Add SmolLM (smollm2) #9354
Changes from 1 commit
@@ -94,6 +94,7 @@
    "static_llama",
    "qwen2_5",
    "phi-4-mini",
    "smollm",
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
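For context, this hunk registers "smollm" alongside the other Llama-family example models. A minimal sketch of how such a registry is commonly consumed, e.g. to populate export CLI choices; the parser and flag below are illustrative assumptions, not the actual executorch export code:

```python
# Hypothetical sketch: expose the registered example models as CLI choices.
# The two lists mirror the ones in the diff above.
import argparse

EXECUTORCH_DEFINED_MODELS = ["static_llama", "qwen2_5", "phi-4-mini", "smollm"]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Export an example model (sketch).")
    parser.add_argument(
        "--model",
        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
        help="Registered example model to export.",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--model", "smollm"])
    print(args.model)  # -> smollm
```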
@@ -0,0 +1,14 @@
{
  "dim": 576,
  "ffn_dim_multiplier": 1,
Comment: There's a size mismatch error during quantization. I'm not very sure about the definition of dim and ffn_dim_multiplier here; it looks like some values may be wrong. Would you mind providing some pointers/context on this? Appreciate it! @jackzhxng The model structure is below: LlamaForCausalLM(…)
  "hidden_dim": 576,
Reply: Thank you! I got it mixed up with the hidden_size 😄
"n_heads": 9, | ||
"n_kv_heads": 3, | ||
"n_layers": 30, | ||
"norm_eps": 1e-05, | ||
"rope_theta": 10000.0, | ||
"use_scaled_rope": false, | ||
"vocab_size": 49152, | ||
"use_hf_rope": true, | ||
Comment: this should be …
Reply: Thank you!! Updated
"attention_qkv_bias": false | ||
} |
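On the dim / hidden_dim question in the thread above: in the Llama-style params JSON used by these examples, dim is the model (embedding) width while hidden_dim is the feed-forward intermediate width, so the two normally differ; mixing hidden_dim up with the HF hidden_size is exactly the mismatch discussed. A hedged sketch of deriving these fields from a Hugging Face LlamaConfig-style config.json; the HF key names are the standard LlamaConfig ones, and the helper name is made up for illustration:

```python
import json


def hf_config_to_params(config_path: str) -> dict:
    """Sketch: map a LlamaConfig-style config.json onto the params fields above."""
    with open(config_path) as f:
        cfg = json.load(f)
    return {
        "dim": cfg["hidden_size"],               # model/embedding width (576 here)
        "hidden_dim": cfg["intermediate_size"],  # FFN width, not hidden_size
        "n_layers": cfg["num_hidden_layers"],
        "n_heads": cfg["num_attention_heads"],
        "n_kv_heads": cfg["num_key_value_heads"],
        "norm_eps": cfg["rms_norm_eps"],
        "rope_theta": cfg.get("rope_theta", 10000.0),
        "vocab_size": cfg["vocab_size"],
        "use_scaled_rope": False,
        "use_hf_rope": True,
        "attention_qkv_bias": False,
    }
```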
@@ -0,0 +1,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model


class SmolLMModel(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "SmolLMModel",
]
@@ -0,0 +1,85 @@
import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_SMOLLM_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value
    # Input and output embeddings are tied.
    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict

Comment: Might be because of this; input and output embeddings are not shared for Llama, which this model is based off of.
Reply: Makes sense, removed this.
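Following up on the tied-embedding thread above (the PR ultimately removed the tying line): a defensive alternative, sketched here with a hypothetical helper that is not part of the PR, is to tie only when the converted checkpoint carries no separate output projection:

```python
# Sketch: tie the output projection to the input embeddings only when the
# converted checkpoint does not already have its own "output.weight".
def maybe_tie_output_embedding(converted_state_dict: dict) -> dict:
    if "output.weight" not in converted_state_dict:
        converted_state_dict["output.weight"] = converted_state_dict[
            "tok_embeddings.weight"
        ]
    return converted_state_dict
```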
def main():
    parser = argparse.ArgumentParser(
        description="Convert SmolLM weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()

    # Don't necessarily need to use the TorchTune checkpointer; we could just aggregate the checkpoint files ourselves.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=args.input_dir,
        checkpoint_files=["model.safetensors"],
        output_dir=".",
        model_type="MISTRAL",
    )

Comment: Change to Llama.
Reply: Thank you, updated~
    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()

    print("Converting checkpoint...")
    sd = smollm_tune_to_meta(sd["model"])

    torch.save(sd, args.output)
    print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
    main()
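Usage, as implied by the argparse setup: the script takes the HF checkpoint directory (containing model.safetensors) and an output path. Assuming the script is saved as convert_weights.py, a quick sanity check of the converted checkpoint might look like this; the paths are placeholders:

```python
# Run the converter, e.g.:
#     python convert_weights.py path/to/SmolLM2-checkpoint smollm2_meta.pth
# then spot-check a few tensor shapes in the Meta-format state dict.
import torch

sd = torch.load("smollm2_meta.pth", map_location="cpu")
print(sd["tok_embeddings.weight"].shape)         # expect (vocab_size, dim) = (49152, 576)
print(sd["layers.0.attention.wq.weight"].shape)  # expect (dim, dim) = (576, 576)
print(sd["layers.0.attention.wk.weight"].shape)  # expect (n_kv_heads * head_dim, dim) = (192, 576)
```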
Comment: Rename this and the directory to smolllm2.
Reply: Thanks! Should it be smollm2 or smolllm2?
Comment: Ah - it should be smollm2*