Skip to content

Commit 57b4f02

Browse files
authored
[DEV] Support TRITON_INTERPRET=1 (#73)
1 parent c14d958 commit 57b4f02

File tree

6 files changed

+40
-18
lines changed

6 files changed

+40
-18
lines changed

.github/workflows/python-app.yml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ name: Python application
55

66
on:
77
push:
8-
branches-ignore:
9-
- '**'
8+
branches:
9+
- main
10+
- keren/v2.0
1011
pull_request:
11-
branches-ignore:
12-
- '**'
12+
branches:
13+
- main
14+
- keren/v2.0
1315

1416
permissions:
1517
contents: read
@@ -21,6 +23,8 @@ concurrency:
2123
jobs:
2224
build:
2325
runs-on: ubuntu-latest
26+
env:
27+
TRITON_INTERPRET: "1"
2428

2529
steps:
2630
- uses: actions/checkout@v3
@@ -44,7 +48,7 @@ jobs:
4448
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
4549
pip uninstall pytorch-triton -y
4650
47-
- name: Clone Triton and Install
51+
- name: Install Triton
4852
run: |
4953
pip install triton==3.1.0
5054

tests/test_autotune_add.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import torch
23
import triton
34
import triton.language as tl
@@ -7,8 +8,12 @@
78
from triton_viz import config as cfg
89

910

10-
cfg.sanitizer_backend = "symexec"
11+
try:
12+
torch.cuda.current_device()
13+
except:
14+
pytest.skip("This test requires a CUDA-enabled environment.", allow_module_level=True)
1115

16+
cfg.sanitizer_backend = "symexec"
1217

1318
@triton.autotune(
1419
configs=[
@@ -39,8 +44,8 @@ def test_autotune_add_inrange():
3944
This test uses n_elements = 128, matching the size of the input tensors.
4045
It should NOT cause any out-of-bound access.
4146
"""
42-
x = torch.randn(128, device="cuda")
43-
y = torch.randn(128, device="cuda")
47+
x = torch.randn(128)
48+
y = torch.randn(128)
4449
out = torch.empty_like(x)
4550

4651
# The kernel launch uses n_elements=128, aligned with the tensor size.
@@ -55,8 +60,8 @@ def test_autotune_add_out_of_bound():
5560
This test deliberately sets n_elements = 256, exceeding the actual buffer size (128).
5661
It will likely cause out-of-bound reads/writes, which may trigger errors or warnings.
5762
"""
58-
x = torch.randn(128, device="cuda")
59-
y = torch.randn(128, device="cuda")
63+
x = torch.randn(128)
64+
y = torch.randn(128)
6065
out = torch.empty_like(x)
6166

6267
# The kernel launch uses n_elements=256, exceeding the valid tensor size.

tests/test_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import pytest
1+
import pytest, os
2+
os.environ["TRITON_SANITIZER_BACKEND"] = "off"
23
import triton_viz.core.config as cfg
34

45

tests/test_print_traceback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def kernel_A(ptr, n):
2727

2828

2929
def test_print_nested_functions():
30-
x = torch.arange(4, device="cuda", dtype=torch.float32)
30+
x = torch.arange(4, dtype=torch.float32)
3131
print("Input:", x)
3232

3333
# We'll launch a grid bigger than x.numel() to force a out-of-bounds error

tests/test_wrapper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def _decorator(fn):
5959
env = os.environ.copy()
6060
env["PYTHONPATH"] = str(tmp_path) + os.pathsep + env.get("PYTHONPATH", "")
6161
env["TRITON_SANITIZER_BACKEND"] = "symexec"
62+
env["TRITON_INTERPRET"] = "1"
6263

6364
# run the dummy program using triton-sanitizer
6465
cmd = ["triton-sanitizer", str(tmp_path / "dummy_program.py")]

triton_viz/core/trace.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,20 @@ def add_client(self, new_client: Union[Client, str]) -> None:
3434
new_client_instance = self._normalize_client(new_client)
3535
self.client_manager.add_clients([new_client_instance])
3636

37-
def __init__(self, kernel: JITFunction, client: Union[str, Client]) -> None:
38-
assert isinstance(kernel, JITFunction), "Kernel must be a JITFunction"
39-
self.interpreter_fn = InterpretedFunction(kernel.fn)
37+
def __init__(
38+
self,
39+
kernel: Union[JITFunction, InterpretedFunction],
40+
client: Union[str, Client],
41+
) -> None:
4042
self.fn = kernel
43+
if isinstance(kernel, InterpretedFunction):
44+
self.interpreter_fn = kernel
45+
elif isinstance(kernel, JITFunction):
46+
self.interpreter_fn = InterpretedFunction(kernel.fn)
47+
else:
48+
raise TypeError(
49+
f"Kernel must be JITFunction or InterpretedFunction, got {type(kernel)}"
50+
)
4151
self.arg_names = kernel.arg_names
4252
self.client_manager = ClientManager()
4353
self.add_client(client)
@@ -76,7 +86,7 @@ def decorator(kernel) -> Trace:
7686
return kernel
7787

7888
# First-time wrapping
79-
if isinstance(kernel, JITFunction):
89+
if isinstance(kernel, (JITFunction, InterpretedFunction)):
8090
return Trace(kernel, clients)
8191

8292
# If the object is already a Trace, just append the new client(s)
@@ -85,8 +95,9 @@ def decorator(kernel) -> Trace:
8595
trace.add_client(clients)
8696
return trace
8797

88-
# If the object is neither a JITFunction nor Trace, raise an error
89-
raise TypeError(f"Expected JITFunction, got {type(kernel)}")
98+
raise TypeError(
99+
f"Expected JITFunction, InterpretedFunction or Trace, got {type(kernel)}"
100+
)
90101

91102
return decorator
92103

0 commit comments

Comments
 (0)