diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index ff2b82f6c65..cb887252571 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -95,6 +95,7 @@ "static_llama", "qwen2_5", "phi-4-mini", + "smollm2", ] TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] diff --git a/examples/models/smollm2/135M_config.json b/examples/models/smollm2/135M_config.json new file mode 100644 index 00000000000..604c7e94ab5 --- /dev/null +++ b/examples/models/smollm2/135M_config.json @@ -0,0 +1,14 @@ +{ + "dim": 576, + "ffn_dim_multiplier": 1, + "hidden_dim": 1536, + "n_heads": 9, + "n_kv_heads": 3, + "n_layers": 30, + "norm_eps": 1e-05, + "rope_theta": 10000.0, + "use_scaled_rope": false, + "vocab_size": 49152, + "use_hf_rope": false, + "attention_qkv_bias": false + } diff --git a/examples/models/smollm2/__init__ b/examples/models/smollm2/__init__ new file mode 100644 index 00000000000..3d01bf9eb42 --- /dev/null +++ b/examples/models/smollm2/__init__ @@ -0,0 +1,14 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.example.models.llama.model import Llama2Model + + +class SmolLM2Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "SmolLM2Model", +] diff --git a/examples/models/smollm2/convert_weights.py b/examples/models/smollm2/convert_weights.py new file mode 100644 index 00000000000..db80bd47b8c --- /dev/null +++ b/examples/models/smollm2/convert_weights.py @@ -0,0 +1,80 @@ +import argparse +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + +from torchtune.training import FullModelHFCheckpointer + +# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings. +_SMOLLM_FROM_META = { + "tok_embeddings.weight": "tok_embeddings.weight", + "norm.weight": "norm.scale", + "output.weight": "output.weight", + "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", + "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", + "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", + "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", + "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", + "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", + "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", + "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", + "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", +} + + +def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()} + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + return converted_state_dict + + +def main(): + parser = argparse.ArgumentParser( + description="Convert SmolLM weights to Meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + + # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. + checkpointer = FullModelHFCheckpointer( + checkpoint_dir=args.input_dir, + checkpoint_files=["model.safetensors"], + output_dir=".", + model_type="LLAMA", + ) + + print("Loading checkpoint...") + sd = checkpointer.load_checkpoint() + + print("Converting checkpoint...") + sd = smollm_tune_to_meta(sd["model"]) + + torch.save(sd, args.output) + print(f"Checkpoint saved to {args.output}") + + +if __name__ == "__main__": + main()