
[Bug] Potential bugs in "_grouped_mm" in Llama4 MoE codes #1237

Open
Tracked by #1118
raymin0223 opened this issue May 29, 2025 · 8 comments

@raymin0223

Bug description

I encountered NaN loss values when running the experimental Llama 4 MoE code.
The errors come from here.

As far as I know, offsets is defined as torch.cumsum(num_local_tokens_per_expert), while x (routed_input) is permuted into a buffer of shape original_shape + num_experts * ALIGN_SIZE_M.
As a result, x.shape[0] and offsets[-1] do not match.

I'm not sure which expert the redundant rows of x are assigned to in _grouped_mm.
The expected behavior should be that their outputs are always 0, because those rows are filled with 0 values.
However, _grouped_mm sometimes produces large values, and the first index of the outputs even contains inf elements (here).
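
Here is a toy illustration of the mismatch (a sketch, not the actual torchtitan code: the per-expert counts are taken from a later comment below, ALIGN_SIZE_M = 16 is an assumption, and the per-expert alignment of each segment is omitted):

~~~python
import torch

# Toy illustration of the shape mismatch (not the actual torchtitan code).
ALIGN_SIZE_M = 16      # assumed alignment constant
dim = 256
num_local_tokens_per_expert = torch.tensor([299, 254, 158, 278, 167, 343, 263, 286])
num_experts = num_local_tokens_per_expert.numel()

offsets = torch.cumsum(num_local_tokens_per_expert, dim=0).to(torch.int32)

# routed_input is permuted into a buffer with num_experts * ALIGN_SIZE_M extra
# zero-filled rows, so it is longer than offsets[-1].
x = torch.zeros(int(num_local_tokens_per_expert.sum()) + num_experts * ALIGN_SIZE_M,
                dim, dtype=torch.bfloat16)

print(offsets[-1].item(), x.shape[0])  # 2048 vs 2176: the tail rows belong to no expert
~~~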

How to Reproduce?

  1. I used the Llama-3.2-1B tokenizer.
  2. I used debug_model.toml, but with a different batch size and seq_len on 1 H200 GPU. Here is the running script:
torchrun --nnodes 1 --nproc_per_node 1  ./torchtitan/train.py  \
--job.config_file ./torchtitan/experiments/llama4/train_configs/debug_model.toml --job.dump_folder ./outputs/250528_grouped_mm_debug  \
--profiling.save_traces_folder profile_trace --comm.trace_buf_size 0  --checkpoint.folder ./checkpoints/250528_grouped_mm_debug --checkpoint.interval 13000   \
--training.steps 114440 --training.batch_size 1 --training.seq_len 2048   \
--metrics.log_freq 100 --lr_scheduler.warmup_steps 1000 --optimizer.lr 6e-4  \
--parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1
  3. Add x = x.to(torch.bfloat16) and ..., dtype=torch.bfloat16) for self.w1, self.w2, and self.w3, since with 1 GPU the code automatically uses torch.float32, and _grouped_mm requires the tensors to be in bfloat16 on the GPU (see the sketch after this list).
  4. I used pdb to inspect the intermediate outputs one by one.
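
The sketch below shows roughly what steps 3-4 amount to; the shapes (num_experts, dim, hidden_dim) are toy assumptions, while the offsets and the three _grouped_mm calls are the ones from the MoE forward pass inspected in the results below:

~~~python
import torch
import torch.nn.functional as F

# Toy sketch of steps 3-4 (shapes are assumptions; the offsets are the ones
# printed in the results below). Requires a CUDA device and a PyTorch build
# that exposes the private torch._grouped_mm op.
num_experts, dim, hidden_dim = 8, 256, 768
device = "cuda"

# routed_input and the expert weights w1 / w2 / w3, all in bfloat16
x = torch.randn(2176, dim, device=device, dtype=torch.bfloat16)
w1 = torch.randn(num_experts, dim, hidden_dim, device=device, dtype=torch.bfloat16)
w3 = torch.randn(num_experts, dim, hidden_dim, device=device, dtype=torch.bfloat16)
w2 = torch.randn(num_experts, hidden_dim, dim, device=device, dtype=torch.bfloat16)

offsets = torch.tensor([176, 416, 736, 992, 1296, 1584, 1840, 2096],
                       device=device, dtype=torch.int32)

# The same three calls inspected with pdb; rows 2096:2176 of x are not
# covered by any offset.
h = F.silu(torch._grouped_mm(x, w1, offs=offsets))
h = h * torch._grouped_mm(x, w3, offs=offsets)
out = torch._grouped_mm(h, w2, offs=offsets)
print(out[2096:])
~~~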

Results and expected behavior

Routed outputs sometimes show the following results (at the first step or a few steps later):

offsets : tensor([ 176,  416,  736,  992, 1296, 1584, 1840, 2096], device='cuda:0', dtype=torch.int32)

x.shape : torch.Size([2176, 256])

h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) :
tensor([[ 3.7598e-02, -9.3262e-02,  1.3965e-01,  ..., -1.7822e-02,
         -2.2949e-02,  2.0020e-02],
        [ 1.1572e-01,  2.2461e-01,  3.1641e-01,  ...,  8.6060e-03,
         -5.3711e-02, -2.7100e-02],
        [ 1.4551e-01,  2.1973e-02,  1.3086e-01,  ..., -2.5269e-02,
          3.7354e-02, -1.5503e-02],
        ...,
        [-0.0000e+00,  2.9297e-02, -0.0000e+00,  ...,  5.2246e-02,
          7.7462e+18, -1.8066e-02],
        [ 2.8531e+26,  5.1025e-02, -0.0000e+00,  ...,  1.1670e-01,
          3.2028e-28,  1.5076e-02],
        [ 6.3348e+26,  3.8818e-02,  4.0250e+01,  ..., -2.8229e-03,
          2.4844e-32, -8.6670e-03]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SiluBackward0>)

h = h * torch._grouped_mm(x, self.w3, offs=offsets)
tensor([[-1.8692e-03, -2.8992e-03,  1.6327e-03,  ..., -1.5564e-03,
         -1.0681e-02,  5.1022e-05],
        [-5.5237e-03,  6.0425e-03,  1.0864e-02,  ...,  9.8419e-04,
          3.0396e-02, -4.2152e-04],
        [-1.6785e-03, -4.5776e-04, -2.0142e-03,  ...,  1.0193e-02,
         -4.6082e-03, -1.3733e-04],
        ...,
        [ 0.0000e+00,  1.2054e-03, -0.0000e+00,  ..., -2.5177e-03,
          3.5863e+11, -1.7548e-03],
        [       -inf,  6.3705e-04,  0.0000e+00,  ...,  9.5825e-03,
         -2.9000e+02,  3.2234e-04],
        [ 8.4410e+07,  4.0588e-03, -1.0379e+31,  ...,  3.7432e-05,
          1.2387e-07, -1.3733e-03]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<MulBackward0>)

out = torch._grouped_mm(h, self.w2, offs=offsets)
tensor([[ 6.3782e-03,  4.0894e-03, -1.3672e-02,  ..., -8.4839e-03,
         -2.8229e-03, -3.9978e-03],
        [-1.9379e-03, -4.6387e-03,  8.5449e-03,  ..., -4.8523e-03,
         -4.4861e-03, -1.4114e-03],
        [-3.1128e-03, -2.5177e-03, -3.4332e-03,  ...,  1.3062e-02,
         -6.7139e-03, -7.6904e-03],
        ...,
        [-1.6251e-03, -1.3279e-10, -7.3787e+19,  ..., -5.1659e-10,
         -3.8780e+34, -3.5834e-10],
        [ 4.7055e+34, -1.6735e-09,  6.0889e+18,  ..., -1.1205e-09,
          7.1024e+24,  3.1287e-10],
        [-2.4087e-21, -2.1682e-09,  3.0898e+20,  ...,  2.9831e-09,
          2.4898e-30,  5.5297e-10]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<GroupedMmBackward0>)

We expect the tensors at sequence positions 2096 to 2176 to always be zero.
Instead, this causes the hidden states to have NaN values, and eventually the loss becomes NaN as well.

Versions

Python 3.13 with the following packages:

absl-py==2.2.2
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types==0.7.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
attrs==25.3.0
beautifulsoup4==4.13.4
bleach==6.2.0
blessed==1.21.0
blobfile==3.0.0
certifi==2025.4.26
charset-normalizer==3.4.2
click==8.2.0
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
contourpy==1.3.2
cycler==0.12.1
datasets==3.6.0
debugpy @ file:///croot/debugpy_1736267418885/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
defusedxml==0.7.1
dill==0.3.8
docker-pycreds==0.4.0
docstring_parser==0.16
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
fastjsonschema==2.21.1
filelock==3.16.1
fonttools==4.58.0
frozenlist==1.6.0
fsspec==2024.10.0
gitdb==4.0.12
GitPython==3.1.44
gpustat==1.1.1
grpcio==1.71.0
huggingface-hub==0.31.4
idna==3.10
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1745672166/work
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
Jinja2==3.1.4
jsonschema==4.24.0
jsonschema-specifications==2025.4.1
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1727163409502/work
jupyterlab_pygments==0.3.0
kiwisolver==1.4.8
lxml==5.4.0
Markdown==3.8
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.10.3
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
mdurl==0.1.2
mistune==3.1.3
mpmath==1.3.0
multidict==6.4.4
multiprocess==0.70.16
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
networkx==3.4.2
numpy==2.2.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu12==2.26.5
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
packaging==25.0
pandas==2.2.3
pandocfilters==1.5.1
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
pillow==11.2.1
platformdirs==4.3.8
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
propcache==0.3.1
protobuf==6.31.0
psutil==7.0.0
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
pyarrow==20.0.0
pycryptodomex==3.23.0
pydantic==2.11.4
pydantic_core==2.33.2
Pygments==2.19.1
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytorch-triton==3.3.0+git96316ce5
pytz==2025.2
PyYAML==6.0.2
pyzmq @ file:///croot/pyzmq_1734687138743/work
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
rich==14.0.0
rpds-py==0.25.1
safetensors==0.5.3
sentry-sdk==2.29.1
setproctitle==1.3.6
setuptools==70.2.0
shtab==1.7.2
six==1.17.0
smmap==5.0.2
soupsieve==2.7
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
sympy==1.13.3
tabulate==0.9.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tiktoken==0.9.0
tinycss2==1.4.0
tokenizers==0.21.1
torch==2.8.0.dev20250519+cu126
torchdata==0.11.0
tornado @ file:///croot/tornado_1747918059467/work
tqdm==4.67.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
transformers==4.52.1
triton==3.3.0
typeguard==4.4.2
typing-inspection==0.4.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work
tyro==0.9.20
tzdata==2025.2
urllib3==2.4.0
wandb==0.19.11
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
webencodings==0.5.1
Werkzeug==3.1.3
wheel==0.45.1
xxhash==3.5.0
yarl==1.20.0
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
tianyu-l self-assigned this on May 29, 2025
@lessw2020
Contributor

thanks for spotting this @raymin0223!
I think this same issue is being discussed / addressed here:
pytorch/pytorch#154557
but let's keep this open until this is fully resolved.

@raymin0223
Author

Thanks for taking a look at this! As a workaround, I'm explicitly setting offsets[-1] = x.shape[0], i.e., the padding rows go through the last expert but end up being 0, as expected.
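
A minimal sketch of that workaround with toy inputs (the shapes and per-expert counts are made up, chosen as multiples of 16 to mirror the aligned offsets above; the actual one-line change goes where offsets is computed in torchtitan's MoE code):

~~~python
import torch
import torch.nn.functional as F

# Minimal sketch of the workaround (toy shapes and counts; requires CUDA and a
# PyTorch build with torch._grouped_mm).
num_experts, dim, hidden_dim, ALIGN_SIZE_M = 8, 256, 768, 16
device = "cuda"

num_local_tokens_per_expert = torch.tensor([176, 240, 320, 256, 304, 288, 256, 256],
                                           device=device)
offsets = torch.cumsum(num_local_tokens_per_expert, dim=0).to(torch.int32)

total = int(num_local_tokens_per_expert.sum())
x = torch.zeros(total + num_experts * ALIGN_SIZE_M, dim,
                device=device, dtype=torch.bfloat16)
x[:total] = torch.randn(total, dim, device=device, dtype=torch.bfloat16)
w1 = torch.randn(num_experts, dim, hidden_dim, device=device, dtype=torch.bfloat16)

# The workaround: let the last expert also own the zero-filled tail rows.
offsets[-1] = x.shape[0]

h = F.silu(torch._grouped_mm(x, w1, offs=offsets))
print(torch.all(h[total:] == 0))  # the padded rows now stay zero
~~~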

@danielvegamyhre
Contributor

I see what I think is a related issue: using the llama4 debug model with FSDP=2, the loss becomes NaN at step 2, then all tokens are routed to the first expert, and the grouped GEMM kernel cannot handle the case where an expert is assigned 0 tokens.

@tianyu-l
Contributor

tianyu-l commented Jun 2, 2025

@lessw2020
It looks like the redundant padding of size x.shape[0] - offsets[-1] is what causes the difficulty.
The padding that makes each expert handle a multiple of ALIGN_SIZE_M is fine, because those rows are initialized to 0 and will be multiplied by some expert in torch._grouped_mm.

Is it possible for the generate_permute_indices kernel to return a tensor not of fixed size max_len, but of size sum(num_local_tokens_per_expert), depending on how much padding the input needs at minimum?

Update: hmm, could such dynamic shapes cause difficulties for torch.compile?

@lessw2020
Contributor

@tianyu-l - thanks for spotting this issue.
The PR with exact sizing is here:
#1254
It passes unit testing, but we still need to run it with actual models next.

@raymin0223
Author

raymin0223 commented Jun 4, 2025

Thanks @lessw2020,

I tested this on Llama4 with the same settings, and it seems this PR resolves the issue. No more NaN loss!
Here are the intermediate logs from testing.

print(num_local_tokens_per_expert, sum(num_local_tokens_per_expert))
tensor([299, 254, 158, 278, 167, 343, 263, 286], device='cuda:0') tensor(2048, device='cuda:0')

print(x.shape)
torch.Size([17408, 256])

print(offsets)
tensor([ 2176,  4352,  6528,  8704, 10880, 13056, 15232, 17408], device='cuda:0', dtype=torch.int32)

print(h[299:2176])  # First expert's redundant parts 
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)

print(out[299:2176])  # First expert's redundant parts
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)

[titan] 2025-06-04 23:24:47,344 - root - INFO - step: 4200  loss:  5.0034  memory:  4.87GiB(3.48%)  tps: 40,365  tflops: 10.86  mfu: 1.10%
[titan] 2025-06-04 23:24:49,832 - root - INFO - [GC] Peforming periodical GC collection. 0.00 seconds.
[titan] 2025-06-04 23:24:52,367 - root - INFO - [GC] Peforming periodical GC collection. 0.00 seconds.
[titan] 2025-06-04 23:24:52,418 - root - INFO - step: 4300  loss:  5.0769  memory:  4.87GiB(3.48%)  tps: 40,364  tflops: 10.86  mfu: 1.10%
[titan] 2025-06-04 23:24:54,904 - root - INFO - [GC] Peforming periodical GC collection. 0.00 seconds.
[titan] 2025-06-04 23:24:57,438 - root - INFO - [GC] Peforming periodical GC collection. 0.00 seconds.
[titan] 2025-06-04 23:24:57,489 - root - INFO - step: 4400  loss:  5.0065  memory:  4.87GiB(3.48%)  tps: 40,385  tflops: 10.86  mfu: 1.10%
[titan] 2025-06-04 23:24:59,976 - root - INFO - [GC] Peforming periodical GC collection. 0.00 seconds.
[titan] 2025-06-04 23:25:02,511 - root - INFO - [GC] Peforming periodical GC collection. 0.00 seconds.

@lessw2020
Contributor

Thanks very much for your help here @raymin0223!

lessw2020 added a commit that referenced this issue Jun 5, 2025
…rt needed, instead of max_len (#1254)

This PR switches generate_permute_indices to use the exact sizes needed per expert, instead of max_len.
Thus, we now return a tensor of size sum(m_sizes) instead of max_len.
This may resolve the current issue
[here](#1237).

Testing:
Ran both unit tests with dynamic padding; both pass.
Verified that it resolves the NaNs when running llama4 (credit @raymin0223).
#1237 (comment)

~~~
permuted_indices_gpu=tensor([ 0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34,
35, 48, 49, 50, 51, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 4, 5, 6, 7,
20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 8, 9, 10, 11, 24, 25, 26, 27,
40, 41, 42, 43, 56, 57, 58, 59, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, 12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47,
60, 61, 62, 63, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1], device='cuda:0', dtype=torch.int32), 
permuted_indices_cpu=tensor([ 0, 1, 2, 3, 16, 17, 18, 19, 32, 33, 34,
35, 48, 49, 50, 51, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 4, 5, 6, 7,
20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 8, 9, 10, 11, 24, 25, 26, 27,
40, 41, 42, 43, 56, 57, 58, 59, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, 12, 13, 14, 15, 28, 29, 30, 31, 44, 45, 46, 47,
60, 61, 62, 63, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1], dtype=torch.int32)
m_sizes=tensor([32, 32, 32, 32], device='cuda:0', dtype=torch.int32)
Success
tokens_per_expert_group = tensor([4, 0, 2, 3, 1, 0, 0, 5],
device='cuda:0', dtype=torch.int32)
total_tokens_per_expert = tensor([5, 0, 2, 8], device='cuda:0')
m_sizes = tensor([8, 8, 8, 8], device='cuda:0', dtype=torch.int32)
m_offsets = tensor([ 8, 16, 24, 32], device='cuda:0', dtype=torch.int32)
permuted_indices = tensor([ 0, 1, 2, 3, 9, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, 4, 5,
        -1, -1, -1, -1, -1, -1,  6,  7,  8, 10, 11, 12, 13, 14],
       device='cuda:0', dtype=torch.int32)
Expert 1 has zero tokens and 8 slots with all -1
All tests passed successfully!
~~~
@tianyu-l
Contributor

tianyu-l commented Jun 5, 2025

@lessw2020
Sorry I didn't get a chance to further review #1254 before it got merged.

My worry is that with the current Llama 4 TP, this will cause DTensor sharding propagation cache misses, as the input shapes could change a lot from iteration to iteration.

I have one more question/request: is it possible to still always return permuted_indices of size original_shape + num_experts * ALIGN_SIZE_M as before, but let the last element of num_local_tokens_per_expert cover the remaining 0's that are trimmed in this PR?
