Llama 4 issue tracking #1118

Open
3 of 14 tasks
tianyu-l opened this issue Apr 17, 2025 · 6 comments
@tianyu-l
Contributor

tianyu-l commented Apr 17, 2025

High priority

Not high priority for now

  • for-loop implementation of MoE (see the sketch at the end of this comment)
    • with DTensor TP: sharding propagation overhead due to dynamic shapes
      • need to lift cache hit criteria in DTensor sharding prop
      • may be needed by Loss Parallel for per-sequence loss as well
    • with torch.compile: branching on “unbacked” symbolic ints
      • static padding of DTensor may solve this

Not llama4 specific
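
For reference on the for-loop MoE item above, here is a rough, hypothetical sketch of where the dynamic shapes come from (illustrative names only, not torchtitan's actual code): each expert processes a data-dependent subset of tokens, so the per-expert inputs have dynamic shapes, which stresses DTensor sharding-propagation caching and produces unbacked symbolic ints under torch.compile.

```python
# Hypothetical sketch of a for-loop MoE layer; the class and names here are
# illustrative only and not torchtitan's actual implementation.
import torch
import torch.nn as nn


class ForLoopMoE(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim)
            )
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_tokens, dim)
        scores = self.router(x).softmax(dim=-1)          # (num_tokens, num_experts)
        weights, selected = scores.topk(self.top_k, -1)  # (num_tokens, top_k) each
        out = torch.zeros_like(x)
        for expert_id, expert in enumerate(self.experts):
            # The number of tokens routed to each expert is data-dependent, so
            # x[token_idx] has a dynamic shape: this is what causes repeated
            # DTensor sharding propagation (cache misses) and the "unbacked"
            # symbolic ints that torch.compile must branch on.
            token_idx, slot_idx = (selected == expert_id).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            out.index_add_(
                0, token_idx,
                weights[token_idx, slot_idx].unsqueeze(-1) * expert(x[token_idx]),
            )
        return out
```

Statically padding each expert's token buffer to a fixed capacity would make these shapes static (at the cost of some wasted compute), which is why static padding is listed as a possible fix above.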

@Wodlfvllf

I want to contribute. Can you assign this to me?

@tianyu-l
Contributor Author

@Wodlfvllf there's a list of issues/items; do you have a preference for what to work on?

@lkhphuc
Contributor

lkhphuc commented Apr 30, 2025

Re:

  • with AdamW
    • gets stuck after a couple of iterations

Have you tried this with CUDA 12.8? I was able to run the debug model for 1000 iterations on torch nightly with CUDA 12.8, using AdamW.

Edit: Same for Full AC and Selective AC, int or op.

Re:

  • with Activation Checkpointing
    • gets stuck after a couple of iterations

Successfully ran the debug model for 1000 iters with AdamW + (Full/Selective) AC on CUDA 12.8.

@lessw2020
Contributor

lessw2020 commented May 2, 2025

A deeper dive into the torch grouped GEMM hanging issue finds that the root cause is an expert being assigned zero tokens: this causes a hang in the backward pass. (Big thanks to @rciocoiu for the initial isolation work and for tracking this core issue down.)

The choice of Adam vs. AdamW does matter, because with Adam an expert has to get down to zero assigned tokens before hanging, whereas with AdamW it only has to get down to the min_align_m setting (default 16), so the hang happens sooner. This matches users' reports that it hangs faster with AdamW.

I have opened a PyTorch issue on this here: pytorch/pytorch#152668. It includes a minimal repro that can be made to hang or not hang simply by giving an expert a zero offset (zero assigned tokens).

I will investigate whether we can safely pad out any zero-token expert assignments with min_aligned_tokens as the likely workaround/fix. This may also be fixable directly in the kernel, but I think the fastest fix is to autopad.
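
As an illustration of the autopad idea (a hypothetical helper, not the actual generate_permute_indices change): pad any expert with zero assigned tokens up to the alignment size before computing the offsets handed to the grouped GEMM.

```python
import torch


def pad_zero_token_experts(tokens_per_expert: torch.Tensor, min_align: int = 16) -> torch.Tensor:
    """Hypothetical sketch: bump zero-token experts up to `min_align` tokens so
    no group handed to the grouped GEMM has size zero, then return the
    cumulative per-expert offsets."""
    padded = torch.where(
        tokens_per_expert == 0,
        torch.full_like(tokens_per_expert, min_align),
        tokens_per_expert,
    )
    return torch.cumsum(padded, dim=0)


# e.g. experts 1 and 3 received no tokens this step:
# pad_zero_token_experts(torch.tensor([32, 0, 48, 0])) -> tensor([32, 48, 96, 112])
```

The padded slots would still need to be zero-filled (or otherwise masked) so they don't contribute to the output; the point is only that no group size reaches zero.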

@ngimel

ngimel commented May 2, 2025

I already raised this issue with CUTLASS; they indeed don't support K=0 today, though they promise a fix soon. The best thing would indeed be to pad the group to the minimum supported number of tokens (8 for bf16, 16 for fp8).
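
If we go that route, the alignment could plausibly be picked per dtype using the minimums quoted above (a hypothetical helper, just to pin down the numbers):

```python
import torch


def min_group_tokens(dtype: torch.dtype) -> int:
    # Minimum supported group size per the comment above: 8 for bf16, 16 for fp8.
    if dtype == torch.bfloat16:
        return 8
    if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return 16
    raise ValueError(f"no known minimum group size for {dtype}")
```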

lessw2020 added a commit that referenced this issue May 7, 2025
…to avoid hangs with torch_group_gemm (#1166)

This PR updates generate_permute_indices to enable 'auto padding' for
experts that have zero tokens assigned, and resolves the hang that was
being encountered with llama4 in torchtitan and grouped GEMM.

This autopadding is vital to ensure that torch grouped GEMM can
process the backward pass, as zero-token experts currently cause a
hang (see #1118 and pytorch/pytorch#152668).

Further, because we now track total_tokens_per_expert, this PR adds
'skip logic' to the triton kernel so it can jump over experts with
zero tokens.

Usage:
No user change is needed. We simply auto-pad zero-token experts to
alignment-size tokens.

Testing:
a - Ran to 2K iters successfully with AdamW, with expert load balancing
disabled (as this forces the zero-token-expert scenario). Previously
AdamW would hang faster, and both Adam and AdamW would hang if an
expert had zero tokens.
b - Added a unit test for the zero-token-expert case to indices.py as
part of the fast simple testing (and verified it passes).
c - Verified that inference runs with the same torch grouped GEMM (a
previous auto-padding PR of mine would crash here, so this is a key
test).

Screenshot - forced zero-token experts by removing expert load
balancing; note the many zero-token experts, yet the run completes 2K
iterations successfully:
https://github.com/user-attachments/assets/85ac21b0-2b2a-4916-a318-2ebb4530e3b9
@lessw2020
Contributor

The PR to fix the hang for grouped MM (torch._grouped_mm) has landed, and that should resolve the hanging issues with grouped MM regardless of optimizer.
I ran 2K iters with no issue, with expert load balancing disabled to force lots of zero-token experts (zero-token experts were the core issue behind the hang).
@tianyu-l
#1166

wwwjn pushed a commit that referenced this issue May 9, 2025
…to avoid hangs with torch_group_gemm (#1166)

(Same commit message as above.)
wwwjn pushed a commit that referenced this issue May 16, 2025
…to avoid hangs with torch_group_gemm (#1166)

(Same commit message as above.)