
Conversation

@aleozlx
Collaborator

@aleozlx aleozlx commented Dec 6, 2025

📌 Description

Fix the sm110 MoE (CUTLASS backend) functional regression: no viable configs.

The dispatch logic was changed and an extra cluster constraint was added in #2020 / #1925 while adding support for new sm103/sm120 features. That logic hasn't been tested on sm110 due to a current lack of CI resources.

(to be tested)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores
    • Optimized GPU compute capability handling for improved architecture-specific performance across supported hardware configurations.
    • Adjusted fallback mechanisms and resource validation logic for enhanced compatibility across GPU generations.


@coderabbitai
Contributor

coderabbitai bot commented Dec 6, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

This change modifies the MOE GEMM kernel dispatch logic for NVIDIA TensorRT-LLM, specifically narrowing the compute capability window for SM100 dispatch, expanding SM100+ fallback conditions to include SM110, and adding arch-specific gating for tile shape support checks to differentiate SM110 from other SM100 configurations.

Changes

Cohort: MOE GEMM TMA Dispatch Logic
File(s): csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
Summary: Narrowed the SM100 compute capability window from < 120 to < 110 in dispatchMoeGemmFinalDispatchTmaWarpSpecialized; expanded the fallback path to include SM100/SM110/SM120/SM90, with an arch-specific condition for kMinComputeCapability == 110; added conditional gating to exclude the runtime-cluster-shape check for SM110 in are_tile_shapes_supported_sm100.
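
To make the dispatch change concrete, the sketch below is a minimal, self-contained illustration (simplified Arch tags and a hypothetical dispatch_sketch helper, not the actual dispatchMoeGemmFinalDispatchTmaWarpSpecialized) of how narrowing the boundary from < 120 to < 110 keeps SM100/SM103 on the dynamic (runtime) cluster-shape path while routing SM90, SM110, and SM120+ to the static-cluster fallback:

#include <cstdio>

// Simplified stand-in for the real architecture tag types.
template <int CC>
struct Arch {
  static constexpr int kMinComputeCapability = CC;
};

template <typename ArchT>
void dispatch_sketch() {
  if constexpr (ArchT::kMinComputeCapability >= 100 && ArchT::kMinComputeCapability < 110) {
    // SM100 / SM103: dynamic (runtime) cluster shape path.
    std::printf("sm%d -> dynamic cluster path\n", ArchT::kMinComputeCapability);
  } else {
    // SM90, SM110, SM120+: static cluster shape fallback path.
    std::printf("sm%d -> static cluster path\n", ArchT::kMinComputeCapability);
  }
}

int main() {
  dispatch_sketch<Arch<90>>();
  dispatch_sketch<Arch<100>>();
  dispatch_sketch<Arch<103>>();
  dispatch_sketch<Arch<110>>();  // with the old "< 120" bound this hit the dynamic path
  dispatch_sketch<Arch<120>>();
  return 0;
}

Under C++17, SM110 now selects the static-cluster branch, whereas the previous < 120 bound placed it on the dynamic path that had no viable configs.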

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • Fix for moe on sm110 #2190: Makes coordinated changes to MOE GEMM dispatch for SM110, adjusting SM110 handling and TMA epilogue/dispatch logic in parallel.

Suggested reviewers

  • aleozlx
  • djmmoss
  • wenscarl

Poem

🐰 A kernel fine-tunes its dispatch dance,
SM110 gets special circumstance,
Tile shapes supported with arch-aware care,
Compute paths narrowed with precision rare,
GEMM's mixture of experts now shines bright! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Title check (❓ Inconclusive): The title is vague and generic, using non-descriptive terms like 'Fix/moe_sm110' and '(to be tested)' that lack sufficient detail about the specific problem being fixed. Resolution: revise the title to be more specific, e.g., 'Fix SM110 MoE dispatch regression by adjusting cluster constraints', to clearly convey the main change.
✅ Passed checks (1 passed)
  • Description check (✅ Passed): The description includes the required template sections with adequate explanation of the issue, root cause attribution, and checklist completion; however, the PR is explicitly marked 'to be tested' with unverified status.

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ad3f26b and 18428e5.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
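
As a minimal, self-contained sketch of the error-swallowing pattern this learning describes (GemmConfig, probeWorkspaceBytes, and the numeric limits are hypothetical stand-ins, not the actual getWorkspaceSizeImpl), probing candidate configurations and skipping those whose SMEM needs exceed the limit could look like:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <vector>

struct GemmConfig {
  int tile_m;
  int smem_bytes;
};

// Hypothetical probe standing in for a per-config workspace query: pretend the
// hardware allows at most 100 KiB of shared memory per configuration.
std::size_t probeWorkspaceBytes(GemmConfig const& cfg) {
  if (cfg.smem_bytes > 100 * 1024) {
    throw std::runtime_error("SMEM exceeds maximum allowed");
  }
  return static_cast<std::size_t>(cfg.smem_bytes) * 4;  // made-up workspace formula
}

std::size_t getMaxWorkspaceSize(std::vector<GemmConfig> const& configs) {
  std::size_t max_bytes = 0;
  for (auto const& cfg : configs) {
    try {
      max_bytes = std::max(max_bytes, probeWorkspaceBytes(cfg));
    } catch (std::runtime_error const&) {
      // Swallow errors when SMEM exceeds maximum allowed; the config is skipped
      // instead of failing the whole workspace query.
    }
  }
  return max_bytes;
}

int main() {
  std::vector<GemmConfig> configs = {{64, 64 * 1024}, {128, 232 * 1024}, {256, 96 * 1024}};
  std::printf("max workspace bytes: %zu\n", getMaxWorkspaceSize(configs));
  return 0;
}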
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (2)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (1)
  • ClusterShape (246-373)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • ClusterShape (249-386)
🔇 Additional comments (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (2)

179-179: LGTM - SM100 dispatch boundary correctly narrowed.

The condition change from < 120 to < 110 correctly routes SM110 away from the dynamic cluster shape path that was causing the regression. SM100 and SM103 remain in this path as intended.


207-222: LGTM - SM110 correctly added to static cluster shape dispatch path.

SM110 is now grouped with SM90 and SM120+ for the fallback path using static cluster shapes (dynamic_cga = false). This aligns SM110's dispatch behavior with architectures that don't use runtime cluster shapes.



@gemini-code-assist
Contributor

Summary of Changes

Hello @aleozlx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a functional regression in the Mixture-of-Experts (MoE) Cutlass backend for SM110 architectures, which previously failed to find viable configurations. The fix involves adjusting the kernel dispatch logic and relaxing certain cluster shape constraints specifically for SM110, ensuring it correctly utilizes the MoE kernels. These changes were necessary due to recent updates intended for newer architectures that inadvertently affected SM110.

Highlights

  • Dispatch Logic Adjustment: The dispatch logic for Arch::kMinComputeCapability has been refined. The range [100, 120) was narrowed to [100, 110), excluding SM110 from a specific dispatch path.
  • SM110 Dispatch Path Rerouting: SM110 is now explicitly included in a different dispatch path, aligning its behavior with SM120 and SM90 architectures for certain kernel configurations.
  • Cluster Shape Constraint Relaxation: A specific cluster shape constraint within the are_tile_shapes_supported_sm100 function has been conditionally disabled for SM110, addressing the 'no viable configs' issue.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a functional regression for SM110 in the MoE CUTLASS backend by adjusting the dispatch logic. The changes correctly move SM110 to a different dispatch path, aligning it with SM90 and SM120+ instead of SM100/103. A corresponding change is made to the tile shape validation to remove a failing constraint for SM110.

While the changes appear to correctly fix the issue, the patch to the validation logic reveals a structural inconsistency between how architectures are grouped for dispatch versus validation. I've left a comment with a suggestion to address this for better long-term maintainability.

Comment on lines +229 to 235
if constexpr (Arch::kMinComputeCapability != 110) {
  // We use a runtime cluster shape for SM100, so we only support 1x1x1 and 2x1x1 cluster shapes.
  if (cute::size<0>(ClusterShape{}) > 2 || cute::size<1>(ClusterShape{}) != 1 ||
      cute::size<2>(ClusterShape{}) != 1) {
    return false;
  }
}
Contributor


Severity: medium

While this change fixes the regression for SM110, special-casing it within are_tile_shapes_supported_sm100 highlights a structural inconsistency that could affect maintainability.

The changes in dispatchMoeGemmFinalDispatchTmaWarpSpecialized correctly group SM110 with SM90 and SM120+, separating it from the SM100/103 path. However, the validation logic in the calling function are_tile_shapes_supported still groups SM110 with SM100/103, which necessitates this patch.

This discrepancy makes the code harder to reason about, as the logic for SM110 is fragmented. For better long-term code health, the validation paths in are_tile_shapes_supported should be refactored to align with the dispatch paths. I recommend creating a follow-up technical debt issue to address this.
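
For illustration, the following minimal, self-contained sketch (ClusterShapeSketch and cluster_shape_ok are hypothetical stand-ins for the cute::ClusterShape machinery and are_tile_shapes_supported_sm100, not the actual FlashInfer code) shows the effect of the gate quoted above: cluster shapes other than 1x1x1 and 2x1x1 are rejected for SM100/SM103, but the check is skipped entirely for SM110:

#include <cstdio>

// Simplified stand-in for cute's ClusterShape; only the static extents matter here.
template <int M, int N, int K>
struct ClusterShapeSketch {
  static constexpr int m = M, n = N, k = K;
};

template <int MinComputeCapability, typename Cluster>
constexpr bool cluster_shape_ok() {
  if constexpr (MinComputeCapability != 110) {
    // Runtime-cluster architectures (SM100 / SM103): only 1x1x1 and 2x1x1 pass.
    if (Cluster::m > 2 || Cluster::n != 1 || Cluster::k != 1) {
      return false;
    }
  }
  // SM110 skips the runtime-cluster constraint entirely.
  return true;
}

int main() {
  std::printf("sm100 2x1x1 -> %d\n", cluster_shape_ok<100, ClusterShapeSketch<2, 1, 1>>());  // 1
  std::printf("sm100 4x2x1 -> %d\n", cluster_shape_ok<100, ClusterShapeSketch<4, 2, 1>>());  // 0
  std::printf("sm110 4x2x1 -> %d\n", cluster_shape_ok<110, ClusterShapeSketch<4, 2, 1>>());  // 1
  return 0;
}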

Collaborator


agreed with Gemini's suggestion

@yzh119 yzh119 marked this pull request as ready for review December 28, 2025 07:47
@yzh119
Collaborator

yzh119 commented Dec 28, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !220 has been created, and the CI pipeline #40898746 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #40898746: 1/20 passed

