czhu-cohere commented Nov 28, 2025

Purpose

As titled; the benefit of W4A8 is that it can use FP8 tensor cores while maintaining the low memory footprint of W4A16 (with negligible quality loss). In addition, vLLM has no Machete-like implementation for W4A16 grouped GEMM, so the compute gains should be even larger compared to the current Marlin kernels.

The CUTLASS kernel implementation follows CUTLASS example 69, which uses a LUT-based method for fast INT4 -> FP8 conversion. As with dense W4A8, we also add a per-channel/per-token epilogue.
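
For intuition, the LUT approach replaces arithmetic sign-extension and conversion with a 16-entry table lookup keyed by the raw 4-bit nibble. A minimal Python sketch of the idea (the real kernel indexes a register-resident table of FP8 bit patterns via byte shuffles; the table construction here is purely illustrative):

```python
# Illustrative sketch of LUT-based INT4 -> FP8 conversion. All int4 values
# (-8..7) are exactly representable in FP8 E4M3, so plain floats stand in
# for the FP8 bit patterns the actual kernel would store in the table.

# Table indexed by the raw 4-bit nibble (two's-complement int4).
INT4_TO_FP8_LUT = [float(v) if v < 8 else float(v - 16) for v in range(16)]

def unpack_and_convert(packed_bytes):
    """Unpack two int4 values per byte (low nibble first) and convert via LUT."""
    out = []
    for b in packed_bytes:
        out.append(INT4_TO_FP8_LUT[b & 0xF])         # low nibble
        out.append(INT4_TO_FP8_LUT[(b >> 4) & 0xF])  # high nibble
    return out

# Example: byte 0xF2 packs low nibble 0x2 (= 2) and high nibble 0xF (= -1).
print(unpack_and_convert([0xF2, 0x78]))  # [2.0, -1.0, -8.0, 7.0]
```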

We have uploaded a W4A8-quantized variant of Qwen3-30B-A3B as an end-to-end sanity check.

C++ changes

csrc/quantization/cutlass_w4a8/w4a8_utils.cu

  • refactor the INT4 reordering/shuffling utilities into a common file that can be shared between the dense and grouped W4A8 GEMMs.
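
The shared utilities shuffle and pack int4 weight elements into the layout the kernel expects. A hedged sketch of the pack step (the actual permutation in w4a8_utils.cu matches the tensor-core fragment layout; the `perm` argument below is a placeholder for it):

```python
# Hypothetical sketch of an int4 pack/reorder utility: apply a kernel-specific
# permutation to the int4 elements, then pack pairs into bytes (low nibble
# first). The identity default for `perm` is illustrative only.
def pack_int4(values, perm=None):
    """Pack a list of int4 values (-8..7) into bytes, two per byte."""
    if perm is not None:
        values = [values[i] for i in perm]
    assert len(values) % 2 == 0
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)

print(list(pack_int4([2, -1, -8, 7])))  # [242, 120] i.e. [0xF2, 0x78]
```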

csrc/quantization/cutlass_w4a8/get_group_starts.cuh

  • compute the per-expert pointers for the grouped GEMM. Most of the logic is the same as the W8A8 version, but W4A8 needs to account for slightly different input types, packed weights, and the additional group-wise scales.
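
Conceptually, this turns per-expert token counts into base offsets into each flat buffer. A rough Python sketch (the real helper computes device pointers; the function name and layout assumptions here are illustrative):

```python
# Illustrative: derive per-expert start offsets for activations, packed int4
# weights, and group-wise scales from per-expert token counts. Assumes
# activations [tokens, K], weights [K, N] packed 2 int4 per byte, and scales
# [K // group_size, N] per expert.
def group_starts(tokens_per_expert, k, n, group_size, pack_factor=2):
    a_offs, w_offs, s_offs = [], [], []
    tok_start = 0
    for e, m_e in enumerate(tokens_per_expert):
        a_offs.append(tok_start * k)               # activation rows seen so far
        w_offs.append(e * (k * n) // pack_factor)  # int4 weights: 2 per byte
        s_offs.append(e * (k // group_size) * n)   # group-wise scales
        tok_start += m_e
    return a_offs, w_offs, s_offs

print(group_starts([3, 0, 5], k=128, n=64, group_size=128))
# ([0, 384, 384], [0, 4096, 8192], [0, 64, 128])
```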

csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu

  • main CUTLASS kernel implementation and dispatch.
  • when encoding/shuffling the weight matrix (encode_and_reorder_int4b), we construct the layout object and serialize it to a torch tensor so that it can be passed into the grouped GEMM at runtime. This avoids reconstructing the layout at runtime, which would incur significant overhead when the number of experts is large.
    • the static_assert and layout_width guarantee that the layout can be serialized to the expected torch tensor dtype/size
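
The serialize-once idea can be sketched in Python: flatten the layout metadata into a fixed-width buffer at weight-processing time and reinterpret it at runtime instead of rebuilding the layout per call. (The field set and encoding below are hypothetical, not the actual CUTLASS layout.)

```python
import struct

# Hypothetical sketch: serialize a layout description (shape + stride) into a
# flat little-endian int64 buffer once, so the runtime path can consume it
# without reconstructing the layout object per expert. The fixed-width format
# mirrors asserting a known serialized dtype/size up front.
def serialize_layout(shape, stride):
    fields = list(shape) + list(stride)
    return struct.pack(f"<{len(fields)}q", *fields)

def deserialize_layout(buf, rank=2):
    vals = struct.unpack(f"<{2 * rank}q", buf)
    return vals[:rank], vals[rank:]

buf = serialize_layout((128, 64), (64, 1))
print(deserialize_layout(buf))  # ((128, 64), (64, 1))
```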

csrc/quantization/w8a8/cutlass/moe/moe_data.cu

  • W4A8 MoE can reuse many of the W8A8 helper functions/utilities. However, the logic in get_cutlass_moe_mm_problem_sizes is coupled with SwapAB, so I added an argument that lets the caller explicitly specify whether SwapAB is true or false (for the RS GEMM it is always true, since the operand to be dequantized, B, needs to be on the LHS).
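
The effect of the new argument can be sketched as follows: SwapAB exchanges the roles of M and N in each expert's problem size so that the quantized operand ends up on the LHS (signature and names illustrative, not the actual C++ helper):

```python
# Illustrative: build per-expert (M, N, K) problem sizes for a grouped GEMM;
# with swap_ab=True the M and N dimensions are exchanged so operand B sits on
# the LHS, as required by the RS GEMM described above.
def moe_problem_sizes(tokens_per_expert, n, k, swap_ab):
    sizes = []
    for m_e in tokens_per_expert:
        sizes.append((n, m_e, k) if swap_ab else (m_e, n, k))
    return sizes

print(moe_problem_sizes([3, 5], n=64, k=128, swap_ab=True))
# [(64, 3, 128), (64, 5, 128)]
```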

Python changes

vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/modular_kernel.py

  • I think the previous code assumed that group-wise and channel-wise scales are mutually exclusive, which is not the case here, so I added a field to store the channel scales for the case where the original scales (e.g. w1_scale) are used as the group-wise scales.
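
The point is that both scale granularities apply simultaneously rather than being alternatives. A toy dequantization showing group-wise and channel-wise scales composing (shapes and values illustrative):

```python
# Toy dequant for one output channel: group-wise scales (one per group of
# `group_size` elements along the reduction dim K) and a channel-wise scale
# (one per output channel) both apply; they are not mutually exclusive.
def dequant(q, group_scales, channel_scale, group_size):
    # q: int4 weights for one output channel, length K
    return [w * group_scales[i // group_size] * channel_scale
            for i, w in enumerate(q)]

print(dequant([1, 2, 3, 4], group_scales=[0.5, 2.0],
              channel_scale=0.25, group_size=2))
# [0.125, 0.25, 1.5, 2.0]
```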

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

  • define weight loading and pre-processing. Some notable differences from the FP8 CUTLASS MoE:
    • uses different quant methods, FusedMoeWeightScaleSupported.GROUP.value and FusedMoeWeightScaleSupported.CHANNEL.value, to load the group and channel scales respectively
    • stricter requirements on shapes due to the cutlass::reorder_tensor limitation explained above (in practice this means that for small MoEs like Qwen3-30B-A3B you may not be able to run TP=2)
    • s_strides1/s_strides2, which store the strides for the group scales, have shape [num_experts, 2] and dtype int64, since that is what the kernel expects
    • b_strides1/b_strides2 are returned by the reordering op and saved to be passed in at runtime

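
The stride bookkeeping for the group scales can be sketched in plain Python: one [row_stride, col_stride] pair per expert, stored as int64 on the torch side (the concrete values depend on the scale layout; here scales are assumed [K // group_size, N] row-major per expert):

```python
# Illustrative: build the per-expert stride table for the group-wise scales,
# shape [num_experts, 2], holding (row stride, column stride) per expert.
# On the torch side this would be an int64 tensor, as the kernel expects.
def make_scale_strides(num_experts, n):
    return [[n, 1] for _ in range(num_experts)]

print(make_scale_strides(3, 64))  # [[64, 1], [64, 1], [64, 1]]
```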
vllm/model_executor/layers/fused_moe/cutlass_moe.py

  • The structure is similar to the FP8 CUTLASS MoE; the main differences are (1) the extra arguments for the scales, (2) different strides for each input, and (3) SwapAB is always enforced to be true.

Limitations

I have not implemented/checked compatibility with EP options other than the default.

Test Plan

  • kernel correctness test - tests/kernels/quantization/test_cutlass_w4a8_moe.py
  • e2e eval - lm_eval on gsm8k, comparing the Qwen3-30B-A3B W4A16 and W4A8 variants

Test Result

tests/kernels/quantization/test_cutlass_w4a8_moe.py - pass
lm_eval

czhu-cohere/Qwen3-30B-A3B-quantized.w4a8
|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match|↑  |0.9212|±  |0.0074|
|         |       |strict-match    |     8|exact_match|↑  |0.9007|±  |0.0082|

RedHatAI/Qwen3-30B-A3B-quantized.w4a16
|  Tasks  |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_cot|      3|flexible-extract|     8|exact_match|↑  |0.9227|±  |0.0074|
|         |       |strict-match    |     8|exact_match|↑  |0.9052|±  |0.0081|

TODO

I am currently working on benchmark scripts and sweeping over <schedule, tile shape, cluster shape> to get good performance for different problem sizes, but the core logic/interface should be ready for review.



Signed-off-by: czhu-cohere <[email protected]>
gemini-code-assist bot left a comment


Code Review

This pull request adds support for W4A8 Grouped GEMM on Hopper GPUs, which is a significant feature for running quantized Mixture-of-Experts models efficiently. The changes span across C++ CUDA kernels, Python bindings, and integration into the model execution layers. The implementation looks solid, with new tests for the functionality. I've identified a couple of critical issues related to data type checks that could cause runtime failures for supported configurations. Addressing these will improve the robustness of the new kernel.

czhu-cohere force-pushed the czhu/w4a8-moe branch 4 times, most recently from 5529fc7 to ba25cff on November 28, 2025 18:10
Labels

ci/build new-model Requests to new models nvidia