
Add SmolLM (smollm2) #9354

Merged (5 commits) on Mar 20, 2025

Conversation

@Inklingdq (Contributor) commented on Mar 18, 2025:

Summary

Add SmolLM 135M model (smollm2) for issue #9324

python -m examples.models.llama.runner.native --model smollm2 \
--pte smollm2.pte  \
--tokenizer /Users/danqingwang/tmp/snapshots/1d461723eec654e65efdc40cf49301c89c0c92f4/tokenizer.json \
--tokenizer_config /Users/danqingwang/tmp/snapshots/1d461723eec654e65efdc40cf49301c89c0c92f4/tokenizer_config.json \
--prompt "What ingredients are in a California roll?" \
--params examples/models/smollm2/135M_config.json --max_len 64 \
--temperature 0 -kv
import error: No module named 'triton'
W0318 22:08:14.883045 53671 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
[program.cpp:136] InternalConsistency verification requested but not available

California rolls are made with a variety of ingredients, including wheat, corn, rice, and beans. The ingredients are combined in a special way to create a unique texture and flavor.
What is the difference between a California roll and a roll?
The main difference between a

Prefill time: 0.03980898857116699
Generation tok/s: 76.69897827254086

Output agrees with the eager model output.
[screenshot of eager output comparison omitted]

Test plan

Convert to meta format

python examples/models/smollm/convert_weights.py /Users/danqingwang/tmp/snapshots/1d461723eec654e65efdc40cf49301c89c0c92f4/ /Users/danqingwang/smollm.pth
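
For context, this step loads the HF safetensors checkpoint, renames the weights into the meta/llama key format, and saves a .pth that export_llama can read. A minimal sketch of the idea, not the PR's actual convert_weights.py; the key-map entry and file names are illustrative:

# Simplified sketch of the conversion step (illustrative only, not the real script).
import torch
from safetensors.torch import load_file

hf_state_dict = load_file("model.safetensors")

# Illustrative key mapping; a real converter maps every HF name to its
# meta/llama counterpart.
KEY_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
}

converted_state_dict = {}
for key, value in hf_state_dict.items():
    converted_state_dict[KEY_MAP.get(key, key)] = value

torch.save(converted_state_dict, "smollm.pth")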

Run export

./install_executorch.sh --pybind xnnpack && python -m examples.models.llama.export_llama   \
--model smollm2 --params examples/models/smollm2/135M_config.json  \
--checkpoint /Users/danqingwang/smollm.pth -kv --use_sdpa_with_kv_cache  \
-X -d fp32 --metadata '{"get_bos_id":[11191, 12870], "get_eos_ids":[29070, 25853]}'  \
--output_name="smollm2.pte" --verbose
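
The special-token IDs passed via --metadata can be looked up from the HF tokenizer. A small sketch, assuming the published HuggingFaceTB/SmolLM2-135M repo stands in for the local snapshot directory used above:

# Print the BOS/EOS token IDs for the --metadata flag (the model id is an
# assumption; the PR points at a local snapshot directory instead).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
print(tok.bos_token_id, tok.eos_token_id)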

Run test

python -m examples.models.llama.runner.native --model smollm2 \
--pte smollm2.pte \
--tokenizer /Users/danqingwang/tmp/snapshots/1d461723eec654e65efdc40cf49301c89c0c92f4/tokenizer.json \
--tokenizer_config /Users/danqingwang/tmp/snapshots/1d461723eec654e65efdc40cf49301c89c0c92f4/tokenizer_config.json \
--prompt "What ingredients are in a California roll?" \
--params examples/models/smollm2/135M_config.json --max_len 64 \
--temperature 0 -kv

pytorch-bot (bot) commented on Mar 18, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9354

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit 4b07e95 with merge base 01a22b6:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot (Contributor) commented:

Hi @Inklingdq!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@@ -0,0 +1,14 @@
{
"dim": 576,
"ffn_dim_multiplier": 1,

Contributor Author:

There's a size mismatch error during quantization:

	size mismatch for layers.0.feed_forward.w1.weight: copying a param with shape torch.Size([1536, 576]) from checkpoint, the shape in current model is torch.Size([576, 576]).
	size mismatch for layers.0.feed_forward.w2.weight: copying a param with shape torch.Size([576, 1536]) from checkpoint, the shape in current model is torch.Size([576, 576]).
	size mismatch for layers.0.feed_forward.w3.weight: copying a param with shape torch.Size([1536, 576]) from checkpoint, the shape in current model is torch.Size([576, 576]).

I'm not very sure about the definition of dim and ffn_dim_multiplier here; it looks like some value is wrong. Would you mind providing some pointers/context on this? Appreciate it! @jackzhxng

The model structure is below:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
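
For what it's worth, the shapes above suggest that "dim" should match the HF hidden_size (576) and "hidden_dim" should match the HF intermediate_size (1536, the gate/up/down projection width). A quick sanity-check sketch, assuming the published HuggingFaceTB/SmolLM2-135M config matches the local snapshot; the name mapping is inferred from the error above, not taken from this PR:

# Relate the HF config fields to the params JSON fields.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
print("dim        ->", cfg.hidden_size)        # 576, the q_proj/o_proj width
print("hidden_dim ->", cfg.intermediate_size)  # 1536, the gate/up/down_proj width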

checkpoint_dir=args.input_dir,
checkpoint_files=["model.safetensors"],
output_dir=".",
model_type="MISTRAL",

Contributor:

Change to Llama

Contributor Author:

Thank you, updated~

converted_state_dict[new_key] = value

# Input and output embeddings are tied.
converted_state_dict["output.weight"] = converted_state_dict[

Contributor:

Might be because of this; input and output embeddings are not shared in the Llama definition this model is based on.

Contributor Author:

Makes sense, removed this.
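
For reference, a conversion script could handle both conventions by tying only when the checkpoint has no separate lm_head weight. A hedged sketch; the HF key names and the meta-format key "tok_embeddings.weight" are assumptions based on the usual layouts, not code from this PR:

def set_output_weight(hf_state_dict, converted_state_dict):
    # Tie output to input embeddings only when the HF checkpoint does not ship
    # its own lm_head weight (key names assumed from standard HF/meta formats).
    if "lm_head.weight" in hf_state_dict:
        converted_state_dict["output.weight"] = hf_state_dict["lm_head.weight"]
    else:
        converted_state_dict["output.weight"] = converted_state_dict["tok_embeddings.weight"]
    return converted_state_dict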

@@ -94,6 +94,7 @@
"static_llama",
"qwen2_5",
"phi-4-mini",
"smollm",

Contributor:

rename this and directory to smolllm2

Contributor Author:

Thanks! Should it be smollm2 or smolllm2?

Contributor:

Ah - it should be smollm2*

{
"dim": 576,
"ffn_dim_multiplier": 1,
"hidden_dim": 576,

Contributor Author:

Thank you! I got it mixed up with the hidden_size 😄

"rope_theta": 10000.0,
"use_scaled_rope": false,
"vocab_size": 49152,
"use_hf_rope": true,

Contributor:

this should be false

Contributor Author:

Thank you!! Updated

@Inklingdq marked this pull request as ready for review on March 19, 2025 05:15
@Inklingdq requested a review from lucylq as a code owner on March 19, 2025 05:15
@Inklingdq changed the title from "Add SmolLM" to "Add SmolLM (smollm2)" on Mar 19, 2025

@jackzhxng (Contributor) left a comment:

Thanks for helping with this!

@Inklingdq (Contributor Author) commented:

Hi @jackzhxng, it looks like the 3 failing checks are not related to my change; is there any way to rerun the tests?

@facebook-github-bot added the "CLA Signed" label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 20, 2025

@facebook-github-bot (Contributor) commented:

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@jackzhxng merged commit d60173b into pytorch:viable/strict on Mar 20, 2025
75 of 78 checks passed
jackzhxng added a commit that referenced this pull request Mar 24, 2025
jackzhxng pushed a commit that referenced this pull request Mar 24, 2025
@jackzhxng jackzhxng mentioned this pull request Mar 24, 2025
jackzhxng added a commit that referenced this pull request Mar 24, 2025
This reverts commit d60173b (PR #9354), which was accidentally merged directly into `viable/strict` instead of `main`.

---------

Co-authored-by: Inkling <[email protected]>
jackzhxng added a commit that referenced this pull request Mar 24, 2025
Add #9354 by [Inklingdq](https://github.com/Inklingdq), which was accidentally merged to `viable/strict`, to `main`.

Co-authored-by: Inkling <[email protected]>

@larryliu0820 (Contributor) commented:

@jackzhxng can you please pick this into the main branch, not viable/strict? We don't merge into viable/strict directly (there's a job that does that).

@jackzhxng (Contributor) commented:

@larryliu0820 yeah, I'm not sure how this got into viable/strict; I think I just didn't see it. But as you can see above, it's been reverted and picked back into main already.

Labels: CLA Signed, topic: not user facing