Skip to content

Conversation

@hholtmann
Copy link

@hholtmann hholtmann commented Nov 29, 2025

fix(cuda): add device guard and runtime SM dispatch to cutlass_scaled_fp4_mm

Purpose

Currently, the fp4 scaled_mm function doesn't work for the 5090 GPU, resulting in a RuntimeError: Internal Error. See #21274 and #22783 for more informatio

Test Plan

pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py

Test Result

All passed.

Fix cutlass_scaled_fp4_mm: ensure correct CUDA device guard and runtime SM-based kernel dispatch

Signed-off-by: Hendrik Holtmann <[email protected]>
Enhance cutlass_scaled_fp4_mm with device checks (SM100)
@chatgpt-codex-connector
Copy link

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces runtime SM dispatch for cutlass_scaled_fp4_mm, which is a solid improvement to support multiple GPU architectures and fixes issues on newer hardware like the 5090 series. The change from compile-time to runtime dispatch also corrects a latent critical bug where the previous implementation would attempt to return a value from a void function. The new logic is more robust and the improved error message is a good addition. I have one suggestion to refactor the dispatch logic to improve its long-term maintainability and make it less error-prone when adding support for future architectures.

Comment on lines +48 to +60
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) {
cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 mm kernel, vLLM should "
"be compiled using CUDA 12.8 and target "
"compute capability 100 or above.");

#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120) {
cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
return;
}
#endif
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For better maintainability and extensibility, it's recommended to order the runtime dispatch checks from the newest supported SM version to the oldest. This allows simplifying the conditions by removing the upper bounds, as the return statements will prevent fall-through. This makes it easier and less error-prone to add support for future SM architectures.

#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
  if (sm >= 120) {
    cutlass_scaled_fp4_mm_sm120a(D, A, B, A_sf, B_sf, alpha);
    return;
  }
#endif

#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
  if (sm >= 100) {
    cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
    return;
  }
#endif

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

1 participant