@PerkzZheng PerkzZheng commented Dec 24, 2025

📌 Description

This PR adds optimized decode attention kernels for high-throughput (large batch size) + speculative decoding (seqlen_q > 1) workloads.

See the table below for speedups (collected with benchmarks/flashinfer_benchmark.py); seqLenKv is 16K for all cases.

| test case | median_time_ms | median_time_ms (opt) | speedup |
|---|---|---|---|
| Qwen3-480B-fp8_e4m3-batchSize8-seqLenQ2 | 0.057 | 0.046 | 1.24 |
| Qwen3-480B-fp8_e4m3-batchSize16-seqLenQ2 | 0.11 | 0.083 | 1.33 |
| Qwen3-480B-fp8_e4m3-batchSize32-seqLenQ2 | 0.213 | 0.168 | 1.27 |
| Qwen3-480B-fp8_e4m3-batchSize40-seqLenQ2 | 0.266 | 0.241 | 1.10 |
| Qwen3-480B-fp8_e4m3-batchSize64-seqLenQ2 | 0.432 | 0.336 | 1.29 |
| Qwen3-480B-fp8_e4m3-batchSize8-seqLenQ4 | 0.109 | 0.048 | 2.27 |
| Qwen3-480B-fp8_e4m3-batchSize16-seqLenQ4 | 0.212 | 0.083 | 2.55 |
| Qwen3-480B-fp8_e4m3-batchSize32-seqLenQ4 | 0.371 | 0.168 | 2.21 |
| Qwen3-480B-fp8_e4m3-batchSize40-seqLenQ4 | 0.472 | 0.245 | 1.93 |
| Qwen3-480B-fp8_e4m3-batchSize64-seqLenQ4 | 0.736 | 0.348 | 2.11 |
| Qwen3-480B-fp8_e4m3-batchSize8-seqLenQ8 | 0.212 | 0.061 | 3.48 |
| Qwen3-480B-fp8_e4m3-batchSize16-seqLenQ8 | 0.37 | 0.106 | 3.49 |
| Qwen3-480B-fp8_e4m3-batchSize32-seqLenQ8 | 0.732 | 0.239 | 3.06 |
| Qwen3-480B-fp8_e4m3-batchSize40-seqLenQ8 | 0.937 | 0.321 | 2.92 |
| Qwen3-480B-fp8_e4m3-batchSize64-seqLenQ8 | 1.456 | 0.484 | 3.01 |
| GPT-OSS-fp8_e4m3-batchSize8-seqLenQ2 | 0.051 | 0.03 | 1.70 |
| GPT-OSS-fp8_e4m3-batchSize16-seqLenQ2 | 0.098 | 0.054 | 1.81 |
| GPT-OSS-fp8_e4m3-batchSize32-seqLenQ2 | 0.188 | 0.104 | 1.81 |
| GPT-OSS-fp8_e4m3-batchSize40-seqLenQ2 | 0.234 | 0.15 | 1.56 |
| GPT-OSS-fp8_e4m3-batchSize64-seqLenQ2 | 0.332 | 0.199 | 1.67 |
| GPT-OSS-fp8_e4m3-batchSize8-seqLenQ4 | 0.099 | 0.038 | 2.61 |
| GPT-OSS-fp8_e4m3-batchSize16-seqLenQ4 | 0.188 | 0.07 | 2.69 |
| GPT-OSS-fp8_e4m3-batchSize32-seqLenQ4 | 0.332 | 0.136 | 2.44 |
| GPT-OSS-fp8_e4m3-batchSize40-seqLenQ4 | 0.418 | 0.2 | 2.09 |
| GPT-OSS-fp8_e4m3-batchSize64-seqLenQ4 | 0.647 | 0.265 | 2.44 |
| GPT-OSS-fp8_e4m3-batchSize8-seqLenQ8 | 0.188 | 0.039 | 4.82 |
| GPT-OSS-fp8_e4m3-batchSize16-seqLenQ8 | 0.332 | 0.065 | 5.11 |
| GPT-OSS-fp8_e4m3-batchSize32-seqLenQ8 | 0.647 | 0.126 | 5.13 |
| GPT-OSS-fp8_e4m3-batchSize40-seqLenQ8 | 0.83 | 0.185 | 4.49 |
| GPT-OSS-fp8_e4m3-batchSize64-seqLenQ8 | 1.29 | 0.245 | 5.27 |

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Fixed generation attention masking to ensure correct causal behavior during token generation.
  • Performance / Refactor

    • Improved kernel selection, on-demand loading, and launch/configuration logic for more robust and efficient attention execution across devices and SM variants.
  • Chores

    • Updated artifact paths and checksums for FMHA kernels.
  • Tests

    • Expanded parameterized tests to cover larger batch decoding scenarios.


coderabbitai bot commented Dec 24, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Refactors FMHA kernel selection/launch to on-demand cubin loading and new CtaLaunchParams, extends kernel hashing/selection surface and TMA token grouping logic, switches generation path maskType to Causal, updates artifact checksum/path, and adds head_dim=256 batch=32 tests.

Changes

| Cohort / File(s) | Summary |
|---|---|
| FMHA Kernel Selection & Launching Refactor<br>include/flashinfer/trtllm/fmha/fmhaKernels.cuh | Added isFamilySpecificSMPair; introduced public CtaLaunchParams and type aliases; changed signatures for hashID, hashFromRunnerParams, loadKernel, computeCtaAndClusterConfig; added kernel-selection helpers; implemented on-demand module/cubin loading and updated grid/cluster/launch logic. |
| Kernel Parameter Infrastructure<br>include/flashinfer/trtllm/fmha/kernelParams.h | Added FastModDivInt32; added mInflateMax, mNumTokensPerCtaQ, and mNumHeadsQPerKvDivisor to KernelParams; changed makeTmaShapeStrideQ to accept groupsTokensHeadsQ and return numTokensPerCtaQ; updated callers and setKernelParams. |
| Kernel Selection Parameters<br>include/flashinfer/trtllm/fmha/fmhaRunnerParams.h | Added mNumTokensPerPage and mTileSizeQ to TllmGenSelectKernelParams and initialized them (default mTileSizeQ=128). |
| Kernel Launcher Masking<br>csrc/trtllm_fmha_kernel_launcher.cu | In trtllm_paged_attention_launcher generation branch, changed mMaskType from Dense to Causal; added clarifying comment that naming uses dense mask for perf while causal/dense behave equivalently per-CTA. |
| Artifacts<br>flashinfer/artifacts.py | Updated ArtifactPath.TRTLLM_GEN_FMHA and CheckSumHash.TRTLLM_GEN_FMHA constant values. |
| Tests<br>tests/attention/test_trtllm_gen_attention.py | Added three parameterized test cases for head_dim=256 with batch_size=32 across page_size/num_kv_heads/head_grp_size combinations. |
```mermaid
sequenceDiagram
  autonumber
  participant Host as Host (Runner)
  participant Selector as SelectKernelParams
  participant Loader as KernelLoader (module cache)
  participant CUDA as CUDA Driver
  participant GPU as GPU Kernel

  Host->>Selector: prepare RunnerParams & selection params
  Note right of Selector: decide tileSizeQ, scheduler, maskType, kernelType, numTokensPerPage
  Selector->>Loader: request kernel (hash from params)
  alt cached
    Loader-->>Selector: return CUfunction + KernelMeta (cached)
  else not cached
    Loader->>CUDA: load module/cubin
    CUDA-->>Loader: CUfunction + KernelMeta
    Loader-->>Selector: return CUfunction + KernelMeta
  end
  Host->>CUDA: configure grid/cluster using CtaLaunchParams
  CUDA->>GPU: launch kernel (CUfunction) with smem from KernelMeta
  GPU-->>Host: execution complete / results
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • cyx-6
  • wenscarl
  • djmmoss
  • yzh119
  • jiahanc

Poem

🐰 I hopped through kernels, hashes, and queues,
CtaLaunchParams snug in my paws like new shoes.
Causal masks whispered while modules load fast,
Tokens and tiles arranged to hold fast.
Hop—FMHA refined, now let’s run to the grass! 🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The PR title clearly summarizes the main change: adding optimized trtllm-gen decode kernels for high throughput and speculative decoding scenarios. |
| Description check | ✅ Passed | The PR description includes a comprehensive explanation of changes, detailed benchmark results, completed pre-commit checks, and confirmation that tests were added/updated and passing. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%. |


@gemini-code-assist

Summary of Changes

Hello @PerkzZheng, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of decode attention kernels, particularly for high-throughput scenarios involving large batch sizes and speculative decoding. By introducing optimized kernels and a refined dynamic kernel selection strategy, the changes aim to improve efficiency and reduce latency in generation-phase operations. The update also includes adjustments to mask handling, expanded kernel parameterization, and new test cases to ensure robustness and validate the reported speedups.

Highlights

  • Optimized Attention Kernels: Introduced new decode attention kernels specifically designed for high-throughput (large batch size) and speculative decoding (seqlen_q > 1), leading to significant speedups as demonstrated in the PR description benchmarks.
  • Dynamic Kernel Selection Logic: Implemented a more sophisticated kernel selection mechanism for generation-phase attention, including heuristics for MLA and GQA kernels, and an experimental kernel-timing model for grouping tokens and heads into a single CTA to find the optimal tileSizeQ.
  • Mask Type Refinement: Changed the default attention mask type for generation kernels from Dense to Causal in the launcher, with clarification that for single-token queries, a dense mask behaves like a causal mask for performance.
  • Kernel Parameterization Enhancements: Extended kernel parameters to include tileSizeQ and numTokensPerCtaQ, and introduced a FastModDivInt32 utility for efficient division operations within kernel parameters.
  • Artifact Updates: Updated compiled kernel artifacts and their checksums to reflect the new optimizations and kernel changes.
  • Test Coverage: Added new test cases to validate the performance and correctness of the batch decode attention with various configurations, specifically for head_dim_256.


@PerkzZheng PerkzZheng changed the title from "[TRTLLM-Gen Fmha] update trtllm-gen to support groups tokens and headsQ" to "[TRTLLM-Gen Fmha] add optimized trtllm-gen decode kernels for high throughput + speculative decoding" on Dec 24, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant optimizations for speculative decoding in TRT-LLM's FMHA kernels by adding support for grouping tokens and heads. The changes are extensive, involving a major refactoring of the kernel selection logic, hash calculation, and kernel parameter structures. The introduction of a more modular, heuristic-based kernel selection mechanism is a notable improvement. My review focuses on a potential bug in the new performance heuristic and a minor improvement in a math utility function.

Comment on lines 559 to 655
  void selectTileSizeQForGqaGeneration(RunnerParams const& params,
                                       SelectKernelParams& selectKernelParams) const {
    // Define the per-tile mainloop cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelMainloopCost = {
        {128, 2.2},  // Cost factor when tileSizeQ = 128
        {64, 1.68},  // Cost factor when tileSizeQ = 64
        {32, 1.48},  // Cost factor when tileSizeQ = 32
        {16, 1.2},   // Cost factor when tileSizeQ = 16
        {8, 1.0}     // Cost factor when tileSizeQ = 8
    };

    // Define the per-tile reduction cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelReductionCost = {
        {128, 1.32},  // Reduction cost factor when tileSizeQ = 128
        {64, 1.2},    // Reduction cost factor when tileSizeQ = 64
        {32, 1.08},   // Reduction cost factor when tileSizeQ = 32
        {16, 1.03},   // Reduction cost factor when tileSizeQ = 16
        {8, 1.0}      // Reduction cost factor when tileSizeQ = 8
    };

    // The reduction cost emulated as a sequence length factor.
    float const kernelReductionSeqLenFactor = 128.0f;

    // The parameters for launching the kernel.
    CtaLaunchParams ctaLaunchParams;
    // The copy of the selectKernelParams, which makes sure it won't modify the original
    // selectKernelParams when computing the number of CTAs.
    SelectKernelParams selectKernelParamsCopy = selectKernelParams;
    // Load the kernel.
    auto [func, kernelMeta] = loadKernel(params, selectKernelParamsCopy);
    // Compute numCtasX, numCtasY and numCtasZ.
    computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

    // If there are no free SMs or tileSizeQ is already the smallest one, skip the heuristic
    // selection.
    if (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ * 2 >
            params.mMultiProcessorCount ||
        selectKernelParamsCopy.mTileSizeQ <= 8) {
      // No need to select the kernel further.
      return;
    }

    // Candidate tile sizes for tileSizeQ to explore.
    int const candidateTileSizesQ[] = {128, 64, 32, 16, 8};

    // The default tileSizeQ.
    int defaultTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
    // The selected tileSizeQ.
    int selectedTileSizeQ = selectKernelParamsCopy.mTileSizeQ;

    // The minimum modeling kernel time.
    float globalModelingKernelTime = FLT_MAX;
    // Loop over each candidate tile size.
    for (int tileSizeQ : candidateTileSizesQ) {
      // Only consider candidates <= default tileSizeQ.
      if (tileSizeQ > defaultTileSizeQ) {
        continue;
      }

      // Compute the number of CTAs.
      computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

      // Compute the seqLenPerCtaKv.
      int32_t seqLenPerCtaKv =
          flashinfer::ceil_div(flashinfer::ceil_div(params.mMaxSeqLenKv, kernelMeta.mStepKv),
                               ctaLaunchParams.mMaxNumCtasKv) *
          kernelMeta.mStepKv;

      // Compute the modeling kernel time = mainloop cost + reduction cost.
      float modelingKernelTime = kernelMainloopCost[tileSizeQ] * seqLenPerCtaKv +
                                 kernelReductionCost[tileSizeQ] * kernelReductionSeqLenFactor *
                                     ctaLaunchParams.mMaxNumCtasKv;

      // Compute the total number of CTAs.
      int32_t numCtas =
          ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ;
      // Compute the number of waves.
      int32_t numWaves = flashinfer::ceil_div(numCtas, params.mMultiProcessorCount);
      // Compute the total modeling kernel time.
      modelingKernelTime *= numWaves;

      // If this candidate has a lower time than the global minimum, update the global minimum.
      if (modelingKernelTime < globalModelingKernelTime) {
        globalModelingKernelTime = modelingKernelTime;
        selectedTileSizeQ = tileSizeQ;
      }
    }

    // Update the tileSizeQ.
    selectKernelParams.mTileSizeQ = selectedTileSizeQ;
    // Update the kernel type.
    if (selectKernelParams.mTileSizeQ >= 64) {
      selectKernelParams.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
    } else {
      selectKernelParams.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
    }
  }
Severity: high

The heuristic for selecting tileSizeQ in selectTileSizeQForGqaGeneration appears to have a logical flaw. The kernelMeta and ctaLaunchParams are computed once before the loop that iterates through candidate tileSizeQ values. However, inside the loop, computeCtaAndClusterConfig is called again with the same stale kernelMeta, which means ctaLaunchParams is not updated based on the candidate tileSizeQ. This leads to an incorrect cost model evaluation, as the launch configuration (number of CTAs, waves, etc.) doesn't change with the tile size being evaluated.

To correctly evaluate each candidate tileSizeQ, the corresponding kernelMeta should be retrieved, and the launch parameters should be recomputed within the loop. This can be done efficiently by looking up the kernel metadata from mKernelMetaMap without fully reloading the kernel function in each iteration. I've provided a suggested refactoring of the function to address this.

  void selectTileSizeQForGqaGeneration(RunnerParams const& params,
                                       SelectKernelParams& selectKernelParams) const {
    // Define the per-tile mainloop cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelMainloopCost = {
        {128, 2.2},  // Cost factor when tileSizeQ = 128
        {64, 1.68},  // Cost factor when tileSizeQ = 64
        {32, 1.48},  // Cost factor when tileSizeQ = 32
        {16, 1.2},   // Cost factor when tileSizeQ = 16
        {8, 1.0}     // Cost factor when tileSizeQ = 8
    };

    // Define the per-tile reduction cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelReductionCost = {
        {128, 1.32},  // Reduction cost factor when tileSizeQ = 128
        {64, 1.2},    // Reduction cost factor when tileSizeQ = 64
        {32, 1.08},   // Reduction cost factor when tileSizeQ = 32
        {16, 1.03},   // Reduction cost factor when tileSizeQ = 16
        {8, 1.0}      // Reduction cost factor when tileSizeQ = 8
    };

    // The reduction cost emulated as a sequence length factor.
    float const kernelReductionSeqLenFactor = 128.0f;

    // The parameters for launching the kernel.
    CtaLaunchParams ctaLaunchParams;
    // The copy of the selectKernelParams, which makes sure it won't modify the original
    // selectKernelParams when computing the number of CTAs.
    SelectKernelParams selectKernelParamsCopy = selectKernelParams;

    // Candidate tile sizes for tileSizeQ to explore.
    int const candidateTileSizesQ[] = {128, 64, 32, 16, 8};

    // The default tileSizeQ.
    int defaultTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
    // The selected tileSizeQ.
    int selectedTileSizeQ = selectKernelParamsCopy.mTileSizeQ;

    // The minimum modeling kernel time.
    float globalModelingKernelTime = FLT_MAX;
    // Loop over each candidate tile size.
    for (int tileSizeQ : candidateTileSizesQ) {
      // Only consider candidates <= default tileSizeQ.
      if (tileSizeQ > defaultTileSizeQ) {
        continue;
      }

      selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
      if (tileSizeQ >= 64) {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
      } else {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
      }

      // Find kernel meta info without loading the kernel function
      auto [hashId, info] = hashFromRunnerParams(params, selectKernelParamsCopy);
      auto const findMetaIter = mKernelMetaMap.find(hashId);
      if (findMetaIter == mKernelMetaMap.end()) {
        continue;  // No kernel available for this tile size
      }
      auto const& kernelMeta = mKernelMeta[findMetaIter->second];

      // Compute the number of CTAs.
      computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

      // If there are no free SMs, this tile size is not a good candidate.
      if (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ * 2 >
              params.mMultiProcessorCount &&
          tileSizeQ > 8) {  // allow smallest tile size to be selected even if it oversubscribes
        continue;
      }

      // Compute the seqLenPerCtaKv.
      int32_t seqLenPerCtaKv =
          flashinfer::ceil_div(flashinfer::ceil_div(params.mMaxSeqLenKv, kernelMeta.mStepKv),
                               ctaLaunchParams.mMaxNumCtasKv) *
          kernelMeta.mStepKv;

      // Compute the modeling kernel time = mainloop cost + reduction cost.
      float modelingKernelTime = kernelMainloopCost.at(tileSizeQ) * seqLenPerCtaKv +
                                 kernelReductionCost.at(tileSizeQ) * kernelReductionSeqLenFactor *
                                     ctaLaunchParams.mMaxNumCtasKv;

      // Compute the total number of CTAs.
      int32_t numCtas =
          ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ;
      // Compute the number of waves.
      int32_t numWaves = flashinfer::ceil_div(numCtas, params.mMultiProcessorCount);
      // Compute the total modeling kernel time.
      modelingKernelTime *= numWaves;

      // If this candidate has a lower time than the global minimum, update the global minimum.
      if (modelingKernelTime < globalModelingKernelTime) {
        globalModelingKernelTime = modelingKernelTime;
        selectedTileSizeQ = tileSizeQ;
      }
    }

    // Update the tileSizeQ.
    selectKernelParams.mTileSizeQ = selectedTileSizeQ;
    // Update the kernel type.
    if (selectKernelParams.mTileSizeQ >= 64) {
      selectKernelParams.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
    } else {
      selectKernelParams.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
    }
  }

@PerkzZheng (Contributor Author) replied:

It makes sense and I have fixed it.

  }

  int32_t ceilLog2(int32_t value) const {
    return static_cast<int32_t>(std::ceil(std::log2(value)));
Severity: medium

The ceilLog2 implementation uses floating-point math (std::log2), which can be less efficient and potentially have precision issues for integer inputs. A bitwise implementation would be more robust and performant, especially if this function is ever used in a performance-critical path. Consider replacing it with a portable integer-based implementation.

    if (value <= 0) return 0; // Match behavior of std::log2 for non-positive values
    int32_t result = 0;
    int32_t temp = value - 1;
    while (temp > 0) {
        temp >>= 1;
        result++;
    }
    return result;

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)

497-498: Remove extraneous semicolon.

There's a double semicolon on line 498 after the ceil_div call.

🔎 Proposed fix
-    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
-    ;
+    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);

558-655: Cost model uses undocumented magic numbers.

The kernelMainloopCost and kernelReductionCost maps use empirical constants (2.2, 1.68, 1.48, etc.) without documentation on how they were derived. Consider adding a brief comment explaining these are benchmarked/profiled values or referencing the source.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0ccf4e3 and c5b673a.

📒 Files selected for processing (6)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
🧰 Additional context used
🧬 Code graph analysis (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h (1)
  • ceilDiv (42-44)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (15)
flashinfer/artifacts.py (1)

90-90: LGTM! Artifact path and checksum updated consistently.

The TRTLLM_GEN_FMHA artifact path and its corresponding checksum are updated together, which is correct. These changes align with the PR objectives to integrate optimized decode attention kernels, and the existing verify_cubin logic (line 224) will validate the checksum during artifact download.

Also applies to: 110-110

tests/attention/test_trtllm_gen_attention.py (1)

1282-1284: AI summary inconsistency: function name and parameter values are incorrect.

The AI summary states these changes are for test_trtllm_batch_decode_head_dim_256, but they are actually added to test_trtllm_batch_decode_long_sequence_length (line 1300). Additionally, the parameter interpretation is incorrect—the actual values are (batch_size=32, q_len_per_req=[4,8,16], page_size=16, num_kv_heads=2, head_grp_size=8), not what the summary describes.

The test additions themselves look good. They appropriately expand coverage for long-sequence scenarios with moderate batch sizes (32) and varying speculative decoding lengths (q_len_per_req of 4, 8, 16), which aligns well with the PR objectives to optimize for high-throughput workloads and speculative decoding.

include/flashinfer/trtllm/fmha/kernelParams.h (3)

282-301: LGTM – logic for grouping tokens and heads per CTA is clear.

The conditional handling for groupsTokensHeadsQ correctly computes numTokensPerCtaQ when grouping is enabled, with appropriate padding comments. The tuple return value now includes numTokensPerCtaQ for downstream use.
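To make the grouping arithmetic concrete, here is a small standalone sketch. It is illustrative only and not the makeTmaShapeStrideQ code; the head count, tile size, and printed message are assumed example values.

```cpp
// Illustrative sketch of the token/head grouping arithmetic described above.
// Not the FlashInfer implementation; the head count and tile size below are
// made-up example values.
#include <cassert>
#include <cstdio>

int main() {
  // Example: GQA with 8 Q heads sharing one KV head, and a CTA tile of 32 Q rows.
  int const numHeadsQPerKv = 8;  // grouped Q heads that share one KV head
  int const tileSizeQ = 32;      // Q rows processed by one CTA

  // When tokens and heads are grouped into the same CTA, each CTA must process
  // complete head groups, so tileSizeQ has to be divisible by the group size.
  assert(tileSizeQ % numHeadsQPerKv == 0);
  int const numTokensPerCtaQ = tileSizeQ / numHeadsQPerKv;  // 4 tokens here

  // Each CTA then covers numTokensPerCtaQ query tokens times numHeadsQPerKv
  // heads with a single pass over the KV sequence, instead of one pass per token.
  std::printf("tokens per CTA: %d (tile %d / group %d)\n", numTokensPerCtaQ,
              tileSizeQ, numHeadsQPerKv);
  return 0;
}
```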


594-596: Helpful debugging guidance added.

The comment explaining that TMA descriptor errors may be caused by previous kernels and suggesting CUDA_LAUNCH_BLOCKING or cuda-gdb is useful for debugging.


801-803: Verify FastModDivInt32 handles mNumHeadsQPerKv = 1 safely.

Given the potential edge case with ceilLog2(1) mentioned earlier, ensure that when options.mNumHeadsQPerKv == 1, the FastModDivInt32 constructor doesn't produce invalid values.

csrc/trtllm_fmha_kernel_launcher.cu (1)

162-166: LGTM – mask type change with clear rationale.

The switch from Dense to Causal for generation is well-documented. The comment clarifies that kernel naming retains "dense" for performance reasons while the behavior is equivalent when each CTA processes a single tokenQ.
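As a quick sanity check of that rationale, the following toy program (independent of the kernel code; sequence lengths are made-up values) shows that a causal mask over a single query token placed at the end of the KV sequence admits every KV position, exactly like a dense mask:

```cpp
// Toy check (not kernel code): with a single query token appended at the end
// of the KV cache, the causal rule admits every KV position, so it matches a
// dense (all-visible) mask. Sequence lengths are arbitrary example values.
#include <cassert>

int main() {
  int const seqLenKv = 16;        // cached tokens, including the new one
  int const qPos = seqLenKv - 1;  // the lone query token sits at the end

  for (int kvPos = 0; kvPos < seqLenKv; ++kvPos) {
    bool const causalVisible = (kvPos <= qPos);  // causal rule
    bool const denseVisible = true;              // dense rule
    assert(causalVisible == denseVisible);
  }
  // With more than one query token per CTA (speculative decoding), the two
  // rules diverge, which is why the generation path now requests Causal.
  return 0;
}
```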

include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (2)

351-361: New kernel selection parameters added.

The mNumTokensPerPage and mTileSizeQ fields extend the kernel selection API to support the new grouping and tile sizing logic. These integrate well with the updated hashID and kernel selection flow in fmhaKernels.cuh.


376-382: Constructor properly initializes new fields.

mNumTokensPerPage is propagated from input params, and mTileSizeQ/mTileSizeKv are initialized to 128, which is consistent with the largest supported tile size. The initialization order matches the declaration order.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (7)

54-62: LGTM – SM family/specific pair detection.

The isFamilySpecificSMPair helper correctly identifies when SM values are family/specific variants (e.g., kSM_100f with kSM_100 or kSM_103), enabling proper hash conflict resolution.


76-92: LGTM – CtaLaunchParams encapsulates launch configuration.

The new struct cleanly groups related launch parameters, improving code organization and reducing parameter passing overhead.


117-133: LGTM – Hash conflict resolution prefers specific SM.

The logic correctly handles hash conflicts between family and specific SM versions, preferring the specific version (e.g., kSM_100 or kSM_103 over kSM_100f).


399-402: LGTM – Factor of 2 for reduction overhead.

The comment clearly explains the factor of 2 is applied to balance the reduction overhead against mainloop benefits.


768-829: LGTM – On-demand kernel loading with caching.

The refactored loadKernel method properly caches modules and functions, handles shared memory configuration for large kernels (≥48KB), and includes helpful error messages.
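For readers unfamiliar with the pattern, a minimal sketch of such on-demand loading with two-level caching might look like the code below. This is not FlashInfer's loadKernel: the KernelCache class, the getCubinBytes helper, and the hash-keyed maps are assumptions made for illustration; only the CUDA driver calls (cuModuleLoadData, cuModuleGetFunction, cuFuncSetAttribute) are real APIs.

```cpp
// Minimal sketch of on-demand kernel loading with two-level caching, roughly
// mirroring the pattern described above. getCubinBytes() is a placeholder.
#include <cuda.h>

#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical helper that returns the cubin image for a kernel hash.
std::vector<char> getCubinBytes(uint64_t kernelHash);

class KernelCache {
 public:
  CUfunction get(uint64_t kernelHash, std::string const& funcName, int sharedMemBytes) {
    // Function-level cache: return immediately if already resolved.
    if (auto it = mFunctions.find(kernelHash); it != mFunctions.end()) {
      return it->second;
    }
    // Module-level cache: load the cubin only the first time it is needed.
    CUmodule module;
    if (auto it = mModules.find(kernelHash); it != mModules.end()) {
      module = it->second;
    } else {
      std::vector<char> cubin = getCubinBytes(kernelHash);
      if (cuModuleLoadData(&module, cubin.data()) != CUDA_SUCCESS) {
        throw std::runtime_error("failed to load cubin for kernel hash");
      }
      mModules.emplace(kernelHash, module);
    }
    CUfunction func;
    if (cuModuleGetFunction(&func, module, funcName.c_str()) != CUDA_SUCCESS) {
      throw std::runtime_error("kernel " + funcName + " not found in module");
    }
    // Large dynamic shared memory (>= 48KB) must be opted into explicitly.
    if (sharedMemBytes >= 48 * 1024) {
      cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, sharedMemBytes);
    }
    mFunctions.emplace(kernelHash, func);
    return func;
  }

 private:
  std::unordered_map<uint64_t, CUmodule> mModules;
  std::unordered_map<uint64_t, CUfunction> mFunctions;
};
```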


334-345: LGTM – Generation kernel CTA handling for spec-decoding.

The logic correctly distinguishes between groupsTokensHeadsQ enabled/disabled scenarios, computing numCtasPerSeqQ appropriately for speculative decoding workloads.


583-590: Line 588 is not within the heuristic selection loop—it's the initial kernel load before the loop begins. The subsequent heuristic loop (lines 613+) only iterates over candidate tile sizes using the already-loaded kernelMeta; it does not call loadKernel again. If loadKernel fails due to a missing kernel, that's a legitimate error condition (kernel hash not found) that should be reported, not silently caught during heuristic probing.

@PerkzZheng PerkzZheng marked this pull request as draft December 24, 2025 15:12
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
include/flashinfer/trtllm/fmha/kernelParams.h (2)

53-55: Consider replacing floating-point ceilLog2 with a bitwise implementation.

The current implementation uses std::log2, which is less efficient and may have precision issues for integer inputs. A bitwise approach would be more robust and performant.

🔎 Proposed bitwise implementation
  int32_t ceilLog2(int32_t value) const {
-   return static_cast<int32_t>(std::ceil(std::log2(value)));
+   if (value <= 1) return 0;
+   int32_t result = 0;
+   int32_t temp = value - 1;
+   while (temp > 0) {
+     temp >>= 1;
+     result++;
+   }
+   return result;
  }

Based on past review comments.


46-50: Critical: Edge case divisor == 1 leads to negative mShift causing undefined behavior.

When divisor = 1, ceilLog2(1) returns 0, so mShift = 0 - 1 = -1. Negative shift values cause undefined behavior in the multiplier calculation (uint64_t(1) << (32 + mShift) at line 49) and may break downstream usages of mShift.

🔎 Proposed fix
  FastModDivInt32(int32_t divisor) : mDivisor(divisor) {
+   if (divisor == 1) {
+     mShift = 0;
+     mMultiplier = 1;
+     return;
+   }
    mShift = ceilLog2(mDivisor) - 1;
    mMultiplier = static_cast<uint32_t>(
        flashinfer::ceil_div(uint64_t(1) << (32 + mShift), static_cast<uint64_t>(mDivisor)));
  }

Based on past review comments.

🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)

559-665: The loop now correctly reloads kernel metadata for each candidate tileSizeQ.

The function addresses the prior review concern by:

  • Updating selectKernelParamsCopy.mTileSizeQ at line 618 for each candidate
  • Reloading the kernel with loadKernel at line 626 to get the correct kernelMeta
  • Recomputing CTA configuration at line 629 with the updated metadata

However, loadKernel (line 787) throws an exception if a kernel for a particular tileSizeQ doesn't exist. Consider wrapping the loadKernel call in a try-catch to gracefully skip unavailable candidates rather than aborting the entire selection.

🔎 Optional: graceful handling of missing kernels
    for (int tileSizeQ : candidateTileSizesQ) {
      // Only consider candidates <= default tileSizeQ.
      if (tileSizeQ > defaultTileSizeQ) {
        continue;
      }

      selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
      if (tileSizeQ >= 64) {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
      } else {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
      }

-     // Load the kernel.
-     std::tie(func, kernelMeta) = loadKernel(params, selectKernelParamsCopy);
+     // Load the kernel. Skip if not available.
+     try {
+       std::tie(func, kernelMeta) = loadKernel(params, selectKernelParamsCopy);
+     } catch (...) {
+       continue;  // Skip candidates without available kernels
+     }

      // Compute the number of CTAs.
      computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

496-498: Optional: Remove stray semicolon.

Line 498 has an extra semicolon (;;) after the maxNumCtasPerSeqKv calculation, which is harmless but reduces code cleanliness.

🔎 Proposed fix
    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
-   ;
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c5b673a and e7f9ba0.

📒 Files selected for processing (2)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
include/flashinfer/trtllm/fmha/kernelParams.h (6)

148-150: LGTM.

The new mInflateMax field is clearly documented and adds support for adjusting max value inflation during iterations.


173-174: Verify default initialization {1} given FastModDivInt32 edge case.

The default initialization FastModDivInt32 mNumHeadsQPerKvDivisor{1} triggers the divisor == 1 edge case flagged earlier (leading to mShift = -1). Ensure this default is intentional and that the critical fix for FastModDivInt32 constructor is applied to prevent undefined behavior.


205-207: LGTM.

The signature expansion to include groupsTokensHeadsQ aligns with the new per-CTA tokenization logic described in the PR summary.


296-296: LGTM.

The return tuple now includes numTokensPerCtaQ, consistent with the expanded per-CTA tokenization logic.


639-641: LGTM.

The call site correctly passes the new groupsTokensHeadsQ parameter and captures the expanded return tuple including numTokensPerCtaQ.


796-798: LGTM.

The new fields mNumHeadsQPerKvDivisor and mNumTokensPerCtaQ are correctly initialized from options and the computed numTokensPerCtaQ value.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (10)

54-62: LGTM.

The isFamilySpecificSMPair function correctly identifies when two SM values represent family/specific architecture pairs (e.g., SM_100f with SM_100 or SM_103), supporting graceful handling of Blackwell GPU variants.
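A rough sketch of what such a pairing check does is shown below; the enum values are placeholders invented for this example, not FlashInfer's real kSM_* constants.

```cpp
// Illustrative family/specific SM pairing check. Enum values are placeholders.
#include <cassert>

enum Sm { kSM_100 = 100, kSM_103 = 103, kSM_100f = 1000 /* "family" variant */ };

// True when one value is the family variant and the other is a specific SM
// belonging to that family (e.g. SM_100f pairs with SM_100 or SM_103).
bool isFamilySpecificSmPairSketch(Sm a, Sm b) {
  auto pairs = [](Sm family, Sm specific) {
    return family == kSM_100f && (specific == kSM_100 || specific == kSM_103);
  };
  return pairs(a, b) || pairs(b, a);
}

int main() {
  assert(isFamilySpecificSmPairSketch(kSM_100f, kSM_100));
  assert(isFamilySpecificSmPairSketch(kSM_103, kSM_100f));
  assert(!isFamilySpecificSmPairSketch(kSM_100, kSM_103));  // two specifics don't pair
  return 0;
}
```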


76-93: LGTM.

The new CtaLaunchParams struct is a clean refactor that groups related kernel launch parameters, replacing multiple tuple returns with a more maintainable structure.


117-134: LGTM.

The hash conflict detection correctly allows only family/specific SM pairs to share a hash, and sensibly prefers specific SM versions (e.g., SM_100 over SM_100f) when both are present.


140-183: LGTM.

The expanded hashID signature and bit layout correctly incorporate the new kernel selection parameters (tileSizeQ, reuseSmemKForV, uses2CtaMma, sparseMla) with appropriate assertions and bit-packing.
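As a general illustration of the bit-packing idea (field names and widths here are invented for the example and do not match the real hashID layout):

```cpp
// Sketch of packing kernel-selection fields into one 64-bit hash, asserting
// that each field fits its bit budget. Widths are invented for this example.
#include <cassert>
#include <cstdint>

uint64_t packKernelHashSketch(int tileSizeQ, int numTokensPerPage, bool reuseSmemKForV,
                              bool uses2CtaMma, bool sparseMla) {
  // tileSizeQ in {8, 16, 32, 64, 128}: store log2 in 3 bits.
  int log2TileSizeQ = 0;
  while ((1 << log2TileSizeQ) < tileSizeQ) ++log2TileSizeQ;
  assert((1 << log2TileSizeQ) == tileSizeQ && log2TileSizeQ < 8);  // fits 3 bits
  // numTokensPerPage is assumed to be a power of two up to 128: log2 in 3 bits.
  int log2TokensPerPage = 0;
  while ((1 << log2TokensPerPage) < numTokensPerPage) ++log2TokensPerPage;
  assert((1 << log2TokensPerPage) == numTokensPerPage && log2TokensPerPage < 8);

  uint64_t hash = 0;
  int shift = 0;
  hash |= uint64_t(log2TileSizeQ) << shift;      shift += 3;
  hash |= uint64_t(log2TokensPerPage) << shift;  shift += 3;
  hash |= uint64_t(reuseSmemKForV) << shift;     shift += 1;
  hash |= uint64_t(uses2CtaMma) << shift;        shift += 1;
  hash |= uint64_t(sparseMla) << shift;          shift += 1;
  return hash;
}

int main() {
  // Distinct configurations must map to distinct hashes.
  assert(packKernelHashSketch(64, 32, true, false, false) !=
         packKernelHashSketch(128, 32, true, false, false));
  return 0;
}
```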


511-556: LGTM.

The selectMlaGenerationKernel heuristic appropriately selects between low-latency (SwapsMmaAbForGeneration) and high-throughput (KeepsMmaAbForGeneration) kernels based on numHeadsQPerKv and GPU utilization, with clear logic for enabling 2-CTA MMA mode.


667-707: LGTM.

The selectGqGenerationKernel function uses clear threshold-based heuristics for selecting tileSizeQ and kernelType, and appropriately delegates to the cost-model-based selectTileSizeQForGqaGeneration when maxSeqLenQ > 1 for speculative decoding.


399-402: LGTM.

The factor-of-2 adjustment in maxNumCtasPerSeqKv is well-justified by the comment: it prevents splitting KV sequences so finely that reduction overhead exceeds mainloop speedup benefits.


477-484: LGTM.

The refactor to update CtaLaunchParams in place (lines 478–483) is clean and consistent with the new struct-based parameter passing introduced in this PR.


777-839: LGTM.

The loadKernel refactor implements clean on-demand kernel loading with two-level caching (modules and functions), appropriate shared memory configuration for large allocations (≥48KB), and clear error messages when kernels are not found.


204-314: LGTM.

The run() method correctly integrates the new CtaLaunchParams struct and on-demand loadKernel approach, replacing the previous tuple-based parameter passing with a cleaner, more maintainable flow.

Comment on lines 280 to 294
if (groupsHeadsQ) {
  if (isSpecDecodingGenerationKernel(options.mKernelType)) {
    FLASHINFER_CHECK((tileSizeQ % numGroupedHeads == 0), "internal error");
    tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                       static_cast<uint32_t>(numGroupedHeads), 1,
                                       static_cast<uint32_t>(tileSizeQ / numGroupedHeads)};
    if (groupsTokensHeadsQ) {
      // Currently, it requires each CTA to process complete headsQ (i.e. numGroupedHeads) at a
      // time, so it allows paddings in the end. Removing paddings needs re-organizing the Q
      // tensor to [numTokensQ, numGroupedHeads, numHeads, headDimQ] and we might want to revisit
      // this in the future.
      numTokensPerCtaQ = static_cast<int32_t>(numTokensPerCtaQ / numGroupedHeads);
    } else {
      tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                         static_cast<uint32_t>(tileSizeQ), 1, 1};
      numGroupedHeads = tileSizeQ;
      numTokensPerCtaQ = 1;
    }
    tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                       static_cast<uint32_t>(numGroupedHeads), 1,
                                       static_cast<uint32_t>(numTokensPerCtaQ)};
  }

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# Search for makeTmaShapeStrideQ definition and context
rg -nP 'makeTmaShapeStrideQ' --type=cpp -B5 -A20

Repository: flashinfer-ai/flashinfer

Length of output: 7590


🏁 Script executed:

# Search for tileSizeQ and numGroupedHeads initialization/assignment
rg -nP '\btileSizeQ\s*=' --type=cpp -B2 -A2 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look for the complete function containing this code
fd -e h | xargs rg -l 'groupsHeadsQ\|groupsTokensHeadsQ' | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check for any assertions or validation around division
rg -nP 'numTokensPerCtaQ|numGroupedHeads' include/flashinfer/trtllm/fmha/kernelParams.h -B2 -A2

Repository: flashinfer-ai/flashinfer

Length of output: 2957


🏁 Script executed:

# Search for tileSizeQ/mTileSizeQ definition and constraints
rg -nP 'mTileSizeQ|tileSizeQ' --type=cpp -B3 -A3 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 5598


🏁 Script executed:

# Search for kernel metadata definition and mTileSizeQ initialization
fd -e h -e cpp | xargs rg -l 'mTileSizeQ' | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 163


🏁 Script executed:

# Check if there are any constraints on tileSizeQ values in test files or kernel selection logic
rg -nP 'mTileSizeQ.*=' --type=cpp | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for kernel metadata setup and how mTileSizeQ gets assigned/constrained
rg -nP 'mGroupsTokensHeadsQ|groupsTokensHeadsQ' --type=cpp -B5 -A5 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 3587


🏁 Script executed:

# Look for any validation or constraints on tileSizeQ relative to mNumHeadsQPerKv
rg -nP 'numHeadsQPerKv|mNumHeadsQPerKv' --type=cpp | grep -i 'tile\|size' | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 593


🏁 Script executed:

# Check if there are tests or examples that show tileSizeQ and mNumHeadsQPerKv combinations
fd -e cu -e cpp | xargs rg -l 'mTileSizeQ.*=' | head -3

Repository: flashinfer-ai/flashinfer

Length of output: 89


🏁 Script executed:

# Search for where kernel metadata is created/selected and if there are constraints between tileSizeQ and mNumHeadsQPerKv
fd -e cpp -e cu | xargs rg -l 'mGroupsTokensHeadsQ.*=' | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check kernel selection logic and metadata initialization
rg -nP 'KernelMeta|kernelMeta.*{' --type=cpp -A10 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4231


🏁 Script executed:

# Look for any divmod, remainder, or modulo checks related to tileSizeQ or numGroupedHeads
rg -nP '%|divmod' include/flashinfer/trtllm/fmha/kernelParams.h -B3 -A3

Repository: flashinfer-ai/flashinfer

Length of output: 1465


Add a runtime assertion or comment to document the divisibility constraint at line 286.

At line 286, numTokensPerCtaQ is computed as tileSizeQ / numGroupedHeads and cast to int32_t, which silently truncates if the division is inexact. The comment acknowledges that padding is "currently required" and suggests revisiting this design. While the current implementation tolerates padding, there should be either:

  • An explicit FLASHINFER_CHECK(tileSizeQ % numGroupedHeads == 0, ...) to catch misconfigured kernel metadata, or
  • A clearer comment documenting that tileSizeQ must be divisible by numGroupedHeads when groupsTokensHeadsQ is true

This aligns with existing validation patterns in the codebase (e.g., line 225 for head counts, line 806 for sparse MLA top-k).

@PerkzZheng PerkzZheng marked this pull request as ready for review December 24, 2025 15:15
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

593-604: Out-of-bounds access when dim < 5 in error logging.

The error logging code accesses fixed indices [0-4] for shapes, tileShapes, and tileStrides, and [0-3] for stridesInBytes. However, dim can be 2–5 (per the check at line 555), so accessing index 4 when dim is 2, 3, or 4 is undefined behavior.

🔎 Proposed fix using a loop or conditional printing
-      std::cerr << "Shape: " << shapes[0] << " " << shapes[1] << " " << shapes[2] << " "
-                << shapes[3] << " " << shapes[4] << std::endl;
-      std::cerr << "Stride: " << stridesInBytes[0] << " " << stridesInBytes[1] << " "
-                << stridesInBytes[2] << " " << stridesInBytes[3] << std::endl;
-      std::cerr << "tileShapes: " << tileShapes[0] << " " << tileShapes[1] << " " << tileShapes[2]
-                << " " << tileShapes[3] << " " << tileShapes[4] << std::endl;
-      std::cerr << "tileStrides: " << tileStrides[0] << " " << tileStrides[1] << " "
-                << tileStrides[2] << " " << tileStrides[3] << " " << tileStrides[4] << std::endl;
+      std::cerr << "Shape:";
+      for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << shapes[ii];
+      std::cerr << std::endl;
+      std::cerr << "Stride:";
+      for (int32_t ii = 0; ii < dim - 1; ++ii) std::cerr << " " << stridesInBytes[ii];
+      std::cerr << std::endl;
+      std::cerr << "tileShapes:";
+      for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << tileShapes[ii];
+      std::cerr << std::endl;
+      std::cerr << "tileStrides:";
+      for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << tileStrides[ii];
+      std::cerr << std::endl;
♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

280-294: Add a runtime assertion for the divisibility constraint.

The division at line 286 can silently truncate if tileSizeQ is not evenly divisible by numGroupedHeads. While the comment documents that padding is "currently required," adding an explicit FLASHINFER_CHECK would catch misconfigured kernel metadata early and align with validation patterns elsewhere in this file (e.g., line 225, line 806).

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e7f9ba0 and e2734cd.

📒 Files selected for processing (1)
  • include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/trtllm/fmha/kernelParams.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (7)
include/flashinfer/trtllm/fmha/kernelParams.h (7)

44-62: LGTM - FastModDivInt32 implementation correctly addresses prior feedback.

The edge case for divisor == 1 is now handled via std::max(ceilLog2(mDivisor) - 1, 0), and the code reuses flashinfer::ceil_div instead of a local duplicate. The mAdd field is correctly left at 0 per the CCCL fast modulo division algorithm for signed types.
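For readers who want to see the multiply-shift trick in isolation, here is a self-contained sketch. It is not FlashInfer's FastModDivInt32 (which follows the CCCL scheme with an mAdd term); the divisor == 1 case is simply special-cased here to show why the edge case needs attention, and the result is spot-checked against the hardware divider.

```cpp
// Minimal sketch of multiply-shift "fast division" for non-negative int32
// values. Names and the d == 1 special case are illustrative assumptions.
#include <cassert>
#include <cstdint>

struct FastDivSketch {
  int32_t d;
  int shift;      // ceil(log2(d)) - 1 for d >= 2
  uint32_t mult;  // ceil(2^(32 + shift) / d)

  explicit FastDivSketch(int32_t divisor) : d(divisor) {
    assert(d >= 1);
    if (d == 1) {  // avoid a negative shift / overflowing multiplier
      shift = 0;
      mult = 0;    // unused; div() returns x directly
      return;
    }
    int ceilLog2 = 0;
    while ((int64_t{1} << ceilLog2) < d) ++ceilLog2;
    shift = ceilLog2 - 1;
    uint64_t num = uint64_t{1} << (32 + shift);
    mult = static_cast<uint32_t>((num + d - 1) / d);  // ceil division
  }

  int32_t div(int32_t x) const {  // valid for 0 <= x < 2^31
    if (d == 1) return x;
    return static_cast<int32_t>((static_cast<uint64_t>(x) * mult) >> (32 + shift));
  }
  int32_t mod(int32_t x) const { return x - div(x) * d; }
};

int main() {
  // Brute-force spot check against the hardware divider.
  for (int32_t d : {1, 2, 3, 5, 7, 8, 12, 40, 128}) {
    FastDivSketch f(d);
    for (int32_t x : {0, 1, 2, 63, 1000, 123456789, 2147483647}) {
      assert(f.div(x) == x / d && f.mod(x) == x % d);
    }
  }
  return 0;
}
```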


148-150: LGTM - New inflation parameter is well-documented.

The mInflateMax field has a clear comment explaining its purpose. The field will be zero-initialized via memset in setKernelParams.


173-180: LGTM - New kernel parameters for token/head grouping.

The mNumHeadsQPerKvDivisor is correctly typed as FastModDivInt32{1} for fast modulo operations, and mNumTokensPerCtaQ supports the new grouping kernel feature.


205-207: LGTM - Clean API extension.

The new groupsTokensHeadsQ parameter cleanly extends the function signature to support the new token/head grouping mode.


589-591: Helpful debugging guidance added.

The updated comment clarifying that errors may originate from previous kernels and suggesting CUDA_LAUNCH_BLOCKING or cuda-gdb is a useful addition for debugging TMA initialization failures.


639-641: LGTM - Call site correctly updated.

The structured binding now captures the new numTokensPerCtaQ return value, and kernelMeta.mGroupsTokensHeadsQ is passed as the new parameter.


796-798: LGTM - Kernel parameters correctly assigned.

The mNumHeadsQPerKvDivisor is initialized from options.mNumHeadsQPerKv, and mNumTokensPerCtaQ is propagated from the makeTmaShapeStrideQ return value.


yzh119 commented Dec 25, 2025

/bot run

@flashinfer-bot

GitLab MR !218 has been created, and the CI pipeline #40778397 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[FAILED] Pipeline #40778397: 1/20 passed
