Trainable Tokens: Support for Weight Tying #2399

Draft

githubnemo wants to merge 6 commits into main from feature/custom-token-tuner-weight-tying
Conversation

githubnemo (Collaborator) commented Feb 25, 2025

This is a follow-up to PR #2376, adding support for weight tying. Do not merge this before #2376 is merged.

What is this

Some models, such as gpt2, tie the weights between the LM head and the input embeddings for various reasons. If we use the trainable tokens adapter, we change the result of the input embeddings' forward() but we do not change the weights themselves (unless we merge()). This means that the changes are not reflected in the tied weights, such as the LM head, leading to wrong results during training.
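
To make the problem concrete, here is a minimal sketch (not part of this PR; gpt2 is just one example of a weight-tied model) showing that the input embedding and the LM head share a single weight tensor, so an adapter that only overrides the embedding's forward() leaves the LM head unaffected:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

emb = model.get_input_embeddings()       # transformer.wte
lm_head = model.get_output_embeddings()  # lm_head, tied to wte

# Tied weights: both modules point at the same underlying tensor.
assert emb.weight.data_ptr() == lm_head.weight.data_ptr()

# A TrainableTokens-style adapter only changes what emb.forward() returns for the
# selected token ids; emb.weight itself stays untouched until merge() is called.
# Since lm_head reads the shared weight tensor directly, it never sees the adapted
# values and keeps producing logits from the stale rows.
```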

How it is solved

The current approach searches for tied layers and puts TrainableTokensLayer adapters on them as well, but initialized to use the parameters from the embedding layer's TrainableTokensLayer. This is done via the tied_adapter argument of TrainableTokensLayer.__init__().
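
The sketch below is a simplified illustration of that idea, not the actual PEFT implementation (the class internals, parameter names, and the additive-delta formulation are assumptions): the adapter wrapping a tied layer reuses the delta parameter owned by the embedding-side adapter instead of creating its own, so both sides train a single tensor.

```python
import torch
import torch.nn as nn


class TrainableTokensLayer(nn.Module):
    """Toy stand-in for the real adapter; only the weight-tying mechanics are shown."""

    def __init__(self, base_layer: nn.Module, token_indices, tied_adapter=None):
        super().__init__()
        self.base_layer = base_layer
        self.token_indices = token_indices
        if tied_adapter is None:
            # The embedding-side adapter owns the trainable deltas.
            num, dim = len(token_indices), base_layer.weight.shape[1]
            self.trainable_tokens_delta = nn.Parameter(torch.zeros(num, dim))
        else:
            # The tied side (e.g. the LM head) shares the very same Parameter object,
            # so any training update is visible to both layers.
            self.trainable_tokens_delta = tied_adapter.trainable_tokens_delta

    def _adapted_weight(self):
        # Apply the deltas to the selected rows without touching the base weight.
        w = self.base_layer.weight.clone()
        w[self.token_indices] += self.trainable_tokens_delta
        return w

    def forward(self, x):
        if isinstance(self.base_layer, nn.Embedding):
            return nn.functional.embedding(x, self._adapted_weight())
        return nn.functional.linear(x, self._adapted_weight())
```

Because both wrappers hold the same Parameter object, gradients accumulate into a single tensor regardless of whether the loss reaches it through the embedding path or the LM-head path.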

What needs to be done

  • encoder-decoder model tests
  • support for standalone TrainableTokens adapter
  • more tests

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

githubnemo force-pushed the feature/custom-token-tuner-weight-tying branch from 69948b9 to ac70db6 on February 26, 2025 16:00
nemo added 5 commits February 26, 2025 17:21
Notably, we remove the duplication filter of `named_modules` when searching for the (tied) target modules, since tied weights are by definition duplicates.
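
As a hedged illustration of why that filter matters (toy model, not from the PR; real transformers models may tie at the parameter level rather than by registering the same module twice): PyTorch's `named_modules()` defaults to `remove_duplicate=True`, which reports a module registered under two names only once, so the second, tied name would never show up as a target.

```python
import torch.nn as nn

emb = nn.Embedding(10, 4)


class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = emb
        self.lm_head = emb  # the same module registered under a second name


model = TiedToy()
print([n for n, _ in model.named_modules()])
# ['', 'embed_tokens']  -- the duplicate name is filtered out
print([n for n, _ in model.named_modules(remove_duplicate=False)])
# ['', 'embed_tokens', 'lm_head']  -- both names are visible
```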
It's now possible to let the adapter decide which is the input embedding layer based on the output
of `model.get_input_embeddings()`. If that fails, the default is still `embed_tokens`.
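
A rough sketch of that lookup (the helper name is hypothetical, not from the PR): ask the model for its input embedding module and resolve its name, falling back to `embed_tokens` when that fails.

```python
def find_input_embedding_name(model, default="embed_tokens"):
    """Return the module name of the input embedding, or `default` as a fallback."""
    try:
        input_emb = model.get_input_embeddings()
    except (AttributeError, NotImplementedError):
        return default
    if input_emb is None:
        return default
    # Resolve the module object back to its registered name.
    for name, module in model.named_modules():
        if module is input_emb:
            return name
    return default
```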