[Bug] Potential bugs in "_grouped_mm" in Llama4 MoE codes #1237
Comments
thanks for spotting this @raymin0223!
Thanks for taking a look into this! I'm explicitly setting
I see what I think is a related issue: using the llama4 debug model with FSDP=2, the loss becomes NaN at step 2, then all tokens are routed to the first expert, and the grouped gemm kernel cannot handle the case where experts are assigned 0 tokens.
@lessw2020 Is it possible for the update: hmm, could such dynamic shapes cause difficulties for torch.compile?
Thanks @lessw2020, I tested this on Llama4 with the same settings, and it seems like this PR resolves the issue. No more NaN loss!
thanks very much for your help here @raymin0223!
…rt needed, instead of max_len (#1254)

This PR switches generate_permute_indices to use the exact sizes needed per expert, instead of max_len. Thus, we now return a tensor of size sum(m_sizes) instead of max_len. This may resolve the current issue [here](#1237).

Testing: Ran both unit tests with dynamic padding; both pass. Verified that it resolves the NaNs when running llama4 (credit @raymin0223). #1237 (comment)

~~~
permuted_indices_gpu=tensor([ 0,  1,  2,  3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                              4,  5,  6,  7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                              8,  9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                             12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
                            device='cuda:0', dtype=torch.int32),
permuted_indices_cpu=tensor([ 0,  1,  2,  3, 16, 17, 18, 19, 32, 33, 34, 35, 48, 49, 50, 51,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                              4,  5,  6,  7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                              8,  9, 10, 11, 24, 25, 26, 27, 40, 41, 42, 43, 56, 57, 58, 59,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                             12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47, 60, 61, 62, 63,
                             -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
                            dtype=torch.int32)
m_sizes=tensor([32, 32, 32, 32], device='cuda:0', dtype=torch.int32)
Success
tokens_per_expert_group = tensor([4, 0, 2, 3, 1, 0, 0, 5], device='cuda:0', dtype=torch.int32)
total_tokens_per_expert = tensor([5, 0, 2, 8], device='cuda:0')
m_sizes = tensor([8, 8, 8, 8], device='cuda:0', dtype=torch.int32)
m_offsets = tensor([ 8, 16, 24, 32], device='cuda:0', dtype=torch.int32)
permuted_indices = tensor([ 0,  1,  2,  3,  9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
                            4,  5, -1, -1, -1, -1, -1, -1,  6,  7,  8, 10, 11, 12, 13, 14],
                           device='cuda:0', dtype=torch.int32)
Expert 1 has zero tokens and 8 slots with all -1
All tests passed successfully!
~~~
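As a rough illustration of the sizing behavior shown in the second test case above (this is not the actual generate_permute_indices kernel; the round-up-to-alignment rule and the minimum of one aligned block per expert are inferred from the log):

~~~python
# Illustration only: compute per-expert block sizes rounded up to a multiple of
# the alignment, plus their cumulative offsets, matching how m_sizes / m_offsets
# appear in the log above. Not the actual generate_permute_indices code.
import torch

def exact_sizes_and_offsets(total_tokens_per_expert: torch.Tensor, align: int):
    # round each expert's token count up to a multiple of `align`;
    # a zero-token expert still gets one aligned block, as in the log above
    m_sizes = ((total_tokens_per_expert + align - 1) // align).clamp(min=1) * align
    m_offsets = torch.cumsum(m_sizes, dim=0)
    return m_sizes.to(torch.int32), m_offsets.to(torch.int32)

total_tokens_per_expert = torch.tensor([5, 0, 2, 8])  # values from the log above
m_sizes, m_offsets = exact_sizes_and_offsets(total_tokens_per_expert, align=8)
print(m_sizes)    # tensor([8, 8, 8, 8], dtype=torch.int32)
print(m_offsets)  # tensor([ 8, 16, 24, 32], dtype=torch.int32)
~~~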
@lessw2020 My worry is that in the current Llama 4 TP, this will cause DTensor sharding propagation cache misses, as the shapes of the input could change a lot from iteration to iteration. I have one more question/request: is it possible to still always return
Bug description
Description of the bugs.
I encountered NaN loss values when running the Llama 4 MoE experimental code.
The errors come from here.
Afaik `offsets` are defined as `torch.cumsum(num_local_tokens_per_expert)`, and `x` (`routed_input`) is permuted to a shape of `original_shape + num_experts * ALIGN_SIZE_M`. Thus, there is a difference between `x.shape[0]` and `offsets[-1]`.
I'm not sure which expert will be allocated those redundant rows of `x` in `grouped_mm`. I believe the expected behavior would be that the outputs from them are always 0, because they are filled with 0 values. But `_grouped_mm` sometimes produces large values, where the first index of the outputs gets `inf` elements (here).
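To make the mismatch concrete, here is a small sketch (illustration only, not the torchtitan code; the token counts and hidden size are made up) that follows the description above:

~~~python
# Minimal sketch of the shape mismatch described above. The names
# num_local_tokens_per_expert and ALIGN_SIZE_M follow the issue text;
# the padding arithmetic is taken from the description, not from the repo.
import torch

ALIGN_SIZE_M = 16
num_experts = 4
num_local_tokens_per_expert = torch.tensor([40, 0, 7, 25])

# offsets handed to the grouped GEMM: cumulative per-expert token counts
offsets = torch.cumsum(num_local_tokens_per_expert, dim=0)  # offsets[-1] == 72

# routed_input is permuted/padded with num_experts * ALIGN_SIZE_M extra
# zero-filled rows for alignment, so its first dim exceeds offsets[-1]
total_tokens = int(num_local_tokens_per_expert.sum())
padded_rows = total_tokens + num_experts * ALIGN_SIZE_M
x = torch.zeros(padded_rows, 256, dtype=torch.bfloat16)

# The mismatch the issue points at: the trailing zero rows of x are not
# covered by offsets, so no expert clearly owns them in the grouped GEMM.
print(x.shape[0], int(offsets[-1]))  # e.g. 136 vs 72
assert x.shape[0] != int(offsets[-1])
~~~

How to Reproduce?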
1. I used `debug_model.toml`, but with a different batch size and seq_len, on 1 H200 GPU. Here is the running script:
2. I added `x = x.to(torch.bfloat16)` and `..., dtype=torch.bfloat16)` for `self.w1`, `self.w2`, and `self.w3`, since 1 GPU will automatically use torch.float32 in the code and `_grouped_mm` requires the tensors to be on GPU (a sketch of these changes follows this list).
3. I used `pdb` to get intermediate outputs one by one.
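The stand-in module below is a hypothetical sketch, not torchtitan's actual GroupedExperts implementation; it only shows where the two bfloat16 changes from step 2 would go (weights created in bfloat16, and the routed input cast before the expert matmuls). The loop over experts is a plain reference computation, not `_grouped_mm`.

~~~python
# Hypothetical sketch: where the bf16 casts from the reproduction steps would sit.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGroupedExperts(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, num_experts: int):
        super().__init__()
        # mirrors the "..., dtype=torch.bfloat16" change for self.w1 / self.w2 / self.w3
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16) * 0.02)
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, dim, dtype=torch.bfloat16) * 0.02)
        self.w3 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16) * 0.02)

    def forward(self, x: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        # mirrors the "x = x.to(torch.bfloat16)" change before the grouped matmul
        x = x.to(torch.bfloat16)
        outs, start = [], 0
        for e, end in enumerate(offsets.tolist()):
            xe = x[start:end]  # rows routed to expert e
            h = F.silu(xe @ self.w1[e]) * (xe @ self.w3[e])
            outs.append(h @ self.w2[e])
            start = end
        return torch.cat(outs, dim=0)
~~~

With `offsets = torch.cumsum(num_local_tokens_per_expert, dim=0)`, each slice `x[start:end]` is one expert's contiguous block, which is what the real grouped GEMM consumes in a single kernel call.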
Results and Expected Behaviors.
Routed outputs sometimes show the following results (at the first step or a few steps later):
We expect the tensors at sequence positions 2096 to 2176 to always be zero. This causes the hidden states to have NaN values, and eventually NaN loss values.
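A minimal sanity check along these lines is sketched below; `routed_output` and `offsets` name the tensors discussed above, and the slicing assumes the zero-filled padding rows sit past `offsets[-1]`, as described in the bug report:

~~~python
# Sketch of the expectation stated above (not code from the repo): rows of the
# routed output beyond offsets[-1] were zero-filled on input, so they should
# stay zero and finite after _grouped_mm.
import torch

def check_padded_rows(routed_output: torch.Tensor, offsets: torch.Tensor) -> None:
    pad_region = routed_output[int(offsets[-1]):]
    assert torch.isfinite(routed_output).all(), "found inf/nan in routed output"
    assert torch.count_nonzero(pad_region) == 0, "padded rows produced non-zero values"
~~~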
Versions
Python 3.13 with the following packages: