Llama 4 issue tracking #1118
Comments
I want to contribute. Can you assign this to me?
@Wodlfvllf there's a list of issues / items; do you have a preference on what to work on?
Have you tried this with CUDA 12.8? I was able to run the debug model for 1000 iterations on torch nightly with CUDA 12.8, using AdamW. Edit: same for Full AC and Selective AC (both the int and op modes).
Successfully ran the debug model for 1000 iterations with AdamW + (Full/Selective) AC on CUDA 12.8.
A deeper dive on the torch group gemm hanging issue finds that the root cause is an expert being assigned zero tokens: this causes a hang in the backward pass. (Big thanks to @rciocoiu for the initial isolation work and tracking this core issue down.) The choice of Adam vs. AdamW does matter, because with Adam an expert has to reach zero assigned tokens before hanging, whereas with AdamW it only has to drop to the min_align_m setting (default is 16), so the hang happens sooner; this matches users reporting that it hangs faster with AdamW. I have opened a PyTorch issue on this here: pytorch/pytorch#152668, and there is a minimal repro case that can be made to hang or not hang simply by having an expert with a zero offset (zero tokens assigned).
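For illustration, a minimal sketch of the failing pattern, assuming the 2D x 3D variant of torch._grouped_mm with an int32 tensor of cumulative per-expert token counts (this is not the exact repro from pytorch/pytorch#152668, and it requires a recent nightly plus hardware that supports grouped MM):

```python
import torch

device = "cuda"
num_experts, dim, hidden = 4, 256, 512

# Routed tokens (stacked across experts) and per-expert weight matrices.
x = torch.randn(48, dim, device=device, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(num_experts, dim, hidden, device=device, dtype=torch.bfloat16, requires_grad=True)

# Cumulative token counts per expert: group sizes are 16, 16, 0, 16.
# Expert 2 gets zero tokens because two consecutive offsets are equal.
offs = torch.tensor([16, 32, 32, 48], device=device, dtype=torch.int32)

out = torch._grouped_mm(x, w, offs=offs)  # forward may complete...
out.sum().backward()                      # ...while the backward pass hangs (pre-fix)
```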
I already raised this issue with CUTLASS; they indeed don't support K=0 today, though they promise a fix soon. The best thing would be to pad each group to the minimum supported number of tokens (8 for bf16, 16 for fp8).
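A small sketch of that padding idea, assuming the function name and the round-up-to-alignment behavior (these are illustrative choices, not the torchtitan API):

```python
import torch

def pad_token_counts(tokens_per_expert: torch.Tensor, min_group_size: int) -> torch.Tensor:
    """Round each expert's token count up so no group ends up with K=0."""
    padded = tokens_per_expert.clamp(min=min_group_size)
    # Also round up to a multiple of the minimum group size / alignment.
    return ((padded + min_group_size - 1) // min_group_size) * min_group_size

tokens_per_expert = torch.tensor([16, 0, 40, 7])
padded = pad_token_counts(tokens_per_expert, 8)          # tensor([16,  8, 40,  8])
offs = torch.cumsum(padded, dim=0).to(torch.int32)       # cumulative offsets for grouped MM
```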
…to avoid hangs with torch_group_gemm (#1166)

This PR updates generate_permute_indices to enable 'auto padding' for experts that have zero tokens assigned and resolves the hang that was being encountered with llama4 titan and group gemm. This auto-padding is vital to ensure that torch group gemm is able to process the backward pass, as zero-token experts currently cause a hang (see #1118 and pytorch/pytorch#152668). Further, because we now track total_tokens_per_expert, this PR adds 'skip logic' to the triton kernel so it can jump over experts with zero tokens.

Usage: no user change is needed. We simply auto-pad zero-token experts to alignment-size tokens.

Testing:
a - Ran to 2K iterations with expert load balancing disabled (as this forces the zero-token expert scenario) successfully with AdamW. Previously AdamW would hang faster, and both Adam and AdamW would hang if an expert had zero tokens.
b - Added a unit test for the zero-token expert case in indices.py as part of the fast simple testing (and verified it passes).
c - Verified that inference can run with the same torch group gemm (a previous PR I had with auto-padding would crash, so this is a key test).

Screenshot: forced zero-token experts by removing load balancing of experts; note the many zero-token experts but the run successfully reaching 2K iterations (https://github.com/user-attachments/assets/85ac21b0-2b2a-4916-a318-2ebb4530e3b9).
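To make the auto-padding idea concrete, here is a rough, hypothetical sketch of building permute indices with per-expert padding; the function name, top-1 routing, and dummy-index convention are assumptions for illustration, not the actual generate_permute_indices implementation:

```python
import torch

def permute_with_padding(top_expert: torch.Tensor, num_experts: int, alignment: int = 16):
    """top_expert[i] is the expert id chosen for token i (top-1 routing for simplicity)."""
    tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
    # Pad every expert's slot up to a multiple of `alignment`, and to at least
    # `alignment` rows, so no expert ends up with a zero-size group.
    padded = torch.clamp(
        ((tokens_per_expert + alignment - 1) // alignment) * alignment, min=alignment
    )
    total_tokens = top_expert.numel()
    # Padded rows point at a dummy index (== total_tokens), i.e. an extra zero row
    # appended to the token buffer, so they can be ignored after the grouped GEMM.
    permute = torch.full((int(padded.sum()),), total_tokens, dtype=torch.int64)
    start = 0
    for e in range(num_experts):
        idx = (top_expert == e).nonzero(as_tuple=True)[0]
        permute[start : start + idx.numel()] = idx
        start += int(padded[e])
    offs = torch.cumsum(padded, dim=0).to(torch.int32)
    return permute, offs
```

In this sketch, the permuted input would be gathered as `torch.cat([x, x.new_zeros(1, x.shape[-1])])[permute]` so the dummy rows contribute zeros, and `offs` is what gets passed to the grouped MM.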
The PR to fix the hang for grouped MM (torch._grouped_mm) has landed, and that should resolve the hanging issues with grouped MM, regardless of optimizer.
High priority:
- torch._grouped_mm (and the triton kernel for aligning indices)

Not high priority for now

Not llama4 specific