
Conversation

@yzh119 (Collaborator) commented Jan 1, 2026

📌 Description

Since its 4.3 release, cute-dsl supports compiling with tvm-ffi (https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html), which lets users pass torch tensors directly, with negligible DLPack conversion cost and without manually creating cute tensors from cute pointers.

In this PR we refactored the existing cute-dsl kernels to enable tvm-ffi and simplify the torch -> cute-dsl boilerplate.
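For context on why the tensor hand-off is cheap: tvm-ffi consumes torch tensors through the DLPack protocol, which exchanges a capsule describing existing memory rather than copying data. A minimal torch-only sketch (not FlashInfer code) of that zero-copy round trip:

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

x = torch.randn(4, 8)                 # any torch tensor (CPU here so the snippet runs anywhere)
cap = to_dlpack(x)                    # export as a DLPack capsule; no data is copied
y = from_dlpack(cap)                  # re-import; shares the same storage
assert y.data_ptr() == x.data_ptr()   # zero-copy round trip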

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • Refactor
    • Updated RMSNorm and AddRMSNorm FP4 quantization kernel interfaces to use tensor-based inputs instead of pointer-based parameters, streamlining the API and improving memory handling in the optimization layer.


@coderabbitai (Contributor) commented Jan 1, 2026

📝 Walkthrough

This change transitions two RMSNorm FP4 quantization kernels from pointer-based to tensor-based TVM-FFI invocation. The kernel interfaces are refactored to accept tensors directly instead of raw pointers, compilation paths are updated with symbolic tensors for dynamic shapes, and runtime flows now pass tensors through TVM-FFI instead of marshalling pointers via CUDA driver APIs.

Changes

All four cohorts below cover the same two files: flashinfer/cute_dsl/add_rmsnorm_fp4quant.py and flashinfer/cute_dsl/rmsnorm_fp4quant.py.

  • Kernel Interface Migration: Replaced pointer-based kernel signatures (x_ptr, w_ptr, y_ptr, s_ptr, r_ptr) with tensor-based parameters (mX, mW, mY, mS, mR). Updated host __call__ methods to accept and pass tensors directly via TVM-FFI. Removed raw pointer marshalling and CUDA driver stream handling.
  • Compilation Path: Introduced symbolic M-size tensors and fake tensor fixtures for TVM-FFI compilation. Removed prior pointer-generation logic. Added swizzle-aware scale tensor layout handling at compile time.
  • Runtime Tensor Handling: Updated the tensor_api path to flatten/contiguate scale tensors when using swizzled output layouts. Removed manual pointer offset and layout construction in host code.
  • Documentation & Comments: Updated docstrings and inline comments to reflect tensor-based inputs, TVM-FFI tensor passing, and dynamic shape behavior.
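To make the migration concrete, here is an illustrative before/after of the host-side entry point. The parameter names (x_ptr/mX and friends) are taken from the summary above; the class names, bodies, and the stream handling note are placeholders, not the actual FlashInfer implementation.

# Before (pointer-based): the host rebuilt cute tensors from raw pointers and
# marshalled the CUDA stream through the driver API before launching.
class PointerBasedKernel:
    def __call__(self, x_ptr, w_ptr, y_ptr, s_ptr, r_ptr, stream):
        ...

# After (tensor-based): tensors cross the TVM-FFI boundary directly, so no
# pointer marshalling or manual layout construction happens on the host.
# The launch stream is assumed to come from the TVM-FFI environment.
class TensorBasedKernel:
    def __call__(self, mX, mW, mY, mS, mR):
        ...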

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host Code
    participant TVM as TVM-FFI
    participant Compile as Kernel Compiler
    participant CUDA as CUDA Kernel
    
    rect rgb(200, 220, 255)
    Note over Host,CUDA: Compilation Phase
    Host->>Host: Create symbolic M tensor<br/>and fake tensor fixtures
    Host->>Compile: Invoke kernel compiler<br/>with symbolic tensors
    Compile->>TVM: Register fake tensors<br/>for TVM compilation
    Compile->>CUDA: Compile kernel<br/>with symbolic shapes
    end
    
    rect rgb(200, 255, 220)
    Note over Host,CUDA: Runtime Invocation Phase
    Host->>Host: Prepare input tensors<br/>(mX, mW, mY, mS, mR)
    Host->>Host: Flatten/contiguate scale<br/>tensor if swizzled layout
    Host->>TVM: Pass tensors via<br/>TVM-FFI interface
    TVM->>CUDA: Launch kernel with<br/>tensor-based inputs
    CUDA->>CUDA: Execute computation<br/>on tensor data
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • jiahanc
  • kahyunnam
  • jimmyzho
  • nvmbreughe
  • aleozlx

Poem

🐰 Pointers fade to tensors bright,
TVM-FFI shines so right!
Kernels dance with shapes dynamic,
Swizzle-wise and most organic,
Code flows clean, no marshals needed—
Architecture's call has been heeded! 🎉

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
  • Description check: ❓ Inconclusive. The description gives clear context about the tvm-ffi capability and the PR objectives, but the checklist items (pre-commit checks and tests) remain unchecked. Resolution: check off the completed pre-commit and test items, or explicitly note any that are not applicable, to make the PR's readiness transparent.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title accurately summarizes the main change: refactoring cute-dsl kernels to simplify boilerplate and enable tvm-ffi integration for direct tensor passing.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.

@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the existing cute-dsl kernels, specifically for RMSNorm FP4 quantization, to leverage TVM-FFI. This integration aims to simplify the interaction between PyTorch tensors and the underlying CUDA kernels by allowing direct tensor passing, thereby reducing boilerplate code and improving the overall efficiency and developer experience when working with CUTLASS's cute-dsl.

Highlights

  • TVM-FFI Integration: The core change enables TVM-FFI (Foreign Function Interface) for cute-dsl kernels, allowing direct passing of PyTorch tensors to CUDA kernels without manual pointer conversions, significantly simplifying the boilerplate.
  • Simplified Kernel Interface: The __call__ methods within the AddRMSNormFP4QuantKernel and RMSNormFP4QuantKernel classes have been refactored to directly accept cute.Tensor objects, eliminating the need for internal cute.make_tensor calls from raw pointers.
  • Reduced Boilerplate: Removed cuda.bindings.driver, cutlass.torch, and the custom make_ptr utility, streamlining the codebase and reducing dependencies related to manual pointer handling.
  • Dynamic Tensor Compilation: The kernel compilation process now uses cute.runtime.make_fake_compact_tensor with symbolic sizes (cute.sym_int()) and cute.runtime.make_fake_stream with use_tvm_ffi_env_stream=True, enabling more flexible and efficient compilation for dynamic tensor shapes via TVM-FFI (see the sketch below).
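A hedged sketch of how those pieces fit together at compile time. Every name below (sym_int, make_fake_compact_tensor, make_fake_stream, use_tvm_ffi_env_stream, the --enable-tvm-ffi option) is quoted from this summary or from review comments further down; the import path, argument order, and placeholder values are assumptions rather than the verified cute-dsl API.

import cutlass.cute as cute   # import path assumed

hidden_size = 7168            # illustrative hidden dimension
kernel = ...                  # placeholder for e.g. an RMSNormFP4QuantKernel instance
dtype = ...                   # placeholder element type for the input

sym_m = cute.sym_int()        # symbolic M, so one compilation serves every batch size
fake_x = cute.runtime.make_fake_compact_tensor(dtype, (sym_m, hidden_size))
fake_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

# Passing the TVM-FFI option here is an assumption about how the flag is supplied;
# the review only states that compilation "enables TVM-FFI with the --enable-tvm-ffi option".
compiled = cute.compile(kernel, fake_x, fake_stream, options="--enable-tvm-ffi")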


@gemini-code-assist bot left a comment

Code Review

This pull request refactors the cute-dsl kernels to leverage tvm-ffi, which is a significant improvement. By enabling tvm-ffi, the code is simplified by allowing torch.Tensor objects to be passed directly to the kernels, removing the boilerplate for manual pointer creation and management. The changes in add_rmsnorm_fp4quant.py and rmsnorm_fp4quant.py are consistent and correctly use cute.runtime.make_fake_compact_tensor with symbolic dimensions for compilation. My review includes a couple of suggestions to correct misleading comments for better code clarity. Overall, this is a great change that improves maintainability.

# Scale factor tensor layout depends on swizzle mode
if is_sf_swizzled_layout:
    # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M

medium

The comment incorrectly states that the swizzled size is independent of M. The number of M-tiles (num_m_tiles) is derived from M (the batch dimension), so the total swizzled size is dependent on M. The implementation correctly uses a symbolic integer for this dynamic size, but the comment is misleading and should be corrected for clarity.

Suggested change
-# Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
+# Size is `num_m_tiles * num_k_tiles * 512`, which depends on the `M` dimension.

# Scale factor tensor layout depends on swizzle mode
if is_sf_swizzled_layout:
    # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M

medium

This comment is misleading. The swizzled size is dependent on M because num_m_tiles is calculated based on M. While the code correctly uses a symbolic size, the comment should be updated to reflect this dependency to avoid confusion.

Suggested change
-# Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
+# Size is `num_m_tiles * num_k_tiles * 512`, which depends on the `M` dimension.
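A quick arithmetic check of the point both comments above make, using the ceil(M / 128) relation cited in the nitpick further down; num_k_tiles is left as a free parameter since its exact derivation is not shown here.

import math

def swizzled_sf_size(M: int, num_k_tiles: int) -> int:
    num_m_tiles = math.ceil(M / 128)              # M-tile count grows with M
    return num_m_tiles * num_k_tiles * 512

print(swizzled_sf_size(64, 14))    # 1 M-tile  -> 7168
print(swizzled_sf_size(256, 14))   # 2 M-tiles -> 14336, so the size does depend on M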

@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)

1658-1661: Outdated section header.

The section header on line 1659 says "Pointer-based Compilation" but the code now uses tensor-based TVM-FFI compilation. This should be updated for consistency.

Suggested fix
 # =============================================================================
-# PyTorch API Functions - Streamlined with Pointer-based Compilation
+# PyTorch API Functions - Streamlined with TVM-FFI Tensor Compilation
 # =============================================================================
🧹 Nitpick comments (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)

1706-1713: Minor: Misleading comment about M-independence.

The comment states the swizzled size "is independent of M", but num_m_tiles = ceil(M / 128), so the size actually depends on M. The implementation using a separate symbolic variable is correct, but the comment is confusing.

Suggested fix
     if is_sf_swizzled_layout:
         # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
-        # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
-        # Use a separate symbolic variable for this size
+        # Size is: num_m_tiles * num_k_tiles * 512
+        # Use a separate symbolic variable since this has different shape semantics
         sym_swizzled_size = cute.sym_int()
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

39-671: Consider extracting shared intrinsics and utilities to a common module.

Both rmsnorm_fp4quant.py and add_rmsnorm_fp4quant.py share substantial duplicate code (~800+ lines):

  • PTX intrinsics (set_block_rank, store_shared_remote, ld_global_v4_u32, etc.)
  • Half2/BFloat2 SIMD intrinsics
  • FP8/UE8M0 conversion intrinsics
  • Reduction utilities (warp_reduce, block_reduce, cluster_reduce)
  • get_sm_version function

Extracting these to a shared module (e.g., flashinfer/cute_dsl/intrinsics.py) would reduce maintenance burden and ensure consistency.
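A minimal sketch of what that extraction could look like; the module path and helper names are taken from the list above, while the stub bodies are placeholders rather than the real implementations.

# flashinfer/cute_dsl/intrinsics.py (hypothetical shared module)

def get_sm_version(device_id: int = 0) -> int:
    """Placeholder for the SM-version query duplicated in both kernels."""
    raise NotImplementedError

def warp_reduce(value, op):
    """Placeholder for the duplicated warp-level reduction helper."""
    raise NotImplementedError

# Both kernel files would then share a single import, e.g.:
#   from flashinfer.cute_dsl.intrinsics import get_sm_version, warp_reduce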

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6f1624c and db694aa.

📒 Files selected for processing (2)
  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
  • flashinfer/cute_dsl/rmsnorm_fp4quant.py
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
🧬 Code graph analysis (2)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
csrc/tvm_ffi_utils.h (1)
  • Tensor (304-306)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • hidden_size (265-265)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
csrc/tvm_ffi_utils.h (1)
  • Tensor (304-306)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • hidden_size (265-265)
🔇 Additional comments (5)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (3)

1001-1018: LGTM - Kernel interface refactored to tensor-based API.

The signature change from pointer-based to tensor-based inputs aligns with the TVM-FFI refactoring objective. The docstrings accurately describe the expected tensor shapes and layouts.


1739-1760: LGTM - Runtime tensor API correctly handles tensor passing.

The tensor_api closure appropriately handles the scale tensor layout (flatten for swizzled, contiguous for non-swizzled). The caller rmsnorm_fp4quant ensures input tensors are contiguous before passing them.


1726-1737: LGTM - TVM-FFI compilation setup.

The compilation correctly uses a fake stream with use_tvm_ffi_env_stream=True to capture the environment stream at runtime, and enables TVM-FFI with the --enable-tvm-ffi option.

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)

2103-2171: LGTM - TVM-FFI compilation setup consistent with rmsnorm_fp4quant.py.

The fake tensor creation and compilation approach is consistent with the pattern in rmsnorm_fp4quant.py. The implementation correctly uses symbolic dimensions and TVM-FFI options.

Note: The same minor comment about "independent of M" at line 2141 applies here as well.


2173-2196: LGTM - Runtime tensor API correctly handles tensor passing.

The tensor_api closure appropriately handles the scale tensor layout and matches the pattern from rmsnorm_fp4quant.py.

Comment on lines +1017 to +1025
"""Host function to launch the kernel.
Takes tensors directly via TVM-FFI.
- mX: Input tensor, shape (M, H), row-major
- mR: Residual tensor (will be updated in-place), shape (M, H), row-major
- mW: Weight tensor, shape (H,)
- mY: Output FP4 tensor, shape (M, H // 2), row-major (packed)
- mS: Scale factor tensor, shape depends on swizzle mode
"""

⚠️ Potential issue | 🟡 Minor

Docstring claims in-place update of mR, but kernel doesn't write back.

The docstring states "mR: Residual tensor (will be updated in-place)", but the kernel computes h = x + r and stores it only in shared memory (sH) for normalization. The residual tensor mR is never modified. This could mislead callers expecting the residual to be updated.

Suggested fix
-        - mR: Residual tensor (will be updated in-place), shape (M, H), row-major
+        - mR: Residual tensor, shape (M, H), row-major

@bkryu (Collaborator) commented Jan 2, 2026

Thanks @yzh119, the previous torch -> cute-dsl overhead was a bit of a pain, and this addresses the issue well. This PR will be helpful in preparing future cute-dsl kernels' APIs to use tvm-ffi correctly.

Once the merge conflict with #2260 has been resolved, I can review.
