Add SmolLM (smollm2) #9354
New file (@@ -0,0 +1,14 @@): SmolLM2 model params

```json
{
  "dim": 576,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 576,
  "n_heads": 9,
  "n_kv_heads": 3,
  "n_layers": 30,
  "norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "use_scaled_rope": false,
  "vocab_size": 49152,
  "use_hf_rope": true,
  "attention_qkv_bias": false
}
```

Review thread on `"hidden_dim": 576` (author's reply): Thank you! I got it mixed up with the hidden_size 😄

Review thread on `"use_hf_rope": true`: this should be …
Author's reply: Thank you!! Updated
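For context on where these numbers come from, here is a minimal sketch (not part of this PR; the helper name and the idea of generating the dict programmatically are illustrative) of filling the fields that have direct counterparts in the checkpoint's Hugging Face `config.json`, using standard `LlamaConfig` key names. The inline values match the SmolLM2-135M module structure quoted later in the thread; note that `hidden_dim` corresponds to the FFN `intermediate_size` (1536), not `hidden_size`, which is the point of the `hidden_dim` comment above.

```python
import json

# Hypothetical helper: fill the ExecuTorch llama params dict from a Hugging Face
# LlamaConfig-style config.json. Only fields with direct HF counterparts are derived;
# the remaining flags (use_scaled_rope, use_hf_rope, attention_qkv_bias) are set to
# the values used in this PR.
def params_from_hf_config(path: str) -> dict:
    with open(path) as f:
        hf = json.load(f)
    return {
        "dim": hf["hidden_size"],                 # 576 for SmolLM2-135M
        "hidden_dim": hf["intermediate_size"],    # 1536: the FFN width, not hidden_size
        "n_heads": hf["num_attention_heads"],     # 9
        "n_kv_heads": hf["num_key_value_heads"],  # 3
        "n_layers": hf["num_hidden_layers"],      # 30
        "norm_eps": hf["rms_norm_eps"],           # 1e-05
        "rope_theta": hf["rope_theta"],           # 10000.0
        "vocab_size": hf["vocab_size"],           # 49152
        "use_scaled_rope": False,
        "use_hf_rope": True,
        "attention_qkv_bias": False,
    }
```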
New file (@@ -0,0 +1,14 @@): SmolLM model wrapper

```python
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model


class SmolLMModel(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "SmolLMModel",
]
```
New file (@@ -0,0 +1,80 @@): weight conversion script

```python
import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_SMOLLM_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    return converted_state_dict


def main():
    parser = argparse.ArgumentParser(
        description="Convert SmolLM weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()

    # We don't strictly need the TorchTune checkpointer; the checkpoint files could
    # also be aggregated by hand.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=args.input_dir,
        checkpoint_files=["model.safetensors"],
        output_dir=".",
        model_type="LLAMA",
    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()

    print("Converting checkpoint...")
    sd = smollm_tune_to_meta(sd["model"])

    torch.save(sd, args.output)
    print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
    main()
```
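As a quick sanity check after running the script (e.g. `python convert_weights.py <hf_checkpoint_dir> smollm2.pth`, where the script name and paths are placeholders), the converted Meta-format keys and shapes can be compared against the SmolLM2-135M structure quoted below; a minimal sketch, not part of this PR:

```python
# Sketch: verify the converted Meta-format checkpoint has the expected keys/shapes.
# "smollm2.pth" is a placeholder for the conversion script's output path.
import torch

sd = torch.load("smollm2.pth", map_location="cpu", weights_only=True)

assert "tok_embeddings.weight" in sd and "output.weight" in sd
assert sd["tok_embeddings.weight"].shape == (49152, 576)           # (vocab_size, dim)
assert sd["layers.0.attention.wq.weight"].shape == (576, 576)      # n_heads * head_dim x dim
assert sd["layers.0.attention.wk.weight"].shape == (192, 576)      # n_kv_heads * head_dim x dim
assert sd["layers.0.feed_forward.w1.weight"].shape == (1536, 576)  # ffn hidden_dim x dim
print("Converted checkpoint is consistent with the expected SmolLM2-135M shapes")
```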
Review comment:
There's a size mismatch error during quantization.
I'm not very sure about the definition of dim and ffn_dim_multiplier here; it looks like one of the values may be wrong. Would you mind providing some pointers/context on this? Appreciate it! @jackzhxng
The model structure is below:
```
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
```
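From this structure, the FFN (gate/up) projections are 576 -> 1536, i.e. the intermediate size is 1536. For comparison, here is a sketch of the FFN sizing rule used by Meta's Llama reference code; whether ExecuTorch's llama ModelArgs applies exactly this rule, and only when `hidden_dim` is absent from params.json, is an assumption to verify, as is the `multiple_of=256` default:

```python
from typing import Optional

# Hypothetical helper reproducing the FFN sizing rule from Meta's Llama reference
# code. Assumes ExecuTorch's llama ModelArgs follows the same convention and skips
# it when "hidden_dim" is set explicitly; check the actual ModelArgs code to confirm.
def llama_ffn_hidden_dim(dim: int, ffn_dim_multiplier: Optional[float], multiple_of: int = 256) -> int:
    hidden_dim = 4 * dim
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the nearest multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# dim=576, ffn_dim_multiplier=1 -> 1536, matching the 576 -> 1536 gate/up projections
# above, whereas an explicit "hidden_dim": 576 would give 576 x 576 FFN weights and
# could explain the size mismatch seen during quantization.
print(llama_ffn_hidden_dim(576, 1))  # 1536
```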