Full DPO Distributed #2275
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2275. ✅ No failures as of commit d463e70 with merge base 23896c3. (This comment was automatically generated by Dr. CI.)
@joecummings Please take a look and let me know if you have feedback!

Hey @sam-pi! Thanks so much for adding this. I had a quick skim through and it looked good to me. I'll have a closer look soon. First, a couple of high-level points: did you manage to train using these configs? If so, could you attach some evidence of successful runs (e.g. WandB links)? I'm particularly interested in the hardware requirements for the 70B config. We may want to think about offering some additional memory/performance improvements for this recipe in particular, such as different parallelization configurations for the reference model (which doesn't need gradients to be sharded), offloading the entire reference model to CPU, etc.

@SalmanMohammadi Please take a look at my training run screenshots and configs at the bottom of the PR summary (I re-uploaded the screenshot of my WandB run). I tried to show a comparison of a rank/alpha-256 LoRA DPO run against a full DPO run (only 100 iterations).

Any updates on merging this to main? Really excited to use it 😄
    _,
) = self.concatenated_forward(self._ref_model, batch)

loss, chosen_rewards, rejected_rewards = self._loss_fn(
Another heads up: we log these below but we're not taking GAS into account.
(lmk if these comments are unhelpful btw and I'll stop 🙂 -- just trying to get this PR to run / verify on our setup and commenting as I find discrepancies)
Not at all, your comments are incredibly helpful and more than welcome! Thanks for taking the time to help review.
noob q: what's GAS?
Gradient Accumulation Steps
Yeah you're totally right. We should update to correct for gradient accumulation steps.
In that case I assume the same holds for the LoRA DPO recipe too, right?
@bogdansalyp I understand that we're going to leave this as a follow-up?
@SalmanMohammadi wait, but it's already accounted for:
metrics are scaled for grad_acc steps -> https://github.com/pytorch/torchtune/pull/2275/files#diff-ca64e06ed8af5f389cf3ecfc1fa0eda295e21c4248d44bdfb69992d57368033cR728
and loss is also averaged out -> https://github.com/pytorch/torchtune/pull/2275/files#diff-ca64e06ed8af5f389cf3ecfc1fa0eda295e21c4248d44bdfb69992d57368033cR724
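For readers following along, here is a minimal sketch of the scaling pattern being discussed, using hypothetical stand-in values; the names mirror the recipe (`running_metrics`, `scaling_factor`), but this is only an illustration, not the exact code from the diff:

```python
import torch

# Hypothetical stand-ins: names mirror the recipe, values are toy data.
grad_accum_steps = 4
scaling_factor = 1.0 / grad_accum_steps

running_metrics = {"rewards/chosen": torch.tensor(0.0)}
weight = torch.nn.Parameter(torch.ones(8))

for _ in range(grad_accum_steps):
    chosen_rewards = torch.randn(8)            # per-microbatch rewards
    loss = (weight * torch.randn(8)).sum()     # per-microbatch loss
    # Metrics are accumulated pre-scaled, so after grad_accum_steps
    # microbatches the logged value is an average, not a sum.
    running_metrics["rewards/chosen"] += scaling_factor * chosen_rewards.mean()
    # The loss is scaled the same way before backward so gradients are
    # averaged over the accumulation window.
    (loss * scaling_factor).backward()

print(running_metrics["rewards/chosen"])  # average reward over the window
```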
With the changes I mentioned in the comments I was able to get parity with NeMo's DPO using the same data / hparams, e.g. here are the loss curves (loss-curve comparison screenshot in the original comment). Really awesome work! Pretty excited to use this.
I'm going to investigate some alternative sharding strategies for the reference model, and see if I can get single-node training working for 70B. Will update soon. @sam-pi would you be up for looking into @EugenHotaj's comments above?

OK, so that we're not blocking this PR, I'm going to leave exploring different parallelism strategies for a follow-up. Let's make the necessary fixes to this recipe and bring it in line with our other distributed recipes. @sam-pi If the 70B config doesn't work on a single node, I'd also suggest we remove it for now and add it back in after patching in the changes from #2301. What do you think?

Thanks, I will look into all these fixes today and also remove the 70B config for now.
This is looking great, thanks so much for adding this @sam-pi! Aside from my inline comments, it'd be good to confirm that various features like compile, optimizer-in-backward, etc. are working and doing what we'd expect (we can even add e.g. compile to the recipe test).
Codecov Report

Coverage Diff:
             main     #2275     +/-
Coverage    66.60%   63.23%   -3.37%
Files          353      358       +5
Lines        20645    21370     +725
Hits         13750    13514     -236
Misses        6895     7856     +961

View full report in Codecov by Sentry.
The loss starting at 5.5 seems off? We might not be normalizing correctly somewhere. Also wow, 2x improvement in tps with compile! Is this pretty common / expected? I thought compile would give maybe a 10-20% extra boost, but 2x is awesome!

The rewards being different between compile and non-compile is also a bit suspicious. I'm surprised we're not able to get fully deterministic results between compile / non-compile. Maybe we're forgetting to set the seed somewhere?

@EugenHotaj compile is awesome 😃. 2x is definitely not unheard of; in some cases I've seen gains even larger than that. I agree that the large delta in rewards is suspicious. Compile tends not to achieve exact numerical parity, but we should also rule out something like #2335. Once we verify the delta is actually coming from compile, we can loop in some compiler folks to help out here.
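As an aside for anyone unfamiliar with how compile gets applied, here is a rough sketch of general torch.compile usage (in torchtune recipes this is typically driven by a config flag; the toy model below is purely illustrative, not the recipe's policy model):

```python
import torch

# Toy model purely for illustration; the recipe compiles the actual policy model.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())

# torch.compile returns an optimized wrapper around the module. Speedups vary
# by model and hardware; the 2x tokens/sec reported above is on the high end.
compiled_model = torch.compile(model)

out = compiled_model(torch.randn(4, 16))
print(out.shape)
```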
Note we were running these with 8x H100 GPUs, otherwise the same default config. With 4x H100 GPUs the loss starts closer to 2.7, and with 2x H100 the loss starts around 1.4.

Yeah, so we probably forgot to normalize by the number of data-parallel ranks. The loss should not change with the number of ranks we have.

Ok, I added in all_reduce with average, and it works! I can see a starting loss of 0.69 for both 2x and 8x H100s.
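A minimal sketch of the fix described here, assuming a torch.distributed process group is already initialized; the helper name below is made up for illustration and mirrors the idea rather than the exact recipe code:

```python
import torch
import torch.distributed as dist


def average_across_ranks(value: torch.Tensor) -> torch.Tensor:
    """Average a scalar loss/metric across all data-parallel ranks for logging.

    Without this reduction the logged value grows with the number of ranks,
    which matches the 1.4 / 2.7 / 5.5 starting losses seen on 2 / 4 / 8 GPUs
    (the expected DPO starting loss is ln(2) ~= 0.69).
    """
    value = value.detach().clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(value, op=dist.ReduceOp.AVG)
    return value
```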
LGTM. I'll let @ebsmothers have one final look - please seek his approval before merging.
Thanks so much for this contribution and your patience through the review process @sam-pi @EugenHotaj @bogdansalyp!
Thanks @sam-pi, @bogdansalyp and others for your patience on this one! I do have a few more questions but I think we're close here. The two main ones are making sure we're applying optimizer-in-backward and activation offloading in the right way. Especially for optimizer-in-backward I'm not sure the loss is being calculated in a consistent way. I mentioned this in an inline comment, but open to dropping this from the recipe altogether if you feel it will make things more straightforward. Modulo these last few comments I think we are in good shape here.
@pytest.mark.integration_test
@pytest.mark.parametrize("optimizer_in_bwd", [False])
# @pytest.mark.parametrize("optimizer_in_bwd", [False, True])
# TODO: whoever fixes opt-in-bwd checkpointing without async, please fix this test
Thanks for adding this. I just created #2360 for figuring this out, so I can update there.
running_metrics["rewards/chosen"] += (
    scaling_factor * chosen_rewards.mean()
)
running_metrics["rewards/rejected"] += (
    scaling_factor * rejected_rewards.mean()
)
running_metrics["rewards/accuracies"] += (
    scaling_factor * reward_accuracies.mean()
)
running_metrics["log_probs/chosen"] += (
    scaling_factor * policy_chosen_log_probs.detach().mean()
)
running_metrics["log_probs/rejected"] += (
    scaling_factor * policy_rejected_log_probs.detach().mean()
)
running_metrics["logits/chosen"] += (
    scaling_factor * policy_chosen_logits_mean
)
running_metrics["logits/rejected"] += (
    scaling_factor * policy_rejected_logits_mean
)
We should definitely figure out a utility to reduce all this logging boilerplate
I have something like

batch_trajectory = Trajectory(
    map(partial(torch.mul, scaling_factor ...

That being said, someone on Twitter called it "machine learning engineer bs", so up to you if it's readable.
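For what it's worth, one possible shape for such a utility; the helper name and signature below are hypothetical, not something that exists in torchtune:

```python
from typing import Dict

import torch


def accumulate_scaled_metrics(
    running: Dict[str, torch.Tensor],
    updates: Dict[str, torch.Tensor],
    scaling_factor: float,
) -> None:
    """Hypothetical helper: add scaling_factor * mean(update) to each running metric.

    Collapses the repeated `running_metrics[k] += scaling_factor * v.mean()`
    blocks above into a single call.
    """
    for key, value in updates.items():
        scaled = scaling_factor * value.detach().float().mean()
        running[key] = running.get(key, torch.tensor(0.0)) + scaled


# Example usage with names borrowed from the recipe (illustrative only):
# accumulate_scaled_metrics(
#     running_metrics,
#     {
#         "rewards/chosen": chosen_rewards,
#         "rewards/rejected": rejected_rewards,
#         "rewards/accuracies": reward_accuracies,
#         "log_probs/chosen": policy_chosen_log_probs,
#         "log_probs/rejected": policy_rejected_log_probs,
#     },
#     scaling_factor,
# )
```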
recipes/full_dpo_distributed.py (outdated diff)

break

# batch is input_ids, labels
with self.activations_handling_ctx:
It seems like we wrap more logic in activations_handling_ctx than we do in the LoRA DPO recipe; why is this? (There I believe we just use it for the model forward pass, which intuitively makes more sense to me.) Also, running with activation offloading I see about a 20-25% slowdown, do you see this as well? If so, we should figure out why, because in @janeyx99's runs with SFT the slowdown was usually only 1-2%. (Admittedly I am running on A100, not H100, so actually my tokens/sec seems quite a bit lower than yours in general.)
Ah good point, that should be updated I think. I can add that now
I updated this because if we place it inside concatenated_forward, the context will be fired for both the reference and policy model. We only want it for the policy model, and the hooks have been registered on the policy model. I'm not sure what behaviour we'd get if we tried to run the reference model in the context manager set up for the policy model.
I just updated this in my latest commit so that activations_handling_ctx is optionally applied via an additional optional bool argument. This appears to show a ~20% speedup in tokens per second, @ebsmothers.
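A rough sketch of the pattern described in this thread: only wrap the policy model's forward in the offloading context, and fall back to a null context for the reference model. The function name and flag below are illustrative, not the recipe's exact API:

```python
import contextlib

import torch


def forward_with_optional_offloading(
    model: torch.nn.Module,
    batch: torch.Tensor,
    activations_handling_ctx=None,
    use_offloading_ctx: bool = True,
) -> torch.Tensor:
    """Run a forward pass, optionally inside the activation-offloading context.

    The offloading hooks are registered on the policy model only, so the
    reference model should pass use_offloading_ctx=False and get a nullcontext.
    """
    ctx = (
        activations_handling_ctx
        if (use_offloading_ctx and activations_handling_ctx is not None)
        else contextlib.nullcontext()
    )
    with ctx:
        return model(batch)
```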
recipes/full_dpo_distributed.py (outdated diff)

    running_metrics[key], op=torch.distributed.ReduceOp.AVG
)
# We multiply by world_size to undo FSDP2 gradient normalization.
loss = loss * (world_size / num_tokens)
Maybe I'm missing something here, but it seems that when optimizer-in-backward is enabled, we are normalizing the loss by the number of tokens, but when it's not enabled, we aren't?
Not to drop too much of a bomb, but a separate question is whether it's worth it to onboard optimizer-in-backward to this recipe at all. For a feature like this I want to make sure value >> added complexity. Optimizer-in-backward (especially in distributed recipes) has relatively high complexity, so would only add it if there is a clear need for it. Would be interested to hear your thoughts on that though
Thanks so much @sam-pi, @bogdansalyp, @SalmanMohammadi for the team effort on this one! Once CI is green I think we're good to go here.
Context
Adapted from the great work in #1966
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses: relates to #2082
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example
Commands and Sample Outputs
Full DPO Config
LoRA DPO Config