feat: Add FP8/NVFP4 quant fusion for MNNVL Allreduce #2263
base: main
Conversation
📝 Walkthrough: This pull request introduces quantization support to the TRT-LLM MNNVL AllReduce fusion operations.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks: ❌ Failed checks (1 inconclusive), ✅ Passed checks (2 passed)
Summary of Changes
Hello @timlee0212, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the MNNVL Allreduce functionality by integrating FP8 and NVFP4 quantization fusion, enabling more memory-efficient and potentially faster distributed computation, especially when combined with RMSNorm. The changes update core data structures, introduce a new set of quantization utility functions, and modify the existing CUDA kernels and dispatch mechanisms to support the new quantization types across all five fusion patterns of the unified allreduce interface.
Code Review
This pull request adds FP8/NVFP4 quantization fusion for MNNVL Allreduce. The changes are extensive and introduce new quantization logic. My review focuses on correctness and maintainability. I've identified a critical bug in the kernel dispatch logic that needs to be fixed. Additionally, there is significant code duplication that should be addressed by refactoring the shared code into a common header file. The author has already noted these duplications with TODO comments, and I've formalized them as review comments.
        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP4);    \
        break;                                                         \
      case QuantType::kNone:                                           \
        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone);  \
// TODO: These code are shared with trtllm_allreduce_fusion.cuh, and moe_allreduce_fusion; Should we
// move them to a shared header?
namespace quant {
As pointed out in the TODO, this large block of code containing the quant namespace is duplicated across multiple files (trtllm_allreduce_fusion.cuh, moe_allreduce_fusion.cuh). This creates a significant maintenance burden. Please refactor this into a shared utility header (e.g., include/flashinfer/comm/quant_utils.cuh) and include it where needed. This will greatly improve code maintainability.
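For illustration only, a minimal sketch of how such a shared header might start out is shown below. The file name comes from the comment above, while the surrounding namespace and the decision to show only the enum here are assumptions rather than the PR's actual layout; the quant_fp8/quant_nvfp4 device helpers referenced later in this review would move in as well.

// quant_utils.cuh -- hypothetical shared header (sketch only)
#pragma once

namespace flashinfer {
namespace quant {

// Single definition replacing the copies currently kept in
// trtllm_allreduce_fusion.cuh, moe_allreduce_fusion.cuh, and trtllm_mnnvl_allreduce.cuh.
enum class QuantType : int {
  kNone = 0,
  kFP8 = 1,
  kFP4 = 2,
};

}  // namespace quant
}  // namespace flashinfer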
// TODO: Same; This enum defination is duplicated
enum class QuantType : int {
  kNone = 0,
  kFP8 = 1,
  kFP4 = 2,
};
Actionable comments posted: 3
🧹 Nitpick comments (2)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2)
40-45: Address the TODO: consolidate the duplicated enum definition. The TODO comment indicates this QuantType enum is duplicated elsewhere. Consider moving it to a shared header to maintain a single source of truth and avoid potential divergence. Would you like me to help identify the other location(s) where this enum is defined and suggest a consolidation approach?
1745: Minor: inconsistent naming (kELTS_SIZE vs kELT_SIZE used elsewhere). This constant is named kELTS_SIZE, but similar constants in other kernels (e.g., line 1178 in oneshotAllreduceFusionKernel) use kELT_SIZE. Consider using consistent naming across kernels for maintainability.
📜 Review details
📒 Files selected for processing (1)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
⏰ Context from checks skipped due to timeout of 90000ms (1): GitHub Check: Deploy Docs
🔇 Additional comments (5)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (5)
47-72: LGTM! The struct expansion is well-organized with sensible defaults. The explicit initialization of pointer fields to nullptr and enum fields to their "off" states is good practice for optional fusion parameters.
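As a purely illustrative sketch of the pattern being praised here, assuming the QuantType enum quoted earlier and borrowing field names from code quoted elsewhere in this review (the PR's actual struct name, members, and types may differ):

// Sketch: optional quant-fusion parameters default to "off"/nullptr, so call
// sites that never set them keep the pre-existing allreduce behavior.
struct AllReduceFusionParamsSketch {
  // ... existing allreduce fields ...
  bool rmsNormFusion = false;              // RMSNorm fusion disabled by default
  QuantType quantType = QuantType::kNone;  // quantization disabled by default
  void* quantOut = nullptr;                // FP8/NVFP4 quantized output buffer (optional)
  void* scalingFactorOut = nullptr;        // NVFP4 scaling-factor output (optional)
  float* outputScale = nullptr;            // global output scale (optional)
};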
922-959: LGTM! The SF offset calculation handles both batched and non-batched scenarios correctly. The architecture guard for SM100+ and graceful nullptr return for unsupported layouts are appropriate.
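The guard-and-fallback pattern described here might look roughly like the standalone sketch below; the function name, signature, and offset math are illustrative assumptions rather than the PR's code, and only the SM100 check and the graceful nullptr return mirror what the comment describes.

#include <cstddef>
#include <cstdint>

// Sketch: return a scaling-factor (SF) output pointer only on SM100+ and only
// for a layout we understand; otherwise return nullptr so the caller skips the write.
__device__ inline uint8_t* sf_out_ptr_sketch(uint8_t* sf_base, int token, int dim, int layout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (layout != 0) {
    return nullptr;  // unsupported SF layout: degrade gracefully
  }
  // NVFP4 uses one scaling factor per 16-element block, hence dim / 16 SF entries per token.
  return sf_base + static_cast<size_t>(token) * static_cast<size_t>(dim / 16);
#else
  (void)sf_base; (void)token; (void)dim; (void)layout;
  return nullptr;  // NVFP4 scaling factors are only handled on SM100 or newer here
#endif
}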
1162-1176: LGTM! The kernel template extension is well-designed. The static_assert correctly enforces that quantization requires RMSNorm fusion, preventing invalid usage patterns at compile time.
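For readers who have not opened the diff, the compile-time constraint being praised can be expressed along the lines of the sketch below; the helper name and the template parameter names are assumptions, and only the relationship (quantization implies RMSNorm fusion) comes from the comment.

// Sketch: reject any instantiation that requests quantization without RMSNorm
// fusion, since the quantized output is only produced on the fused path.
template <bool RMSNORM, QuantType QType>
__device__ void check_fusion_combination_sketch() {
  static_assert(RMSNORM || QType == QuantType::kNone,
                "Quantization fusion requires RMSNorm fusion to be enabled");
}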
1857-1927: LGTM! The two-shot dispatch macros correctly extend the existing pattern to support quantization. The switch-based dispatch on QuantType and the macro cleanup with #undef are well-structured.
1140-1146: The reinterpret_cast at line 1141 is safe and correctly casts between compatible memory layouts. PackedVec<float4, half> (16 bytes: a union of float4 and half[8]) and vec_t<half, 8> (16 bytes: containing int4 data[1]) both occupy identical 128-bit memory with no padding, making direct reinterpretation valid.
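The size and padding argument above can be checked mechanically; below is a standalone sketch using stand-in types that mirror (but are not) the library's PackedVec<float4, half> and vec_t<half, 8> definitions.

#include <cuda_fp16.h>     // half
#include <vector_types.h>  // float4, int4

// Stand-in for PackedVec<float4, half>: a union of float4 and half[8].
union PackedVecLike {
  float4 packed;
  half unpacked[8];
};

// Stand-in for vec_t<half, 8>: a struct holding int4 data[1].
struct VecTLike {
  int4 data[1];
};

// Both views occupy exactly 128 bits with 16-byte alignment and no padding,
// which is what makes the reinterpret_cast between them well-behaved.
static_assert(sizeof(PackedVecLike) == 16 && sizeof(VecTLike) == 16, "both are 16 bytes");
static_assert(alignof(PackedVecLike) == 16 && alignof(VecTLike) == 16, "both are 16-byte aligned");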
template <typename T>
__device__ inline T cuda_abs(T val) {
  assert(false);
  return {};
}
assert(false) is ineffective in release builds, and the function silently returns a default-constructed value.
The generic cuda_abs template uses assert(false) which is stripped in release builds. The function then returns an uninitialized/default-constructed value, leading to silent incorrect behavior. Consider using a static_assert with a dependent false expression instead.
🔎 Proposed fix
template <typename T>
__device__ inline T cuda_abs(T val) {
- assert(false);
- return {};
+ static_assert(sizeof(T) == 0, "cuda_abs not specialized for this type");
+ return {}; // Never reached, but needed for compilation
}

#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE)                          \
  if (params.rmsNormFusion) {                                          \
    switch (params.quantType) {                                        \
      case QuantType::kFP8:                                            \
        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP8);    \
        break;                                                         \
      case QuantType::kFP4:                                            \
        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP4);    \
        break;                                                         \
      case QuantType::kNone:                                           \
        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone);  \
        break;                                                         \
      default:                                                         \
        FLASHINFER_ERROR("Unsupported quant type! Got " +              \
                         std::to_string(static_cast<int>(params.quantType))); \
        return cudaErrorInvalidValue;                                  \
    }                                                                  \
  } else {                                                             \
    LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone);      \
  }
Bug: RMSNorm fusion disabled when quantType=kNone despite rmsNormFusion=true.
When params.rmsNormFusion is true but params.quantType is kNone, the dispatch incorrectly launches the kernel with RMSNORM=false (line 1384). This breaks the RMSNorm-only fusion pattern (pattern #2 in the unified interface).
🔎 Proposed fix
#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE) \
if (params.rmsNormFusion) { \
switch (params.quantType) { \
case QuantType::kFP8: \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP8); \
break; \
case QuantType::kFP4: \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP4); \
break; \
case QuantType::kNone: \
- LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone); \
+ LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kNone); \
break; \
default: \
FLASHINFER_ERROR("Unsupported quant type! Got " + \
std::to_string(static_cast<int>(params.quantType))); \
return cudaErrorInvalidValue; \
} \
} else { \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone); \
}

if constexpr (QType == QuantType::kFP8) {
  quant::quant_fp8<T, float4, kELTS_PER_LOAD>(r_out, quantOut, outputScale, threadOffset);
} else if constexpr (QType == QuantType::kFP4) {
  quant::quant_nvfp4<T, float4, kELTS_PER_LOAD>(r_out, quantOut, scalingFactorOut,
                                                outputScale, token, dim,
                                                token_offset / kELTS_PER_LOAD, sfLayout);
Bug: Incorrect offset passed to quant_fp8 - uses threadOffset instead of global offset.
The quant_fp8 call at line 1737 passes threadOffset (which is just threadIdx.x) as the output offset. This should be blockLoadOffset + threadLoadOffset to match the global position, similar to the outputNorm write at line 1734 and the quant_nvfp4 call which correctly computes its offset.
🔎 Proposed fix
if (outputNorm != nullptr) {
*reinterpret_cast<float4*>(&outputNorm[blockLoadOffset + threadLoadOffset]) = r_out.packed;
}
if constexpr (QType == QuantType::kFP8) {
- quant::quant_fp8<T, float4, kELTS_PER_LOAD>(r_out, quantOut, outputScale, threadOffset);
+ quant::quant_fp8<T, float4, kELTS_PER_LOAD>(r_out, quantOut, outputScale, blockLoadOffset + threadLoadOffset);
} else if constexpr (QType == QuantType::kFP4) {
quant::quant_nvfp4<T, float4, kELTS_PER_LOAD>(r_out, quantOut, scalingFactorOut,
outputScale, token, dim,
token_offset / kELTS_PER_LOAD, sfLayout);
}

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if constexpr (QType == QuantType::kFP8) {
  quant::quant_fp8<T, float4, kELTS_PER_LOAD>(r_out, quantOut, outputScale,
                                              blockLoadOffset + threadLoadOffset);
} else if constexpr (QType == QuantType::kFP4) {
  quant::quant_nvfp4<T, float4, kELTS_PER_LOAD>(r_out, quantOut, scalingFactorOut,
                                                outputScale, token, dim,
                                                token_offset / kELTS_PER_LOAD, sfLayout);
}
🤖 Prompt for AI Agents
In include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh around lines 1736 to 1741,
the quant_fp8 call incorrectly passes threadOffset (threadIdx.x) as the output
offset instead of the global output offset; change the argument to use the
combined global offset (blockLoadOffset + threadLoadOffset) so quant_fp8 writes
to the same global position as outputNorm and quant_nvfp4, ensuring the offset
calculation mirrors the other calls.
Hi @timlee0212, is this PR ready? I noticed that you marked it as a draft.
No, it's still WIP, so I converted it to a draft.
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit