Implement persistent kernels #238
Merged
jansel added a commit that referenced this pull request on Jul 3, 2025:
Enabled with `config["pid_type"]="persistent_blocked"` or `"persistent_interleaved"`. This also refactors much of the program id handling. stack-info: PR: #238, branch: jansel/stack/77
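For context, a Helion kernel opting into this mode might look like the sketch below. Only the `pid_type` value comes from this PR; the `helion.Config` field names, the `hl.tile` loop, and the block sizes (chosen to mirror the generated output) are illustrative assumptions rather than code from this change.

```python
import torch
import helion
import helion.language as hl

# Hypothetical source kernel for the generated `simple_add` output shown below.
# Only pid_type="persistent_blocked" is taken from this PR; the config fields
# and tile loop follow Helion's usual example style and are assumptions.
@helion.kernel(config=helion.Config(block_sizes=[32, 16], pid_type="persistent_blocked"))
def simple_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    result = x.new_empty(x.size())
    for tile in hl.tile(x.size()):  # one tile per virtual program id
        result[tile] = x[tile] + y[tile]
    return result
```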
Example output:

============================================================
PERSISTENT BLOCKED KERNEL
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _simple_add_kernel(x, y, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid):
        num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
        pid_0 = virtual_pid % num_blocks_0
        pid_1 = virtual_pid // num_blocks_0
        offset_0 = pid_0 * _BLOCK_SIZE_0
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        mask_0 = indices_0 < x_size_0
        offset_1 = pid_1 * _BLOCK_SIZE_1
        indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
        mask_1 = indices_1 < x_size_1
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        v_0 = load + load_1
        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_0, mask_0[:, None] & mask_1[None, :])

def simple_add(x: torch.Tensor, y: torch.Tensor):
    result = x.new_empty(x.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    _simple_add_kernel[_NUM_SM,](x, y, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return result

def _simple_add_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    result = x.new_empty(x.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_simple_add_kernel)(x, y, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
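In the blocked form above, each of the `_NUM_SM` programs walks one contiguous chunk of the virtual pid range (`block_size = cdiv(total_pids, _NUM_SM)` pids per program). A small standalone Python sketch of that mapping, separate from the generated code:

```python
import math

def blocked_schedule(total_pids: int, num_sm: int) -> list[list[int]]:
    # Mirrors the kernel: block_size = tl.cdiv(total_pids, _NUM_SM); program p
    # loops over [p * block_size, min(p * block_size + block_size, total_pids)).
    block_size = math.ceil(total_pids / num_sm)
    return [
        list(range(p * block_size, min((p + 1) * block_size, total_pids)))
        for p in range(num_sm)
    ]

# 10 tiles on 4 programs -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(blocked_schedule(10, 4))
```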
============================================================
PERSISTENT INTERLEAVED KERNEL
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _simple_add_kernel(x, y, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
    for virtual_pid in tl.range(tl.program_id(0), total_pids, _NUM_SM):
        num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
        pid_0 = virtual_pid % num_blocks_0
        pid_1 = virtual_pid // num_blocks_0
        offset_0 = pid_0 * _BLOCK_SIZE_0
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        mask_0 = indices_0 < x_size_0
        offset_1 = pid_1 * _BLOCK_SIZE_1
        indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
        mask_1 = indices_1 < x_size_1
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        v_0 = load + load_1
        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_0, mask_0[:, None] & mask_1[None, :])

def simple_add(x: torch.Tensor, y: torch.Tensor):
    result = x.new_empty(x.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    _simple_add_kernel[_NUM_SM,](x, y, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return result

def _simple_add_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    result = x.new_empty(x.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_simple_add_kernel)(x, y, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
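The interleaved form instead strides through the same virtual pid space, so program p handles pids p, p + _NUM_SM, p + 2 * _NUM_SM, and so on. The equivalent host-side sketch:

```python
def interleaved_schedule(total_pids: int, num_sm: int) -> list[list[int]]:
    # Mirrors the kernel loop: tl.range(tl.program_id(0), total_pids, _NUM_SM)
    return [list(range(p, total_pids, num_sm)) for p in range(num_sm)]

# 10 tiles on 4 programs -> [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
print(interleaved_schedule(10, 4))
```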
============================================================
PERSISTENT BLOCKED WITH FOR EACH PROGRAM ID (Multi-loop)
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _multi_loop_example_kernel(x, result1, y, result2, x_size_0, x_size_1, y_size_0, y_size_1, result1_stride_0, result1_stride_1, result2_stride_0, result2_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1) + tl.cdiv(y_size_0, _BLOCK_SIZE_2) * tl.cdiv(y_size_1, _BLOCK_SIZE_3)
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid):
        pid_shared = virtual_pid
        if pid_shared < tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1):
            num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
            pid_0 = pid_shared % num_blocks_0
            pid_1 = pid_shared // num_blocks_0
            offset_0 = pid_0 * _BLOCK_SIZE_0
            indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
            mask_0 = indices_0 < x_size_0
            offset_1 = pid_1 * _BLOCK_SIZE_1
            indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
            mask_1 = indices_1 < x_size_1
            load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
            v_0 = 2.0
            v_1 = load * v_0
            tl.store(result1 + (indices_0[:, None] * result1_stride_0 + indices_1[None, :] * result1_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
        else:
            pid_shared -= tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
            num_blocks_1 = tl.cdiv(y_size_0, _BLOCK_SIZE_2)
            pid_2 = pid_shared % num_blocks_1
            pid_3 = pid_shared // num_blocks_1
            offset_2 = pid_2 * _BLOCK_SIZE_2
            indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
            mask_2 = indices_2 < y_size_0
            offset_3 = pid_3 * _BLOCK_SIZE_3
            indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
            mask_3 = indices_3 < y_size_1
            load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_3[None, :] * y_stride_1), mask_2[:, None] & mask_3[None, :], other=0)
            v_2 = 3.0
            v_3 = load_1 * v_2
            tl.store(result2 + (indices_2[:, None] * result2_stride_0 + indices_3[None, :] * result2_stride_1), v_3, mask_2[:, None] & mask_3[None, :])

def multi_loop_example(x: torch.Tensor, y: torch.Tensor):
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    _BLOCK_SIZE_2 = 32
    _BLOCK_SIZE_3 = 16
    _multi_loop_example_kernel[_NUM_SM,](x, result1, y, result2, x.size(0), x.size(1), y.size(0), y.size(1), result1.stride(0), result1.stride(1), result2.stride(0), result2.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return (result1, result2)

def _multi_loop_example_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    _BLOCK_SIZE_2 = 32
    _BLOCK_SIZE_3 = 16
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_multi_loop_example_kernel)(x, result1, y, result2, x.size(0), x.size(1), y.size(0), y.size(1), result1.stride(0), result1.stride(1), result2.stride(0), result2.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
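In the multi-loop case, both top-level loops share one flattened pid space: pids below the first loop's tile count drive the first body, and the remainder, after subtracting that count, index the second body (the `if pid_shared < ...` / `else` dispatch above). A minimal sketch of that dispatch:

```python
def dispatch(virtual_pid: int, loop0_tiles: int) -> tuple[int, int]:
    # Returns (loop index, pid within that loop), mirroring the pid_shared logic.
    if virtual_pid < loop0_tiles:
        return 0, virtual_pid
    return 1, virtual_pid - loop0_tiles

# With 6 tiles in the first loop, shared pid 8 maps to tile 2 of the second loop.
assert dispatch(8, 6) == (1, 2)
```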
============================================================
PERSISTENT INTERLEAVED WITH FOR EACH PROGRAM ID (Multi-loop)
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _multi_loop_example_kernel(x, result1, y, result2, x_size_0, x_size_1, y_size_0, y_size_1, result1_stride_0, result1_stride_1, result2_stride_0, result2_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1) + tl.cdiv(y_size_0, _BLOCK_SIZE_2) * tl.cdiv(y_size_1, _BLOCK_SIZE_3)
    for virtual_pid in tl.range(tl.program_id(0), total_pids, _NUM_SM):
        pid_shared = virtual_pid
        if pid_shared < tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1):
            num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
            pid_0 = pid_shared % num_blocks_0
            pid_1 = pid_shared // num_blocks_0
            offset_0 = pid_0 * _BLOCK_SIZE_0
            indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
            mask_0 = indices_0 < x_size_0
            offset_1 = pid_1 * _BLOCK_SIZE_1
            indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
            mask_1 = indices_1 < x_size_1
            load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
            v_0 = 2.0
            v_1 = load * v_0
            tl.store(result1 + (indices_0[:, None] * result1_stride_0 + indices_1[None, :] * result1_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
        else:
            pid_shared -= tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
            num_blocks_1 = tl.cdiv(y_size_0, _BLOCK_SIZE_2)
            pid_2 = pid_shared % num_blocks_1
            pid_3 = pid_shared // num_blocks_1
            offset_2 = pid_2 * _BLOCK_SIZE_2
            indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
            mask_2 = indices_2 < y_size_0
            offset_3 = pid_3 * _BLOCK_SIZE_3
            indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
            mask_3 = indices_3 < y_size_1
            load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_3[None, :] * y_stride_1), mask_2[:, None] & mask_3[None, :], other=0)
            v_2 = 3.0
            v_3 = load_1 * v_2
            tl.store(result2 + (indices_2[:, None] * result2_stride_0 + indices_3[None, :] * result2_stride_1), v_3, mask_2[:, None] & mask_3[None, :])

def multi_loop_example(x: torch.Tensor, y: torch.Tensor):
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    _BLOCK_SIZE_2 = 32
    _BLOCK_SIZE_3 = 16
    _multi_loop_example_kernel[_NUM_SM,](x, result1, y, result2, x.size(0), x.size(1), y.size(0), y.size(1), result1.stride(0), result1.stride(1), result2.stride(0), result2.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return (result1, result2)

def _multi_loop_example_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    result1 = x.new_empty(x.size())
    result2 = y.new_empty(y.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    _BLOCK_SIZE_2 = 32
    _BLOCK_SIZE_3 = 16
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_multi_loop_example_kernel)(x, result1, y, result2, x.size(0), x.size(1), y.size(0), y.size(1), result1.stride(0), result1.stride(1), result2.stride(0), result2.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
============================================================
PERSISTENT BLOCKED WITH L2 GROUPING
============================================================
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _simple_add_kernel(x, y, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    total_pids = tl.cdiv(x_size_0, _BLOCK_SIZE_0) * tl.cdiv(x_size_1, _BLOCK_SIZE_1)
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid):
        num_pid_m = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
        num_pid_n = tl.cdiv(x_size_1, _BLOCK_SIZE_1)
        num_pid_in_group = 8 * num_pid_n
        group_id = virtual_pid // num_pid_in_group
        first_pid_m = group_id * 8
        group_size_m = min(num_pid_m - first_pid_m, 8)
        pid_0 = first_pid_m + virtual_pid % num_pid_in_group % group_size_m
        pid_1 = virtual_pid % num_pid_in_group // group_size_m
        offset_0 = pid_0 * _BLOCK_SIZE_0
        indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
        mask_0 = indices_0 < x_size_0
        offset_1 = pid_1 * _BLOCK_SIZE_1
        indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
        mask_1 = indices_1 < x_size_1
        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        load_1 = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
        v_0 = load + load_1
        tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), v_0, mask_0[:, None] & mask_1[None, :])

def simple_add(x: torch.Tensor, y: torch.Tensor):
    result = x.new_empty(x.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    _simple_add_kernel[_NUM_SM,](x, y, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return result

def _simple_add_make_precompiler(x: torch.Tensor, y: torch.Tensor):
    result = x.new_empty(x.size())
    _NUM_SM = helion.runtime.get_num_sm(torch.device(type='cuda', index=0))
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 16
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_simple_add_kernel)(x, y, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _NUM_SM, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
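The L2-grouping variant re-orders `virtual_pid` so that tiles are visited in groups of 8 along dimension 0 before advancing along dimension 1, similar to the grouped ordering in Triton's matmul tutorial, which improves L2 reuse between neighbouring tiles. A plain-Python rendering of the index math in the kernel above:

```python
def l2_grouped_pid(virtual_pid: int, num_pid_m: int, num_pid_n: int, group_m: int = 8) -> tuple[int, int]:
    # Mirrors the group_id / first_pid_m / group_size_m computation in the kernel.
    num_pid_in_group = group_m * num_pid_n
    group_id = virtual_pid // num_pid_in_group
    first_pid_m = group_id * group_m
    group_size_m = min(num_pid_m - first_pid_m, group_m)
    pid_0 = first_pid_m + virtual_pid % num_pid_in_group % group_size_m
    pid_1 = virtual_pid % num_pid_in_group // group_size_m
    return pid_0, pid_1

# The first 8 pids walk down rows 0-7 of column-tile 0, then move to column-tile 1.
print([l2_grouped_pid(i, num_pid_m=16, num_pid_n=4) for i in range(10)])
```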
This was referenced on Jul 5, 2025.
yf225 approved these changes on Jul 7, 2025.
Stacked PRs (oldest at bottom):
Implement persistent kernels

Enabled with `config["pid_type"]="persistent_blocked"` or `"persistent_interleaved"`. This also refactors much of the program id handling.