
[claude] Make MXFP8 cuda kernels ABI stable#3610

Merged
andrewor14 merged 1 commit into main from mxfp8-kernel-abi-stable on Mar 23, 2026
Conversation

@andrewor14
Contributor

@andrewor14 andrewor14 commented Jan 9, 2026

Test Plan:

```
pytest test/prototype/mx_formats/test_kernels.py -v -s -k cuda
pytest test/prototype/moe_training/test_kernels.py -v -s -k cuda_mx_block
pytest test/prototype/moe_training/test_training.py -v -s
```

Prompt:

```
Make these two files ABI stable:
torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp

Read these for instructions:
pytorch/docs/source/notes/libtorch_stable_abi.md
cppdocs/_sources/stable.rst.txt

Use these files for an example:
Before: flash-attention/hopper/flash_api.cpp
After: flash-attention/hopper/flash_api_stable.cpp

Additional instructions:
Replace at::zeros with torch::stable::new_zeros
Replace at::empty with torch::stable::new_empty
Replace TORCH_CHECK with STD_TORCH_CHECK without changing the content of the check, just replace the function name
Replace c10::cuda::CUDAGuard with DeviceGuard
When calling tensor.stride(), no need to check for tensor.numel() > 0 beforehand
When using cudaStream_t, remember to include cuda_runtime.h
When calling aoti_torch_get_current_cuda_stream, get the device index from a tensor, not from the general torch::stable::accelerator::getCurrentDeviceIndex();

Don’t do these things:
Don’t use torch::stable::new_empty_strided since this API does not exist
Don’t define a dummy _C module that can be accessed from python
Don’t declare aoti_torch_get_current_cuda_stream, just include it from torch/csrc/inductor/aoti_torch/c/shim.h and add -DUSE_CUDA to both cxx and nvcc in setup.py
Don’t box kernels manually, just use TORCH_BOX
Don’t change switch statements into if else blocks
```

Follow-up prompts:

```
Revert the content changes of STD_TORCH_CHECK, just replace the function name

For the changes that involved tensor.data_ptr<some_dtype>() in the old code, revert the changes and replace it with tensor.const_data_ptr<some_dtype>()
```

Output:

```
● I've successfully made both files ABI stable. Here's a summary of the changes:

  Changes to mxfp8_cuda.cu:

  1. Updated includes: Replaced ATen and torch/extension.h with stable ABI headers (torch/csrc/stable/tensor.h, torch/csrc/inductor/aoti_torch/c/shim.h, torch/headeronly/core/ScalarType.h, torch/headeronly/util/Exception.h, torch/headeronly/util/shim_utils.h)
  2. Added cuda_runtime.h for cudaStream_t
  3. Changed tensor type: torch::Tensor → torch::stable::Tensor (with using declaration)
  4. Changed scalar type enums: torch::kFloat32 → torch::headeronly::ScalarType::Float, etc.
  5. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK
  6. Changed CUDA stream access: at::cuda::getCurrentCUDAStream() → aoti_torch_get_current_cuda_stream(input.get_device_index(), &stream_ptr) with TORCH_ERROR_CODE_CHECK
  7. Fixed stride access: .strides()[i] → .stride(i) and .sizes()[i] → .size(i)

  Changes to mxfp8_extension.cpp:

  1. Updated includes: Replaced torch/library.h, ATen/ATen.h, ATen/cuda/CUDAContext.h, c10/cuda/CUDAGuard.h with stable headers
  2. Changed tensor type: at::Tensor → torch::stable::Tensor
  3. Changed DeviceGuard: c10::cuda::CUDAGuard → torch::stable::accelerator::DeviceGuard (with input.get_device_index() instead of input.device())
  4. Changed error checking: TORCH_CHECK → STD_TORCH_CHECK
  5. Changed tensor creation:
    - at::empty → torch::stable::new_empty
    - at::zeros → torch::stable::new_zeros
    - For strided tensors: Created with transposed dimensions then used torch::stable::transpose to get column-major layout
  6. Changed CUDA stream access: Used aoti_torch_get_current_cuda_stream with device index from tensor
  7. Changed library registration: TORCH_LIBRARY_IMPL → STABLE_TORCH_LIBRARY_IMPL with TORCH_BOX wrapper

  Changes to setup.py:

  1. Added -DUSE_CUDA to both cxx and nvcc compiler arguments for the mxfp8 extension
```
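The setup.py piece of the summary above can be sketched as a small config fragment. This is an illustrative sketch, not the actual torchao build file: the variable name `mxfp8_extra_compile_args` and the surrounding flags are placeholders, and only the `-DUSE_CUDA` entries reflect the change described.

```python
# Hypothetical sketch of the setup.py change described above: pass -DUSE_CUDA
# to both the host (cxx) and device (nvcc) compilers. The dict shape matches
# how torch.utils.cpp_extension.CUDAExtension accepts extra_compile_args;
# the other flags are placeholders.
mxfp8_extra_compile_args = {
    "cxx": [
        "-O3",
        "-std=c++17",
        "-DUSE_CUDA",  # exposes aoti_torch_get_current_cuda_stream in shim.h
    ],
    "nvcc": [
        "-O3",
        "-DUSE_CUDA",  # the same define is needed in .cu translation units
    ],
}

# Both lists must carry the define, since the CUDA stream shim declaration is
# guarded by #ifdef USE_CUDA in torch/csrc/inductor/aoti_torch/c/shim.h.
assert all("-DUSE_CUDA" in args for args in mxfp8_extra_compile_args.values())
```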

@pytorch-bot

pytorch-bot bot commented Jan 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3610

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 18 Pending

As of commit 0211669 with merge base 34322b5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 9, 2026
@andrewor14 andrewor14 requested a review from janeyx99 January 9, 2026 16:26
@andrewor14 andrewor14 changed the title from "[draft] Make MXFP8 cuda kernel ABI stable" to "[claude] Make MXFP8 cuda kernel ABI stable" Jan 9, 2026
@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 2f3e96d to 7b28a33 on January 9, 2026 16:26
@andrewor14 andrewor14 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Jan 9, 2026
Contributor

@janeyx99 janeyx99 left a comment

The best way to test this is by setting up a build where you can pass in -DTORCH_TARGET_VERSION and build the kernel + ensure it builds and runs. I'd recommend setting up a separate stable wheel for the migration to make testing easier as you slowly migrate kernels to the stable wheel

Edit: I just learned that you only have 3 files to migrate. In that case you do have the option of just biting the bullet and changing the setup.py in torchao to add this flag: https://github.com/Dao-AILab/flash-attention/pull/2155/files.

Note you want 20a0 and not 2090 in the hex constant though, for 2.10 (minor version 10 is 0a in hex; 09 would be 2.9).

@andrewor14 andrewor14 marked this pull request as draft January 9, 2026 16:31
@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch 2 times, most recently from 3784e4f to 2b13608 on January 9, 2026 20:29
@andrewor14 andrewor14 changed the title from "[claude] Make MXFP8 cuda kernel ABI stable" to "[claude] Make MXFP8 cuda kernels ABI stable" Jan 9, 2026
@andrewor14 andrewor14 requested a review from janeyx99 January 9, 2026 20:36
```cpp
#include <torch/headeronly/util/Exception.h>

// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h
extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);
```
Contributor

I don't think you need this, you should just pass -DUSE_CUDA as a flag in setup.py.

Though hmm @mikaylagawarecki should we automatically pass this in through CUDAExtension for better UX?

@mikaylagawarecki mikaylagawarecki Jan 9, 2026

This or -DUSE_CUDA would both work

tbh I'm not sure that silently defining USE_CUDA in CUDAExtension is good UX 🤔 (Where did USE_CUDA originate from? Is it a pytorch thing or a cuda thing? I'm wondering whether users might also define USE_CUDA themselves for separate purposes hmm, wdyt?)

```cpp
case torch::kFloat32:
DType get_input_dtype(const Tensor &t) {
  auto scalar_type = t.scalar_type();
  if (scalar_type == torch::headeronly::ScalarType::Float) {
```
Contributor

you should still be able to use a switch statement i believe, but no biggie

```cpp
const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x;
output_rowwise = at::empty({rows, cols}, options_fp8);
scales_rowwise = at::empty({rows, num_col_blocks}, options_scale);
output_rowwise = torch::stable::empty(
```
Contributor

new_empty?

here and below

Contributor

@janeyx99 janeyx99 left a comment

Yea, this looks close. Here's when I'd recommend just building the wheel and letting the error messages and test cases get you 100% there.

@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch 3 times, most recently from ca53f6c to 70a834b on January 12, 2026 23:45
```cpp
// Create column-major tensor by creating transposed shape and transposing
// We need shape {rows, cols} with strides {1, rows}, so create {cols, rows} and transpose
Tensor output_colwise_tmp = torch::stable::new_empty(input, {cols, rows}, torch::headeronly::ScalarType::Float8_e4m3fn);
output_colwise = torch::stable::transpose(output_colwise_tmp, 0, 1);
```
Contributor Author

@janeyx99 @danielvegamyhre is this right? Seems like claude is doing this for all column major tensors

Contributor

@danielvegamyhre danielvegamyhre Jan 15, 2026

looks right to me, it allocates empty shape (cols, rows) with strides (rows, 1) (row major) then transpose so we have (rows, cols) with strides (1, rows) (col major). since transpose is just a metadata change this shouldn't impact perf
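The reasoning above can be checked with a small pure-Python stride model. This is an illustrative sketch only, not torchao or PyTorch code; `row_major_strides` is a made-up helper that reproduces contiguous stride math.

```python
# Illustrative model of the "allocate transposed, then transpose" trick
# for producing a column-major tensor (not torchao code).

def row_major_strides(shape):
    # Contiguous (row-major) strides: the last dimension has stride 1.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

rows, cols = 4, 3

# Step 1: allocate shape (cols, rows) contiguously -> strides (rows, 1).
tmp_shape = [cols, rows]
tmp_strides = row_major_strides(tmp_shape)
assert tmp_strides == [rows, 1]

# Step 2: transpose(0, 1) is metadata-only: swap sizes and strides.
out_shape = [tmp_shape[1], tmp_shape[0]]        # (rows, cols)
out_strides = [tmp_strides[1], tmp_strides[0]]  # (1, rows) -> column-major

assert out_shape == [rows, cols]
assert out_strides == [1, rows]
```

Since only the size/stride metadata changes in step 2, no data is copied, which is why the transpose has no performance cost.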

Contributor

@danielvegamyhre danielvegamyhre Jan 16, 2026

also as an aside (and a note to myself) i think the original code for the scales can be simplified, but will leave it for a future PR to avoid merge conflicts with this

```cpp
// Get raw pointers
const uint8_t* scales_ptr = reinterpret_cast<const uint8_t*>(scales_tensor.data_ptr());
const int32_t* offsets_ptr = input_group_end_offsets.data_ptr<int32_t>();
const int32_t* offsets_ptr = static_cast<const int32_t*>(input_group_end_offsets.data_ptr());
```
Contributor Author

do we need this change?

@mikaylagawarecki mikaylagawarecki Jan 16, 2026

You can use input_group_end_offsets.const_data_ptr<int32_t>(); the stable Tensor.data_ptr() doesn't accept templates atm, hence the issue :(

for future reference, there's also .mutable_data_ptr() in stable

```python
+ [
    "-gencode=arch=compute_100,code=sm_100",
    "-gencode=arch=compute_120,code=compute_120",
    "-DUSE_CUDA",
```
Contributor Author

@janeyx99 @mikaylagawarecki I needed to add this for aoti_torch_get_current_cuda_stream. Is this what you would recommend?

@andrewor14 andrewor14 marked this pull request as ready for review January 15, 2026 22:19
@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 476167e to ff1b866 on January 16, 2026 18:58
Contributor

@janeyx99 janeyx99 left a comment

Approval contingent on passing in the build flag and the newly built AO wheel being able to run tests correctly.

f"-DPy_LIMITED_API={min_supported_cpython_hexcode}",
"-std=c++17",
"-O3",
"-DUSE_CUDA",
Contributor

also pass in a -DTORCH_TARGET_VERSION=0x020a000000000000 for 2.10
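The two TORCH_TARGET_VERSION literals that appear in this thread are consistent with the major and minor versions occupying the top two bytes of the 64-bit constant. The sketch below checks that reading against the values actually used; the byte-layout interpretation is an inference from those values, not an authoritative description of the format.

```python
# Sketch: interpret -DTORCH_TARGET_VERSION literals as major/minor packed
# into the top two bytes of a 64-bit value. Inferred from the constants
# used in this PR, not from PyTorch documentation.

def target_version(major: int, minor: int) -> int:
    return (major << 56) | (minor << 48)

assert target_version(2, 10) == 0x020A000000000000  # flag suggested for 2.10
assert target_version(2, 11) == 0x020B000000000000  # flag later used for 2.11

# This is also why the review note says "20a0" and not "2090":
# minor version 10 is 0x0a, while "2090" would encode 2.9.
assert hex(target_version(2, 10))[2:6] == "20a0"
assert hex(target_version(2, 9))[2:6] == "2090"
```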

@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from ff1b866 to 4fa9a70 on January 26, 2026 20:58
@danielvegamyhre
Contributor

@andrewor14 i'll build and do some tests with this PR locally and follow up with the results

@danielvegamyhre
Contributor

danielvegamyhre commented Jan 26, 2026

@andrewor14 when I run unit tests I see runtime errors like this:

RuntimeError: call, /home/dev/.conda/envs/torch/lib/python3.12/site-packages/torch/include/torch/csrc/stable/stableivalue_conversions.h:530, Not yet supported ScalarType 44, please file an issue describing your use case.

Repro command:

  • pytest test/prototype/moe_training/test_training.py -v -s

@janeyx99 any thoughts on this?

@danielvegamyhre
Contributor

danielvegamyhre commented Jan 27, 2026

also @andrewor14 i just landed a test fix for an unrelated issue #3728, i would suggest rebasing on main now to get cleaner test signal

@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 4fa9a70 to 0841c49 on January 27, 2026 15:28
@andrewor14
Contributor Author

andrewor14 commented Jan 27, 2026

Hmm I'm hitting a different error after rebasing and running the test:

pytest test/prototype/moe_training/test_training.py -v -s 

Error:

FAILED test/prototype/moe_training/test_training.py::test_moe_training[recipe_config1-False-target_fqns0] -

RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling
`cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, lda,
b, CUDA_R_16BF, ldb, &fbeta, c, std::is_same_v<C_Dtype, float...

Any ideas? Wonder if it's a shape error? We did add some transposes in this PR. @danielvegamyhre

@janeyx99
Contributor

> RuntimeError: call, /home/dev/.conda/envs/torch/lib/python3.12/site-packages/torch/include/torch/csrc/stable/stableivalue_conversions.h:530, Not yet supported ScalarType 44, please file an issue describing your use case.
>
> Repro command:
>
> • pytest test/prototype/moe_training/test_training.py -v -s
>
> @janeyx99 any thoughts on this?

Oh boy, this is cuz our shim doesn't currently support the 44th ScalarType that is Float8_e8m0fnu --> is that needed? I'll patch something for 2.11 if so, but it would mean we can't migrate this til then :/

@andrewor14 andrewor14 marked this pull request as draft January 28, 2026 16:54
@andrewor14
Contributor Author

Marking this PR as draft so we don't accidentally merge it. Seems like this is blocked on Float8_e8m0fnu ABI stable support for now, probably until torch 2.11

@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 0841c49 to 5edc119 on February 6, 2026 20:11
```cpp
int64_t scale_rowwise_stride_dim0 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(0) : 0;
int64_t scale_rowwise_stride_dim1 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(1) : 0;
int64_t scale_colwise_stride_dim0 = scales_colwise.dim() >= 2 ? scales_colwise.stride(0) : 0;
int64_t scale_colwise_stride_dim1 = scales_colwise.dim() >= 2 ? scales_colwise.stride(1) : 0;
```
Contributor Author

@andrewor14 andrewor14 Feb 6, 2026

Hey @danielvegamyhre I had to add the dim check to pass tests for some reason. Does this look right? Here's what claude said:

The Problem:

In your commit, you changed from .strides()[i] to .stride(i) to use the stable ABI. However, the stable ABI's .stride(i) method validates the dimension index, while .strides()[i] accessed the array directly without bounds checking.

When rowwise=False, scales_rowwise is created as a 1D tensor with shape {0}. When the code calls scales_rowwise.stride(1), it fails because dimension index 1 is out of range for a 1D tensor (valid indices are only -1 and 0).

The Fix (in torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu):

Guard the stride accesses with dimension checks:

```cpp
// Get strides of scale ptrs (guard against 1D empty tensors when rowwise/colwise is false)
int64_t scale_rowwise_stride_dim0 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(0) : 0;
int64_t scale_rowwise_stride_dim1 = scales_rowwise.dim() >= 2 ? scales_rowwise.stride(1) : 0;
int64_t scale_colwise_stride_dim0 = scales_colwise.dim() >= 2 ? scales_colwise.stride(0) : 0;
int64_t scale_colwise_stride_dim1 = scales_colwise.dim() >= 2 ? scales_colwise.stride(1) : 0;
```

This was actually a latent bug in the original code - the unstable ABI's .strides()[i] silently returned garbage/undefined values for out-of-bounds access. The stable ABI's stricter bounds checking exposed it.
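The guard's behavior can be modeled in a few lines of plain Python. This is an illustrative sketch only; `FakeTensor` and `guarded_stride` are made-up stand-ins for the stable-ABI tensor and the C++ ternary above, not torchao code.

```python
# Illustrative model of the dim() >= 2 guard (not torchao/PyTorch code).
class FakeTensor:
    """Minimal stand-in holding only shape/stride metadata."""
    def __init__(self, shape, strides):
        self.shape, self.strides = shape, strides
    def dim(self):
        return len(self.shape)
    def stride(self, i):
        # Stable-ABI style: bounds-checked, unlike a raw .strides()[i] read.
        if not (-self.dim() <= i < self.dim()):
            raise IndexError(f"dim {i} out of range for {self.dim()}-D tensor")
        return self.strides[i]

def guarded_stride(t, i):
    # Mirrors the C++ ternary: t.dim() >= 2 ? t.stride(i) : 0
    return t.stride(i) if t.dim() >= 2 else 0

scales_2d = FakeTensor((8, 4), (4, 1))
scales_1d_empty = FakeTensor((0,), (1,))  # shape {0} when rowwise=False

assert guarded_stride(scales_2d, 1) == 1
assert guarded_stride(scales_1d_empty, 1) == 0  # no IndexError, thanks to the guard
```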

Contributor

Looks good to me

"-O3",
"-DUSE_CUDA",
# define TORCH_TARGET_VERSION with min version 2.11 for ABI stable Float8_e8m0fnu
"-DTORCH_TARGET_VERSION=0x020b000000000000",
Contributor Author

Note: updated the target version to 2.11 cause we need pytorch/pytorch#173669

Contributor

should we add a guard for torch_version_at_least("2.11.0") as well?

@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 5edc119 to c19db71 on March 19, 2026 22:00
@andrewor14 andrewor14 added the module: not user facing Use this tag if you don't want this PR to show up in release notes label Mar 19, 2026
@andrewor14 andrewor14 marked this pull request as ready for review March 19, 2026 22:01
@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch 2 times, most recently from 8558dff to a4fca0c on March 23, 2026 21:20
@andrewor14 andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from a4fca0c to 0211669 on March 23, 2026 21:21
@andrewor14 andrewor14 merged commit aa9be68 into main Mar 23, 2026
27 of 33 checks passed

5 participants