
Commit 8419a74

Add SmolLM

Author: Danqing Wang (MPK)
1 parent ce612b8, commit 8419a74

File tree: 5 files changed, +114 -0 lines changed

Diff for: examples/models/llama/export_llama_lib.py (+1)

@@ -94,6 +94,7 @@
     "static_llama",
     "qwen2_5",
     "phi-4-mini",
+    "smollm",
 ]
 TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
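
For context, this one-line change registers "smollm" as a selectable model name alongside the existing llama-family examples. A minimal sketch of what that enables, assuming the list shown above is named EXECUTORCH_DEFINED_MODELS (the name is inferred from the surrounding file, which also defines TORCHTUNE_DEFINED_MODELS) and that the name maps onto the new examples/models/smollm package added below:

# Hypothetical illustration, not part of the diff: the registered name is the
# key callers use to pick the SmolLM example.
EXECUTORCH_DEFINED_MODELS = ["static_llama", "qwen2_5", "phi-4-mini", "smollm"]  # excerpt

model_name = "smollm"
assert model_name in EXECUTORCH_DEFINED_MODELS
print(f"example package: executorch.examples.models.{model_name}")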

Diff for: examples/models/smollm/135M_config.json (+14)

{
  "dim": 576,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 576,
  "n_heads": 9,
  "n_kv_heads": 3,
  "n_layers": 30,
  "norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "use_scaled_rope": false,
  "vocab_size": 49152,
  "use_hf_rope": true,
  "attention_qkv_bias": false
}

Diff for: examples/models/smollm/README.md

Whitespace-only changes.

Diff for: examples/models/smollm/__init__ (+14)

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model


class SmolLMModel(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "SmolLMModel",
]
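
SmolLMModel is a thin wrapper so the shared llama example code can build the model by name. A hypothetical usage sketch, assuming the wrapper accepts the same "checkpoint" and "params" keyword arguments the other llama-family examples pass through to Llama2Model (neither name is defined in this diff):

# Hypothetical: construct the example model directly. The kwarg names and the
# checkpoint filename are assumptions, not part of this commit.
from executorch.examples.models.smollm import SmolLMModel

model = SmolLMModel(
    checkpoint="smollm_135m.pth",  # e.g. the output of convert_weights.py below
    params="examples/models/smollm/135M_config.json",
)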

Diff for: examples/models/smollm/convert_weights.py (+85)

import argparse
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

from torchtune.training import FullModelHFCheckpointer

# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_SMOLLM_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
}


def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    # Input and output embeddings are tied.
    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict


def main():
    parser = argparse.ArgumentParser(
        description="Convert SmolLM weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()

    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
    checkpointer = FullModelHFCheckpointer(
        checkpoint_dir=args.input_dir,
        checkpoint_files=["model.safetensors"],
        output_dir=".",
        model_type="MISTRAL",
    )

    print("Loading checkpoint...")
    sd = checkpointer.load_checkpoint()

    print("Converting checkpoint...")
    sd = smollm_tune_to_meta(sd["model"])

    torch.save(sd, args.output)
    print(f"Checkpoint saved to {args.output}")


if __name__ == "__main__":
    main()
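
Per the argparse definition above, the script takes the directory holding the Hugging Face checkpoint and an output path, e.g. python examples/models/smollm/convert_weights.py <hf_checkpoint_dir> smollm_135m.pth (output filename is just an example). The key renaming itself can be exercised in isolation; a toy sketch with dummy tensors, assuming torchtune is installed and the module is importable from the repo root:

# Toy illustration of smollm_tune_to_meta: torchtune-style keys are mapped back
# to Meta-style names, and the output projection is tied to the embedding.
import torch

from executorch.examples.models.smollm.convert_weights import smollm_tune_to_meta

tune_sd = {
    "tok_embeddings.weight": torch.zeros(49152, 576),
    "norm.scale": torch.ones(576),
    "layers.0.attn.q_proj.weight": torch.zeros(576, 576),
    "layers.0.mlp_norm.scale": torch.ones(576),
}
meta_sd = smollm_tune_to_meta(tune_sd)
print(sorted(meta_sd))
# ['layers.0.attention.wq.weight', 'layers.0.ffn_norm.weight',
#  'norm.weight', 'output.weight', 'tok_embeddings.weight']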
