[WIP] Refactor: simplify torch -> cute-dsl boilerplate and enable tvm-ffi for cute-dsl kernels #2279
base: main
Conversation
📝 Walkthrough

This change transitions two RMSNorm FP4 quantization kernels from pointer-based to tensor-based TVM-FFI invocation. The kernel interfaces are refactored to accept tensors directly instead of raw pointers, compilation paths are updated with symbolic tensors for dynamic shapes, and runtime flows now pass tensors through TVM-FFI instead of marshalling pointers via CUDA driver APIs.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host Code
    participant TVM as TVM-FFI
    participant Compile as Kernel Compiler
    participant CUDA as CUDA Kernel
    rect rgb(200, 220, 255)
        Note over Host,CUDA: Compilation Phase
        Host->>Host: Create symbolic M tensor<br/>and fake tensor fixtures
        Host->>Compile: Invoke kernel compiler<br/>with symbolic tensors
        Compile->>TVM: Register fake tensors<br/>for TVM compilation
        Compile->>CUDA: Compile kernel<br/>with symbolic shapes
    end
    rect rgb(200, 255, 220)
        Note over Host,CUDA: Runtime Invocation Phase
        Host->>Host: Prepare input tensors<br/>(mX, mW, mY, mS, mR)
        Host->>Host: Flatten/contiguate scale<br/>tensor if swizzled layout
        Host->>TVM: Pass tensors via<br/>TVM-FFI interface
        TVM->>CUDA: Launch kernel with<br/>tensor-based inputs
        CUDA->>CUDA: Execute computation<br/>on tensor data
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: ❌ 1 check inconclusive, ✅ 2 checks passed.
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request refactors the existing cute-dsl kernels, specifically for RMSNorm FP4 quantization, to leverage TVM-FFI. This integration aims to simplify the interaction between PyTorch tensors and the underlying CUDA kernels by allowing direct tensor passing, thereby reducing boilerplate code and improving the overall efficiency and developer experience when working with CUTLASS's cute-dsl.
Code Review
This pull request refactors the cute-dsl kernels to leverage tvm-ffi, which is a significant improvement. By enabling tvm-ffi, the code is simplified by allowing torch.Tensor objects to be passed directly to the kernels, removing the boilerplate for manual pointer creation and management. The changes in add_rmsnorm_fp4quant.py and rmsnorm_fp4quant.py are consistent and correctly use cute.runtime.make_fake_compact_tensor with symbolic dimensions for compilation. My review includes a couple of suggestions to correct misleading comments for better code clarity. Overall, this is a great change that improves maintainability.
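For orientation, the compile-time pattern being reviewed looks roughly like the sketch below. It is assembled only from names that appear in this PR's discussion (cute.sym_int, cute.runtime.make_fake_compact_tensor, the --enable-tvm-ffi option); the argument order, the options keyword, and the kernel placeholder are assumptions, not copied from the diff:

```python
import cutlass
import cutlass.cute as cute

H = 4096                 # hidden size, fixed at compile time
M = cute.sym_int()       # symbolic batch dimension: one compile serves all M

# Fake compact tensors describe dtype/shape/layout without allocating memory;
# the kernel is compiled against them. (Signature is an assumption.)
fake_x = cute.runtime.make_fake_compact_tensor(cutlass.Float16, (M, H))
fake_w = cute.runtime.make_fake_compact_tensor(cutlass.Float16, (H,))

# "--enable-tvm-ffi" produces an artifact that accepts torch tensors directly,
# converting through dlpack at call time. `kernel` stands in for the host fn.
compiled = cute.compile(kernel, fake_x, fake_w, options=["--enable-tvm-ffi"])
```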
```python
# Scale factor tensor layout depends on swizzle mode
if is_sf_swizzled_layout:
    # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
```
The comment incorrectly states that the swizzled size is independent of M. The number of M-tiles (num_m_tiles) is derived from M (the batch dimension), so the total swizzled size is dependent on M. The implementation correctly uses a symbolic integer for this dynamic size, but the comment is misleading and should be corrected for clarity.
```diff
-    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
+    # Size is `num_m_tiles * num_k_tiles * 512`, which depends on the `M` dimension.
```
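To make the dependence concrete (num_m_tiles = ceil(M / 128), as noted in the CodeRabbit thread below; the num_k_tiles value here is an arbitrary illustrative assumption):

```python
import math

def swizzled_scale_size(M: int, num_k_tiles: int = 4) -> int:
    # num_m_tiles grows with the batch dimension M, so the total size does too.
    num_m_tiles = math.ceil(M / 128)
    return num_m_tiles * num_k_tiles * 512

print(swizzled_scale_size(128))  # 1 * 4 * 512 = 2048
print(swizzled_scale_size(256))  # 2 * 4 * 512 = 4096 -> the size depends on M
```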
```python
# Scale factor tensor layout depends on swizzle mode
if is_sf_swizzled_layout:
    # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
```
This comment is misleading. The swizzled size is dependent on M because num_m_tiles is calculated based on M. While the code correctly uses a symbolic size, the comment should be updated to reflect this dependency to avoid confusion.
```diff
-    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
+    # Size is `num_m_tiles * num_k_tiles * 512`, which depends on the `M` dimension.
```
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
1658-1661: Outdated section header. The section header on line 1659 says "Pointer-based Compilation" but the code now uses tensor-based TVM-FFI compilation. This should be updated for consistency.
Suggested fix
```diff
 # =============================================================================
-# PyTorch API Functions - Streamlined with Pointer-based Compilation
+# PyTorch API Functions - Streamlined with TVM-FFI Tensor Compilation
 # =============================================================================
```
🧹 Nitpick comments (2)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (1)
1706-1713: Minor: Misleading comment about M-independence. The comment states the swizzled size "is independent of M", but num_m_tiles = ceil(M / 128), so the size actually depends on M. The implementation using a separate symbolic variable is correct, but the comment is confusing.

Suggested fix

```diff
 if is_sf_swizzled_layout:
     # For swizzled mode, use 1D layout - the swizzle pattern is computed in kernel
-    # Size is: num_m_tiles * num_k_tiles * 512, which is independent of M
-    # Use a separate symbolic variable for this size
+    # Size is: num_m_tiles * num_k_tiles * 512
+    # Use a separate symbolic variable since this has different shape semantics
     sym_swizzled_size = cute.sym_int()
```

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)
39-671: Consider extracting shared intrinsics and utilities to a common module. Both rmsnorm_fp4quant.py and add_rmsnorm_fp4quant.py share substantial duplicate code (~800+ lines):

- PTX intrinsics (set_block_rank, store_shared_remote, ld_global_v4_u32, etc.)
- Half2/BFloat2 SIMD intrinsics
- FP8/UE8M0 conversion intrinsics
- Reduction utilities (warp_reduce, block_reduce, cluster_reduce)
- get_sm_version function

Extracting these to a shared module (e.g., flashinfer/cute_dsl/intrinsics.py) would reduce maintenance burden and ensure consistency; see the sketch after this list.
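For illustration, after such an extraction the two kernel files would pull the shared helpers from one place. The module path is the one the review suggests, and the export list simply mirrors the names above (a hypothetical sketch, not code from this PR):

```python
# In rmsnorm_fp4quant.py and add_rmsnorm_fp4quant.py, the ~800 duplicated
# lines would collapse to a single shared import:
from .intrinsics import (
    set_block_rank,
    store_shared_remote,
    ld_global_v4_u32,
    warp_reduce,
    block_reduce,
    cluster_reduce,
    get_sm_version,
)
```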
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/rmsnorm_fp4quant.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
- flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
- flashinfer/cute_dsl/rmsnorm_fp4quant.py
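That caching convention typically looks like the following minimal sketch (the function name and parameters are hypothetical, not from this PR):

```python
import functools

@functools.cache
def _compile_rmsnorm_fp4quant(hidden_size: int, dtype_str: str,
                              is_sf_swizzled_layout: bool):
    """Compile once per unique (hashable) argument combination; repeated
    calls with the same arguments return the cached compiled kernel."""
    ...  # build fake tensors and invoke cute.compile(...)
```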
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
🧬 Code graph analysis (2)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
- csrc/tvm_ffi_utils.h (1): Tensor (304-306)
- include/flashinfer/trtllm/fused_moe/runner.h (1): hidden_size (265-265)

flashinfer/cute_dsl/rmsnorm_fp4quant.py (2)
- csrc/tvm_ffi_utils.h (1): Tensor (304-306)
- include/flashinfer/trtllm/fused_moe/runner.h (1): hidden_size (265-265)
🔇 Additional comments (5)
flashinfer/cute_dsl/rmsnorm_fp4quant.py (3)
1001-1018: LGTM - Kernel interface refactored to tensor-based API. The signature change from pointer-based to tensor-based inputs aligns with the TVM-FFI refactoring objective. The docstrings accurately describe the expected tensor shapes and layouts.
1739-1760: LGTM - Runtime tensor API correctly handles tensor passing. The tensor_api closure appropriately handles the scale tensor layout (flatten for swizzled, contiguous for non-swizzled). The caller rmsnorm_fp4quant ensures input tensors are contiguous before passing them.
1726-1737: LGTM - TVM-FFI compilation setup. The compilation correctly uses a fake stream with use_tvm_ffi_env_stream=True to capture the environment stream at runtime, and enables TVM-FFI with the --enable-tvm-ffi option.

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (2)
2103-2171: LGTM - TVM-FFI compilation setup consistent with rmsnorm_fp4quant.py. The fake tensor creation and compilation approach is consistent with the pattern in rmsnorm_fp4quant.py. The implementation correctly uses symbolic dimensions and TVM-FFI options. Note: the same minor comment about "independent of M" at line 2141 applies here as well.
2173-2196: LGTM - Runtime tensor API correctly handles tensor passing. The tensor_api closure appropriately handles the scale tensor layout and matches the pattern from rmsnorm_fp4quant.py.
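A rough sketch of the closure pattern these comments describe (only the flatten-vs-contiguous handling is taken from the review; the wrapper name, body, and kernel argument order are assumptions):

```python
def _make_tensor_api(compiled_kernel, is_sf_swizzled_layout: bool):
    def tensor_api(x, w, y, s, r=None):
        # Swizzled layouts are handed to the kernel as a flat 1D buffer;
        # plain layouts only need to be contiguous.
        s = s.flatten() if is_sf_swizzled_layout else s.contiguous()
        if r is None:
            compiled_kernel(x, w, y, s)      # rmsnorm_fp4quant variant
        else:
            compiled_kernel(x, r, w, y, s)   # add_rmsnorm_fp4quant variant
    return tensor_api
```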
| """Host function to launch the kernel. | ||
| Takes tensors directly via TVM-FFI. | ||
| - mX: Input tensor, shape (M, H), row-major | ||
| - mR: Residual tensor (will be updated in-place), shape (M, H), row-major | ||
| - mW: Weight tensor, shape (H,) | ||
| - mY: Output FP4 tensor, shape (M, H // 2), row-major (packed) | ||
| - mS: Scale factor tensor, shape depends on swizzle mode | ||
| """ |
Docstring claims in-place update of mR, but kernel doesn't write back.
The docstring states "mR: Residual tensor (will be updated in-place)", but the kernel computes h = x + r and stores it only in shared memory (sH) for normalization. The residual tensor mR is never modified. This could mislead callers expecting the residual to be updated.
Suggested fix

```diff
-    - mR: Residual tensor (will be updated in-place), shape (M, H), row-major
+    - mR: Residual tensor, shape (M, H), row-major
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| """Host function to launch the kernel. | |
| Takes tensors directly via TVM-FFI. | |
| - mX: Input tensor, shape (M, H), row-major | |
| - mR: Residual tensor (will be updated in-place), shape (M, H), row-major | |
| - mW: Weight tensor, shape (H,) | |
| - mY: Output FP4 tensor, shape (M, H // 2), row-major (packed) | |
| - mS: Scale factor tensor, shape depends on swizzle mode | |
| """ | |
| """Host function to launch the kernel. | |
| Takes tensors directly via TVM-FFI. | |
| - mX: Input tensor, shape (M, H), row-major | |
| - mR: Residual tensor, shape (M, H), row-major | |
| - mW: Weight tensor, shape (H,) | |
| - mY: Output FP4 tensor, shape (M, H // 2), row-major (packed) | |
| - mS: Scale factor tensor, shape depends on swizzle mode | |
| """ |
📌 Description
cute-dsl supports compiling with tvm-ffi since the 4.3 release (https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html), which lets users pass torch tensors directly with negligible dlpack conversion cost, without manually creating cute tensors from cute pointers.

In this PR we refactor the existing cute-dsl kernels to enable tvm-ffi and simplify the torch -> cute-dsl boilerplate.
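Concretely, the call-site simplification looks roughly like this; the old pointer boilerplate is paraphrased, and the tensor dtypes/shapes plus the compiled_kernel placeholder are illustrative assumptions:

```python
import torch

x = torch.randn(256, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
y = torch.empty(256, 2048, dtype=torch.uint8, device="cuda")  # packed FP4 out
s = torch.empty(8192, dtype=torch.uint8, device="cuda")       # scale factors

# Before (paraphrased): each call wrapped raw device pointers by hand, e.g.
#   x_ptr = make_ptr(cutlass.Float16, x.data_ptr(), ...)  # one per tensor
# After: a kernel compiled with --enable-tvm-ffi accepts torch tensors
# directly; the dlpack handoff happens inside tvm-ffi on the caller's stream.
compiled_kernel(x, w, y, s)
```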
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes