From 8419a7459ef1fea55e75a866f7558fb08286f189 Mon Sep 17 00:00:00 2001 From: "Danqing Wang (MPK)" Date: Tue, 18 Mar 2025 00:36:20 -0700 Subject: [PATCH 1/4] Add SmolLM --- examples/models/llama/export_llama_lib.py | 1 + examples/models/smollm/135M_config.json | 14 ++++ examples/models/smollm/README.md | 0 examples/models/smollm/__init__ | 14 ++++ examples/models/smollm/convert_weights.py | 85 +++++++++++++++++++++++ 5 files changed, 114 insertions(+) create mode 100644 examples/models/smollm/135M_config.json create mode 100644 examples/models/smollm/README.md create mode 100644 examples/models/smollm/__init__ create mode 100644 examples/models/smollm/convert_weights.py diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 2319ec0c6a7..246a6561d50 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -94,6 +94,7 @@ "static_llama", "qwen2_5", "phi-4-mini", + "smollm", ] TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] diff --git a/examples/models/smollm/135M_config.json b/examples/models/smollm/135M_config.json new file mode 100644 index 00000000000..7f7a6526bda --- /dev/null +++ b/examples/models/smollm/135M_config.json @@ -0,0 +1,14 @@ +{ + "dim": 576, + "ffn_dim_multiplier": 1, + "hidden_dim": 576, + "n_heads": 9, + "n_kv_heads": 3, + "n_layers": 30, + "norm_eps": 1e-05, + "rope_theta": 10000.0, + "use_scaled_rope": false, + "vocab_size": 49152, + "use_hf_rope": true, + "attention_qkv_bias": false + } diff --git a/examples/models/smollm/README.md b/examples/models/smollm/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/models/smollm/__init__ b/examples/models/smollm/__init__ new file mode 100644 index 00000000000..745540ebb56 --- /dev/null +++ b/examples/models/smollm/__init__ @@ -0,0 +1,14 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.example.models.llama.model import Llama2Model + + +class SmolLMModel(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "SmolLMModel", +] diff --git a/examples/models/smollm/convert_weights.py b/examples/models/smollm/convert_weights.py new file mode 100644 index 00000000000..a9dfe08a4b6 --- /dev/null +++ b/examples/models/smollm/convert_weights.py @@ -0,0 +1,85 @@ +import argparse +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + +from torchtune.training import FullModelHFCheckpointer + +# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings. 
+_SMOLLM_FROM_META = { + "tok_embeddings.weight": "tok_embeddings.weight", + "norm.weight": "norm.scale", + "output.weight": "output.weight", + "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", + "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", + "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", + "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", + "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", + "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", + "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", + "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", + "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", +} + + +def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()} + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + # Input and output embeddings are tied. + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def main(): + parser = argparse.ArgumentParser( + description="Convert SmolLM weights to Meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + + # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. + checkpointer = FullModelHFCheckpointer( + checkpoint_dir=args.input_dir, + checkpoint_files=["model.safetensors"], + output_dir=".", + model_type="MISTRAL", + ) + + print("Loading checkpoint...") + sd = checkpointer.load_checkpoint() + + print("Converting checkpoint...") + sd = smollm_tune_to_meta(sd["model"]) + + torch.save(sd, args.output) + print(f"Checkpoint saved to {args.output}") + + +if __name__ == "__main__": + main() From 34c5dee313e94df14196c203d9c4e797a066e13b Mon Sep 17 00:00:00 2001 From: "Danqing Wang (MPK)" Date: Tue, 18 Mar 2025 16:05:32 -0700 Subject: [PATCH 2/4] address comment --- examples/models/llama/export_llama_lib.py | 2 +- examples/models/smollm/convert_weights.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 246a6561d50..7c30120a282 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -94,7 +94,7 @@ "static_llama", "qwen2_5", "phi-4-mini", - "smollm", + "smolllm2", ] TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] diff --git a/examples/models/smollm/convert_weights.py b/examples/models/smollm/convert_weights.py index a9dfe08a4b6..db80bd47b8c 100644 --- a/examples/models/smollm/convert_weights.py +++ b/examples/models/smollm/convert_weights.py @@ -42,11 +42,6 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch. 
new_key = get_mapped_key(key, inverted_mapping_dict) converted_state_dict[new_key] = value - # Input and output embeddings are tied. - converted_state_dict["output.weight"] = converted_state_dict[ - "tok_embeddings.weight" - ] - return converted_state_dict @@ -68,7 +63,7 @@ def main(): checkpoint_dir=args.input_dir, checkpoint_files=["model.safetensors"], output_dir=".", - model_type="MISTRAL", + model_type="LLAMA", ) print("Loading checkpoint...") From 91e5dd27cb4b710f8f9da27b3102b22a9ac94731 Mon Sep 17 00:00:00 2001 From: "Danqing Wang (MPK)" Date: Tue, 18 Mar 2025 17:39:24 -0700 Subject: [PATCH 3/4] change use_hf_rope to False --- examples/models/smollm/135M_config.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/models/smollm/135M_config.json b/examples/models/smollm/135M_config.json index 7f7a6526bda..604c7e94ab5 100644 --- a/examples/models/smollm/135M_config.json +++ b/examples/models/smollm/135M_config.json @@ -1,7 +1,7 @@ { "dim": 576, "ffn_dim_multiplier": 1, - "hidden_dim": 576, + "hidden_dim": 1536, "n_heads": 9, "n_kv_heads": 3, "n_layers": 30, @@ -9,6 +9,6 @@ "rope_theta": 10000.0, "use_scaled_rope": false, "vocab_size": 49152, - "use_hf_rope": true, + "use_hf_rope": false, "attention_qkv_bias": false } From 14d3ca78051dfc135a93bb4c0f64d9334b09a8c1 Mon Sep 17 00:00:00 2001 From: "Danqing Wang (MPK)" Date: Tue, 18 Mar 2025 22:08:44 -0700 Subject: [PATCH 4/4] update model name to smollm2 --- examples/models/llama/export_llama_lib.py | 2 +- examples/models/smollm/README.md | 0 examples/models/{smollm => smollm2}/135M_config.json | 0 examples/models/{smollm => smollm2}/__init__ | 4 ++-- examples/models/{smollm => smollm2}/convert_weights.py | 0 5 files changed, 3 insertions(+), 3 deletions(-) delete mode 100644 examples/models/smollm/README.md rename examples/models/{smollm => smollm2}/135M_config.json (100%) rename examples/models/{smollm => smollm2}/__init__ (84%) rename examples/models/{smollm => smollm2}/convert_weights.py (100%) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 7c30120a282..6a32b99f1de 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -94,7 +94,7 @@ "static_llama", "qwen2_5", "phi-4-mini", - "smolllm2", + "smollm2", ] TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] diff --git a/examples/models/smollm/README.md b/examples/models/smollm/README.md deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/examples/models/smollm/135M_config.json b/examples/models/smollm2/135M_config.json similarity index 100% rename from examples/models/smollm/135M_config.json rename to examples/models/smollm2/135M_config.json diff --git a/examples/models/smollm/__init__ b/examples/models/smollm2/__init__ similarity index 84% rename from examples/models/smollm/__init__ rename to examples/models/smollm2/__init__ index 745540ebb56..3d01bf9eb42 100644 --- a/examples/models/smollm/__init__ +++ b/examples/models/smollm2/__init__ @@ -4,11 +4,11 @@ from executorch.example.models.llama.model import Llama2Model -class SmolLMModel(Llama2Model): +class SmolLM2Model(Llama2Model): def __init__(self, **kwargs): super().__init__(**kwargs) __all__ = [ - "SmolLMModel", + "SmolLM2Model", ] diff --git a/examples/models/smollm/convert_weights.py b/examples/models/smollm2/convert_weights.py similarity index 100% rename from examples/models/smollm/convert_weights.py rename to examples/models/smollm2/convert_weights.py
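
For a quick sanity check of the key mapping introduced in convert_weights.py, the following standalone sketch mirrors the torchtune-to-Meta renaming that smollm_tune_to_meta performs. The local map_key helper and the toy state dict (key names and tensor shapes) are illustrative stand-ins, not the converter's real code path — the actual script relies on torchtune's get_mapped_key and FullModelHFCheckpointer — but the sketch runs with only torch installed.

import re
from typing import Dict

import torch

# Subset of the _SMOLLM_FROM_META mapping from convert_weights.py: Meta name -> torchtune name.
_SMOLLM_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "norm.weight": "norm.scale",
    "output.weight": "output.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
}
# Inverted, as in smollm_tune_to_meta: torchtune name -> Meta name.
_FROM_TUNE = {v: k for k, v in _SMOLLM_FROM_META.items()}


def map_key(key: str, mapping: Dict[str, str]) -> str:
    # Simplified stand-in for torchtune's get_mapped_key: pull out the layer
    # index, look up the "{}" template, then re-insert the index.
    match = re.match(r"(.*layers\.)(\d+)(\..*)", key)
    if match:
        template = f"{match.group(1)}{{}}{match.group(3)}"
        return mapping[template].format(match.group(2))
    return mapping[key]


if __name__ == "__main__":
    # Toy torchtune-style state dict; shapes are illustrative, not SmolLM2's real ones.
    tune_sd = {
        "tok_embeddings.weight": torch.zeros(8, 4),
        "layers.0.attn.q_proj.weight": torch.zeros(4, 4),
    }
    meta_sd = {map_key(k, _FROM_TUNE): v for k, v in tune_sd.items()}
    print(sorted(meta_sd))
    # ['layers.0.attention.wq.weight', 'tok_embeddings.weight']

In the actual flow, the converted state dict is saved with torch.save and the resulting Meta-format checkpoint is what the "smollm2" entry registered in export_llama_lib.py is expected to consume, together with 135M_config.json.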