-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
SM120 / NVFP4: add device guard and runtime SM dispatch to cutlass_scaled_fp4_mm #29711
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: releases/v0.11.2
Are you sure you want to change the base?
SM120 / NVFP4: add device guard and runtime SM dispatch to cutlass_scaled_fp4_mm #29711
Conversation
Fix cutlass_scaled_fp4_mm: ensure correct CUDA device guard and runtime SM-based kernel dispatch Signed-off-by: Hendrik Holtmann <[email protected]>
Enhance cutlass_scaled_fp4_mm with device checks (SM100)
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces runtime SM dispatch for cutlass_scaled_fp4_mm, which is a solid improvement to support multiple GPU architectures and fixes issues on newer hardware like the 5090 series. The change from compile-time to runtime dispatch also corrects a latent critical bug where the previous implementation would attempt to return a value from a void function. The new logic is more robust and the improved error message is a good addition. I have one suggestion to refactor the dispatch logic to improve its long-term maintainability and make it less error-prone when adding support for future architectures.
| #if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 | ||
| if (sm >= 100 && sm < 120) { | ||
| cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha); | ||
| return; | ||
| } | ||
| #endif | ||
| TORCH_CHECK_NOT_IMPLEMENTED(false, | ||
| "No compiled nvfp4 mm kernel, vLLM should " | ||
| "be compiled using CUDA 12.8 and target " | ||
| "compute capability 100 or above."); | ||
|
|
||
| #if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 | ||
| if (sm >= 120) { | ||
| cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha); | ||
| return; | ||
| } | ||
| #endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For better maintainability and extensibility, it's recommended to order the runtime dispatch checks from the newest supported SM version to the oldest. This allows simplifying the conditions by removing the upper bounds, as the return statements will prevent fall-through. This makes it easier and less error-prone to add support for future SM architectures.
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
fix(cuda): add device guard and runtime SM dispatch to cutlass_scaled_fp4_mm
Purpose
Currently, the fp4 scaled_mm function doesn't work for the 5090 GPU, resulting in a RuntimeError: Internal Error. See #21274 and #22783 for more informatio
Test Plan
pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
Test Result
All passed.