import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key
from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping from Meta's checkpoint keys to torchtune's keys.
_SMOLLM_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}
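# Note: the "{}" placeholders above stand for the layer index; torchtune's
# get_mapped_key fills them in when translating keys. For example (illustrative),
# the torchtune key "layers.0.attn.q_proj.weight" maps back to the Meta key
# "layers.0.attention.wq.weight".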


def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    # Input and output embeddings are tied.
    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict
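

# Minimal usage sketch for the conversion function on its own (file names are
# illustrative, assuming a torchtune-format state dict saved with torch.save):
#
#   tune_sd = torch.load("smollm_tune.pth", map_location="cpu")
#   meta_sd = smollm_tune_to_meta(tune_sd)
#   torch.save(meta_sd, "smollm_meta.pth")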


def main():
    parser = argparse.ArgumentParser(
        description="Convert SmolLM weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")
    args = parser.parse_args()

    # We don't strictly need the torchtune checkpointer here; the checkpoint
    # files could also be aggregated manually.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=args.input_dir,
        checkpoint_files=["model.safetensors"],
        output_dir=".",
        model_type="MISTRAL",
    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()

    print("Converting checkpoint...")
    sd = smollm_tune_to_meta(sd["model"])

    torch.save(sd, args.output)
    print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
    main()
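

# Example invocation (paths are illustrative):
#
#   python convert_weights.py /path/to/smollm_hf_checkpoint_dir smollm_meta.pth
#
# The input directory is expected to contain a model.safetensors file from the
# Hugging Face SmolLM checkpoint; the converted Meta-format state dict is written
# to the given output path with torch.save.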