[claude] Make MXFP8 cuda kernels ABI stable #3610
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3610
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 2b13608 with merge base b78bd40:
BROKEN TRUNK - The following job failed but was already present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 2f3e96d to 7b28a33
The best way to test this is by setting up a build where you can pass in -DTORCH_TARGET_VERSION, then confirm the kernel builds and runs. I'd recommend setting up a separate stable wheel for the migration to make testing easier as you slowly migrate kernels to it.
Edit: I just learned that you only have 3 files to migrate. In that case you do have the option of just biting the bullet and changing the setup.py in torchao to add this flag: https://github.com/Dao-AILab/flash-attention/pull/2155/files.
Note that you want 20a0 and not 2090 though, since this targets 2.10.
Force-pushed from 7b28a33 to 3784e4f
Prompt:

```
Make these two files ABI stable:
torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp

Read these for instructions:
pytorch/docs/source/notes/libtorch_stable_abi.md
cppdocs/_sources/stable.rst.txt

Use these files for an example:
Before: flash-attention/hopper/flash_api.cpp
After: flash-attention/hopper/flash_api_stable.cpp

Additional instructions:
- Replace at::empty with torch::stable::empty
- Replace at::zeros with torch::stable::new_zeros
- No need to touch function calls that are already stable, like tensor.sizes()
- When calling tensor.stride(), no need to check for tensor.numel() > 0 beforehand
- When replacing TORCH_CHECK with STD_TORCH_CHECK, don't touch the content of the check, just replace the function name
- Don't make a PyModule
- No need to box kernels manually, just use TORCH_BOX
- Minimize unnecessary/unrelated style changes
```

Claude output:

```
● Both files have been updated to be ABI stable. Here's a summary of the changes:

mxfp8_cuda.cu:
- Replaced #include <ATen/cuda/CUDAContext.h> and #include <torch/extension.h> with stable headers (torch/csrc/stable/tensor.h, torch/csrc/inductor/aoti_torch/c/shim.h, torch/headeronly/core/ScalarType.h, torch/headeronly/util/Exception.h)
- Changed torch::Tensor to torch::stable::Tensor (via using declaration)
- Replaced torch::kFloat32, etc. with torch::headeronly::ScalarType::Float, etc.
- Replaced TORCH_CHECK with STD_TORCH_CHECK
- Replaced at::cuda::getCurrentCUDAStream() with the stable C shim API (aoti_torch_get_current_cuda_stream)

mxfp8_extension.cpp:
- Replaced all ATen includes with stable headers
- Changed at::Tensor to torch::stable::Tensor
- Replaced at::empty with torch::stable::empty
- Replaced at::zeros with torch::stable::new_zeros
- Replaced at::empty_strided patterns with torch::stable::empty + torch::stable::transpose to achieve the same strided layout
- Replaced TORCH_CHECK with STD_TORCH_CHECK
- Replaced c10::cuda::CUDAGuard with torch::stable::accelerator::DeviceGuard
- Replaced TORCH_LIBRARY_IMPL with STABLE_TORCH_LIBRARY_IMPL
- Wrapped kernel implementations with TORCH_BOX
- Replaced at::cuda::getCurrentCUDAStream() with the stable C shim API
```
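To make the summary above concrete, here is a minimal, hedged sketch of the pattern it describes. The op name `mxfp8_quantize_placeholder`, the `torchao` schema namespace, and the dtype check are placeholders for illustration rather than the PR's actual code, and header paths may differ slightly across PyTorch versions.

```cpp
// Minimal sketch of the migration pattern summarized above, not the PR's
// actual diff: stable headers, STD_TORCH_CHECK in place of TORCH_CHECK, and a
// boxed registration via STABLE_TORCH_LIBRARY_IMPL + TORCH_BOX.
#include <torch/csrc/stable/library.h>         // STABLE_TORCH_LIBRARY_IMPL, TORCH_BOX
#include <torch/csrc/stable/tensor.h>          // torch::stable::Tensor
#include <torch/headeronly/core/ScalarType.h>  // torch::headeronly::ScalarType
#include <torch/headeronly/util/Exception.h>   // STD_TORCH_CHECK

using torch::stable::Tensor;

// Placeholder op: the real kernels quantize to MXFP8; only the ABI-stable
// scaffolding around the kernel launch is shown here.
Tensor mxfp8_quantize_placeholder(Tensor input) {
  STD_TORCH_CHECK(
      input.scalar_type() == torch::headeronly::ScalarType::Float ||
          input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
      "input must be float32 or bfloat16");
  // ... launch the CUDA kernel on the stream returned by
  // aoti_torch_get_current_cuda_stream(...) ...
  return input;
}

// Assumes the op schema is declared elsewhere (e.g. in a STABLE_TORCH_LIBRARY
// block); TORCH_BOX generates the boxed wrapper, so no manual boxing is needed.
STABLE_TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("mxfp8_quantize_placeholder", TORCH_BOX(&mxfp8_quantize_placeholder));
}
```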
Force-pushed from 3784e4f to 2b13608
```diff
+ #include <torch/headeronly/util/Exception.h>
+
+ // Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h
+ extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);
```
I don't think you need this; you should just pass -DUSE_CUDA as a flag in setup.py.
Though hmm @mikaylagawarecki should we automatically pass this in through CUDAExtension for better UX?
Either this or -DUSE_CUDA would work.
tbh I'm not sure that silently defining USE_CUDA in CUDAExtension is good UX 🤔 (Where did USE_CUDA originate? Is it a PyTorch thing or a CUDA thing? I'm wondering whether users might also define USE_CUDA themselves for separate purposes. wdyt?)
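As a hedged illustration of the alternative being discussed: if the extension is compiled with -DUSE_CUDA (set from setup.py), shim.h declares aoti_torch_get_current_cuda_stream itself and the hand-written extern "C" prototype from the hunk above is unnecessary. The helper name below is made up for this sketch; the shim signature is the one shown in the diff.

```cpp
// Hedged sketch: with -DUSE_CUDA defined at compile time, shim.h itself
// provides the declaration, so the manual extern "C" prototype can be dropped.
// The helper simply wraps the shim call and checks its error code.
#include <cuda_runtime.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/Exception.h>

inline cudaStream_t current_cuda_stream(int32_t device_index) {
  void* stream = nullptr;
  AOTITorchError err = aoti_torch_get_current_cuda_stream(device_index, &stream);
  STD_TORCH_CHECK(err == AOTI_TORCH_SUCCESS,
                  "aoti_torch_get_current_cuda_stream failed");
  return static_cast<cudaStream_t>(stream);
}
```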
```diff
- case torch::kFloat32:
+ DType get_input_dtype(const Tensor &t) {
+   auto scalar_type = t.scalar_type();
+   if (scalar_type == torch::headeronly::ScalarType::Float) {
```
you should still be able to use a switch statement, I believe, but no biggie
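A sketch of that suggestion, assuming the headeronly ScalarType enum can be switched on directly (the diff above already compares against it). The DType enum here is a hypothetical stand-in for the extension's own type, not the PR's actual definition.

```cpp
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/Exception.h>

using torch::stable::Tensor;

// Hypothetical stand-in for the extension's own DType enum.
enum class DType { kFloat32, kBFloat16 };

DType get_input_dtype(const Tensor& t) {
  switch (t.scalar_type()) {
    case torch::headeronly::ScalarType::Float:
      return DType::kFloat32;
    case torch::headeronly::ScalarType::BFloat16:
      return DType::kBFloat16;
    default:
      STD_TORCH_CHECK(false, "Unsupported input dtype for MXFP8 quantization");
      return DType::kFloat32;  // unreachable; keeps -Wreturn-type quiet
  }
}
```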
```diff
  const int64_t num_col_blocks = (cols + scale_dim_x - 1) / scale_dim_x;
- output_rowwise = at::empty({rows, cols}, options_fp8);
- scales_rowwise = at::empty({rows, num_col_blocks}, options_scale);
+ output_rowwise = torch::stable::empty(
```
new_empty?
here and below
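A hedged sketch of that suggestion: allocate outputs with torch::stable::new_empty so the device and related properties are derived from the input tensor rather than rebuilt by hand. The helper name, the FP8 enumerator, and the exact new_empty argument order are assumptions to check against torch/csrc/stable/ops.h.

```cpp
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>

using torch::stable::Tensor;

// Illustrative helper, not from the PR: allocate the row-wise output so it
// inherits device information from `input` instead of reconstructing options.
Tensor allocate_rowwise_output(const Tensor& input, int64_t rows, int64_t cols) {
  // Float8_e4m3fn is an assumed enumerator name for the FP8 output dtype.
  return torch::stable::new_empty(
      input, {rows, cols}, torch::headeronly::ScalarType::Float8_e4m3fn);
}
```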
janeyx99 left a comment
Yeah, this looks close. Here's where I'd recommend just building the wheel and letting the error messages and test cases get you 100% of the way there.