
Fix training #774

Open · wants to merge 24 commits into main
Conversation

@michaelbenayoun (Member) commented Jan 31, 2025

What does this PR do?

  • Updates the NeuronTrainer to match the Transformers version.
  • Fixes the way mixed-precision training is handled; it was breaking several trainings, such as Llama. We no longer use the XLA_USE_BF16 and XLA_DOWNCAST_BF16 flags, following the instructions from here.
  • Fixes support for gradient clipping: it now always happens between the reduction of the gradients across devices and the optimizer step. The gradient norm is also now always reported when logging.
  • Adds support for ignore_index in the parallel_cross_entropy_loss. There was a big issue when training with TP: the model was not learning. After investigating, it was linked to the input being padded and the vanilla parallel_cross_entropy from neuronx_distributed not supporting ignore_index (see the sketch after this list):
    • First, loss.mean() does not work in this case because the loss for the ignored tokens is not zeroed.
    • Second, the ignored tokens contributed to the gradient, which effectively destroys training.
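For context, here is a minimal sketch of the masked-loss idea behind the ignore_index fix. It is not the actual parallel_cross_entropy_loss code from this PR; the function name and the IGNORE_INDEX value are illustrative:

```python
import torch

# Illustrative value only, matching the usual Transformers convention.
IGNORE_INDEX = -100

def masked_mean_loss(per_token_loss: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Reduce a per-token loss while skipping ignored (e.g. padded) positions.

    `per_token_loss` and `labels` share the same shape, e.g. (batch, seq_len).
    Positions where labels == IGNORE_INDEX contribute neither to the mean nor
    to the gradient.
    """
    mask = (labels != IGNORE_INDEX).to(per_token_loss.dtype)
    # Zero out the ignored positions so they cannot leak into the gradient...
    masked_loss = per_token_loss * mask
    # ...and average over the number of real tokens, not the full sequence length.
    return masked_loss.sum() / mask.sum().clamp(min=1)
```

A plain per_token_loss.mean() would instead divide by the total number of positions and back-propagate through the padded tokens, which is exactly the failure described above.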

For now, DP + TP can lead to compilation issues with SDK 2.20, but they seem to be gone with SDK 2.21.

Tests performed

  • Llama (HuggingFaceTB/SmolLM2-135M-Instruct) can overfit with dp=1 tp=1
  • Llama (HuggingFaceTB/SmolLM2-135M-Instruct) + LoRA can overfit with dp=1 tp=1
  • Llama (meta-llama/Llama-3.2-1B) can overfit with dp=1 tp=2
  • Llama + LoRA (meta-llama/Llama-3.2-1B) can overfit with dp=1 tp=2
  • Llama (meta-llama/Llama-3.2-1B) can overfit with dp=4 tp=2 (only tested on SDK 2.21; otherwise a compiler error occurs)
  • Actual training of Llama (meta-llama/Llama-3.2-1B) with dp=4 tp=2 on SDK 2.21, compared against GPU runs

[W&B chart, 20/02/2025 17:35:18]

[W&B chart, 20/02/2025 17:35:38]

To be done in follow-up PRs

  • Add AdamW_FP32Params for more stable training in mixed precision (see the illustrative sketch below)
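For reference, the usual idea behind such an optimizer is to keep an FP32 master copy of every (possibly BF16) parameter and apply the AdamW update in full precision before copying the result back. The sketch below is purely illustrative and is not the planned AdamW_FP32Params implementation; the class name and structure are assumptions:

```python
import torch

class AdamWWithFP32Master(torch.optim.AdamW):
    """Illustrative only: step on FP32 master copies of (possibly BF16) parameters."""

    def __init__(self, params, **kwargs):
        params = list(params)
        self.model_params = params
        # Keep an FP32 master copy of every model parameter.
        self.master_params = [
            p.detach().clone().float().requires_grad_(True) for p in params
        ]
        super().__init__(self.master_params, **kwargs)

    def step(self, closure=None):
        with torch.no_grad():
            # Copy the (possibly low-precision) gradients onto the FP32 masters.
            for p, master in zip(self.model_params, self.master_params):
                master.grad = None if p.grad is None else p.grad.detach().float()
        loss = super().step(closure)
        with torch.no_grad():
            # Write the updated FP32 weights back into the model parameters.
            for p, master in zip(self.model_params, self.master_params):
                p.copy_(master.to(p.dtype))
        return loss
```

The trade-off is extra memory (an FP32 copy of every parameter plus FP32 optimizer states) in exchange for more stable updates when the model itself runs in BF16.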

@tengomucho (Collaborator) left a comment


I find it difficult to understand what your changes accomplish; could you give more details about that, please?
Also, would you mind pointing to a test (or example) that now works after your changes?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@michaelbenayoun changed the title from "Fixes and updates training code for Transformers 4.48.1" to "Fix training" on Feb 20, 2025
@dacorvo (Collaborator) left a comment

Several failures in the CI:

  • style check,
  • errors in the training code directly related to some of the changes (_PARALLEL_CROSS_ENTROPY_SHOULD_PRESERVE_INPUT is not found).

Maybe this should be rebased on the 2.21.1 branch once it is merged.

@JingyaHuang (Collaborator) left a comment

Not much to comment on; kudos for finding the fix!!

(will approve when the CIs pass)

@@ -426,6 +434,37 @@ def _peft_tuner_embedding_to_parallel_embedding(
return parent, parallel_linear


class ParallelEmbeddingsFixed(layers.ParallelEmbedding):
# TODO: remove when updating to neuronx_distributed >= 0.10.0
A collaborator left a comment on this diff:
What about raising an error when the version is >= 0.10.0 (so we don't forget to remove it)?
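For illustration, a minimal sketch of such a guard. It assumes neuronx_distributed exposes a __version__ attribute; if it does not, importlib.metadata could be used instead:

```python
from packaging import version

import neuronx_distributed

# Assumption: neuronx_distributed exposes __version__; otherwise use
# importlib.metadata.version("neuronx_distributed").
if version.parse(neuronx_distributed.__version__) >= version.parse("0.10.0"):
    raise RuntimeError(
        "ParallelEmbeddingsFixed is only needed for neuronx_distributed < 0.10.0; "
        "this workaround should now be removed."
    )
```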

@michaelbenayoun (Member, Author) replied:
> Several failures in the CI:
>
>   • style check,
>   • errors in the training code directly related to some of the changes (_PARALLEL_CROSS_ENTROPY_SHOULD_PRESERVE_INPUT is not found).
>
> Maybe this should be rebased on the 2.21.1 branch once it is merged.

I am working on fixing the CI failures.
If you want, you can merge the 2.21.1 branch and then I will rebase and adapt: some changes in this PR are specific to SDK 2.20, so if we are moving to 2.21 very soon it does not make sense to add them.

@dacorvo (Collaborator) commented Feb 21, 2025

I bumped the SDK version to 2.21.1: you can now rebase your branch and drop the 2.20 specifics.
