
Conversation


@timlee0212 timlee0212 commented Dec 24, 2025

📌 Description

  • Add FP8/NVFP4 quant fusion to MNNVL Allreduce
  • Support all 5 fusion patterns defined in the unified allreduce interface.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added FP8 and FP4 quantization type support for all-reduce fusion operations.
    • Extended all-reduce fusion kernel with quantization output capabilities.
    • Enhanced RMSNorm kernel with quantization fusion support.
    • Added quantization utility functions for type conversions between FP8, FP4, FP16, and BF16 formats.


@timlee0212 timlee0212 marked this pull request as draft December 24, 2025 05:09

coderabbitai bot commented Dec 24, 2025

📝 Walkthrough

This pull request introduces quantization support to TRT-LLM MNNVL AllReduce fusion operations. Changes include a new QuantType enum, expanded AllReduceFusionParams with quantization fields, template-based kernel extensions for FP8/FP4 quantization paths, RMSNorm fusion variants, and comprehensive quantization utility functions for type conversions and scale factor handling.
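For orientation, the FP8 path boils down to the usual scale-then-narrow pattern sketched below. This is an illustrative device helper only; the function name and the inverse-scale argument are assumptions and do not mirror the exact quant_fp8 signature added in this PR.

#include <cuda_fp8.h>

// Illustrative per-element FP8 quantization: scale into the e4m3 range, then
// narrow. Hypothetical helper; not the header's actual quant_fp8.
__device__ inline __nv_fp8_e4m3 quantize_to_fp8(float val, float inv_scale) {
  return __nv_fp8_e4m3(val * inv_scale);  // narrowing conversion to e4m3
}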

Changes

  • Quantization Infrastructure — include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
    Added QuantType enum (kNone, kFP8, kFP4); introduced quant::details and quant::maths namespaces with type casts and conversions (cuda_cast, abs, max, reciprocal) between FP8/FP4/FP16/BF16; implemented fp32_vec_to_e2m1, cvt_warp_fp16_to_fp4, quant_fp8, and quant_nvfp4 functions; added cvt_quant_to_fp4_get_sf_out_offset for scale factor offset computation.
  • Data Structures & Parameters — include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
    Updated AllReduceFusionParams struct with quantization support: added outputScale, sfLayout, quantType, quantOut, scalingFactorOut, residualOut, and output pointers; added residualIn, gamma, epsilon, rmsNormFusion, and launchWithPdl fields with appropriate defaults (a usage sketch follows this list).
  • Kernel Signatures & Dispatch Logic — include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
    Extended oneshotAllreduceFusionKernel with template parameter QuantType QType (default kNone) and quantization output pointers; updated LAUNCH_ALLREDUCE_KERNEL and DISPATCH_ALLREDUCE_KERNEL macros to route based on quantization type; updated oneshotAllreduceFusionDispatch and twoshotAllreduceFusionDispatch to support quantization-enabled fusion paths.
  • RMSNorm Fusion Kernels — include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
    Introduced rmsNormLamport_fusion<T, QuantType, LoadsPerThread> template extending the previous rmsNormLamport with quantization support; accepts quantOut, scalingFactorOut, and sfLayout parameters; includes logic to dispatch to the appropriate quantization kernels (FP8/FP4).
  • Internal Refactoring & Headers — include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
    Replaced T_IN/T_OUT type usage with unified T where applicable; adjusted load/store paths for PackedType handling; added new header includes (fp4_layout.cuh, vec_dtypes.cuh); added utility aliases (PackedVec, toFloat); guarded CUDA architecture requirements for FP8/FP4 features.
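A minimal host-side sketch of how these new fields might be populated for the allreduce + RMSNorm + NVFP4 pattern. Field names follow the summary above; the helper function, template parameter, and pointer arguments are assumptions, not the header's definitive API.

// Hypothetical helper: fill in the quantization-related fields of
// AllReduceFusionParams for an allreduce + RMSNorm + NVFP4 fusion.
template <typename T>
void configureFp4Fusion(AllReduceFusionParams<T>& params, T* residual_in, T* gamma,
                        float* output_scale, void* quant_out, void* sf_out) {
  params.rmsNormFusion = true;         // fuse RMSNorm after the allreduce
  params.quantType = QuantType::kFP4;  // emit NVFP4 quantized output
  params.residualIn = residual_in;     // residual added before the norm
  params.gamma = gamma;                // RMSNorm weights
  params.epsilon = 1e-5f;
  params.outputScale = output_scale;   // global quantization scale
  params.quantOut = quant_out;         // packed FP4 output buffer
  params.scalingFactorOut = sf_out;    // per-block scale factors
  params.launchWithPdl = true;         // programmatic dependent launch
}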

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly Related PRs

Suggested Reviewers

  • cyx-6
  • nvmbreughe
  • djmmoss
  • yzh119

Poem

🐰 A hoppy refactor, quantization takes flight!
FP8 and FP4 pack data oh so tight,
RMSNorm fusion dances in the GPU light,
AllReduce kernels scale to the height,
Templates and macros make everything right! 🚀

Pre-merge checks

❌ Failed checks (1 inconclusive)
  • Description check — ❓ Inconclusive. The description addresses the main objective and mentions support for all five fusion patterns, but lacks critical details on testing, related issues, and pre-commit validation status. Resolution: complete the checklist by marking items as done/not applicable, link any related issues, and add test details or notes on test coverage for the new quantization fusion features.
✅ Passed checks (2 passed)
  • Title check — ✅ Passed. The title clearly and concisely describes the main change: adding FP8/NVFP4 quantization fusion support to MNNVL Allreduce, which directly matches the primary focus of the changeset.
  • Docstring Coverage — ✅ Passed. No functions found in the changed files to evaluate docstring coverage; skipping docstring coverage check.

@gemini-code-assist
Contributor

Summary of Changes

Hello @timlee0212, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the MNNVL Allreduce functionality by integrating FP8 and NVFP4 quantization fusion. This allows for more memory-efficient and potentially faster distributed computations, especially when combined with RMSNorm. The changes involve updating core data structures, introducing a new set of quantization utility functions, and modifying existing CUDA kernels and dispatch mechanisms to support these new quantization types across all five unified allreduce interface fusion patterns.

Highlights

  • FP8/NVFP4 Quantization Fusion: Introduced support for FP8 and NVFP4 quantization directly into the MNNVL Allreduce operation, enabling more efficient data handling.
  • Unified Interface Support: The implementation now supports all five fusion patterns defined in the unified allreduce interface, enhancing versatility.
  • Quantization Parameters Extension: Extended the AllReduceFusionParams struct to include new fields for managing quantization outputs, scaling factors, and layout configurations.
  • Dedicated Quantization Functions: Added a new quant namespace with specialized CUDA device functions for performing FP8 and NVFP4 conversions, scaling, and other related mathematical operations.
  • Kernel Integration: Modified existing Allreduce and RMSNorm kernels (oneshotAllreduceFusionKernel, rmsNormLamport_fusion) to seamlessly incorporate the new quantization logic and parameters.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds FP8/NVFP4 quantization fusion for MNNVL Allreduce. The changes are extensive and introduce new quantization logic. My review focuses on correctness and maintainability. I've identified a critical bug in the kernel dispatch logic that needs to be fixed. Additionally, there is significant code duplication that should be addressed by refactoring the shared code into a common header file to improve maintainability. The author has already noted these duplications with TODO comments, and I've formalized them as review comments.

LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP4); \
break; \
case QuantType::kNone: \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone); \

critical

There appears to be a bug here. When params.rmsNormFusion is true and params.quantType is QuantType::kNone, the kernel is launched with RMSNORM=false. It should be launched with RMSNORM=true to correctly apply the RMS normalization.

        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kNone);           \

Comment on lines +520 to +522
// TODO: These code are shared with trtllm_allreduce_fusion.cuh, and moe_allreduce_fusion; Should we
// move them to a shared header?
namespace quant {

high

As pointed out in the TODO, this large block of code containing the quant namespace is duplicated across multiple files (trtllm_allreduce_fusion.cuh, moe_allreduce_fusion.cuh). This creates a significant maintenance burden. Please refactor this into a shared utility header (e.g., include/flashinfer/comm/quant_utils.cuh) and include it where needed. This will greatly improve code maintainability.
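A minimal sketch of that consolidation, assuming the hypothetical path include/flashinfer/comm/quant_utils.cuh named above; the enum values are taken from this PR, everything else is illustrative.

// quant_utils.cuh (hypothetical shared header)
#pragma once

// Single definition reused by trtllm_allreduce_fusion.cuh,
// moe_allreduce_fusion.cuh, and trtllm_mnnvl_allreduce.cuh.
enum class QuantType : int {
  kNone = 0,
  kFP8 = 1,
  kFP4 = 2,
};

namespace quant {
// cuda_cast, cuda_abs, fp32_vec_to_e2m1, quant_fp8, quant_nvfp4, ... would
// move here unchanged from the per-file copies.
}  // namespace quant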

Comment on lines +40 to +45
// TODO: Same; This enum definition is duplicated
enum class QuantType : int {
kNone = 0,
kFP8 = 1,
kFP4 = 2,
};

medium

As noted in the TODO, this QuantType enum is duplicated. It also exists in include/flashinfer/comm/trtllm_allreduce_fusion.cuh. To improve maintainability and prevent inconsistencies, please define this enum in a shared header file and include it where needed.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2)

40-45: Address the TODO: Consolidate duplicated enum definition.

The TODO comment indicates this QuantType enum is duplicated elsewhere. Consider moving it to a shared header to maintain a single source of truth and avoid potential divergence.

Would you like me to help identify the other location(s) where this enum is defined and suggest a consolidation approach?


1745-1745: Minor: Inconsistent naming kELTS_SIZE vs kELT_SIZE used elsewhere.

This constant is named kELTS_SIZE but similar constants in other kernels (e.g., line 1178 in oneshotAllreduceFusionKernel) use kELT_SIZE. Consider using consistent naming across kernels for maintainability.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25de38e and 4a4d184.

📒 Files selected for processing (1)
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (5)

47-72: LGTM!

The struct expansion is well-organized with sensible defaults. The explicit initialization of pointer fields to nullptr and enum fields to their "off" states is good practice for optional fusion parameters.


922-959: LGTM!

The SF offset calculation handles both batched and non-batched scenarios correctly. The architecture guard for SM100+ and graceful nullptr return for unsupported layouts are appropriate.


1162-1176: LGTM!

The kernel template extension is well-designed. The static_assert correctly enforces that quantization requires RMSNorm fusion, preventing invalid usage patterns at compile time.
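For reference, the guard described here typically takes a shape like the following; the kernel name and template parameter order are placeholders, not the actual signature.

// Placeholder kernel illustrating the compile-time constraint: quantized
// output is only produced on the RMSNorm fusion path.
template <bool RMSNorm, QuantType QType = QuantType::kNone>
__global__ void allreduceFusionKernelSketch(/* ... */) {
  static_assert(QType == QuantType::kNone || RMSNorm,
                "Quantized output requires RMSNorm fusion");
}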


1857-1927: LGTM!

The two-shot dispatch macros correctly extend the existing pattern to support quantization. The switch-based dispatch on QuantType and the macro cleanup with #undef are well-structured.


1140-1146: The reinterpret_cast at line 1141 is safe and correctly casts between compatible memory layouts. PackedVec<float4, half> (16 bytes: a union of float4 and half[8]) and vec_t<half, 8> (16 bytes: containing int4 data[1]) both occupy identical 128-bit memory with no padding, making direct reinterpretation valid.
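A compile-time check expressing the same equivalence could look like this, using the PackedVec and vec_t types referenced above (a sketch, not code from the header):

// Both packed types must be exactly 128 bits with matching alignment for the
// reinterpret_cast between them to be layout-safe.
static_assert(sizeof(PackedVec<float4, half>) == sizeof(vec_t<half, 8>) &&
                  sizeof(vec_t<half, 8>) == 16,
              "packed 128-bit layouts must match");
static_assert(alignof(PackedVec<float4, half>) == alignof(vec_t<half, 8>),
              "alignment must match for reinterpret_cast");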

Comment on lines +784 to +788
template <typename T>
__device__ inline T cuda_abs(T val) {
assert(false);
return {};
}

⚠️ Potential issue | 🟠 Major

assert(false) is ineffective in release builds and returns uninitialized value.

The generic cuda_abs template uses assert(false) which is stripped in release builds. The function then returns an uninitialized/default-constructed value, leading to silent incorrect behavior. Consider using a static_assert with a dependent false expression instead.

🔎 Proposed fix
 template <typename T>
 __device__ inline T cuda_abs(T val) {
-  assert(false);
-  return {};
+  static_assert(sizeof(T) == 0, "cuda_abs not specialized for this type");
+  return {};  // Never reached, but needed for compilation
 }

Comment on lines +1374 to 1393
#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE) \
if (params.rmsNormFusion) { \
switch (params.quantType) { \
case QuantType::kFP8: \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP8); \
break; \
case QuantType::kFP4: \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP4); \
break; \
case QuantType::kNone: \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone); \
break; \
default: \
FLASHINFER_ERROR("Unsupported quant type! Got " + \
std::to_string(static_cast<int>(params.quantType))); \
return cudaErrorInvalidValue; \
} \
} else { \
LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone); \
}

⚠️ Potential issue | 🔴 Critical

Bug: RMSNorm fusion disabled when quantType=kNone despite rmsNormFusion=true.

When params.rmsNormFusion is true but params.quantType is kNone, the dispatch incorrectly launches the kernel with RMSNORM=false (line 1384). This breaks the RMSNorm-only fusion pattern (pattern #2 in the unified interface).

🔎 Proposed fix
 #define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE)                                 \
   if (params.rmsNormFusion) {                                                 \
     switch (params.quantType) {                                               \
       case QuantType::kFP8:                                                   \
         LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP8);           \
         break;                                                                \
       case QuantType::kFP4:                                                   \
         LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kFP4);           \
         break;                                                                \
       case QuantType::kNone:                                                  \
-        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone);         \
+        LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true, QuantType::kNone);          \
         break;                                                                \
       default:                                                                \
         FLASHINFER_ERROR("Unsupported quant type! Got " +                     \
                          std::to_string(static_cast<int>(params.quantType))); \
         return cudaErrorInvalidValue;                                         \
     }                                                                         \
   } else {                                                                    \
     LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false, QuantType::kNone);             \
   }

Comment on lines +1736 to +1741
if constexpr (QType == QuantType::kFP8) {
quant::quant_fp8<T, float4, kELTS_PER_LOAD>(r_out, quantOut, outputScale, threadOffset);
} else if constexpr (QType == QuantType::kFP4) {
quant::quant_nvfp4<T, float4, kELTS_PER_LOAD>(r_out, quantOut, scalingFactorOut,
outputScale, token, dim,
token_offset / kELTS_PER_LOAD, sfLayout);

⚠️ Potential issue | 🔴 Critical

Bug: Incorrect offset passed to quant_fp8 - uses threadOffset instead of global offset.

The quant_fp8 call at line 1737 passes threadOffset (which is just threadIdx.x) as the output offset. This should be blockLoadOffset + threadLoadOffset to match the global position, similar to the outputNorm write at line 1734 and the quant_nvfp4 call which correctly computes its offset.

🔎 Proposed fix
       if (outputNorm != nullptr) {
         *reinterpret_cast<float4*>(&outputNorm[blockLoadOffset + threadLoadOffset]) = r_out.packed;
       }
       if constexpr (QType == QuantType::kFP8) {
-        quant::quant_fp8<T, float4, kELTS_PER_LOAD>(r_out, quantOut, outputScale, threadOffset);
+        quant::quant_fp8<T, float4, kELTS_PER_LOAD>(r_out, quantOut, outputScale, blockLoadOffset + threadLoadOffset);
       } else if constexpr (QType == QuantType::kFP4) {
         quant::quant_nvfp4<T, float4, kELTS_PER_LOAD>(r_out, quantOut, scalingFactorOut,
                                                       outputScale, token, dim,
                                                       token_offset / kELTS_PER_LOAD, sfLayout);
       }

@yzh119 yzh119 marked this pull request as ready for review December 24, 2025 06:24

yzh119 commented Dec 24, 2025

Hi @timlee0212, is this PR ready? I noticed that you marked it as draft.

@timlee0212
Contributor Author

Hi @timlee0212, is this PR ready? I noticed that you marked it as draft.

No, it's still WIP. I've converted it back to draft.

@timlee0212 timlee0212 marked this pull request as draft December 31, 2025 06:07