[TRTLLM-Gen Fmha] add optimized trtllm-gen decode kernels for high throughput + speculative decoding #2265
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Refactors FMHA kernel selection/launch to on-demand cubin loading and a new CtaLaunchParams struct, extends the kernel hashing/selection surface and TMA token grouping logic, switches the generation path maskType to Causal, updates the artifact checksum/path, and adds head_dim=256 batch=32 tests.
sequenceDiagram
autonumber
participant Host as Host (Runner)
participant Selector as SelectKernelParams
participant Loader as KernelLoader (module cache)
participant CUDA as CUDA Driver
participant GPU as GPU Kernel
Host->>Selector: prepare RunnerParams & selection params
Note right of Selector `#bfe9e0`: decide tileSizeQ, scheduler,\nmaskType, kernelType, numTokensPerPage
Selector->>Loader: request kernel (hash from params)
alt cached
Loader-->>Selector: return CUfunction + KernelMeta (cached)
else not cached
Loader->>CUDA: load module/cubin
CUDA-->>Loader: CUfunction + KernelMeta
Loader-->>Selector: return CUfunction + KernelMeta
end
Host->>CUDA: configure grid/cluster using CtaLaunchParams
CUDA->>GPU: launch kernel (CUfunction) with smem from KernelMeta
GPU-->>Host: execution complete / results
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
📜 Recent review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Applied to files:
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
🔇 Additional comments (7)
Summary of Changes

Hello @PerkzZheng, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the performance of decode attention kernels, particularly for high-throughput scenarios involving large batch sizes and speculative decoding. By introducing optimized kernels and a refined dynamic kernel selection strategy, the changes aim to improve efficiency and reduce latency in generation-phase operations. The update also includes adjustments to mask handling, expanded kernel parameterization, and new test cases to ensure robustness and validate the reported speedups.
Code Review
This pull request introduces significant optimizations for speculative decoding in TRT-LLM's FMHA kernels by adding support for grouping tokens and heads. The changes are extensive, involving a major refactoring of the kernel selection logic, hash calculation, and kernel parameter structures. The introduction of a more modular, heuristic-based kernel selection mechanism is a notable improvement. My review focuses on a potential bug in the new performance heuristic and a minor improvement in a math utility function.
void selectTileSizeQForGqaGeneration(RunnerParams const& params,
                                     SelectKernelParams& selectKernelParams) const {
  // Define the per-tile mainloop cost model for different tileSizeQ choices.
  std::unordered_map<int, float> kernelMainloopCost = {
      {128, 2.2},  // Cost factor when tileSizeQ = 128
      {64, 1.68},  // Cost factor when tileSizeQ = 64
      {32, 1.48},  // Cost factor when tileSizeQ = 32
      {16, 1.2},   // Cost factor when tileSizeQ = 16
      {8, 1.0}     // Cost factor when tileSizeQ = 8
  };
  // Define the per-tile reduction cost model for different tileSizeQ choices.
  std::unordered_map<int, float> kernelReductionCost = {
      {128, 1.32},  // Reduction cost factor when tileSizeQ = 128
      {64, 1.2},    // Reduction cost factor when tileSizeQ = 64
      {32, 1.08},   // Reduction cost factor when tileSizeQ = 32
      {16, 1.03},   // Reduction cost factor when tileSizeQ = 16
      {8, 1.0}      // Reduction cost factor when tileSizeQ = 8
  };
  // The reduction cost emulated as a sequence length factor.
  float const kernelReductionSeqLenFactor = 128.0f;
  // The parameters for launching the kernel.
  CtaLaunchParams ctaLaunchParams;
  // The copy of the selectKernelParams, which makes sure it won't modify the original
  // selectKernelParams when computing the number of CTAs.
  SelectKernelParams selectKernelParamsCopy = selectKernelParams;
  // Load the kernel.
  auto [func, kernelMeta] = loadKernel(params, selectKernelParamsCopy);
  // Compute numCtasX, numCtasY and numCtasZ.
  computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);
  // If there are no free SMs or tileSizeQ is already the smallest one, skip the heuristic
  // selection.
  if (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ * 2 >
          params.mMultiProcessorCount ||
      selectKernelParamsCopy.mTileSizeQ <= 8) {
    // No need to select the kernel further.
    return;
  }
  // Candidate tile sizes for tileSizeQ to explore.
  int const candidateTileSizesQ[] = {128, 64, 32, 16, 8};
  // The default tileSizeQ.
  int defaultTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
  // The selected tileSizeQ.
  int selectedTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
  // The minimum modeling kernel time.
  float globalModelingKernelTime = FLT_MAX;
  // Loop over each candidate tile size.
  for (int tileSizeQ : candidateTileSizesQ) {
    // Only consider candidates <= default tileSizeQ.
    if (tileSizeQ > defaultTileSizeQ) {
      continue;
    }
    // Compute the number of CTAs.
    computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);
    // Compute the seqLenPerCtaKv.
    int32_t seqLenPerCtaKv =
        flashinfer::ceil_div(flashinfer::ceil_div(params.mMaxSeqLenKv, kernelMeta.mStepKv),
                             ctaLaunchParams.mMaxNumCtasKv) *
        kernelMeta.mStepKv;
    // Compute the modeling kernel time = mainloop cost + reduction cost.
    float modelingKernelTime = kernelMainloopCost[tileSizeQ] * seqLenPerCtaKv +
                               kernelReductionCost[tileSizeQ] * kernelReductionSeqLenFactor *
                                   ctaLaunchParams.mMaxNumCtasKv;
    // Compute the total number of CTAs.
    int32_t numCtas =
        ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ;
    // Compute the number of waves.
    int32_t numWaves = flashinfer::ceil_div(numCtas, params.mMultiProcessorCount);
    // Compute the total modeling kernel time.
    modelingKernelTime *= numWaves;
    // If this candidate has a lower time than the global minimum, update the global minimum.
    if (modelingKernelTime < globalModelingKernelTime) {
      globalModelingKernelTime = modelingKernelTime;
      selectedTileSizeQ = tileSizeQ;
    }
  }
  // Update the tileSizeQ.
  selectKernelParams.mTileSizeQ = selectedTileSizeQ;
  // Update the kernel type.
  if (selectKernelParams.mTileSizeQ >= 64) {
    selectKernelParams.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
  } else {
    selectKernelParams.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
  }
}
The heuristic for selecting tileSizeQ in selectTileSizeQForGqaGeneration appears to have a logical flaw. The kernelMeta and ctaLaunchParams are computed once before the loop that iterates through candidate tileSizeQ values. However, inside the loop, computeCtaAndClusterConfig is called again with the same stale kernelMeta, which means ctaLaunchParams is not updated based on the candidate tileSizeQ. This leads to an incorrect cost model evaluation, as the launch configuration (number of CTAs, waves, etc.) doesn't change with the tile size being evaluated.
To correctly evaluate each candidate tileSizeQ, the corresponding kernelMeta should be retrieved, and the launch parameters should be recomputed within the loop. This can be done efficiently by looking up the kernel metadata from mKernelMetaMap without fully reloading the kernel function in each iteration. I've provided a suggested refactoring of the function to address this.
void selectTileSizeQForGqaGeneration(RunnerParams const& params,
SelectKernelParams& selectKernelParams) const {
// Define the per-tile mainloop cost model for different tileSizeQ choices.
std::unordered_map<int, float> kernelMainloopCost = {
{128, 2.2}, // Cost factor when tileSizeQ = 128
{64, 1.68}, // Cost factor when tileSizeQ = 64
{32, 1.48}, // Cost factor when tileSizeQ = 32
{16, 1.2}, // Cost factor when tileSizeQ = 16
{8, 1.0} // Cost factor when tileSizeQ = 8
};
// Define the per-tile reduction cost model for different tileSizeQ choices.
std::unordered_map<int, float> kernelReductionCost = {
{128, 1.32}, // Reduction cost factor when tileSizeQ = 128
{64, 1.2}, // Reduction cost factor when tileSizeQ = 64
{32, 1.08}, // Reduction cost factor when tileSizeQ = 32
{16, 1.03}, // Reduction cost factor when tileSizeQ = 16
{8, 1.0} // Reduction cost factor when tileSizeQ = 8
};
// The reduction cost emulated as a sequence length factor.
float const kernelReductionSeqLenFactor = 128.0f;
// The parameters for launching the kernel.
CtaLaunchParams ctaLaunchParams;
// The copy of the selectKernelParams, which makes sure it won't modify the original
// selectKernelParams when computing the number of CTAs.
SelectKernelParams selectKernelParamsCopy = selectKernelParams;
// Candidate tile sizes for tileSizeQ to explore.
int const candidateTileSizesQ[] = {128, 64, 32, 16, 8};
// The default tileSizeQ.
int defaultTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
// The selected tileSizeQ.
int selectedTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
// The minimum modeling kernel time.
float globalModelingKernelTime = FLT_MAX;
// Loop over each candidate tile size.
for (int tileSizeQ : candidateTileSizesQ) {
// Only consider candidates <= default tileSizeQ.
if (tileSizeQ > defaultTileSizeQ) {
continue;
}
selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
if (tileSizeQ >= 64) {
selectKernelParamsCopy.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
} else {
selectKernelParamsCopy.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
}
// Find kernel meta info without loading the kernel function
auto [hashId, info] = hashFromRunnerParams(params, selectKernelParamsCopy);
auto const findMetaIter = mKernelMetaMap.find(hashId);
if (findMetaIter == mKernelMetaMap.end()) {
continue; // No kernel available for this tile size
}
auto const& kernelMeta = mKernelMeta[findMetaIter->second];
// Compute the number of CTAs.
computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);
// If there are no free SMs, this tile size is not a good candidate.
if (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ * 2 >
params.mMultiProcessorCount &&
tileSizeQ > 8) { // allow smallest tile size to be selected even if it oversubscribes
continue;
}
// Compute the seqLenPerCtaKv.
int32_t seqLenPerCtaKv =
flashinfer::ceil_div(flashinfer::ceil_div(params.mMaxSeqLenKv, kernelMeta.mStepKv),
ctaLaunchParams.mMaxNumCtasKv) *
kernelMeta.mStepKv;
// Compute the modeling kernel time = mainloop cost + reduction cost.
float modelingKernelTime = kernelMainloopCost.at(tileSizeQ) * seqLenPerCtaKv +
kernelReductionCost.at(tileSizeQ) * kernelReductionSeqLenFactor *
ctaLaunchParams.mMaxNumCtasKv;
// Compute the total number of CTAs.
int32_t numCtas =
ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ;
// Compute the number of waves.
int32_t numWaves = flashinfer::ceil_div(numCtas, params.mMultiProcessorCount);
// Compute the total modeling kernel time.
modelingKernelTime *= numWaves;
// If this candidate has a lower time than the global minimum, update the global minimum.
if (modelingKernelTime < globalModelingKernelTime) {
globalModelingKernelTime = modelingKernelTime;
selectedTileSizeQ = tileSizeQ;
}
}
// Update the tileSizeQ.
selectKernelParams.mTileSizeQ = selectedTileSizeQ;
// Update the kernel type.
if (selectKernelParams.mTileSizeQ >= 64) {
selectKernelParams.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
} else {
selectKernelParams.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
}
}
It makes sense and I have fixed it.
}

int32_t ceilLog2(int32_t value) const {
  return static_cast<int32_t>(std::ceil(std::log2(value)));
The ceilLog2 implementation uses floating-point math (std::log2), which can be less efficient and potentially have precision issues for integer inputs. A bitwise implementation would be more robust and performant, especially if this function is ever used in a performance-critical path. Consider replacing it with a portable integer-based implementation.
if (value <= 0) return 0; // Match behavior of std::log2 for non-positive values
int32_t result = 0;
int32_t temp = value - 1;
while (temp > 0) {
temp >>= 1;
result++;
}
return result;
Actionable comments posted: 3
🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)
497-498: Remove extraneous semicolon. There's a double semicolon on line 498 after the ceil_div call.

🔎 Proposed fix

- int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
- ;
+ int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
558-655: Cost model uses undocumented magic numbers. The kernelMainloopCost and kernelReductionCost maps use empirical constants (2.2, 1.68, 1.48, etc.) without documentation on how they were derived. Consider adding a brief comment explaining that these are benchmarked/profiled values, or referencing the source.

📜 Review details
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
csrc/trtllm_fmha_kernel_launcher.cu
flashinfer/artifacts.py
include/flashinfer/trtllm/fmha/fmhaKernels.cuh
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
include/flashinfer/trtllm/fmha/kernelParams.h
tests/attention/test_trtllm_gen_attention.py
🧰 Additional context used
🧬 Code graph analysis (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h (1)
ceilDiv(42-44)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (15)
flashinfer/artifacts.py (1)
90-90: LGTM! Artifact path and checksum updated consistently. The TRTLLM_GEN_FMHA artifact path and its corresponding checksum are updated together, which is correct. These changes align with the PR objectives to integrate optimized decode attention kernels, and the existing verify_cubin logic (line 224) will validate the checksum during artifact download.

Also applies to: 110-110
tests/attention/test_trtllm_gen_attention.py (1)
1282-1284: AI summary inconsistency: function name and parameter values are incorrect. The AI summary states these changes are for test_trtllm_batch_decode_head_dim_256, but they are actually added to test_trtllm_batch_decode_long_sequence_length (line 1300). Additionally, the parameter interpretation is incorrect: the actual values are (batch_size=32, q_len_per_req=[4,8,16], page_size=16, num_kv_heads=2, head_grp_size=8), not what the summary describes.

The test additions themselves look good. They appropriately expand coverage for long-sequence scenarios with moderate batch sizes (32) and varying speculative decoding lengths (q_len_per_req of 4, 8, 16), which aligns well with the PR objectives to optimize for high-throughput workloads and speculative decoding.
include/flashinfer/trtllm/fmha/kernelParams.h (3)
282-301: LGTM – logic for grouping tokens and heads per CTA is clear. The conditional handling for groupsTokensHeadsQ correctly computes numTokensPerCtaQ when grouping is enabled, with appropriate padding comments. The tuple return value now includes numTokensPerCtaQ for downstream use.
594-596: Helpful debugging guidance added. The comment explaining that TMA descriptor errors may be caused by previous kernels and suggesting CUDA_LAUNCH_BLOCKING or cuda-gdb is useful for debugging.
801-803: Verify FastModDivInt32 handles mNumHeadsQPerKv = 1 safely. Given the potential edge case with ceilLog2(1) mentioned earlier, ensure that when options.mNumHeadsQPerKv == 1, the FastModDivInt32 constructor doesn't produce invalid values.

csrc/trtllm_fmha_kernel_launcher.cu (1)
162-166: LGTM – mask type change with clear rationale. The switch from Dense to Causal for generation is well-documented. The comment clarifies that kernel naming retains "dense" for performance reasons while the behavior is equivalent when each CTA processes a single tokenQ.

include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (2)
351-361: New kernel selection parameters added. The mNumTokensPerPage and mTileSizeQ fields extend the kernel selection API to support the new grouping and tile sizing logic. These integrate well with the updated hashID and kernel selection flow in fmhaKernels.cuh.
376-382: Constructor properly initializes new fields. mNumTokensPerPage is propagated from input params, and mTileSizeQ / mTileSizeKv are initialized to 128, which is consistent with the largest supported tile size. The initialization order matches the declaration order.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (7)
54-62: LGTM – SM family/specific pair detection. The isFamilySpecificSMPair helper correctly identifies when SM values are family/specific variants (e.g., kSM_100f with kSM_100 or kSM_103), enabling proper hash conflict resolution.
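For readers who want to see the shape of this family/specific matching, here is a minimal standalone sketch of the idea; it is not the flashinfer implementation, and the kSM_* values below are illustrative placeholders for the real constants in the codebase.

#include <cassert>

// Illustrative placeholder values; the actual kSM_* constants are defined in flashinfer.
enum SmVersion { kSM_100 = 100, kSM_103 = 103, kSM_100f = 1000 };

// Returns true when one value is the Blackwell "family" entry (kSM_100f) and the
// other is a specific member of that family (kSM_100 or kSM_103).
inline bool isFamilySpecificSMPairSketch(int sm0, int sm1) {
  auto matches = [](int family, int specific) {
    return family == kSM_100f && (specific == kSM_100 || specific == kSM_103);
  };
  return matches(sm0, sm1) || matches(sm1, sm0);
}

int main() {
  assert(isFamilySpecificSMPairSketch(kSM_100f, kSM_100));
  assert(isFamilySpecificSMPairSketch(kSM_103, kSM_100f));
  assert(!isFamilySpecificSMPairSketch(kSM_100, kSM_103));  // two specific values are not a pair
  return 0;
}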
76-92: LGTM – CtaLaunchParams encapsulates launch configuration. The new struct cleanly groups related launch parameters, improving code organization and reducing parameter passing overhead.
117-133: LGTM – Hash conflict resolution prefers specific SM. The logic correctly handles hash conflicts between family and specific SM versions, preferring the specific version (e.g., kSM_100 or kSM_103 over kSM_100f).
399-402: LGTM – Factor of 2 for reduction overhead. The comment clearly explains that the factor of 2 is applied to balance the reduction overhead against mainloop benefits.
768-829: LGTM – On-demand kernel loading with caching. The refactored loadKernel method properly caches modules and functions, handles shared memory configuration for large kernels (≥48KB), and includes helpful error messages.
334-345: LGTM – Generation kernel CTA handling for spec-decoding. The logic correctly distinguishes between groupsTokensHeadsQ enabled/disabled scenarios, computing numCtasPerSeqQ appropriately for speculative decoding workloads.
583-590: Line 588 is not within the heuristic selection loop; it is the initial kernel load before the loop begins. The subsequent heuristic loop (lines 613+) only iterates over candidate tile sizes using the already-loaded kernelMeta; it does not call loadKernel again. If loadKernel fails due to a missing kernel, that's a legitimate error condition (kernel hash not found) that should be reported, not silently caught during heuristic probing.
Actionable comments posted: 1
♻️ Duplicate comments (2)
include/flashinfer/trtllm/fmha/kernelParams.h (2)
53-55: Consider replacing floating-point ceilLog2 with a bitwise implementation. The current implementation uses std::log2, which is less efficient and may have precision issues for integer inputs. A bitwise approach would be more robust and performant.

🔎 Proposed bitwise implementation
  int32_t ceilLog2(int32_t value) const {
-   return static_cast<int32_t>(std::ceil(std::log2(value)));
+   if (value <= 1) return 0;
+   int32_t result = 0;
+   int32_t temp = value - 1;
+   while (temp > 0) {
+     temp >>= 1;
+     result++;
+   }
+   return result;
  }

Based on past review comments.
46-50: Critical: Edge case divisor == 1 leads to negative mShift causing undefined behavior. When divisor = 1, ceilLog2(1) returns 0, so mShift = 0 - 1 = -1. Negative shift values cause undefined behavior in the multiplier calculation (uint64_t(1) << (32 + mShift) at line 49) and may break downstream usages of mShift.

🔎 Proposed fix
  FastModDivInt32(int32_t divisor) : mDivisor(divisor) {
+   if (divisor == 1) {
+     mShift = 0;
+     mMultiplier = 1;
+     return;
+   }
    mShift = ceilLog2(mDivisor) - 1;
    mMultiplier = static_cast<uint32_t>(
        flashinfer::ceil_div(uint64_t(1) << (32 + mShift), static_cast<uint64_t>(mDivisor)));
  }

Based on past review comments.
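As a sanity check on the multiply-and-shift scheme discussed here, the following small standalone sketch (not the flashinfer code; names are illustrative) derives the shift and multiplier the same way, special-cases divisor == 1 as the proposed fix does, and compares the results against plain integer division and modulo.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

struct FastDivSketch {
  int32_t divisor;
  int32_t shift;
  uint32_t multiplier;

  explicit FastDivSketch(int32_t d) : divisor(d), shift(0), multiplier(1) {
    if (d == 1) return;  // special case, as suggested above
    int32_t log2Ceil = static_cast<int32_t>(std::ceil(std::log2(static_cast<double>(d))));
    shift = std::max(log2Ceil - 1, 0);
    uint64_t pow2 = uint64_t(1) << (32 + shift);
    multiplier = static_cast<uint32_t>((pow2 + d - 1) / d);  // ceil_div(2^(32+shift), d)
  }

  int32_t div(int32_t x) const {
    if (divisor == 1) return x;
    return static_cast<int32_t>((static_cast<uint64_t>(x) * multiplier) >> (32 + shift));
  }
  int32_t mod(int32_t x) const { return x - div(x) * divisor; }
};

int main() {
  for (int32_t d : {1, 2, 3, 7, 16, 33, 128}) {
    FastDivSketch f(d);
    for (int32_t x = 0; x < (1 << 16); ++x) {
      assert(f.div(x) == x / d);
      assert(f.mod(x) == x % d);
    }
  }
  return 0;
}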
🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)
559-665: The loop now correctly reloads kernel metadata for each candidate tileSizeQ. The function addresses the prior review concern by:

- Updating selectKernelParamsCopy.mTileSizeQ at line 618 for each candidate
- Reloading the kernel with loadKernel at line 626 to get the correct kernelMeta
- Recomputing CTA configuration at line 629 with the updated metadata

However, loadKernel (line 787) throws an exception if a kernel for a particular tileSizeQ doesn't exist. Consider wrapping the loadKernel call in a try-catch to gracefully skip unavailable candidates rather than aborting the entire selection.

🔎 Optional: graceful handling of missing kernels

  for (int tileSizeQ : candidateTileSizesQ) {
    // Only consider candidates <= default tileSizeQ.
    if (tileSizeQ > defaultTileSizeQ) {
      continue;
    }
    selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
    if (tileSizeQ >= 64) {
      selectKernelParamsCopy.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
    } else {
      selectKernelParamsCopy.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
    }
-   // Load the kernel.
-   std::tie(func, kernelMeta) = loadKernel(params, selectKernelParamsCopy);
+   // Load the kernel. Skip if not available.
+   try {
+     std::tie(func, kernelMeta) = loadKernel(params, selectKernelParamsCopy);
+   } catch (...) {
+     continue;  // Skip candidates without available kernels
+   }
    // Compute the number of CTAs.
    computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);
496-498: Optional: Remove stray semicolon. Line 498 has an extra semicolon (;;) after the maxNumCtasPerSeqKv calculation, which is harmless but reduces code cleanliness.

🔎 Proposed fix

  int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
- ;
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh
include/flashinfer/trtllm/fmha/kernelParams.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (16)
include/flashinfer/trtllm/fmha/kernelParams.h (6)
148-150: LGTM. The new mInflateMax field is clearly documented and adds support for adjusting max value inflation during iterations.

173-174: Verify default initialization {1} given FastModDivInt32 edge case. The default initialization FastModDivInt32 mNumHeadsQPerKvDivisor{1} triggers the divisor == 1 edge case flagged earlier (leading to mShift = -1). Ensure this default is intentional and that the critical fix for the FastModDivInt32 constructor is applied to prevent undefined behavior.

205-207: LGTM. The signature expansion to include groupsTokensHeadsQ aligns with the new per-CTA tokenization logic described in the PR summary.

296-296: LGTM. The return tuple now includes numTokensPerCtaQ, consistent with the expanded per-CTA tokenization logic.

639-641: LGTM. The call site correctly passes the new groupsTokensHeadsQ parameter and captures the expanded return tuple including numTokensPerCtaQ.

796-798: LGTM. The new fields mNumHeadsQPerKvDivisor and mNumTokensPerCtaQ are correctly initialized from options and the computed numTokensPerCtaQ value.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (10)
54-62: LGTM. The isFamilySpecificSMPair function correctly identifies when two SM values represent family/specific architecture pairs (e.g., SM_100f with SM_100 or SM_103), supporting graceful handling of Blackwell GPU variants.

76-93: LGTM. The new CtaLaunchParams struct is a clean refactor that groups related kernel launch parameters, replacing multiple tuple returns with a more maintainable structure.

117-134: LGTM. The hash conflict detection correctly allows only family/specific SM pairs to share a hash, and sensibly prefers specific SM versions (e.g., SM_100 over SM_100f) when both are present.

140-183: LGTM. The expanded hashID signature and bit layout correctly incorporate the new kernel selection parameters (tileSizeQ, reuseSmemKForV, uses2CtaMma, sparseMla) with appropriate assertions and bit-packing.
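For readers who want a concrete picture of what such bit-packing looks like, here is a generic standalone sketch; the field order and widths are illustrative only and do not reflect the actual hashID layout in fmhaKernels.cuh.

#include <cassert>
#include <cstdint>

// Packs a few selection parameters into one 64-bit key. A real layout must reserve
// enough bits for every field and keep packing and unpacking in sync.
uint64_t packSelectionKeySketch(uint32_t headDim, uint32_t tileSizeQ, bool reuseSmemKForV,
                                bool uses2CtaMma, bool sparseMla) {
  uint64_t key = 0;
  int offset = 0;
  auto pack = [&](uint64_t value, int numBits) {
    assert(value < (uint64_t(1) << numBits) && "field does not fit in its bit range");
    key |= value << offset;
    offset += numBits;
  };
  pack(headDim, 12);   // illustrative width: values up to 4095
  pack(tileSizeQ, 8);  // illustrative width: values up to 255
  pack(reuseSmemKForV, 1);
  pack(uses2CtaMma, 1);
  pack(sparseMla, 1);
  return key;
}

int main() {
  uint64_t a = packSelectionKeySketch(256, 16, false, true, false);
  uint64_t b = packSelectionKeySketch(256, 32, false, true, false);
  assert(a != b);  // different tileSizeQ values must produce different keys
  return 0;
}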
511-556: LGTM. The selectMlaGenerationKernel heuristic appropriately selects between low-latency (SwapsMmaAbForGeneration) and high-throughput (KeepsMmaAbForGeneration) kernels based on numHeadsQPerKv and GPU utilization, with clear logic for enabling 2-CTA MMA mode.

667-707: LGTM. The selectGqGenerationKernel function uses clear threshold-based heuristics for selecting tileSizeQ and kernelType, and appropriately delegates to the cost-model-based selectTileSizeQForGqaGeneration when maxSeqLenQ > 1 for speculative decoding.

399-402: LGTM. The factor-of-2 adjustment in maxNumCtasPerSeqKv is well-justified by the comment: it prevents splitting KV sequences so finely that reduction overhead exceeds mainloop speedup benefits.

477-484: LGTM. The refactor to update CtaLaunchParams in place (lines 478–483) is clean and consistent with the new struct-based parameter passing introduced in this PR.

777-839: LGTM. The loadKernel refactor implements clean on-demand kernel loading with two-level caching (modules and functions), appropriate shared memory configuration for large allocations (≥48KB), and clear error messages when kernels are not found.
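A minimal sketch of that two-level caching pattern using the CUDA driver API directly is shown below; this is illustrative only, not the flashinfer loadKernel (error checking is omitted, the cache is not thread-safe, and the file-reading helper is a local stand-in).

#include <cuda.h>

#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <unordered_map>

class KernelCacheSketch {
 public:
  CUfunction getOrLoad(uint64_t kernelHash, std::string const& cubinPath, char const* funcName,
                       int dynamicSmemBytes) {
    // Level 1: function cache keyed by the kernel hash.
    if (auto it = mFunctions.find(kernelHash); it != mFunctions.end()) {
      return it->second;
    }
    // Level 2: module cache keyed by the cubin path.
    CUmodule module;
    if (auto it = mModules.find(cubinPath); it != mModules.end()) {
      module = it->second;
    } else {
      std::string cubin = readFile(cubinPath);
      cuModuleLoadData(&module, cubin.data());
      mModules.emplace(cubinPath, module);
    }
    CUfunction func;
    cuModuleGetFunction(&func, module, funcName);
    // Kernels that need 48KB or more of dynamic shared memory must opt in explicitly.
    if (dynamicSmemBytes >= 48 * 1024) {
      cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, dynamicSmemBytes);
    }
    mFunctions.emplace(kernelHash, func);
    return func;
  }

 private:
  static std::string readFile(std::string const& path) {
    std::ifstream file(path, std::ios::binary);
    return std::string(std::istreambuf_iterator<char>(file), {});
  }

  std::unordered_map<std::string, CUmodule> mModules;
  std::unordered_map<uint64_t, CUfunction> mFunctions;
};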
204-314: LGTM. The run() method correctly integrates the new CtaLaunchParams struct and on-demand loadKernel approach, replacing the previous tuple-based parameter passing with a cleaner, more maintainable flow.
if (groupsHeadsQ) {
  if (isSpecDecodingGenerationKernel(options.mKernelType)) {
    FLASHINFER_CHECK((tileSizeQ % numGroupedHeads == 0), "internal error");
    tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                       static_cast<uint32_t>(numGroupedHeads), 1,
                                       static_cast<uint32_t>(tileSizeQ / numGroupedHeads)};
    if (groupsTokensHeadsQ) {
      // Currently, it requires each CTA to process complete headsQ (i.e. numGroupedHeads) at a
      // time, so it allows paddings in the end. Removing paddings needs re-organizing the Q
      // tensor to [numTokensQ, numGroupedHeads, numHeads, headDimQ] and we might want to revisit
      // this in the future.
      numTokensPerCtaQ = static_cast<int32_t>(numTokensPerCtaQ / numGroupedHeads);
    } else {
      tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                         static_cast<uint32_t>(tileSizeQ), 1, 1};
      numGroupedHeads = tileSizeQ;
      numTokensPerCtaQ = 1;
    }
    tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                       static_cast<uint32_t>(numGroupedHeads), 1,
                                       static_cast<uint32_t>(numTokensPerCtaQ)};
  }
🧩 Analysis chain
🏁 Script executed:
# Search for makeTmaShapeStrideQ definition and context
rg -nP 'makeTmaShapeStrideQ' --type=cpp -B5 -A20

Repository: flashinfer-ai/flashinfer
Length of output: 7590
🏁 Script executed:
# Search for tileSizeQ and numGroupedHeads initialization/assignment
rg -nP '\btileSizeQ\s*=' --type=cpp -B2 -A2 | head -50

Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look for the complete function containing this code
fd -e h | xargs rg -l 'groupsHeadsQ\|groupsTokensHeadsQ' | head -5

Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check for any assertions or validation around division
rg -nP 'numTokensPerCtaQ|numGroupedHeads' include/flashinfer/trtllm/fmha/kernelParams.h -B2 -A2

Repository: flashinfer-ai/flashinfer
Length of output: 2957
🏁 Script executed:
# Search for tileSizeQ/mTileSizeQ definition and constraints
rg -nP 'mTileSizeQ|tileSizeQ' --type=cpp -B3 -A3 | head -80

Repository: flashinfer-ai/flashinfer
Length of output: 5598
🏁 Script executed:
# Search for kernel metadata definition and mTileSizeQ initialization
fd -e h -e cpp | xargs rg -l 'mTileSizeQ' | head -5

Repository: flashinfer-ai/flashinfer
Length of output: 163
🏁 Script executed:
# Check if there are any constraints on tileSizeQ values in test files or kernel selection logic
rg -nP 'mTileSizeQ.*=' --type=cpp | head -30

Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for kernel metadata setup and how mTileSizeQ gets assigned/constrained
rg -nP 'mGroupsTokensHeadsQ|groupsTokensHeadsQ' --type=cpp -B5 -A5 | head -80

Repository: flashinfer-ai/flashinfer
Length of output: 3587
🏁 Script executed:
# Look for any validation or constraints on tileSizeQ relative to mNumHeadsQPerKv
rg -nP 'numHeadsQPerKv|mNumHeadsQPerKv' --type=cpp | grep -i 'tile\|size' | head -20

Repository: flashinfer-ai/flashinfer
Length of output: 593
🏁 Script executed:
# Check if there are tests or examples that show tileSizeQ and mNumHeadsQPerKv combinations
fd -e cu -e cpp | xargs rg -l 'mTileSizeQ.*=' | head -3

Repository: flashinfer-ai/flashinfer
Length of output: 89
🏁 Script executed:
# Search for where kernel metadata is created/selected and if there are constraints between tileSizeQ and mNumHeadsQPerKv
fd -e cpp -e cu | xargs rg -l 'mGroupsTokensHeadsQ.*=' | head -5

Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check kernel selection logic and metadata initialization
rg -nP 'KernelMeta|kernelMeta.*{' --type=cpp -A10 | head -100

Repository: flashinfer-ai/flashinfer
Length of output: 4231
🏁 Script executed:
# Look for any divmod, remainder, or modulo checks related to tileSizeQ or numGroupedHeads
rg -nP '%|divmod' include/flashinfer/trtllm/fmha/kernelParams.h -B3 -A3

Repository: flashinfer-ai/flashinfer
Length of output: 1465
Add a runtime assertion or comment to document the divisibility constraint at line 286.
At line 286, numTokensPerCtaQ is computed as tileSizeQ / numGroupedHeads and cast to int32_t, which silently truncates if the division is inexact. The comment acknowledges that padding is "currently required" and suggests revisiting this design. While the current implementation tolerates padding, there should be either:
- An explicit
FLASHINFER_CHECK(tileSizeQ % numGroupedHeads == 0, ...) to catch misconfigured kernel metadata, or
- A clearer comment documenting that tileSizeQ must be divisible by numGroupedHeads when groupsTokensHeadsQ is true
This aligns with existing validation patterns in the codebase (e.g., line 225 for head counts, line 806 for sparse MLA top-k).
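To make the failure mode concrete, here is a tiny standalone illustration using a plain assert in place of FLASHINFER_CHECK; the helper name and values are made up for this sketch.

#include <cassert>
#include <cstdint>

// Without a guard, integer division silently truncates when tileSizeQ is not a
// multiple of numGroupedHeads; the assert turns that misconfiguration into a
// loud failure instead.
int32_t numTokensPerCtaQSketch(int32_t tileSizeQ, int32_t numGroupedHeads) {
  assert(tileSizeQ % numGroupedHeads == 0 &&
         "tileSizeQ must be divisible by numGroupedHeads when groupsTokensHeadsQ is set");
  return tileSizeQ / numGroupedHeads;
}

int main() {
  assert(numTokensPerCtaQSketch(64, 8) == 8);  // 64 / 8 divides exactly
  // numTokensPerCtaQSketch(64, 6);            // 64 / 6 would truncate to 10; the assert fires first
  return 0;
}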
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
593-604: Out-of-bounds access when dim < 5 in error logging. The error logging code accesses fixed indices [0-4] for shapes, tileShapes, and tileStrides, and [0-3] for stridesInBytes. However, dim can be 2–5 (per the check at line 555), so accessing index 4 when dim is 2, 3, or 4 is undefined behavior.

🔎 Proposed fix using a loop or conditional printing
- std::cerr << "Shape: " << shapes[0] << " " << shapes[1] << " " << shapes[2] << " "
-           << shapes[3] << " " << shapes[4] << std::endl;
- std::cerr << "Stride: " << stridesInBytes[0] << " " << stridesInBytes[1] << " "
-           << stridesInBytes[2] << " " << stridesInBytes[3] << std::endl;
- std::cerr << "tileShapes: " << tileShapes[0] << " " << tileShapes[1] << " " << tileShapes[2]
-           << " " << tileShapes[3] << " " << tileShapes[4] << std::endl;
- std::cerr << "tileStrides: " << tileStrides[0] << " " << tileStrides[1] << " "
-           << tileStrides[2] << " " << tileStrides[3] << " " << tileStrides[4] << std::endl;
+ std::cerr << "Shape:";
+ for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << shapes[ii];
+ std::cerr << std::endl;
+ std::cerr << "Stride:";
+ for (int32_t ii = 0; ii < dim - 1; ++ii) std::cerr << " " << stridesInBytes[ii];
+ std::cerr << std::endl;
+ std::cerr << "tileShapes:";
+ for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << tileShapes[ii];
+ std::cerr << std::endl;
+ std::cerr << "tileStrides:";
+ for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << tileStrides[ii];
+ std::cerr << std::endl;
♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
280-294: Add a runtime assertion for the divisibility constraint. The division at line 286 can silently truncate if tileSizeQ is not evenly divisible by numGroupedHeads. While the comment documents that padding is "currently required," adding an explicit FLASHINFER_CHECK would catch misconfigured kernel metadata early and align with validation patterns elsewhere in this file (e.g., line 225, line 806).
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/trtllm/fmha/kernelParams.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (7)
include/flashinfer/trtllm/fmha/kernelParams.h (7)
44-62: LGTM - FastModDivInt32 implementation correctly addresses prior feedback. The edge case for divisor == 1 is now handled via std::max(ceilLog2(mDivisor) - 1, 0), and the code reuses flashinfer::ceil_div instead of a local duplicate. The mAdd field is correctly left at 0 per the CCCL fast modulo division algorithm for signed types.

148-150: LGTM - New inflation parameter is well-documented. The mInflateMax field has a clear comment explaining its purpose. The field will be zero-initialized via memset in setKernelParams.

173-180: LGTM - New kernel parameters for token/head grouping. The mNumHeadsQPerKvDivisor is correctly typed as FastModDivInt32{1} for fast modulo operations, and mNumTokensPerCtaQ supports the new grouping kernel feature.

205-207: LGTM - Clean API extension. The new groupsTokensHeadsQ parameter cleanly extends the function signature to support the new token/head grouping mode.

589-591: Helpful debugging guidance added. The updated comment clarifying that errors may originate from previous kernels and suggesting CUDA_LAUNCH_BLOCKING or cuda-gdb is a useful addition for debugging TMA initialization failures.

639-641: LGTM - Call site correctly updated. The structured binding now captures the new numTokensPerCtaQ return value, and kernelMeta.mGroupsTokensHeadsQ is passed as the new parameter.

796-798: LGTM - Kernel parameters correctly assigned. The mNumHeadsQPerKvDivisor is initialized from options.mNumHeadsQPerKv, and mNumTokensPerCtaQ is propagated from the makeTmaShapeStrideQ return value.
/bot run

[FAILED] Pipeline #40778397: 1/20 passed
📌 Description
This MR adds the optimized decode attention kernels for high throughput (large batch size) + speculative decoding (seqlen_q > 1).
See below for speedups (collected by benchmarks/flashinfer_benchmark.py). The seqlenKv is 16K for all cases.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Performance / Refactor
Chores
Tests