
fix: support grad clipping for TP through replicating non-sharded modules #36132

Open · wants to merge 1 commit into base: main

Conversation

Contributor

@kmehant kmehant commented Feb 11, 2025

What does this PR do?

torch.nn.utils.clip_grad_norm_ does not support a heterogeneous set of parameters that mixes DTensors and plain Tensors. This PR enables gradient clipping under tensor parallelism by also distributing the modules that are not involved in TP: all such modules are replicated across the device mesh, so every parameter ends up as a DTensor.

The PR also adds a new parallel style, ReplicateParallel, so that the existing TP APIs can be used as-is for this replication step. We could consider contributing it back to PyTorch if that makes sense (cc: @kwen2501); otherwise we can maintain it in transformers.
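For illustration, here is a minimal, self-contained sketch of the mechanism (run under torchrun with NCCL). The toy Block module is hypothetical, and distribute_module is used as a stand-in for the ReplicateParallel style this PR introduces:

# torchrun --nproc-per-node=2 sketch.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

class Block(nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim, bias=False)    # column-sharded below
        self.down = nn.Linear(4 * dim, dim, bias=False)  # row-sharded below
        self.norm = nn.LayerNorm(dim)                    # not part of the TP plan

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
model = Block().cuda()

# Shard the TP modules as before.
parallelize_module(model, mesh, {"up": ColwiseParallel(), "down": RowwiseParallel()})

# What this PR does conceptually: also replicate the modules the plan does not
# shard (norms, embeddings, ...). distribute_module with no partition_fn
# replicates every parameter, so the whole model now holds only DTensors.
distribute_module(model.norm, mesh)

# Fake some gradients; ones_like on a DTensor parameter yields a DTensor.
for p in model.parameters():
    p.grad = torch.ones_like(p)

# clip_grad_norm_ now sees a homogeneous set of DTensors and works.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if dist.get_rank() == 0:
    print("total grad norm:", total_norm)
dist.destroy_process_group()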

⭐ Note: We will rebase this PR once #34194 is merged; some of the workflow changes you see here will disappear at that point.

fixes: #36296

Concerns

Concern 1

When we do two TP runs with gradient clipping under the exact same training settings, we do not get exact loss parity between the runs, although both runs eventually converge. I am worried that the Replicate placement has something to do with this.
[Screenshot: loss curves of the two runs, 2025-02-11]

Concern 2

Grad norms are not the same on each rank. I would expect that in TP training the grad norms come out identical across ranks, but that is not the case.
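For reference, a small debugging sketch that makes this concern measurable (report_grad_norm is a hypothetical helper; model is assumed to be the TP-parallelized model): it gathers the per-rank total grad norm so the values can be compared directly.

import torch
import torch.distributed as dist

def report_grad_norm(model: torch.nn.Module) -> None:
    # Total grad norm exactly as clip_grad_norm_ computes it; max_norm=inf
    # means no gradients are actually modified.
    total = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    if hasattr(total, "full_tensor"):  # DTensor -> plain local tensor
        total = total.full_tensor()
    total = total.detach().reshape(1)
    gathered = [torch.zeros_like(total) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, total)
    if dist.get_rank() == 0:
        print("grad norm per rank:", [t.item() for t in gathered])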

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @muellerzr and @SunMarc
@kwen2501 from PyTorch

Member

SunMarc commented Feb 20, 2025

Can you have a look, @kwen2501?

Signed-off-by: Mehant Kammakomati <[email protected]>
@kwen2501
Contributor

@weifengpy @mori360 do you mind having a look at the two concerns here? Thanks!

@weifengpy

> @weifengpy @mori360 do you mind having a look at the two concerns here? Thanks!

> does not support a heterogeneous set of parameters that mixes DTensors and plain Tensors

implicit_replication was introduced exactly to mix DTensors with plain tensors. Maybe it's a cleaner fit here.

import torch
from torch.distributed._tensor.experimental import implicit_replication

with implicit_replication():
    # call gradient clipping here, e.g. (`model` and the max-norm value are placeholders):
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

code pointer: https://github.com/pytorch/pytorch/blob/8b818ab58f635f999de2c8a5bf8e6c01d0c122ed/test/distributed/tensor/parallel/test_tp_examples.py#L262-L264

Contributor Author

kmehant commented Feb 21, 2025

@weifengpy Do you recommend using implicit_replication instead?

Contributor

@kwen2501 kwen2501 left a comment

Thanks for fixing gradient clipping with TP.
Functionality-wise, the code change looks reasonable.
I am consulting with my colleagues to see whether it is strictly necessary to explicitly annotate non-sharded modules as replicated.

Comment on lines +120 to +127
"layers.*.self_attn.o_proj": "rowwise_output_dtensor",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.down_proj": "rowwise_output_dtensor",
"embed_tokens": "replicateparallel_output_dtensor",
"layers.*.post_attention_layernorm": "replicateparallel_output_dtensor",
"layers.*.input_layernorm": "replicateparallel_output_dtensor",
"norm": "replicateparallel_output_dtensor",
Contributor

Thanks for extending the configs here.
I wonder if some of these settings are more relevant to training than to inference.
(On the other hand, I don't know much about HF's user profile: is it more training or more inference?)
If some of the settings are training-specific, is it possible to separate them out? Or should we make the config customizable at run time?
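One possible shape for such run-time customization, purely as a sketch (the model id is hypothetical, the override keys reuse the style strings from the diff above, and whether the loading path honors an override like this end-to-end is exactly the open question here):

from transformers import AutoConfig, AutoModelForCausalLM

# Sketch only: edit the per-module plan on the config before loading with TP.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
config.base_model_tp_plan = {
    **(config.base_model_tp_plan or {}),
    # training-oriented choice: keep the output as a DTensor for grad clipping
    "norm": "replicateparallel_output_dtensor",
}
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", config=config, tp_plan="auto"
)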

Comment on lines +354 to +358
class ReplicateParallel(ParallelStyle):
    """
    Replicate an nn.Module.
    Users can compose it with other parallel styles, such as RowwiseParallel, to achieve a fully distributed model.
    A fully distributed model is needed for gradient clipping.
Contributor

@weifengpy @wz337 @tianyu-l
I wonder if there is anything we can do on the DTensor side so that users don't have to annotate the entire model just to perform gradient clipping.
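For context, a minimal sketch of what such a style can look like on top of distribute_module; this is not necessarily the exact implementation in this PR (which also offers *_output_dtensor variants for keeping DTensor outputs):

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ParallelStyle

class ReplicateParallel(ParallelStyle):
    """Replicate a module's parameters across the TP mesh as DTensors."""

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # With no partition_fn, distribute_module replicates every parameter of
        # `module` across `device_mesh`, so clip_grad_norm_ only sees DTensors.
        return distribute_module(module, device_mesh)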

Comment on lines -347 to -348
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
Contributor

nit: keep this TODO?
