
Commit 1eb2ad6

Inklingdq authored and jackzhxng committed
Add SmolLM (smollm2) (#9354)
1 parent 8cd1b93 commit 1eb2ad6

File tree

4 files changed: +109 −0 lines changed


examples/models/llama/export_llama_lib.py

+1
@@ -96,6 +96,7 @@
     "static_llama",
     "qwen2_5",
     "phi-4-mini",
+    "smollm2",
 ]
 TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
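
Registering "smollm2" in EXECUTORCH_DEFINED_MODELS is what lets the llama export flow accept the new model name. The sketch below only illustrates that kind of registry check; resolve_model_name is a hypothetical helper for illustration, not part of export_llama_lib.

# Illustrative sketch only -- not the actual export_llama_lib code.
EXECUTORCH_DEFINED_MODELS = [
    "static_llama",
    "qwen2_5",
    "phi-4-mini",
    "smollm2",
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


def resolve_model_name(name: str) -> str:
    # A name outside both registries should fail fast with a clear error.
    if name not in EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS:
        raise ValueError(f"Unsupported model: {name}")
    return name


print(resolve_model_name("smollm2"))  # -> smollm2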

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"dim": 576,
3+
"ffn_dim_multiplier": 1,
4+
"hidden_dim": 1536,
5+
"n_heads": 9,
6+
"n_kv_heads": 3,
7+
"n_layers": 30,
8+
"norm_eps": 1e-05,
9+
"rope_theta": 10000.0,
10+
"use_scaled_rope": false,
11+
"vocab_size": 49152,
12+
"use_hf_rope": false,
13+
"attention_qkv_bias": false
14+
}
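
This new params JSON (its exact path is not shown in this view, but it presumably sits under examples/models/smollm2/) describes the model architecture: 30 layers, hidden size 576, and grouped-query attention with 9 query heads sharing 3 KV heads, values that correspond to the smallest SmolLM2 variant. A minimal sanity-check sketch, assuming a path for the file:

# Minimal sanity-check sketch for the config above. The path is an
# assumption -- adjust it to wherever the params JSON is saved.
import json

with open("examples/models/smollm2/config/params.json") as f:
    params = json.load(f)

# 9 query heads share 3 KV heads (grouped-query attention, group size 3),
# and each head is dim / n_heads = 64 wide.
assert params["dim"] % params["n_heads"] == 0
assert params["n_heads"] % params["n_kv_heads"] == 0
head_dim = params["dim"] // params["n_heads"]
group_size = params["n_heads"] // params["n_kv_heads"]
print(f"head_dim={head_dim}, kv group size={group_size}")  # head_dim=64, kv group size=3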

examples/models/smollm2/__init__.py

+14
@@ -0,0 +1,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model


class SmolLM2Model(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "SmolLM2Model",
]
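
The package only subclasses Llama2Model and exports the class via __all__; everything else (checkpoint loading, params parsing) is inherited. The following sketch shows how a registry-style lookup could resolve the registered "smollm2" name to this class. It is an illustration of the general idea, not the actual ExecuTorch model factory.

# Illustrative sketch: resolve the "smollm2" package to its exported class.
import importlib

module = importlib.import_module("executorch.examples.models.smollm2")
model_cls = getattr(module, module.__all__[0])
print(model_cls)  # <class '...SmolLM2Model'>

# SmolLM2Model adds no behaviour of its own; construction is inherited
# unchanged from Llama2Model.
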
+80
@@ -0,0 +1,80 @@
import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key
from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_SMOLLM_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    return converted_state_dict


def main():
    parser = argparse.ArgumentParser(
        description="Convert SmolLM weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()

    # We don't strictly need the TorchTune checkpointer here; the checkpoint
    # files could also be aggregated manually.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=args.input_dir,
        checkpoint_files=["model.safetensors"],
        output_dir=".",
        model_type="LLAMA",
    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()

    print("Converting checkpoint...")
    sd = smollm_tune_to_meta(sd["model"])

    torch.save(sd, args.output)
    print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
    main()
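
The conversion script takes the checkpoint directory and the output path as its two positional arguments, loads the Hugging Face safetensors checkpoint through torchtune's FullModelHFCheckpointer (which yields torchtune-keyed tensors), and renames every key back to the Meta/llama layout using the inverted mapping. A self-contained sketch of that renaming with dummy tensors is shown below; map_key is a simplified stand-in for torchtune's get_mapped_key (its handling of the "{}" layer placeholder is an assumption here), not the library implementation.

# Self-contained sketch of the key renaming that smollm_tune_to_meta performs.
import re

import torch

_SMOLLM_FROM_META = {
    "norm.weight": "norm.scale",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
}
_TUNE_TO_META = {v: k for k, v in _SMOLLM_FROM_META.items()}


def map_key(key: str, mapping: dict) -> str:
    # Replace the layer number with "{}" so the template can be looked up,
    # then substitute the number back into the mapped template.
    match = re.search(r"\d+", key)
    template = re.sub(r"\d+", "{}", key)
    mapped = mapping[template]
    return mapped.format(match.group()) if match else mapped


dummy = {
    "norm.scale": torch.zeros(576),
    "layers.7.attn.q_proj.weight": torch.zeros(576, 576),
}
converted = {map_key(k, _TUNE_TO_META): v for k, v in dummy.items()}
print(sorted(converted))  # ['layers.7.attention.wq.weight', 'norm.weight']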
