Add SmolLM (smollm2) #9354

Merged · 5 commits · Mar 20, 2025
1 change: 1 addition & 0 deletions examples/models/llama/export_llama_lib.py
@@ -94,6 +94,7 @@
"static_llama",
"qwen2_5",
"phi-4-mini",
"smollm",
Contributor: rename this and directory to smolllm2

Author: Thanks! Should it be smollm2 or smolllm2?

Contributor: Ah - it should be smollm2*

]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]

14 changes: 14 additions & 0 deletions examples/models/smollm/135M_config.json
@@ -0,0 +1,14 @@
{
"dim": 576,
"ffn_dim_multiplier": 1,
Author: There's a size mismatch error during quantization:

	size mismatch for layers.0.feed_forward.w1.weight: copying a param with shape torch.Size([1536, 576]) from checkpoint, the shape in current model is torch.Size([576, 576]).
	size mismatch for layers.0.feed_forward.w2.weight: copying a param with shape torch.Size([576, 1536]) from checkpoint, the shape in current model is torch.Size([576, 576]).
	size mismatch for layers.0.feed_forward.w3.weight: copying a param with shape torch.Size([1536, 576]) from checkpoint, the shape in current model is torch.Size([576, 576]).

I'm not sure about the definition of dim and ffn_dim_multiplier here; it looks like some value is wrong. Would you mind providing some pointers/context on this? Appreciate it! @jackzhxng

The model structure is below

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)

"hidden_dim": 576,
Author: Thank you! I got it mixed with the hidden_size 😄

"n_heads": 9,
"n_kv_heads": 3,
"n_layers": 30,
"norm_eps": 1e-05,
"rope_theta": 10000.0,
"use_scaled_rope": false,
"vocab_size": 49152,
"use_hf_rope": true,
Contributor: this should be false

Author: Thank you!! Updated

"attention_qkv_bias": false
}
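
As a side note on the mismatch discussed in the thread above: the checkpoint's feed-forward weights are 1536-wide, which is what hidden_dim controls in this config. A minimal sketch of that shape check follows; the helper name is made up, and the corrected value (hidden_dim = 1536, matching gate_proj/up_proj in the printed model structure) is inferred from the review discussion rather than taken from this commit.

def expected_ffn_shapes(dim, hidden_dim):
    # Shapes a Llama-style model definition allocates for one MLP block:
    # w1/w3 project dim -> hidden_dim, w2 projects hidden_dim -> dim.
    return {
        "feed_forward.w1.weight": (hidden_dim, dim),
        "feed_forward.w2.weight": (dim, hidden_dim),
        "feed_forward.w3.weight": (hidden_dim, dim),
    }

# Shapes reported by the SmolLM2-135M checkpoint in the error messages above.
checkpoint_shapes = {
    "feed_forward.w1.weight": (1536, 576),
    "feed_forward.w2.weight": (576, 1536),
    "feed_forward.w3.weight": (1536, 576),
}

assert expected_ffn_shapes(dim=576, hidden_dim=576) != checkpoint_shapes   # the reported mismatch
assert expected_ffn_shapes(dim=576, hidden_dim=1536) == checkpoint_shapes  # hidden_dim = 1536 lines up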
14 changes: 14 additions & 0 deletions examples/models/smollm/__init__
@@ -0,0 +1,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model


class SmolLMModel(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "SmolLMModel",
]
85 changes: 85 additions & 0 deletions examples/models/smollm/convert_weights.py
@@ -0,0 +1,85 @@
import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_SMOLLM_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    # Input and output embeddings are tied.
    converted_state_dict["output.weight"] = converted_state_dict[
Contributor: Might be because of this; input and output embeddings are not shared for Llama, which this model is based off of.

Author: Makes sense, removed this.

"tok_embeddings.weight"
]

return converted_state_dict


def main():
    parser = argparse.ArgumentParser(
        description="Convert SmolLM weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()

    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=args.input_dir,
        checkpoint_files=["model.safetensors"],
        output_dir=".",
        model_type="MISTRAL",
Contributor: Change to Llama

Author: Thank you, updated~

    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()

    print("Converting checkpoint...")
    sd = smollm_tune_to_meta(sd["model"])

    torch.save(sd, args.output)
    print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
    main()
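
For context, a small usage sketch of the conversion path above. The tiny tensors and file paths are placeholders rather than values from the PR; the key names follow the _SMOLLM_FROM_META mapping.

import torch

# Hypothetical torchtune-style state dict, only to exercise the key mapping;
# shapes are placeholders, not the real SmolLM2 sizes.
tune_sd = {
    "tok_embeddings.weight": torch.zeros(8, 4),
    "norm.scale": torch.zeros(4),
    "layers.0.attn.q_proj.weight": torch.zeros(4, 4),
    "layers.0.mlp.w1.weight": torch.zeros(8, 4),
}

meta_sd = smollm_tune_to_meta(tune_sd)
print(sorted(meta_sd.keys()))
# With this commit's code the output also contains "output.weight", copied from
# "tok_embeddings.weight" by the tie above; per the review that copy is removed,
# after which only the directly mapped keys remain:
# ['layers.0.attention.wq.weight', 'layers.0.feed_forward.w1.weight',
#  'norm.weight', 'tok_embeddings.weight']

Running the script end to end on a downloaded checkpoint directory (which must contain model.safetensors) would look something like `python convert_weights.py /path/to/SmolLM2-135M smollm2.pth`, with both paths illustrative.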