
[QST] cute.arch.warpgroup_reg_alloc in PTX, but not in SASS #2927

@ipiszy-x

Description


I have a simple CuTeDSL kernel that calls cute.arch.warpgroup_reg_alloc(200). I can find setmaxnreg.inc.sync.aligned.u32 200; in the generated PTX, but I cannot find USETMAXREG.TRY_ALLOC.CTAPOOL in the generated SASS.

However, if I run the example examples/python/CuTeDSL/blackwell/fmha.py, I can find USETMAXREG.TRY_ALLOC.CTAPOOL in the SASS.

I'm on GB200 with CUDA 13. Command line:

CUTE_DSL_KEEP_PTX=1 CUTE_DSL_KEEP_CUBIN=1 python3 hello.py 
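
For reference, one way to inspect the SASS is to disassemble the cubin kept by CUTE_DSL_KEEP_CUBIN (the filename below is a placeholder for whatever file the DSL writes out):

cuobjdump -sass kernel.cubin | grep -i SETMAXREG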

Code:

import cutlass
import cutlass.cute as cute
import torch

from cutlass.cute.runtime import from_dlpack
from cuda.bindings.driver import CUstream


@cute.kernel
def kernel(
    t_in: cute.Tensor,
    t_out: cute.Tensor,
    t_atomic: cute.Tensor,
    NUM_ELEMS_PER_THREAD: cutlass.Constexpr,
    NUM_ELEMS_PER_BLOCK: cutlass.Constexpr,
):
    cute.arch.warpgroup_reg_alloc(200)  # request a larger per-thread register budget; lowers to setmaxnreg.inc in PTX

    m, n = t_in.shape
    tidx, tidy, tidz = cute.arch.thread_idx()
    bidx, bidy, bidz = cute.arch.block_idx()
    idx = tidx * NUM_ELEMS_PER_THREAD + bidx * NUM_ELEMS_PER_BLOCK
    t_in_1d = cute.make_tensor(t_in.iterator + idx, cute.make_layout(NUM_ELEMS_PER_THREAD))
    t_out_1d = cute.make_tensor(t_out.iterator + idx, cute.make_layout(NUM_ELEMS_PER_THREAD))

    for i in range(NUM_ELEMS_PER_THREAD):
        t_out_1d[i] = t_in_1d[i]

@cute.jit
def hello_world(
    t_in: cute.Tensor, t_out: cute.Tensor, t_atomic: cute.Tensor, stream: CUstream = None
):
    print("Hello World from host!")

    NUM_ELEMS_PER_BLOCK = 512
    NUM_ELEMS_PER_THREAD = 4

    m, n = t_in.shape

    # Launch kernel
    kernel(t_in, t_out, t_atomic, NUM_ELEMS_PER_THREAD, NUM_ELEMS_PER_BLOCK).launch(
        grid=(cute.ceil_div(m * n, NUM_ELEMS_PER_BLOCK), 1, 1),  # One thread block per NUM_ELEMS_PER_BLOCK elements
        block=(
            NUM_ELEMS_PER_BLOCK // NUM_ELEMS_PER_THREAD,
            1,
            1,
        ),  # NUM_ELEMS_PER_BLOCK // NUM_ELEMS_PER_THREAD = 128 threads (one warp group) per thread block
        stream=stream,
    )


M, N = 2, 512
t_in = torch.randn(M, N, device="cuda", dtype=torch.float16)
t_out = torch.empty(M, N, device="cuda", dtype=torch.float16)
t_atomic = torch.zeros(1, device="cuda", dtype=torch.int32)

t_in_ = from_dlpack(t_in, assumed_align=16)
t_out_ = from_dlpack(t_out, assumed_align=16)
t_atomic_ = from_dlpack(t_atomic)

print("Compiling...")
s = torch.cuda.Stream()  # Create a new stream.
hello_world_compiled = cute.compile(
    hello_world, t_in_, t_out_, t_atomic_, stream=CUstream(s.cuda_stream)
)

# Run the pre-compiled version
print("Running compiled version...")
with torch.cuda.stream(s):
    hello_world_compiled(t_in_, t_out_, t_atomic_, stream=CUstream(s.cuda_stream))
    s.synchronize()
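
The kernel is a straight elementwise copy of t_in into t_out, so a quick sanity check using only standard PyTorch can be appended after the synchronize:

# The copy should reproduce the input bit-for-bit.
assert torch.equal(t_in, t_out), "copy kernel produced unexpected output"
print("Copy verified.")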
