Implement embedding name inference
It's now possible to let the adapter infer the input embedding layer from the output
of `model.get_input_embeddings()`. If that fails, the default is still `embed_tokens`.
nemo committed Feb 27, 2025
1 parent 7d2a715 commit 66b8078
Showing 7 changed files with 157 additions and 22 deletions.
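To illustrate the behavior this commit introduces, here is a minimal, self-contained sketch (not part of the diff); the toy model and the names `ToyModel`, `embed_in` and `lin0` are assumptions for illustration only:

import torch

from peft import LoraConfig, get_peft_model


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_in = torch.nn.Embedding(100, 16)  # deliberately not named `embed_tokens`
        self.lin0 = torch.nn.Linear(16, 2)

    def forward(self, x):
        return self.lin0(self.embed_in(x))

    def get_input_embeddings(self):
        # the adapter now uses this to infer the name of the embedding layer
        return self.embed_in


# a plain list of token indices is enough; the embedding layer is inferred automatically
config = LoraConfig(target_modules=["lin0"], trainable_token_indices=[0, 1, 3])
peft_model = get_peft_model(ToyModel(), config)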
6 changes: 3 additions & 3 deletions src/peft/peft_model.py
@@ -53,6 +53,7 @@
PeftType,
TaskType,
_get_batch_size,
_get_input_embeddings_name,
_prepare_prompt_learning_config,
_set_adapter,
_set_trainable,
@@ -958,9 +959,8 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
if isinstance(peft_config.trainable_token_indices, dict):
target_layers = peft_config.trainable_token_indices
else:
target_layers = {"embed_tokens": peft_config.trainable_token_indices}

# TODO viable to use model.get_input_embeddings() to find the correct name?
layer_name = _get_input_embeddings_name(self.model) or "embed_tokens"
target_layers = {layer_name: peft_config.trainable_token_indices}

if self.modules_to_save:
for target_layer in target_layers:
24 changes: 13 additions & 11 deletions src/peft/tuners/lora/config.py
@@ -274,13 +274,14 @@ class LoraConfig(PeftConfig):
megatron_core (`Optional[str]`):
The core module from Megatron to use, defaults to `"megatron.core"`.
trainable_token_indices (`Optional[Union[List[int], dict[str, List[int]]]]`):
Lets you specify which token indices to selectively fine-tune without requiring to re-train the whole
embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list of indices
which will then target the `embed_tokens` layer, or, if your model is using a different layer for
embedding, you can specify a dictionary where the key is the name of the embedding module and the values
are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. Note that training with FSDP/DeepSpeed
might not yet be fully supported with this option enabled. Also note that models using weight-tying are
currently not supported.
Lets you specify which token indices to selectively fine-tune without requiring to re-train the
whole embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices
in two ways. Either you specify a list of indices, which will then target the model's input embedding
layer (or, if it cannot be inferred, `embed_tokens`). Alternatively, you can specify a dictionary where the key
is the name of the embedding module and the values are the list of token indices, e.g.
`{'embed_tokens': [0, 1, ...]}`.
Note that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled.
Also note that models using weight-tying are currently not supported.
loftq_config (`Optional[LoftQConfig]`):
The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone weights
and initialize Lora layers. Also pass `init_lora_weights='loftq'`. Note that you should not pass a
@@ -444,10 +445,11 @@ class LoraConfig(PeftConfig):
metadata={
"help": (
"Lets you specify which token indices to selectively fine-tune without requiring to re-train the "
"whole embedding matrix using the `peft.TrainableTokensModel` method. You can either specify a list "
"of indices which will then target the `embed_tokens` layer, or, if your model is using a different "
"layer for embedding, you can specify a dictionary where the key is the name of the embedding module "
"and the values are the list of token indices, e.g. `{'embed_tokens': [0, 1, ...]}`. "
"whole embedding matrix using the `peft.TrainableTokensModel` method. You can specify token indices "
"in two ways. Either you specify a list of indices which will then target the model's input embedding "
"layer (or, if not found, `embed_tokens`). Alternatively, you can specify a dictionary where the key "
"is the name of the embedding module and the values are the list of token indices, e.g. "
"`{'embed_tokens': [0, 1, ...]}`. "
"Note that training with FSDP/DeepSpeed might not yet be fully supported with this option enabled. "
"Also note that models using weight-tying are currently not supported."
)
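As a brief illustration of the two forms described in the updated help text (an editorial sketch, not part of the commit; the layer name `embed_in` is only an assumed example):

from peft import LoraConfig

# 1) plain list: the embedding layer is inferred from `model.get_input_embeddings()`,
#    falling back to `embed_tokens` if inference is not possible
config_inferred = LoraConfig(target_modules="all-linear", trainable_token_indices=[0, 1, 3])

# 2) explicit mapping: name the embedding module yourself, e.g. for unusual layer names
config_explicit = LoraConfig(target_modules="all-linear", trainable_token_indices={"embed_in": [0, 1, 3]})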
15 changes: 9 additions & 6 deletions src/peft/tuners/trainable_tokens/config.py
@@ -40,9 +40,11 @@ class TrainableTokensConfig(PeftConfig):
token with a tokenizer, you can tokenize the string and look at the returned `input_ids`. The closer the
amount of indices is to the total amount of tokens, the less efficient this method gets.
target_modules (`Optional[Union[list[str], str]]`):
List of module names or regex expression of the module names to replace with our `TrainableTokensLayer`.
This is by default the `embed_tokens` layer. But could be multiple embedding-like layers, such as
`embedding`, `encoder.embeddings` or `decoder.embeddings`.
List of module names or regex expression of the module names to replace with our
`TrainableTokensLayer`. If not defined, it will attempt to infer the model's input embedding layer if
the model has a `get_input_embeddings` method (transformer models usually do); if that fails, the
default is `embed_tokens`. Other example targets are `embedding`, `encoder.embeddings` or
`decoder.embeddings`.
init_weights (`bool`):
By default the new token weights are initialized to be the same as the respective token embeddings. This
makes TrainableTokens a no-op when not trained. If set to `False` the weights will be random values. Do not
@@ -61,12 +63,13 @@
},
)
target_modules: Optional[Union[list[str], str]] = field(
default_factory=lambda: ["embed_tokens"],
default=None,
metadata={
"help": (
"List of module names or regex expression of the module names to replace with our "
"`TrainableTokensLayer`. This is by default the `embed_tokens` layer. "
"But could be multiple embedding-like layers, such as `embedding`, `encoder.embeddings` or "
"`TrainableTokensLayer`. If not defined, it will default to the model's input embedding layer if "
"the model has a `get_input_embeddings` method (transformer models usually do), if that fails the "
"default is 'embed_tokens'. Other example targets could be `embedding`, `encoder.embeddings` or "
"`decoder.embeddings`."
),
},
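A short usage sketch of the new `target_modules=None` default (mirroring the tests added below; `base_model` stands in for any model that implements `get_input_embeddings`):

from peft import TrainableTokensConfig, get_peft_model

# with target_modules=None the embedding layer is inferred automatically;
# token indices 0, 1 and 3 of that layer become trainable
config = TrainableTokensConfig(target_modules=None, token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, config)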
6 changes: 5 additions & 1 deletion src/peft/tuners/trainable_tokens/model.py
@@ -23,7 +23,7 @@

from peft.config import PeftConfig
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.utils import AuxiliaryTrainingWrapper, _get_submodules
from peft.utils import AuxiliaryTrainingWrapper, _get_submodules, _get_input_embeddings_name

from .layer import TrainableTokensLayer

@@ -42,6 +42,10 @@ def __getattr__(self, name: str):
return getattr(self.model, name)

def _prepare_adapter_config(self, peft_config, model_config):
# target_modules can be None, which prompts us to infer the embedding layer name ourselves.
if peft_config.target_modules is None:
peft_config.target_modules = _get_input_embeddings_name(self.model) or ["embed_tokens"]

return peft_config

def inject_adapter(
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -33,6 +33,7 @@
_freeze_adapter,
_get_batch_size,
_get_submodules,
_get_input_embeddings_name,
_is_valid_match,
_prepare_prompt_learning_config,
_set_adapter,
12 changes: 12 additions & 0 deletions src/peft/utils/other.py
@@ -665,6 +665,18 @@ def unload_and_optionally_merge_module(
return self.token_adapter.get_base_layer()


def _get_input_embeddings_name(model):
    # Resolve the module name of the model's input embedding layer; return None if it cannot be determined.
    if not hasattr(model, "get_input_embeddings"):
        return None

    input_embeddings = model.get_input_embeddings()
    for name, module in model.named_modules():
        if module is input_embeddings:
            return name

    return None


def _get_submodules(model, key):
parent = model.get_submodule(".".join(key.split(".")[:-1]))
target_name = key.split(".")[-1]
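To make `_get_input_embeddings_name` concrete, a small sketch (not part of the commit; the `Toy` class is illustrative) showing how the helper matches the module returned by `get_input_embeddings()` against `named_modules()`:

import torch

from peft.utils import _get_input_embeddings_name


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 10)

    def get_input_embeddings(self):
        return self.emb


print(_get_input_embeddings_name(Toy()))            # "emb"
print(_get_input_embeddings_name(torch.nn.ReLU()))  # None, no `get_input_embeddings` method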
115 changes: 114 additions & 1 deletion tests/test_trainable_tokens.py
@@ -23,6 +23,51 @@
from peft import AutoPeftModel, LoraConfig, PeftModel, TrainableTokensConfig, get_peft_model
from peft.tuners.trainable_tokens.layer import TrainableTokensLayer
from peft.utils import get_peft_model_state_dict
from peft.utils.other import TrainableTokensWrapper


class ModelEmb(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(100, 10)
        self.lin0 = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.lin0(self.emb(x))

    def get_input_embeddings(self):
        return self.emb


class ModelEmbedIn(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_in = torch.nn.Embedding(100, 10)
        self.lin0 = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.lin0(self.embed_in(x))

    def get_input_embeddings(self):
        return self.embed_in


class ModelEmbedMultiple(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_in = torch.nn.Embedding(100, 10)
        self.embed_in_2 = torch.nn.Embedding(100, 10)
        self.lin0 = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.lin0(self.embed_in(x) + self.embed_in_2(x))

    def get_input_embeddings(self):
        return self.embed_in


class ModelEmbedInNoGet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_in = torch.nn.Embedding(100, 10)
        self.lin0 = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.lin0(self.embed_in(x))


class TestTrainableTokens:
@@ -675,7 +720,6 @@ def test_weight_tying_applied_when_model_is_tied_encoder_decoder(self, peft_conf
assert merged_model.encoder.embed_tokens.weight.data_ptr() == merged_model.lm_head.weight.data_ptr()
assert merged_model.encoder.embed_tokens.weight.data_ptr() == merged_model.decoder.embed_tokens.weight.data_ptr()


@pytest.mark.parametrize(
"peft_config",
[
@@ -735,3 +779,72 @@ def test_original_module_not_in_state_dict(self, model):

state_dict = peft_model.state_dict()
assert not [k for k in state_dict if ".original_module.weight" in k]

@pytest.fixture
def model_emb(self):
return ModelEmb()

@pytest.fixture
def model_embed_in(self):
return ModelEmbedIn()

@pytest.fixture
def model_embed_in_no_get(self):
return ModelEmbedInNoGet()

@pytest.fixture
def model_embed_multiple(self):
return ModelEmbedMultiple()

@pytest.mark.parametrize('model_fixture_name, getter', [
('model_emb', lambda model: model.emb),
('model_embed_in', lambda model: model.embed_in),
('model', lambda model: model.model.model.embed_tokens),
])
def test_default_embedding_name_is_inferred_standalone(self, model_fixture_name, getter, request):
# make sure that the auto targeting works when `target_modules=None`
base_model = request.getfixturevalue(model_fixture_name)

peft_config = TrainableTokensConfig(target_modules=None, token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, peft_config)

assert isinstance(getter(peft_model), TrainableTokensLayer)

@pytest.mark.parametrize('model_fixture_name, getter', [
('model_emb', lambda model: model.emb),
('model_embed_in', lambda model: model.embed_in),
('model', lambda model: model.model.model.embed_tokens),
])
def test_default_embedding_name_is_inferred_combined(self, model_fixture_name, getter, request):
# make sure that the auto targeting works when `trainable_token_indices` is given as a plain list
base_model = request.getfixturevalue(model_fixture_name)

peft_config = LoraConfig(target_modules='all-linear', trainable_token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, peft_config)

assert isinstance(getter(peft_model), TrainableTokensWrapper)

def test_default_embedding_name_cannot_be_inferred(self, model_embed_in_no_get):
# should default to default value `embed_tokens` which is not present in this model
base_model = model_embed_in_no_get

peft_config = TrainableTokensConfig(target_modules=None, token_indices=[0, 1, 3])

with pytest.raises(ValueError) as e:
    get_peft_model(base_model, peft_config)

assert "Target modules ['embed_tokens'] not found in the base model." in str(e)

def test_embedding_name_is_used_when_given_standalone(self, model_embed_multiple):
peft_config = TrainableTokensConfig(target_modules="embed_in_2", token_indices=[0, 1, 3])
peft_model = get_peft_model(model_embed_multiple, peft_config)

assert isinstance(peft_model.model.embed_in_2, TrainableTokensLayer)
assert not isinstance(peft_model.model.embed_in, TrainableTokensLayer)

def test_embedding_name_is_used_when_given_combined(self, model_embed_multiple):
peft_config = LoraConfig(target_modules="all-linear", trainable_token_indices={'embed_in_2': [0, 1, 3]})
peft_model = get_peft_model(model_embed_multiple, peft_config)

assert isinstance(peft_model.model.embed_in_2, TrainableTokensWrapper)
assert not isinstance(peft_model.model.embed_in, TrainableTokensWrapper)
