[WIP][Kernel]Support W4A8 Grouped GEMM on Hopper #29691
Signed-off-by: czhu-cohere <[email protected]>
Code Review
This pull request adds support for W4A8 Grouped GEMM on Hopper GPUs, a significant feature for running quantized Mixture-of-Experts models efficiently. The changes span the C++ CUDA kernels, Python bindings, and integration into the model execution layers. The implementation looks solid, with new tests for the functionality. I've identified a couple of critical issues related to data type checks that could cause runtime failures for supported configurations. Addressing these will improve the robustness of the new kernel.
Purpose
As title; the benefit of W4A8 is that it can use FP8 tensor cores while still maintaining the low memory footprint of W4A16 (with negligible quality loss). In addition, there is no Machete-like implementation in vLLM for W4A16 grouped GEMM, so the compute gains should be even larger compared to the current Marlin kernels.
The CUTLASS kernel implementation follows example 69, which uses a LUT-based method for fast INT4 -> FP8 conversion. As with dense W4A8, we also add a per-channel/per-token epilogue.
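To make the LUT idea concrete, here is a small illustrative numpy sketch (not the kernel's actual code): every signed INT4 value in -8..7 maps to a fixed FP8-representable value, so conversion becomes a 16-entry table lookup on each nibble instead of arithmetic. The helper names below are hypothetical; the real CUTLASS kernel stores FP8-E4M3 bit patterns and does the lookup in registers.

```python
import numpy as np

# 16-entry lookup table indexed by (int4 value + 8). All of -8..7 are exactly
# representable in FP8-E4M3, so the identity mapping emulates the real table.
INT4_VALUES = np.arange(-8, 8, dtype=np.int8)
LUT = INT4_VALUES.astype(np.float32)

def unpack_int4(packed: np.ndarray) -> np.ndarray:
    """Unpack two signed int4 values from each uint8 (low nibble first)."""
    lo = (packed & 0x0F).astype(np.int8)
    hi = (packed >> 4).astype(np.int8)
    # sign-extend the 4-bit values
    lo = np.where(lo >= 8, lo - 16, lo)
    hi = np.where(hi >= 8, hi - 16, hi)
    return np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)

def lut_convert(packed: np.ndarray) -> np.ndarray:
    """Convert packed int4 bytes to float via table lookup, not arithmetic."""
    nibbles = unpack_int4(packed)
    return LUT[nibbles + 8]  # shift -8..7 into table index 0..15

packed = np.array([[0x21, 0xF0]], dtype=np.uint8)  # nibbles decode to 1, 2, 0, -1
converted = lut_convert(packed)
```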
We have uploaded a W4A8 quantized variant of Qwen3-30B-A3B as an e2e sanity check.
C++ changes
- `csrc/quantization/cutlass_w4a8/w4a8_utils.cu`
- `csrc/quantization/cutlass_w4a8/get_group_starts.cuh`
- `csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu`
  - During weight pre-processing (`encode_and_reorder_int4b`), we construct the layout object and serialize it to a torch tensor so that we can pass it into the grouped GEMM at runtime. This is to avoid having to reconstruct the layout itself at runtime, which would incur significant overhead when the number of experts is large. The `static_assert` and `layout_width` should guarantee that the layout can be serialized to the expected torch tensor dtype/size.
- `csrc/quantization/w8a8/cutlass/moe/moe_data.cu`
  - `get_cutlass_moe_mm_problem_sizes` is coupled with `SwapAB`, so I added an argument to allow the user to explicitly specify whether `SwapAB` is true/false (for RS GEMM it is always true, since the argument to be dequantized, B, needs to be on the LHS).
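The SwapAB convention mentioned above can be illustrated with a plain-numpy identity check (a sketch only; variable names are illustrative, not vLLM's API). When the operand being dequantized (the weight B) must sit in the kernel slot reserved for the LHS, the kernel computes the transposed product instead of C = A @ B:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 8)).astype(np.float32)  # activations, M x K
B = rng.standard_normal((8, 3)).astype(np.float32)  # weights, K x N

# Direct GEMM: C = A @ B, weight on the RHS.
C_direct = A @ B

# SwapAB: put B^T (N x K) on the LHS and compute C^T = B^T @ A^T,
# then read the result transposed. Mathematically identical.
C_swapped = (B.T @ A.T).T
```

Because (A @ B)ᵀ = Bᵀ @ Aᵀ holds for any shapes, the swap changes only which operand occupies the LHS slot, which is why the RS GEMM path can always run with SwapAB enabled.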
Python changes
- `vllm/model_executor/layers/fused_moe/config.py`
- `vllm/model_executor/layers/fused_moe/modular_kernel.py`
  - fields such as `w1_scale` are used for the group-wise scales.
- `vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py`
  - uses `FusedMoeWeightScaleSupported.GROUP.value` and `FusedMoeWeightScaleSupported.CHANNEL.value` to load group and channel scales respectively
  - the weights are reordered with the `reorder_tensor` op explained above (in practice that means for small MoEs like Qwen 30B you may not be able to do TP2)
  - `s_strides1/2`, which store strides for the group scales, are stored as shape `[num_experts, 2]` and dtype `int64` since that is what the kernel expects
  - `b_strides1/2` is returned by the reordering op and saved to pass in at runtime
- `vllm/model_executor/layers/fused_moe/cutlass_moe.py`
  - `SwapAB` is always true.
Limitations
Have not implemented/checked compatibility with the different EP options other than default.
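The group-wise plus per-channel scale scheme described in the Python changes above can be sketched in plain numpy (toy shapes, illustrative names; the real path applies scales in the CUTLASS epilogue after an FP8 GEMM). With a product of scales equal to one, dequantization leaves the integer weights unchanged, which makes the layout easy to sanity-check:

```python
import numpy as np

K, N, group_size = 8, 4, 4
num_groups = K // group_size

# Quantized int4-range weights, K x N (stored as float here for clarity).
W_q = np.random.default_rng(1).integers(-8, 8, size=(K, N)).astype(np.float32)

# One scale per (K-group, output column) plus one per output channel.
group_scales = np.full((num_groups, N), 0.5, dtype=np.float32)
channel_scales = np.full((N,), 2.0, dtype=np.float32)

# Dequantize: each group of `group_size` rows shares its group scale,
# then the per-channel scale multiplies each column.
W_deq = W_q * np.repeat(group_scales, group_size, axis=0) * channel_scales

A = np.ones((2, K), dtype=np.float32)  # activations, M x K
C = A @ W_deq
```

Here 0.5 * 2.0 = 1, so `C` equals `A @ W_q` exactly; with real calibrated scales the same broadcasting pattern applies.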
Test Plan
- kernel correctness test - `tests/kernels/quantization/test_cutlass_w4a8_moe.py`
- e2e eval - lm_eval gsm8k, compare qwen3-30b-a3b w4a16 and w4a8 variants
Test Result
- `tests/kernels/quantization/test_cutlass_w4a8_moe.py` - pass
- lm_eval
TODO
I am currently working on benchmark scripts and sweeping the <schedule, tileshape, clustershape> space to get good performance for different problem sizes, but the core logic/interface should be ready for review.