Question about fixed std=0.02 initialization of w1 in moe.py #1257

Open
trestad opened this issue Jun 3, 2025 · 1 comment
Labels
question Further information is requested

Comments

@trestad
Contributor

trestad commented Jun 3, 2025

Hi torchtitan team,

Thanks for the great work on this project! I had a question regarding a detail in the code at moe.py#L92

nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)

I noticed that w1 is initialized with a fixed standard deviation of 0.02, whereas w2 and w3 are initialized using a configurable init_std parameter. I’m wondering if this discrepancy is intentional, and if so, what the reasoning is behind using a hardcoded value for w1.
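To make the discrepancy concrete, here is a minimal sketch of the pattern being asked about. The class name `ToyExperts`, the stacked per-expert parameter shapes, and the `init_std=0.005` value are hypothetical choices for illustration, not torchtitan's actual code; only the three `trunc_normal_` calls are meant to mirror the behavior described above (w1 fixed at 0.02, w2 and w3 driven by `init_std`).

```python
import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    """Hypothetical stand-in for MoE expert weights (not torchtitan's class)."""

    def __init__(self, dim: int, hidden_dim: int, num_experts: int):
        super().__init__()
        # Assumed layout: one weight matrix per expert, stacked along dim 0.
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))

    def init_weights(self, init_std: float) -> None:
        # w1: hardcoded std, independent of init_std (the line in question).
        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
        # w2 and w3: std supplied by the caller.
        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)


torch.manual_seed(0)
experts = ToyExperts(dim=64, hidden_dim=256, num_experts=4)
experts.init_weights(init_std=0.005)  # hypothetical init_std value
for name in ("w1", "w2", "w3"):
    print(name, round(getattr(experts, name).std().item(), 4))
```

`trunc_normal_` truncates at the absolute bounds `a=-2.0`, `b=2.0` by default, so at these small stds the truncation almost never triggers and the sample std of each tensor stays close to the requested value: w1 ends up near 0.02 no matter what `init_std` is passed.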

Would greatly appreciate any insights you could share!

Thanks again!

@tianyu-l
Contributor

tianyu-l commented Jun 3, 2025

I copied this from the Llama 3 FFN init code at https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L305

@lessw2020 do you have more context on the choice of std for w1, w2, w3?
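For context on how far apart the two stds can be, the sketch below assumes that the `init_std` passed into the FFN/expert init is depth-scaled, roughly `0.02 / sqrt(2 * (layer_id + 1))` when depth-wise init is enabled and `0.02 / sqrt(2 * n_layers)` otherwise. That formula is an assumption recalled from torchtitan's Llama 3 `TransformerBlock` and should be checked against `model.py`; the function name `weight_init_std` and its signature are hypothetical.

```python
def weight_init_std(layer_id: int, n_layers: int, depth_init: bool = True) -> float:
    """Hypothetical helper: assumed depth-scaled std passed as init_std."""
    if depth_init:
        # Shrinks with the block's position: deeper blocks get a smaller init.
        return 0.02 / (2 * (layer_id + 1)) ** 0.5
    # Same value for every block, scaled by the total number of layers.
    return 0.02 / (2 * n_layers) ** 0.5


n_layers = 32
for layer_id in (0, 15, 31):
    std = weight_init_std(layer_id, n_layers)
    print(f"layer {layer_id:2d}: init_std={std:.5f}  vs fixed w1 std=0.02")
```

Under this assumption, every layer's `init_std` is strictly below 0.02 (0.02/sqrt(2) ≈ 0.0141 at layer 0, down to 0.0025 at layer 31 of 32), so w1 would consistently start out with a larger spread than w2 and w3.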

@tianyu-l tianyu-l added the question Further information is requested label Jun 3, 2025