Full DPO Distributed #2275

Merged · merged 39 commits into pytorch:main on Feb 7, 2025

Conversation

@sam-pi (Contributor) commented Jan 17, 2025

Context

Adapted from the great work in #1966

What is the purpose of this PR? Is it to

  • add a new feature

Please link to any issues this PR addresses: relates to #2082

Changelog

What are the changes made in this PR?

  • Adds a full DPO distributed training recipe and configs, adapted from the LoRA DPO recipe
  • Includes integration tests
  • Includes configs for Llama 3.1 8B and 70B models

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API

Commands and Sample Outputs

Full DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/full_dpo
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-06
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 2000
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8B-dpo_3605
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
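
For context on the loss section above, here is a minimal sketch of the standard DPO objective that the beta and label_smoothing options parameterize (illustrative only; torchtune.rlhf.loss.DPOLoss may differ in details):

import torch.nn.functional as F

# Illustrative sketch of the standard DPO loss, not torchtune's exact code.
def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps,
             beta=0.05, label_smoothing=0.0):
    # Log-ratio of chosen vs. rejected under the policy, relative to the reference
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        ref_chosen_logps - ref_rejected_logps
    )
    # label_smoothing interpolates between the two sigmoid terms
    losses = (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return losses.mean(), chosen_rewards, rejected_rewards

At initialization the policy matches the reference, so logits is zero and the loss starts at -logsigmoid(0) = ln 2 ≈ 0.69, which is relevant to the loss-normalization discussion later in this thread.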

LoRA DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/lora_dpo
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_rank: 256
  lora_alpha: 256
  lora_dropout: 0.0
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
save_adapter_weights_only: false
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-05
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.1
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 100
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8Blora-dpo_3603
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
[Screenshot 2025-01-16: WandB comparison of the full DPO and LoRA DPO runs]

pytorch-bot (bot) commented Jan 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2275

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d463e70 with merge base 23896c3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jan 17, 2025

@sam-pi (Contributor, Author) commented Jan 17, 2025

@joecummings Please take a look and let me know if you have feedback!

@SalmanMohammadi (Collaborator) commented Jan 20, 2025

Hey @sam-pi! Thanks so much for adding this. I had a quick skim through and it looked good to me. I'll have a closer look soon. First, a couple of high level points.

Did you manage to train using these configs? If so, could you attach some evidence of successful runs (e.g. WandB links)?

I'm particularly interested in the hardware requirements for the 70B config. We may want to think about offering some additional memory performance improvements for this recipe in particular, such as different parallelization configurations for the reference model (which doesn't need gradients to be sharded), offloading the entire reference model to CPU, etc.
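
To illustrate that point, here is a rough, self-contained sketch (with a toy module standing in for the reference LLM; not this PR's code): because the reference model never needs gradients, it can be frozen, kept in eval mode, and run under no_grad, which is what leaves room for sharding or placing it differently from the policy model.

import torch
import torch.nn as nn

ref_model = nn.Linear(16, 16)      # stand-in for the frozen reference model
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)        # reference weights are never updated

batch = torch.randn(4, 16)         # stand-in for a tokenized DPO batch
with torch.no_grad():
    ref_logits = ref_model(batch)  # forward only; no autograd state is kept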

@sam-pi (Contributor, Author) commented Jan 21, 2025

@SalmanMohammadi Please take a look at my training run screenshots and configs at the bottom of the PR summary (I re-uploaded the screenshot of my WandB run). It shows a comparison of a rank/alpha 256 LoRA DPO run against a full DPO run (only 100 iterations).
For Llama3.1-70B-Instruct, I was able to run on 2 nodes of 8x H100 GPUs each (I think this is just 2x the hardware requirement for running a single non-quantized 70B).

@RdoubleA mentioned this pull request Jan 21, 2025

@EugenHotaj (Contributor) commented:

Any updates on merging this to main? Really excited to use it 😄

_,
) = self.concatenated_forward(self._ref_model, batch)

loss, chosen_rewards, rejected_rewards = self._loss_fn(
Contributor:

Another heads up: we log these below but we're not taking GAS into account.

(lmk if these comments are unhelpful btw and I'll stop 🙂 -- just trying to get this PR to run / verify on our setup and commenting as I find discrepancies)

Collaborator:

(lmk if these comments are unhelpful btw and I'll stop 🙂 -- just trying to get this PR to run / verify on our setup and commenting as I find discrepancies)

Not at all, your comments are incredibly helpful and more than welcome! Thanks for taking the time to help review.

Collaborator:

Another heads up: we log these below but we're not taking GAS into account.

noob q: what's GAS?

Contributor:

Gradient Accumulation Steps

Collaborator:

Yeah you're totally right. We should update to correct for gradient accumulation steps.
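
A minimal sketch of that correction, with assumed variable names (this is not the final diff): scale each microbatch's contribution by 1 / gradient_accumulation_steps so the logged value is an average over the optimizer step rather than a sum over microbatches.

import torch

gradient_accumulation_steps = 4
scaling_factor = 1.0 / gradient_accumulation_steps

running_metrics = {"rewards/chosen": torch.tensor(0.0)}
for _ in range(gradient_accumulation_steps):
    chosen_rewards = torch.randn(8)  # stand-in for one microbatch's rewards
    running_metrics["rewards/chosen"] += scaling_factor * chosen_rewards.mean()
# running_metrics now holds a per-optimizer-step average, not a sum over microbatches.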

Contributor:

In that case I assume the same holds for the LoRA DPO recipe too, right?

Collaborator:

@bogdansalyp I understand that we're going to leave this as a follow up?

@EugenHotaj (Contributor) commented:

With the changes I mentioned in the comments I was able to get parity with NeMo's DPO using the same data / hparams. E.g. here are the loss curves:

[Screenshot 2025-01-29: loss curves]

Really awesome work! Pretty excited to use this.

@SalmanMohammadi (Collaborator) commented Jan 30, 2025

Any updates on merging this to main? Really excited to use it 😄

I'm going to try investigating some alternative sharding strategies for the reference model, and see if I can get single-node training working for 70B. Will update soon. @sam-pi would you be up for looking into @EugenHotaj's comments above?

@SalmanMohammadi (Collaborator) commented:

OK, so as not to block this PR, I'm going to leave exploring different parallelism strategies for a follow-up. Let's make the necessary fixes to this recipe and bring it in line with our other distributed recipes.

@sam-pi If the 70B config doesn't work on a single node, I'd also suggest we remove it for now and add it back in after patching in the changes from #2301. What do you think?

@sam-pi (Contributor, Author) commented Jan 30, 2025

OK, so as not to block this PR, I'm going to leave exploring different parallelism strategies for a follow-up. Let's make the necessary fixes to this recipe and bring it in line with our other distributed recipes.

@sam-pi If the 70B config doesn't work on a single node, I'd also suggest we remove it for now and add it back in after patching in the changes from #2301. What do you think?

Thanks, I will look into all these fixes today and also remove the 70B config for now

@ebsmothers (Contributor) left a comment

This is looking great, thanks so much for adding this @sam-pi! Aside from my inline comments it'd be good to confirm that various features like compile, optimizer-in-backward, etc are working and doing what we'd expect (we can even add e.g. compile to the recipe test)

Inline review comments were left on:
  • recipes/configs/llama3_1/70B_full_dpo.yaml
  • tests/recipes/test_full_dpo_distributed.py
  • recipes/full_dpo_distributed.py

@codecov-commenter commented Feb 6, 2025

Codecov Report

Attention: Patch coverage is 3.93120% with 391 lines in your changes missing coverage. Please review.

Project coverage is 63.23%. Comparing base (5764650) to head (893398e).
Report is 18 commits behind head on main.

Files with missing lines                     Patch %   Lines
recipes/full_dpo_distributed.py                0.00%   346 Missing ⚠️
tests/recipes/test_full_dpo_distributed.py    38.09%   26 Missing ⚠️
recipes/lora_dpo_distributed.py                0.00%   17 Missing ⚠️
recipes/full_finetune_distributed.py           0.00%   1 Missing ⚠️
recipes/lora_finetune_distributed.py           0.00%   1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2275      +/-   ##
==========================================
- Coverage   66.60%   63.23%   -3.37%     
==========================================
  Files         353      358       +5     
  Lines       20645    21370     +725     
==========================================
- Hits        13750    13514     -236     
- Misses       6895     7856     +961     

☔ View full report in Codecov by Sentry.

@EugenHotaj (Contributor) commented:

The loss starting at 5.5 seems off? We might not be normalizing correctly somewhere.

Also wow, 2x improvement in tps with compile! Is this pretty common / expected? I thought compile would give maybe 10-20% extra boost but 2x is awesome!

@EugenHotaj (Contributor) commented:

The rewards being different between compile and non-compile is also a bit suspicious. I'm surprised we're not able to get fully deterministic results between compile / non-compile. Maybe we're forgetting to set the seed somewhere?

@ebsmothers (Contributor) commented:

@EugenHotaj compile is awesome 😃. 2x is definitely not unheard of, in some cases I've seen gains even larger than that. I agree that the large delta in rewards is suspicious. Compile tends to not achieve exact numerical parity, but we should also rule out something like #2335. Once we verify the delta is actually coming from compile, we can loop in some compiler folks to help out here.

@sam-pi (Contributor, Author) commented Feb 6, 2025

The loss starting at 5.5 seems off? We might not be normalizing correctly somewhere.

Also wow, 2x improvement in tps with compile! Is this pretty common / expected? I thought compile would give maybe 10-20% extra boost but 2x is awesome!

Note we were running these with 8xH100 GPUs - otherwise the same default config. With 4xH100 GPUs the loss starts closer to 2.7 and with 2xH100 the loss starts around 1.4.

@EugenHotaj (Contributor) commented Feb 6, 2025

Note we were running these with 8xH100 GPUs - otherwise the same default config. With 4xH100 GPUs the loss starts closer to 2.7 and with 2xH100 the loss starts around 1.4.

Yea so we probably forgot to normalize by dp rank. Loss should not change with the number of ranks we have.

@sam-pi (Contributor, Author) commented Feb 6, 2025

Note we were running these with 8xH100 GPUs - otherwise the same default config. With 4xH100 GPUs the loss starts closer to 2.7 and with 2xH100 the loss starts around 1.4.

Yea so we probably forgot to normalize by dp rank. Loss should not change with the number of ranks we have.

OK, I added an all_reduce with average and it works! I now see a starting loss of 0.69 for both 2x and 8x H100s.
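
For reference, the shape of that fix (a sketch with assumed names, mirroring the ReduceOp.AVG usage that appears in the diff further down): average the logged scalar across data-parallel ranks so the reported loss no longer scales with world size.

import torch
import torch.distributed as dist

def average_across_ranks(metric: torch.Tensor) -> torch.Tensor:
    # Average a scalar metric over all data-parallel ranks so the logged
    # value does not depend on how many GPUs the run uses.
    metric = metric.detach().clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(metric, op=dist.ReduceOp.AVG)
    return metric

With that in place, a starting loss of ln 2 ≈ 0.69 on any number of ranks is exactly what you'd expect for DPO.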

@SalmanMohammadi (Collaborator) left a comment

LGTM. I'll let @ebsmothers have one final look - please seek his approval before merging.

Thanks so much for this contribution and your patience through the review process @sam-pi @EugenHotaj @bogdansalyp!

@ebsmothers (Contributor) left a comment

Thanks @sam-pi, @bogdansalyp and others for your patience on this one! I do have a few more questions but I think we're close here. The two main ones are making sure we're applying optimizer-in-backward and activation offloading in the right way. Especially for optimizer-in-backward I'm not sure the loss is being calculated in a consistent way. I mentioned this in an inline comment, but open to dropping this from the recipe altogether if you feel it will make things more straightforward. Modulo these last few comments I think we are in good shape here.

@pytest.mark.integration_test
@pytest.mark.parametrize("optimizer_in_bwd", [False])
# @pytest.mark.parametrize("optimizer_in_bwd", [False, True])
# TODO: whomever fixes opt in bwd checkpointing without async, please fix this test
Contributor:

Thanks for adding this. I just created #2360 for figuring this out, so can update there

Comment on lines +978 to +998
running_metrics["rewards/chosen"] += (
scaling_factor * chosen_rewards.mean()
)
running_metrics["rewards/rejected"] += (
scaling_factor * rejected_rewards.mean()
)
running_metrics["rewards/accuracies"] += (
scaling_factor * reward_accuracies.mean()
)
running_metrics["log_probs/chosen"] += (
scaling_factor * policy_chosen_log_probs.detach().mean()
)
running_metrics["log_probs/rejected"] += (
scaling_factor * policy_rejected_log_probs.detach().mean()
)
running_metrics["logits/chosen"] += (
scaling_factor * policy_chosen_logits_mean
)
running_metrics["logits/rejected"] += (
scaling_factor * policy_rejected_logits_mean
)
Contributor:

We should definitely figure out a utility to reduce all this logging boilerplate

Collaborator:

I have something like batch_trajectory = Trajectory(...) in the PPO recipe (and we've adopted something similar in the new GRPO recipe), e.g. map(partial(torch.mul, scaling_factor), ...).

That being said, someone on Twitter called it "machine learning engineer bs", so up to you whether it's readable.
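
For what it's worth, a small helper along these lines would collapse the block above into one call (a sketch with assumed names, not the recipe's API):

import torch

def accumulate_scaled(running: dict, step_metrics: dict, scaling_factor: float) -> None:
    # Apply one scaling factor to every per-step metric instead of repeating
    # `running[...] += scaling_factor * value` line by line.
    for key, value in step_metrics.items():
        running[key] = running.get(key, 0.0) + scaling_factor * value.detach().mean()

e.g. accumulate_scaled(running_metrics, {"rewards/chosen": chosen_rewards, "rewards/rejected": rejected_rewards}, scaling_factor).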

break

# batch is input_ids, labels
with self.activations_handling_ctx:
Contributor:

It seems like we wrap more logic in activations_handling_ctx than we do in the LoRA DPO recipe, why is this (there I believe we just use it for the model forward pass, which intuitively makes more sense to me)? Also running with activation offloading I see about 20-25% slowdown, do you see this as well? If so, we should figure out why, because in @janeyx99's runs with SFT the slowdown was usually only 1-2%. (Admittedly I am running on A100, not H100, so actually my tokens/sec seems quite a bit lower than yours in general)

Contributor (Author):

Ah good point, that should be updated I think. I can add that now

Collaborator:

I updated this because if we place it inside concatenated_forward the context will be fired for both the reference and policy model. We only want it for the policy model, and the hooks have been registered on the policy model. I'm not sure what behaviour we'd get if we tried to run the reference model in the context manager setup for the policy model.

Contributor (Author):

I just updated this in my latest commit so that activations_handling_ctx is optionally applied via an additional bool argument. This appears to show a ~20% speedup in tokens per second @ebsmothers
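
A rough sketch of that pattern (argument names assumed, not necessarily what the final commit uses): the offloading context wraps only the policy model's forward, and the reference model's forward opts out via a bool flag.

import contextlib

def run_forward(model, batch, activations_handling_ctx, enabled: bool = True):
    # Only the policy model has offloading hooks registered on it, so the
    # reference model's forward passes enabled=False and skips the context.
    ctx = activations_handling_ctx if enabled else contextlib.nullcontext()
    with ctx:
        return model(batch)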

running_metrics[key], op=torch.distributed.ReduceOp.AVG
)
# We multiply by world_size to undo FSDP2 gradient normalization.
loss = loss * (world_size / num_tokens)
Contributor:

Maybe I'm missing something here, but it seems that when optimizer-in-backward is enabled, we are normalizing the loss by the number of tokens, but when it's not enabled, we aren't?
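
To spell out the comment in that snippet with toy numbers (assuming FSDP2's default behavior of averaging gradients across ranks, and that num_tokens is the global, already all-reduced token count):

world_size = 8
num_tokens = 4096        # token count summed across all ranks
local_loss_sum = 512.0   # un-normalized loss summed over this rank's tokens

# FSDP2's gradient averaging implicitly divides by world_size, so scaling the
# loss by world_size / num_tokens leaves an effective 1 / num_tokens
# normalization on the accumulated gradients.
loss_for_backward = local_loss_sum * (world_size / num_tokens)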

Contributor:

Not to drop too much of a bomb, but a separate question is whether it's worth it to onboard optimizer-in-backward to this recipe at all. For a feature like this I want to make sure value >> added complexity. Optimizer-in-backward (especially in distributed recipes) has relatively high complexity, so would only add it if there is a clear need for it. Would be interested to hear your thoughts on that though

Contributor (Author):

Note I reverted commits 893398e b293ce3 9c71083 4fd18bf 6781894 46c59ec eef1b01 716efca

This removes the optimizer_in_bwd updates for now, along with a few minor updates to use log_rank_zero.

@ebsmothers (Contributor) left a comment

Thanks so much @sam-pi, @bogdansalyp, @SalmanMohammadi for the team effort on this one! Once CI is green I think we're good to go here.

@ebsmothers ebsmothers merged commit fb52557 into pytorch:main Feb 7, 2025
17 checks passed
@SalmanMohammadi SalmanMohammadi mentioned this pull request Feb 12, 2025