Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SmolLM (smollm2) #9354

Merged
merged 5 commits into from
Mar 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
@@ -95,6 +95,7 @@
"static_llama",
"qwen2_5",
"phi-4-mini",
"smollm2",
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]

14 changes: 14 additions & 0 deletions examples/models/smollm2/135M_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"dim": 576,
"ffn_dim_multiplier": 1,
"hidden_dim": 1536,
"n_heads": 9,
"n_kv_heads": 3,
"n_layers": 30,
"norm_eps": 1e-05,
"rope_theta": 10000.0,
"use_scaled_rope": false,
"vocab_size": 49152,
"use_hf_rope": false,
"attention_qkv_bias": false
}
14 changes: 14 additions & 0 deletions examples/models/smollm2/__init__
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.example.models.llama.model import Llama2Model


class SmolLM2Model(Llama2Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)


__all__ = [
"SmolLM2Model",
]
80 changes: 80 additions & 0 deletions examples/models/smollm2/convert_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
_SMOLLM_FROM_META = {
"tok_embeddings.weight": "tok_embeddings.weight",
"norm.weight": "norm.scale",
"output.weight": "output.weight",
"layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
"layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
"layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
"layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
"layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
"layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
"layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
"layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
"layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from torchtune's format to Meta's format. This function
doesn't handle any sharding or splitting of state dicts. It follows the
state_dict IN -> state_dict OUT pattern.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}
for key, value in state_dict.items():
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value

return converted_state_dict


def main():
parser = argparse.ArgumentParser(
description="Convert SmolLM weights to Meta format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()

# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=args.input_dir,
checkpoint_files=["model.safetensors"],
output_dir=".",
model_type="LLAMA",
)

print("Loading checkpoint...")
sd = checkpointer.load_checkpoint()

print("Converting checkpoint...")
sd = smollm_tune_to_meta(sd["model"])

torch.save(sd, args.output)
print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
main()