[QST] How to Let __launch_bounds__ and setmaxnreg Work with Each Other? #2007

Open
Maximilianxu opened this issue Dec 23, 2024 · 2 comments

@Maximilianxu

Background

In a dense fp16 GEMM on H800, I use a 192x128 tile and 3 warpgroups, with WG1 and WG2 as the cooperative consumer warpgroups.
Specifically, the producer warpgroup calls cutlass::arch::warpgroup_reg_dealloc<24>(); while the consumers call
cutlass::arch::warpgroup_reg_alloc<232>(); to set the warpgroup-level register count hint.

During compilation, the compiler reports ptxas info : Used 122 registers, and the kernel runs correctly.

What is your question?

On top of that, I added a __launch_bounds__(384, 1) hint to the kernel. The compiler now reports ptxas info : Used 168 registers, which is more or less expected.

However, after launch the kernel hangs at cutlass::arch::warpgroup_reg_alloc<232>();. Some warpgroups cannot proceed, so the wgmma is never issued.

In addition, when I change the consumer register count to cutlass::arch::warpgroup_reg_alloc<168>();, the kernel runs correctly. If I increase this value, the kernel hangs.

The strange thing is that FA3 https://github.com/Dao-AILab/flash-attention/blob/0dfb28174333d9eefb7c1dd4292690a8458d1e89/hopper/flash_fwd_kernel.h#L28 also uses this method.

How should I understand this behavior? Can we dump more information during compilation?

@thakkarV
Collaborator

232 * 2 + 24 != 168 * 3

Change the 24 to 40 and it should work
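
The comparison in this reply can be checked at compile time. What follows is a minimal sketch, not CUTLASS code: every helper name is hypothetical, and it assumes the 384-thread CTA is split into one producer and two consumer warpgroups that all start at the 168 registers per thread reported by ptxas. The only outside fact used is the PTX rule that the setmaxnreg operand must be a multiple of 8 between 24 and 256.

```cpp
// Compile-time check of the register totals compared in the reply above.
// All identifiers are hypothetical helpers for illustration, not CUTLASS APIs.

// Assumption: 384 threads per CTA, split into 3 warpgroups (1 producer, 2 consumers).
constexpr int kNumConsumerWarpgroups = 2;
constexpr int kNumWarpgroups = 3;

// Per-thread register count reported by ptxas with __launch_bounds__(384, 1).
constexpr int kLaunchRegsPerThread = 168;

// PTX requires the setmaxnreg operand to be a multiple of 8 in [24, 256].
constexpr bool is_valid_setmaxnreg(int regs) {
    return regs >= 24 && regs <= 256 && regs % 8 == 0;
}

// Sum of the per-thread counts requested after reconfiguration,
// taking one thread from each warpgroup.
constexpr int requested_total(int producer_regs, int consumer_regs) {
    return producer_regs + kNumConsumerWarpgroups * consumer_regs;
}

// Sum of the per-thread counts the kernel launches with.
constexpr int launch_total() {
    return kLaunchRegsPerThread * kNumWarpgroups;
}

// True when the requested total equals the launch total, as the reply asks.
constexpr bool totals_match(int producer_regs, int consumer_regs) {
    return is_valid_setmaxnreg(producer_regs) &&
           is_valid_setmaxnreg(consumer_regs) &&
           requested_total(producer_regs, consumer_regs) == launch_total();
}

static_assert(requested_total(24, 232) == 488, "232 * 2 + 24");
static_assert(launch_total() == 504, "168 * 3");
static_assert(!totals_match(24, 232), "original configuration does not balance");
static_assert(totals_match(40, 232), "suggested configuration balances");
```

Matching totals is only the condition this reply names. The sketch does not show that it is sufficient for the kernel to run.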

@Maximilianxu
Author

> 232 * 2 + 24 != 168 * 3
>
> Change the 24 to 40 and it should work

I changed it to a 232 * 2 + 40 = 168 * 3 configuration, but it still hangs there.

The same behavior occurs: only decreasing the consumer register count to 168 works, and even 176 hangs.
