I have a simple CuTeDSL kernel that calls cute.arch.warpgroup_reg_alloc(200). I can find setmaxnreg.inc.sync.aligned.u32 200; in the generated PTX, but I cannot find USETMAXREG.TRY_ALLOC.CTAPOOL in the generated SASS.
However, if I run the example examples/python/CuTeDSL/blackwell/fmha.py, I can find USETMAXREG.TRY_ALLOC.CTAPOOL in its SASS.
I'm on GB200 with CUDA 13. Command line:
CUTE_DSL_KEEP_PTX=1 CUTE_DSL_KEEP_CUBIN=1 python3 hello.py
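This is roughly how I check the SASS (the cubin path below is just a placeholder for whatever CUTE_DSL_KEEP_CUBIN dumps on your machine):
cuobjdump -sass <path-to-kept-cubin> | grep -i SETMAXREG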
Code:
import cutlass
import cutlass.cute as cute
import torch
from cutlass.cute.runtime import from_dlpack
from cuda.bindings.driver import CUstream


@cute.kernel
def kernel(
    t_in: cute.Tensor,
    t_out: cute.Tensor,
    t_atomic: cute.Tensor,
    NUM_ELEMS_PER_THREAD: cutlass.Constexpr,
    NUM_ELEMS_PER_BLOCK: cutlass.Constexpr,
):
    cute.arch.warpgroup_reg_alloc(200)
    m, n = t_in.shape
    tidx, tidy, tidz = cute.arch.thread_idx()
    bidx, bidy, bidz = cute.arch.block_idx()
    idx = tidx * NUM_ELEMS_PER_THREAD + bidx * NUM_ELEMS_PER_BLOCK
    t_in_1d = cute.make_tensor(t_in.iterator + idx, cute.make_layout(NUM_ELEMS_PER_THREAD))
    t_out_1d = cute.make_tensor(t_out.iterator + idx, cute.make_layout(NUM_ELEMS_PER_THREAD))
    for i in range(NUM_ELEMS_PER_THREAD):
        t_out_1d[i] = t_in_1d[i]


@cute.jit
def hello_world(
    t_in: cute.Tensor, t_out: cute.Tensor, t_atomic: cute.Tensor, stream: CUstream = None
):
    print("Hello World from host!")
    NUM_ELEMS_PER_BLOCK = 512
    NUM_ELEMS_PER_THREAD = 4
    m, n = t_in.shape
    # Launch kernel
    kernel(t_in, t_out, t_atomic, NUM_ELEMS_PER_THREAD, NUM_ELEMS_PER_BLOCK).launch(
        grid=(cute.ceil_div(m * n, NUM_ELEMS_PER_BLOCK), 1, 1),  # enough blocks to cover all elements
        block=(
            NUM_ELEMS_PER_BLOCK // NUM_ELEMS_PER_THREAD,  # 128 threads = one full warpgroup per block
            1,
            1,
        ),
        stream=stream,
    )


M, N = 2, 512
t_in = torch.randn(M, N, device="cuda", dtype=torch.float16)
t_out = torch.empty(M, N, device="cuda", dtype=torch.float16)
t_atomic = torch.zeros(1, device="cuda", dtype=torch.int32)

t_in_ = from_dlpack(t_in, assumed_align=16)
t_out_ = from_dlpack(t_out, assumed_align=16)
t_atomic_ = from_dlpack(t_atomic)

print("Compiling...")
s = torch.cuda.Stream()  # Create a new stream.
hello_world_compiled = cute.compile(
    hello_world, t_in_, t_out_, t_atomic_, stream=CUstream(s.cuda_stream)
)

# Run the pre-compiled version
print("Running compiled version...")
with torch.cuda.stream(s):
    hello_world_compiled(t_in_, t_out_, t_atomic_, stream=CUstream(s.cuda_stream))
s.synchronize()
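In case it is relevant: as far as I can tell, the main structural difference from fmha.py is that the example is warp specialized, so warpgroups that call warpgroup_reg_alloc are paired with warpgroups that call warpgroup_reg_dealloc. A minimal sketch of that pattern (the register counts, the warp split, and the 256-thread launch are my assumptions for illustration, not taken from fmha.py):

@cute.kernel
def ws_kernel(t_in: cute.Tensor, t_out: cute.Tensor):
    # Assumes the kernel is launched with 256 threads per block (two warpgroups).
    tidx, _, _ = cute.arch.thread_idx()
    warp_idx = tidx // 32
    if warp_idx < 4:
        # "Consumer" warpgroup: grab extra registers.
        cute.arch.warpgroup_reg_alloc(200)
    else:
        # "Producer" warpgroup: give registers back to the CTA pool.
        cute.arch.warpgroup_reg_dealloc(24)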