Skip to content

Implement persistent kernels #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged

Implement persistent kernels #238

merged 1 commit into from
Jul 8, 2025

Conversation

jansel
Copy link
Contributor

@jansel jansel commented Jul 3, 2025

Stacked PRs (oldest at bottom):


Implement persistent kernels

Enabled with config["pid_type"]="persistent_blocked" or
"persistent_interleaved".

This also refactors much of the program id handling.

jansel added a commit that referenced this pull request Jul 3, 2025
Enabled with `config["pid_type"]="persistent_blocked"` or
`"persistent_interleaved"`.

This also refactors much of the program id handling.

stack-info: PR: #238, branch: jansel/stack/77
@jansel jansel force-pushed the jansel/stack/77 branch from 687ab9b to 0afa941 Compare July 3, 2025 20:47
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jul 3, 2025
@jansel
Copy link
Contributor Author

jansel commented Jul 3, 2025

Example output:

============================================================
PERSISTENT BLOCKED KERNEL
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _simple_add_kernel(x, y, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # Persistent "blocked" schedule: the grid is (_NUM_SM,) and each program
    # owns one contiguous chunk of the flat virtual-pid (tile) space.
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
    # Chunk size per program, rounded up so every virtual pid is covered.
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    # Clamp so the last program's chunk (possibly empty) stays in range.
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid):
        # Decompose the flat virtual pid into a 2D tile coordinate
        # (dim-0-major: pid_0 varies fastest).
        num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
        pid_0 = virtual_pid % num_blocks_0
        pid_1 = virtual_pid // num_blocks_0
        offset_0 = pid_0 * _BLOCK_SIZE_0
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        mask_0 = indices_0 < x_size_0  # guard the ragged edge along dim 0
        offset_1 = pid_1 * _BLOCK_SIZE_1
        indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
        mask_1 = indices_1 < x_size_1  # guard the ragged edge along dim 1
        # Masked loads of one (BLOCK_0 x BLOCK_1) tile from x and y.
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        v_0 = load + load_1
        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_0, mask_0[:, None] & mask_1[None, :])

def simple_add(x: torch.Tensor, y: torch.Tensor):
    """Element-wise add of two 2D tensors via the persistent Triton kernel."""
    result = x.new_empty(x.size())
    # Launch exactly one persistent program per SM; the kernel iterates over
    # virtual pids internally.
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_m, block_n = 32, 16
    _simple_add_kernel[(num_sm,)](
        x, y, result,
        x.size(0), x.size(1),
        result.stride(0), result.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_m, block_n,
        num_warps=4, num_stages=3,
    )
    return result

def _simple_add_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    """Return a precompiler for _simple_add_kernel bound to the same launch args as simple_add."""
    from helion.runtime.precompile_shim import make_precompiler
    result = x.new_empty(x.size())
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_m, block_n = 32, 16
    return make_precompiler(_simple_add_kernel)(
        x, y, result,
        x.size(0), x.size(1),
        result.stride(0), result.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_m, block_n,
        num_warps=4, num_stages=3,
    )

============================================================
PERSISTENT INTERLEAVED KERNEL
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _simple_add_kernel(x, y, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # Persistent "interleaved" schedule: program p handles virtual pids
    # p, p + _NUM_SM, p + 2*_NUM_SM, ... (a stride loop over the tile space).
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
    for virtual_pid in tl.range(tl.program_id(0), total_pids, _NUM_SM):
        # Decompose the flat virtual pid into a 2D tile coordinate
        # (dim-0-major: pid_0 varies fastest).
        num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
        pid_0 = virtual_pid % num_blocks_0
        pid_1 = virtual_pid // num_blocks_0
        offset_0 = pid_0 * _BLOCK_SIZE_0
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        mask_0 = indices_0 < x_size_0  # guard the ragged edge along dim 0
        offset_1 = pid_1 * _BLOCK_SIZE_1
        indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
        mask_1 = indices_1 < x_size_1  # guard the ragged edge along dim 1
        # Masked loads of one (BLOCK_0 x BLOCK_1) tile from x and y.
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        v_0 = load + load_1
        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_0, mask_0[:, None] & mask_1[None, :])

def simple_add(x: torch.Tensor, y: torch.Tensor):
    """Element-wise add of two 2D tensors via the persistent Triton kernel."""
    result = x.new_empty(x.size())
    # Launch exactly one persistent program per SM; the kernel strides over
    # virtual pids internally.
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_m, block_n = 32, 16
    _simple_add_kernel[(num_sm,)](
        x, y, result,
        x.size(0), x.size(1),
        result.stride(0), result.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_m, block_n,
        num_warps=4, num_stages=3,
    )
    return result

def _simple_add_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    """Return a precompiler for _simple_add_kernel bound to the same launch args as simple_add."""
    from helion.runtime.precompile_shim import make_precompiler
    result = x.new_empty(x.size())
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_m, block_n = 32, 16
    return make_precompiler(_simple_add_kernel)(
        x, y, result,
        x.size(0), x.size(1),
        result.stride(0), result.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_m, block_n,
        num_warps=4, num_stages=3,
    )

============================================================
PERSISTENT BLOCKED WITH FOR EACH PROGRAM ID (Multi-loop)
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _multi_loop_example_kernel(x, result1, y, result2, x_size_0, x_size_1, y_size_0, y_size_1, result1_stride_0, result1_stride_1, result2_stride_0, result2_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    # Persistent "blocked" schedule over the CONCATENATION of two tile spaces:
    # virtual pids below the first product map to the x loop (result1 = x * 2),
    # the remainder to the y loop (result2 = y * 3).
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1) + tl.cdiv(y_size_0, _BLOCK_SIZE_2) * tl.cdiv(y_size_1, _BLOCK_SIZE_3)
    # Chunk size per program, rounded up so every virtual pid is covered.
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid):
        pid_shared = virtual_pid
        if pid_shared < tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1):
            # First loop: tile of x, dim-0-major decomposition.
            num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
            pid_0 = pid_shared % num_blocks_0
            pid_1 = pid_shared // num_blocks_0
            offset_0 = pid_0 * _BLOCK_SIZE_0
            indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
            mask_0 = indices_0 < x_size_0
            offset_1 = pid_1 * _BLOCK_SIZE_1
            indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
            mask_1 = indices_1 < x_size_1
            load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
            v_0 = 2.0
            v_1 = load * v_0
            tl.store(result1 + (indices_0[:, None] * result1_stride_0 + indices_1[None, :] * result1_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
        else:
            # Second loop: rebase the pid into the y tile space first.
            pid_shared -= tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
            num_blocks_1 = tl.cdiv(y_size_0, _BLOCK_SIZE_2)
            pid_2 = pid_shared % num_blocks_1
            pid_3 = pid_shared // num_blocks_1
            offset_2 = pid_2 * _BLOCK_SIZE_2
            indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
            mask_2 = indices_2 < y_size_0
            offset_3 = pid_3 * _BLOCK_SIZE_3
            indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
            mask_3 = indices_3 < y_size_1
            load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_3[None, :] * y_stride_1), mask_2[:, None] & mask_3[None, :], other=0)
            v_2 = 3.0
            v_3 = load_1 * v_2
            tl.store(result2 + (indices_2[:, None] * result2_stride_0 + indices_3[None, :] * result2_stride_1), v_3, mask_2[:, None] & mask_3[None, :])

def multi_loop_example(x: torch.Tensor, y: torch.Tensor):
    """Run two independent tiled loops (result1 = x * 2, result2 = y * 3) in one launch."""
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    # One persistent program per SM; both loops share one virtual pid space.
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_0, block_1, block_2, block_3 = 32, 16, 32, 16
    _multi_loop_example_kernel[(num_sm,)](
        x, result1, y, result2,
        x.size(0), x.size(1),
        y.size(0), y.size(1),
        result1.stride(0), result1.stride(1),
        result2.stride(0), result2.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_0, block_1, block_2, block_3,
        num_warps=4, num_stages=3,
    )
    return (result1, result2)

def _multi_loop_example_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    """Return a precompiler for _multi_loop_example_kernel with the same launch args as multi_loop_example."""
    from helion.runtime.precompile_shim import make_precompiler
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_0, block_1, block_2, block_3 = 32, 16, 32, 16
    return make_precompiler(_multi_loop_example_kernel)(
        x, result1, y, result2,
        x.size(0), x.size(1),
        y.size(0), y.size(1),
        result1.stride(0), result1.stride(1),
        result2.stride(0), result2.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_0, block_1, block_2, block_3,
        num_warps=4, num_stages=3,
    )

============================================================
PERSISTENT INTERLEAVED WITH FOR EACH PROGRAM ID (Multi-loop)
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _multi_loop_example_kernel(x, result1, y, result2, x_size_0, x_size_1, y_size_0, y_size_1, result1_stride_0, result1_stride_1, result2_stride_0, result2_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    # Persistent "interleaved" schedule over the CONCATENATION of two tile
    # spaces: program p handles pids p, p + _NUM_SM, ...; pids below the first
    # product map to the x loop (result1 = x * 2), the rest to y (result2 = y * 3).
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1) + tl.cdiv(y_size_0, _BLOCK_SIZE_2) * tl.cdiv(y_size_1, _BLOCK_SIZE_3)
    for virtual_pid in tl.range(tl.program_id(0), total_pids, _NUM_SM):
        pid_shared = virtual_pid
        if pid_shared < tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1):
            # First loop: tile of x, dim-0-major decomposition.
            num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
            pid_0 = pid_shared % num_blocks_0
            pid_1 = pid_shared // num_blocks_0
            offset_0 = pid_0 * _BLOCK_SIZE_0
            indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
            mask_0 = indices_0 < x_size_0
            offset_1 = pid_1 * _BLOCK_SIZE_1
            indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
            mask_1 = indices_1 < x_size_1
            load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
            v_0 = 2.0
            v_1 = load * v_0
            tl.store(result1 + (indices_0[:, None] * result1_stride_0 + indices_1[None, :] * result1_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
        else:
            # Second loop: rebase the pid into the y tile space first.
            pid_shared -= tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
            num_blocks_1 = tl.cdiv(y_size_0, _BLOCK_SIZE_2)
            pid_2 = pid_shared % num_blocks_1
            pid_3 = pid_shared // num_blocks_1
            offset_2 = pid_2 * _BLOCK_SIZE_2
            indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
            mask_2 = indices_2 < y_size_0
            offset_3 = pid_3 * _BLOCK_SIZE_3
            indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
            mask_3 = indices_3 < y_size_1
            load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_3[None, :] * y_stride_1), mask_2[:, None] & mask_3[None, :], other=0)
            v_2 = 3.0
            v_3 = load_1 * v_2
            tl.store(result2 + (indices_2[:, None] * result2_stride_0 + indices_3[None, :] * result2_stride_1), v_3, mask_2[:, None] & mask_3[None, :])

def multi_loop_example(x: torch.Tensor, y: torch.Tensor):
    """Run two independent tiled loops (result1 = x * 2, result2 = y * 3) in one launch."""
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    # One persistent program per SM; both loops share one virtual pid space.
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_0, block_1, block_2, block_3 = 32, 16, 32, 16
    _multi_loop_example_kernel[(num_sm,)](
        x, result1, y, result2,
        x.size(0), x.size(1),
        y.size(0), y.size(1),
        result1.stride(0), result1.stride(1),
        result2.stride(0), result2.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_0, block_1, block_2, block_3,
        num_warps=4, num_stages=3,
    )
    return (result1, result2)

def _multi_loop_example_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    """Return a precompiler for _multi_loop_example_kernel with the same launch args as multi_loop_example."""
    from helion.runtime.precompile_shim import make_precompiler
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_0, block_1, block_2, block_3 = 32, 16, 32, 16
    return make_precompiler(_multi_loop_example_kernel)(
        x, result1, y, result2,
        x.size(0), x.size(1),
        y.size(0), y.size(1),
        result1.stride(0), result1.stride(1),
        result2.stride(0), result2.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_0, block_1, block_2, block_3,
        num_warps=4, num_stages=3,
    )

============================================================
PERSISTENT BLOCKED WITH L2 GROUPING
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _simple_add_kernel(x, y, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # Persistent "blocked" schedule combined with grouped tile ordering
    # (the classic Triton L2-grouping remap): consecutive virtual pids walk
    # groups of up to 8 tile-rows together instead of one full row at a time.
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
    # Chunk size per program, rounded up so every virtual pid is covered.
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid):
        num_pid_m = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
        num_pid_n = tl.cdiv(x_size_1, _BLOCK_SIZE_1)
        num_pid_in_group = 8 * num_pid_n  # 8 = group height (tiles along dim 0)
        group_id = virtual_pid // num_pid_in_group
        first_pid_m = group_id * 8
        group_size_m = min(num_pid_m - first_pid_m, 8)  # last group may be short
        # Within a group, pids advance down the group's rows first, then
        # across columns.
        pid_0 = first_pid_m + virtual_pid % num_pid_in_group % group_size_m
        pid_1 = virtual_pid % num_pid_in_group // group_size_m
        offset_0 = pid_0 * _BLOCK_SIZE_0
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        mask_0 = indices_0 < x_size_0  # guard the ragged edge along dim 0
        offset_1 = pid_1 * _BLOCK_SIZE_1
        indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
        mask_1 = indices_1 < x_size_1  # guard the ragged edge along dim 1
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        v_0 = load + load_1
        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_0, mask_0[:, None] & mask_1[None, :])

def simple_add(x: torch.Tensor, y: torch.Tensor):
    """Element-wise add of two 2D tensors via the persistent Triton kernel."""
    result = x.new_empty(x.size())
    # Launch exactly one persistent program per SM; tile ordering is handled
    # inside the kernel.
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_m, block_n = 32, 16
    _simple_add_kernel[(num_sm,)](
        x, y, result,
        x.size(0), x.size(1),
        result.stride(0), result.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_m, block_n,
        num_warps=4, num_stages=3,
    )
    return result

def _simple_add_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    """Return a precompiler for _simple_add_kernel bound to the same launch args as simple_add."""
    from helion.runtime.precompile_shim import make_precompiler
    result = x.new_empty(x.size())
    num_sm = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    block_m, block_n = 32, 16
    return make_precompiler(_simple_add_kernel)(
        x, y, result,
        x.size(0), x.size(1),
        result.stride(0), result.stride(1),
        x.stride(0), x.stride(1),
        y.stride(0), y.stride(1),
        num_sm, block_m, block_n,
        num_warps=4, num_stages=3,
    )

Enabled with `config["pid_type"]="persistent_blocked"` or
`"persistent_interleaved"`.

This also refactors much of the program id handling.

stack-info: PR: #238, branch: jansel/stack/77
@jansel jansel merged commit 5c8e35b into main Jul 8, 2025
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants