
Conversation

@andrewor14 (Contributor) commented Jan 9, 2026

Prompt:

Make these two files ABI stable:
torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp

Read these for instructions:
pytorch/docs/source/notes/libtorch_stable_abi.md
cppdocs/_sources/stable.rst.txt

Use these files for an example:
Before: flash-attention/hopper/flash_api.cpp
After: flash-attention/hopper/flash_api_stable.cpp

Additional instructions:
Replace at::empty with torch::stable::empty
Replace at::zeros with torch::stable::new_zeros
No need to touch function calls that are already stable, like tensor.sizes()
When calling tensor.stride(), no need to check for tensor.numel() > 0 beforehand
When replacing TORCH_CHECK with STD_TORCH_CHECK, don’t touch the content of the check, just replace the function name
Don’t make a PyModule
No need to box kernels manually, just use TORCH_BOX
Minimize unnecessary/unrelated style changes

Claude output:

● Both files have been updated to be ABI stable. Here's a summary of the changes:

  mxfp8_cuda.cu:
  - Replaced #include <ATen/cuda/CUDAContext.h> and #include <torch/extension.h> with stable headers (torch/csrc/stable/tensor.h, torch/csrc/inductor/aoti_torch/c/shim.h, torch/headeronly/core/ScalarType.h, torch/headeronly/util/Exception.h)
  - Changed torch::Tensor to torch::stable::Tensor (via using declaration)
  - Replaced torch::kFloat32, etc. with torch::headeronly::ScalarType::Float, etc.
  - Replaced TORCH_CHECK with STD_TORCH_CHECK
  - Replaced at::cuda::getCurrentCUDAStream() with the stable C shim API (aoti_torch_get_current_cuda_stream)

  mxfp8_extension.cpp:
  - Replaced all ATen includes with stable headers
  - Changed at::Tensor to torch::stable::Tensor
  - Replaced at::empty with torch::stable::empty
  - Replaced at::zeros with torch::stable::new_zeros
  - Replaced at::empty_strided patterns with torch::stable::empty + torch::stable::transpose to achieve the same strided layout
  - Replaced TORCH_CHECK with STD_TORCH_CHECK
  - Replaced c10::cuda::CUDAGuard with torch::stable::accelerator::DeviceGuard
  - Replaced TORCH_LIBRARY_IMPL with STABLE_TORCH_LIBRARY_IMPL
  - Wrapped kernel implementations with TORCH_BOX
  - Replaced at::cuda::getCurrentCUDAStream() with stable C shim API
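To make the summary above concrete, here is a minimal before/after sketch of those replacements, assuming the header locations from the stable ABI docs referenced in the prompt; the op name "mxfp8_quantize", the function name, and the elided argument lists are placeholders, not the actual torchao symbols.

```cpp
// Sketch only: placeholder op/function names; empty()/new_zeros() arguments elided.
#include <torch/csrc/stable/library.h>        // STABLE_TORCH_LIBRARY_IMPL, TORCH_BOX (assumed location)
#include <torch/csrc/stable/tensor.h>         // torch::stable::Tensor
#include <torch/headeronly/core/ScalarType.h> // torch::headeronly::ScalarType
#include <torch/headeronly/util/Exception.h>  // STD_TORCH_CHECK

using torch::stable::Tensor;

Tensor mxfp8_quantize_cuda(Tensor input) {
  // was: TORCH_CHECK(input.scalar_type() == torch::kBFloat16, "input must be bf16");
  STD_TORCH_CHECK(
      input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
      "input must be bf16");
  // was: at::Tensor out = at::empty(..., options_fp8);
  // now: torch::stable::empty(...) / torch::stable::new_zeros(...), arguments elided here
  return input;
}

// was: TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("mxfp8_quantize", &mxfp8_quantize_cuda); }
STABLE_TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("mxfp8_quantize", TORCH_BOX(&mxfp8_quantize_cuda));
}
```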


pytorch-bot bot commented Jan 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3610

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 2b13608 with merge base b78bd40:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Jan 9, 2026
@andrewor14 requested a review from janeyx99 on January 9, 2026 16:26
@andrewor14 changed the title from "[draft] Make MXFP8 cuda kernel ABI stable" to "[claude] Make MXFP8 cuda kernel ABI stable" on Jan 9, 2026
@andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 2f3e96d to 7b28a33 on January 9, 2026 16:26
@andrewor14 added the topic: improvement label Jan 9, 2026
@janeyx99 (Contributor) left a comment


The best way to test this is to set up a build where you pass in -DTORCH_TARGET_VERSION, then make sure the kernel builds and runs. I'd recommend setting up a separate stable wheel for the migration to make testing easier as you slowly migrate kernels to the stable wheel.

Edit: I just learned that you only have 3 files to migrate. In that case you do have the option of just biting the bullet and changing the setup.py in torchao to add this flag: https://github.com/Dao-AILab/flash-attention/pull/2155/files.

Note you want 20a0 and not 2090 though, for 2.10.

@andrewor14 marked this pull request as draft January 9, 2026 16:31
@andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 7b28a33 to 3784e4f on January 9, 2026 18:57
@andrewor14 force-pushed the mxfp8-kernel-abi-stable branch from 3784e4f to 2b13608 on January 9, 2026 20:29
@andrewor14 changed the title from "[claude] Make MXFP8 cuda kernel ABI stable" to "[claude] Make MXFP8 cuda kernels ABI stable" on Jan 9, 2026
@andrewor14 requested a review from janeyx99 January 9, 2026 20:36
```cpp
#include <torch/headeronly/util/Exception.h>

// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h
extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);
```
@janeyx99 (Contributor) commented:

I don't think you need this; you should just pass -DUSE_CUDA as a flag in setup.py.

Though hmm @mikaylagawarecki should we automatically pass this in through CUDAExtension for better UX?

@mikaylagawarecki commented Jan 9, 2026:

Either this or -DUSE_CUDA would work.

tbh I'm not sure that silently defining USE_CUDA in CUDAExtension is good UX 🤔 (Where did USE_CUDA originate from? Is it a PyTorch thing or a CUDA thing? I'm wondering whether users might also define USE_CUDA themselves for separate purposes. wdyt?)
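For reference, a sketch of how that shim call typically gets used in the extension, whichever way the symbol is made visible (this forward declaration or -DUSE_CUDA); the helper itself is hypothetical, not the actual torchao code:

```cpp
#include <cuda_runtime.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>  // AOTITorchError, AOTI_TORCH_SUCCESS
#include <torch/headeronly/util/Exception.h>        // STD_TORCH_CHECK

// Same forward declaration as in the change under review
// (redundant if -DUSE_CUDA is passed so shim.h already exposes it).
extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(
    int32_t device_index, void** ret_stream);

// Hypothetical helper: fetch the current CUDA stream through the stable C shim.
inline cudaStream_t current_cuda_stream(int32_t device_index) {
  void* raw_stream = nullptr;
  AOTITorchError err = aoti_torch_get_current_cuda_stream(device_index, &raw_stream);
  STD_TORCH_CHECK(err == AOTI_TORCH_SUCCESS, "failed to query current CUDA stream");
  return static_cast<cudaStream_t>(raw_stream);
}
```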

```diff
-  case torch::kFloat32:
+DType get_input_dtype(const Tensor &t) {
+  auto scalar_type = t.scalar_type();
+  if (scalar_type == torch::headeronly::ScalarType::Float) {
```
@janeyx99 (Contributor) commented:

You should still be able to use a switch statement, I believe, but no biggie.
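A quick sketch of that suggestion; the DType enum below is a placeholder for whatever the mxfp8 kernel headers actually define:

```cpp
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>

using torch::stable::Tensor;

// Placeholder enum; the real DType lives in the mxfp8 kernel headers.
enum class DType { kFloat32, kBFloat16 };

DType get_input_dtype(const Tensor& t) {
  // A switch still works on the header-only ScalarType enum.
  switch (t.scalar_type()) {
    case torch::headeronly::ScalarType::Float:
      return DType::kFloat32;
    case torch::headeronly::ScalarType::BFloat16:
      return DType::kBFloat16;
    default:
      STD_TORCH_CHECK(false, "Unsupported input dtype");
      return DType::kFloat32;  // unreachable; keeps -Wreturn-type quiet
  }
}
```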

```diff
   const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x;
-  output_rowwise = at::empty({rows, cols}, options_fp8);
-  scales_rowwise = at::empty({rows, num_col_blocks}, options_scale);
+  output_rowwise = torch::stable::empty(
```
@janeyx99 (Contributor) commented:

new_empty?

here and below
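If I'm reading the suggestion right, the hunk above would become something like the following; the new_empty signature (self, sizes, dtype) and the fp8/scale dtype names are my assumptions, not verified against the actual torchao code:

```cpp
// Drop-in sketch for the lines shown above; dtype names are assumptions
// (Float8_e4m3fn for the quantized data, Float8_e8m0fnu for the MX scales).
output_rowwise = torch::stable::new_empty(
    input, {rows, cols}, torch::headeronly::ScalarType::Float8_e4m3fn);
scales_rowwise = torch::stable::new_empty(
    input, {rows, num_col_blocks}, torch::headeronly::ScalarType::Float8_e8m0fnu);
```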

@janeyx99 (Contributor) left a comment

Yea, this looks close. This is the point where I'd recommend just building the wheel and letting the error messages and test cases get you 100% there.
