fix: support grad clipping for TP through replicating non-sharded modules #36132
base: main
Conversation
Force-pushed from 289af7a to 2192a35
Can you have a look, @kwen2501?
Signed-off-by: Mehant Kammakomati <[email protected]>
@weifengpy @mori360 do you mind having a look at the two concerns here? Thanks!
@weifengpy Do you recommend to use …
Thanks for fixing gradient clipping with TP.
Functionality-wise, the code change looks reasonable.
I am consulting with my colleagues to see whether it is strictly necessary to explicitly annotate non-sharded modules as Replicated.
"layers.*.self_attn.o_proj": "rowwise_output_dtensor", | ||
"layers.*.mlp.gate_proj": "colwise", | ||
"layers.*.mlp.up_proj": "colwise", | ||
"layers.*.mlp.down_proj": "rowwise", | ||
"layers.*.mlp.down_proj": "rowwise_output_dtensor", | ||
"embed_tokens": "replicateparallel_output_dtensor", | ||
"layers.*.post_attention_layernorm": "replicateparallel_output_dtensor", | ||
"layers.*.input_layernorm": "replicateparallel_output_dtensor", | ||
"norm": "replicateparallel_output_dtensor", |
Thanks for extending the configs here.
I wonder if some of these settings are more relevant to training than to inference? (On the other hand, I don't know much about HF's user profile -- training more or inference more?)
If some of the settings are specific to training, is it possible to separate them out? Or shall we make the config customizable at runtime?
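To make the "customizable at runtime" idea concrete, here is a hypothetical sketch. It assumes the TP plan stays a plain dict (such as `base_model_tp_plan` on the model config) that can be merged at run time; the `is_training` flag and `training_only_plan` name are illustrative, not part of this PR.

```python
# Hypothetical sketch: keep training-only entries out of the default plan and
# merge them in at run time. The merge point and flag are assumptions.
training_only_plan = {
    "embed_tokens": "replicateparallel_output_dtensor",
    "layers.*.post_attention_layernorm": "replicateparallel_output_dtensor",
    "layers.*.input_layernorm": "replicateparallel_output_dtensor",
    "norm": "replicateparallel_output_dtensor",
}

if is_training:  # hypothetical flag; only training needs the replicated entries
    config.base_model_tp_plan = {**config.base_model_tp_plan, **training_only_plan}
```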
class ReplicateParallel(ParallelStyle):
    """
    Replicate an nn.Module.
    Users can compose it together with other parallel styles like RowwiseParallel to achieve a fully distributed model.
    A fully distributed model is needed for gradient clipping.
@weifengpy @wz337 @tianyu-l
I wonder if there is anything we can do on the DTensor side so that users don't have to annotate the entire model to perform gradient clipping?
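For readers following along outside the diff, here is a minimal sketch of what such a replicate-style `ParallelStyle` can look like, assuming PyTorch's public DTensor and tensor-parallel APIs (torch >= 2.5 import paths). It is not the PR's exact implementation (which, judging by the `*_output_dtensor` plan entries, also controls how module outputs are returned); it only shows the core idea of turning a module's parameters into replicated DTensors.

```python
# Minimal sketch (not the PR's exact code), assuming torch >= 2.5 public APIs.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ParallelStyle


class ReplicateParallel(ParallelStyle):
    """Mark a module's parameters as replicated DTensors so that, together with
    Colwise/RowwiseParallel on the sharded modules, every parameter in the model
    is a DTensor and clip_grad_norm_ sees a homogeneous parameter set."""

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # parallelize_module calls ParallelStyle._apply for each module matched
        # by the TP plan. With no partition_fn, distribute_module replicates all
        # of the module's parameters across the mesh.
        return distribute_module(module, device_mesh)
```

In use, such a style would simply be another value in the TP plan (the `replicateparallel*` entries in the config excerpt above) and would be applied through the same `parallelize_module` path as the Colwise/Rowwise styles.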
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
nit: keep this TODO?
What does this PR do?
`torch.nn.utils.clip_grad_norm_` does not support a heterogeneous set of parameters containing a mix of DTensors and plain Tensors. This PR enables gradient clipping under TP by distributing the non-sharded modules that are not involved in TP: we replicate all such modules across the device mesh. The PR also adds a new parallel style, `ReplicateParallel`, so that the existing TP APIs can be used as-is for this module replication. We could contribute this back to PyTorch if it makes sense (cc: @kwen2501); otherwise we can maintain it in transformers.

⭐ Note: We will rebase this PR once #34194 is merged; some of the workflow changes you see here will disappear once that PR is merged.
fixes: #36296
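As a hedged illustration of the underlying problem (names such as `model` and the mesh setup are assumptions, not code from this PR): before the change, the layer norms, `embed_tokens`, and the final `norm` keep plain `torch.Tensor` parameters while the TP-sharded projections hold DTensors, and `clip_grad_norm_` cannot compute a total norm over that mix; after replication, every parameter is a DTensor.

```python
# Hedged sketch, not code from this PR: inspect the parameter mix and clip.
import torch
from torch.distributed.tensor import DTensor

def param_kinds(model: torch.nn.Module) -> set[str]:
    # Before this PR a TP model typically reports {"Tensor", "DTensor"};
    # after replicating the non-sharded modules it should report {"DTensor"}.
    return {"DTensor" if isinstance(p, DTensor) else "Tensor" for p in model.parameters()}

print(param_kinds(model))  # `model` is an already TP-parallelized model (assumption)

# With a homogeneous DTensor parameter set, gradient clipping works as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```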
Concerns
Concern 1
When we do two TP runs with gradient clipping under the exact same training settings, we do not reproduce exact loss parity between the runs, although both runs eventually converge. I am worried that the `Replicate` sharding has something to do with this.
Concern 2
Grad norms are not the same on each rank. I would assume that in TP training the grad norms should come out the same across ranks, but that is not the case.
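One hedged diagnostic for Concern 2 (a pattern seen in DTensor training loops such as torchtitan, not something this PR adds): when all gradients are DTensors, the total norm returned by `clip_grad_norm_` is itself a DTensor, and only its reduced full value is expected to match across ranks; per-rank partial values can legitimately differ.

```python
# Hedged diagnostic sketch, not part of this PR: compare the *reduced* total
# norm, not a local/partial value, when checking rank-to-rank agreement.
import torch
from torch.distributed.tensor import DTensor

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if isinstance(total_norm, DTensor):
    # full_tensor() materializes the mesh-wide reduction; this value should
    # agree on every rank even when per-rank partial norms do not.
    total_norm = total_norm.full_tensor()
print(f"rank-global grad norm: {total_norm.item():.6f}")
```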
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @muellerzr and @SunMarc
@kwen2501 from PyTorch