[claude] Make MXFP8 CUDA kernels ABI stable #3610
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3610
Note: links to docs will display an error until the docs builds have completed. ⏳ No Failures, 18 Pending as of commit 0211669 with merge base 34322b5. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed 2f3e96d to 7b28a33
The best way to test this is by setting up a build where you can pass in -DTORCH_TARGET_VERSION and build the kernel + ensure it builds and runs. I'd recommend setting up a separate stable wheel for the migration to make testing easier as you slowly migrate kernels to the stable wheel
Edit: I just learned that you only have 3 files to migrate. In that case you do have the option of just biting the bullet and changing the setup.py in torchao to add this flag: https://github.com/Dao-AILab/flash-attention/pull/2155/files.
Note you want 20a0 and not 2090 though, for 2.10 (i.e. 0x020a… rather than 0x0209…, since minor version 10 is 0x0a in hex).
Force-pushed 3784e4f to 2b13608
```cpp
#include <torch/headeronly/util/Exception.h>

// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h
extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);
```
I don't think you need this, you should just pass -DUSE_CUDA as a flag in setup.py.
Though hmm @mikaylagawarecki should we automatically pass this in through CUDAExtension for better UX?
This or -DUSE_CUDA would both work
tbh I'm not sure that silently defining USE_CUDA in CUDAExtension is good UX 🤔 (Where did USE_CUDA originate from? Is it a pytorch thing or a cuda thing? I'm wondering whether users might also define USE_CUDA themselves for separate purposes hmm, wdyt?)
```diff
-  case torch::kFloat32:
+DType get_input_dtype(const Tensor &t) {
+  auto scalar_type = t.scalar_type();
+  if (scalar_type == torch::headeronly::ScalarType::Float) {
```
you should still be able to use a switch statement i believe, but no biggie
```diff
 const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x;
-output_rowwise = at::empty({rows, cols}, options_fp8);
-scales_rowwise = at::empty({rows, num_col_blocks}, options_scale);
+output_rowwise = torch::stable::empty(
```
janeyx99 left a comment:
Yeah, this looks close. This is the point where I'd recommend just building the wheel and letting the error messages and test cases get you 100% of the way there.
Force-pushed ca53f6c to 70a834b
Force-pushed 70a834b to 476167e
```cpp
// Create column-major tensor by creating transposed shape and transposing
// We need shape {rows, cols} with strides {1, rows}, so create {cols, rows} and transpose
Tensor output_colwise_tmp = torch::stable::new_empty(input, {cols, rows}, torch::headeronly::ScalarType::Float8_e4m3fn);
output_colwise = torch::stable::transpose(output_colwise_tmp, 0, 1);
```
@janeyx99 @danielvegamyhre is this right? Seems like claude is doing this for all column major tensors
looks right to me, it allocates empty shape (cols, rows) with strides (rows, 1) (row major) then transpose so we have (rows, cols) with strides (1, rows) (col major). since transpose is just a metadata change this shouldn't impact perf
also as an aside (and a note to myself) i think the original code for the scales can be simplified, but will leave it for a future PR to avoid merge conflicts with this
```diff
 // Get raw pointers
 const uint8_t* scales_ptr = reinterpret_cast<const uint8_t*>(scales_tensor.data_ptr());
-const int32_t* offsets_ptr = input_group_end_offsets.data_ptr<int32_t>();
+const int32_t* offsets_ptr = static_cast<const int32_t*>(input_group_end_offsets.data_ptr());
```
do we need this change?
You can use input_group_end_offsets.const_data_ptr<int32_t>(); the stable Tensor.data_ptr() doesn't accept templates atm, hence the issue :(
for future reference, there's also .mutable_data_ptr() in stable
```diff
 + [
     "-gencode=arch=compute_100,code=sm_100",
     "-gencode=arch=compute_120,code=compute_120",
+    "-DUSE_CUDA",
```
@janeyx99 @mikaylagawarecki I needed to add this for aoti_torch_get_current_cuda_stream. Is this what you would recommend?
Force-pushed 476167e to ff1b866
janeyx99 left a comment:
Approval contingent on passing in the build flag and the newly built AO wheel being able to run tests correctly.
```diff
     f"-DPy_LIMITED_API={min_supported_cpython_hexcode}",
     "-std=c++17",
     "-O3",
+    "-DUSE_CUDA",
```
also pass in a -DTORCH_TARGET_VERSION=0x020a000000000000 for 2.10
Force-pushed ff1b866 to 4fa9a70
@andrewor14 i'll build and do some tests with this PR locally and follow up with the results

@andrewor14 when I run unit tests I see runtime errors like this: Repro command:
@janeyx99 any thoughts on this?

also @andrewor14 i just landed a test fix for an unrelated issue #3728, i would suggest rebasing on main now to get cleaner test signal
Force-pushed 4fa9a70 to 0841c49
Hmm I'm hitting a different error after rebasing and running the test: Error: Any ideas? Wonder if it's a shape error? We did add some transposes in this PR. @danielvegamyhre

Oh boy, this is cuz our shim doesn't currently support the 44th ScalarType that is Float8_e8m0fnu --> is that needed? I'll patch something for 2.11 if so, but it would mean we can't migrate this til then :/

Marking this PR as draft so we don't accidentally merge it. Seems like this is blocked on Float8_e8m0fnu ABI stable support for now, probably until torch 2.11
Force-pushed 0841c49 to 5edc119
```cpp
int64_t scale_rowwise_stride_dim0 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(0) : 0;
int64_t scale_rowwise_stride_dim1 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(1) : 0;
int64_t scale_colwise_stride_dim0 = scales_colwise.dim() >= 2 ? scales_colwise.stride(0) : 0;
int64_t scale_colwise_stride_dim1 = scales_colwise.dim() >= 2 ? scales_colwise.stride(1) : 0;
```
Hey @danielvegamyhre I had to add the dim check to pass tests for some reason. Does this look right? Here's what claude said:
The Problem:
In your commit, you changed from `.strides()[i]` to `.stride(i)` to use the stable ABI. However, the stable ABI's `.stride(i)` method validates the dimension index, while `.strides()[i]` accessed the array directly without bounds checking.

When rowwise=False, scales_rowwise is created as a 1D tensor with shape {0}. When the code calls scales_rowwise.stride(1), it fails because dimension index 1 is out of range for a 1D tensor (valid indices are only -1 and 0).

The Fix (in torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu):
Guard the stride accesses with dimension checks:

```cpp
// Get strides of scale ptrs (guard against 1D empty tensors when rowwise/colwise is false)
int64_t scale_rowwise_stride_dim0 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(0) : 0;
int64_t scale_rowwise_stride_dim1 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(1) : 0;
int64_t scale_colwise_stride_dim0 = scales_colwise.dim() >= 2 ? scales_colwise.stride(0) : 0;
int64_t scale_colwise_stride_dim1 = scales_colwise.dim() >= 2 ? scales_colwise.stride(1) : 0;
```

This was actually a latent bug in the original code: the unstable ABI's `.strides()[i]` silently returned garbage/undefined values for out-of-bounds access. The stable ABI's stricter bounds checking exposed it.
```diff
     "-O3",
     "-DUSE_CUDA",
+    # define TORCH_TARGET_VERSION with min version 2.11 for ABI stable Float8_e8m0fnu
+    "-DTORCH_TARGET_VERSION=0x020b000000000000",
```
Note: updated the target version to 2.11 cause we need pytorch/pytorch#173669
should we add a guard for torch_version_at_least("2.11.0") as well?
Force-pushed 5edc119 to c19db71
Force-pushed 8558dff to a4fca0c
Prompt:
```
I want you to make the mxfp8 kernels under torchao/csrc/cuda/mx_kernels ABI stable.
Read these for instructions:
pytorch/docs/source/notes/libtorch_stable_abi.md
cppdocs/_sources/stable.rst.txt
Use these files for an example:
Before: flash-attention/hopper/flash_api.cpp
After: flash-attention/hopper/flash_api_stable.cpp
Additional instructions:
Replace at::zeros with torch::stable::new_zeros
Replace at::empty with torch::stable::new_empty
Replace TORCH_CHECK with STD_TORCH_CHECK without changing the content of the check, just replace the function name
Replace c10::cuda::CUDAGuard with DeviceGuard
When calling tensor.stride(), no need to check for tensor.numel() > 0 beforehand
When using cudaStream_t, remember to include cuda_runtime.h
When calling aoti_torch_get_current_cuda_stream, get the device index from a tensor, not from the general torch::stable::accelerator::getCurrentDeviceIndex();
Don’t do these things:
Don’t use torch::stable::new_empty_strided since this API does not exist
Don’t define a dummy _C module that can be accessed from python
Don’t declare aoti_torch_get_current_cuda_stream, just include it from torch/csrc/inductor/aoti_torch/c/shim.h and add -DUSE_CUDA to both cxx and nvcc in setup.py
Don’t box kernels manually, just use TORCH_BOX
Don’t change switch statements into if else blocks
```
Follow-up prompts:
```
Revert the content changes of STD_TORCH_CHECK, just replace the function name
For the changes that involved tensor.data_ptr<some_dtype>() in the old code, revert the changes and replace it with tensor.const_data_ptr<some_dtype>()
```
Output:
```
● I've successfully made both files ABI stable. Here's a summary of the changes:
Changes to mxfp8_cuda.cu:
1. Updated includes: Replaced ATen and torch/extension.h with stable ABI headers (torch/csrc/stable/tensor.h, torch/csrc/inductor/aoti_torch/c/shim.h, torch/headeronly/core/ScalarType.h, torch/headeronly/util/Exception.h, torch/headeronly/util/shim_utils.h)
2. Added cuda_runtime.h for cudaStream_t
3. Changed tensor type: torch::Tensor → torch::stable::Tensor (with using declaration)
4. Changed scalar type enums: torch::kFloat32 → torch::headeronly::ScalarType::Float, etc.
5. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK
6. Changed CUDA stream access: at::cuda::getCurrentCUDAStream() → aoti_torch_get_current_cuda_stream(input.get_device_index(), &stream_ptr) with TORCH_ERROR_CODE_CHECK
7. Fixed stride access: .strides()[i] → .stride(i) and .sizes()[i] → .size(i)
Changes to mxfp8_extension.cpp:
1. Updated includes: Replaced torch/library.h, ATen/ATen.h, ATen/cuda/CUDAContext.h, c10/cuda/CUDAGuard.h with stable headers
2. Changed tensor type: at::Tensor → torch::stable::Tensor
3. Changed DeviceGuard: c10::cuda::CUDAGuard → torch::stable::accelerator::DeviceGuard (with input.get_device_index() instead of input.device())
4. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK
5. Changed tensor creation:
- at::empty → torch::stable::new_empty
- at::zeros → torch::stable::new_zeros
- For strided tensors: Created with transposed dimensions then used torch::stable::transpose to get column-major layout
6. Changed CUDA stream access: Used aoti_torch_get_current_cuda_stream with device index from tensor
7. Changed library registration: TORCH_LIBRARY_IMPL → STABLE_TORCH_LIBRARY_IMPL with TORCH_BOX wrapper
Changes to setup.py:
1. Added -DUSE_CUDA to both cxx and nvcc compiler arguments for the mxfp8 extension
```
Force-pushed a4fca0c to 0211669