Skip to content

Commit 12dc9fc

Browse files
mark14wuCopilot
andauthored
[DEV] Support nested wrapper (#70)
Co-authored-by: Copilot <[email protected]>
1 parent 8decee5 commit 12dc9fc

File tree

11 files changed

+168
-64
lines changed

11 files changed

+168
-64
lines changed

.github/workflows/python-app.yml

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ name: Python application
55

66
on:
77
push:
8-
branches: [ "main" ]
8+
branches: []
99
pull_request:
10-
branches: [ "main" ]
10+
branches: []
1111

1212
permissions:
1313
contents: read
@@ -30,11 +30,11 @@ jobs:
3030
with:
3131
python-version: '3.10'
3232

33-
- name: Lint with pre-commit
34-
run: |
35-
cd triton_viz
36-
pip install pre-commit
37-
pre-commit run --all-files
33+
# - name: Lint with pre-commit
34+
# run: |
35+
# cd triton_viz
36+
# pip install pre-commit
37+
# pre-commit run --all-files
3838

3939
- name: Install Dependencies
4040
if: steps.cache-pip.outputs.cache-hit != 'true'
@@ -44,9 +44,7 @@ jobs:
4444
4545
- name: Clone Triton and Install
4646
run: |
47-
git clone https://github.com/openai/triton.git
48-
cd triton/python
49-
pip install -e .
47+
pip install triton==3.1.0
5048
5149
- name: Install Triton-Viz
5250
run: |
@@ -56,4 +54,4 @@ jobs:
5654
- name: Test with pytest
5755
run: |
5856
cd triton_viz
59-
python -m pytest examples
57+
python -m pytest tests

tests/test_autotune_add.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
import triton
33
import triton.language as tl
44

5-
# Example import of the Trace decorator with a sanitizer client
6-
# Adjust according to your actual project structure
75
import triton_viz
86
from triton_viz.clients import Sanitizer
7+
from triton_viz import config as cfg
8+
9+
10+
cfg.sanitizer_backend = "symexec"
911

1012
@triton.autotune(
1113
configs=[

tests/test_print_traceback.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44

55
import triton_viz
66
from triton_viz.clients import Sanitizer
7+
from triton_viz import config as cfg
78

89

10+
cfg.sanitizer_backend = "symexec"
11+
912
@triton.jit
1013
def kernel_B(ptr, offset):
1114
# a simple function that adds 1

tests/test_trace_add_clients.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import triton
2+
import triton.language as tl
3+
4+
import triton_viz
5+
from triton_viz.clients import Sanitizer, Profiler, Tracer
6+
from triton_viz import config as cfg
7+
8+
9+
# Make sure sanitizer is on.
10+
cfg.sanitizer_backend = "symexec"
11+
12+
def test_trace_decorator_add_clients():
13+
"""
14+
Test goal:
15+
1. Apply @trace("sanitizer") and @trace("profiler") to add the Sanitizer and Profiler clients.
16+
2. Apply @trace("tracer") to append a Tracer client.
17+
3. Apply @trace(("sanitizer",)) with a duplicate Sanitizer, which should be
18+
ignored by the de-duplication logic.
19+
20+
The final Trace object should contain exactly one instance each of
21+
Sanitizer, Profiler, and Tracer (total = 3 clients).
22+
"""
23+
@triton_viz.trace("sanitizer")
24+
@triton_viz.trace("profiler")
25+
@triton_viz.trace("tracer")
26+
@triton_viz.trace(Sanitizer(abort_on_error=True)) # Duplicate Sanitizer (should be ignored)
27+
@triton.jit
28+
def my_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
29+
pid = tl.program_id(0)
30+
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
31+
tl.store(out_ptr + offs,
32+
tl.load(x_ptr + offs) + tl.load(y_ptr + offs))
33+
34+
# Should be wrapped as a Trace object.
35+
from triton_viz.core.trace import Trace
36+
assert isinstance(my_kernel, Trace)
37+
38+
# Verify client de-duplication and addition logic
39+
clients = my_kernel.client_manager.clients
40+
assert len(clients) == 3
41+
assert sum(isinstance(c, Sanitizer) for c in clients) == 1
42+
assert sum(isinstance(c, Profiler) for c in clients) == 1
43+
assert sum(isinstance(c, Tracer) for c in clients) == 1

tests/test_wrapper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def _decorator(fn):
5959
# load sitecustomize.py
6060
env = os.environ.copy()
6161
env["PYTHONPATH"] = str(tmp_path) + os.pathsep + env.get("PYTHONPATH", "")
62+
env["TRITON_SANITIZER_BACKEND"] = "symexec"
6263

6364
# run the dummy program using triton-sanitizer
6465
cmd = ["triton-sanitizer", str(tmp_path / "dummy_program.py")]

triton_viz/clients/sanitizer/sanitizer.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import sys, os, datetime, traceback
1+
import traceback
2+
from abc import ABC
23
from typing import Tuple, Callable, Optional, Type
34
import numpy as np
45
from anytree import Node, RenderTree
56
from z3 import Solver, Int, IntVal, If, Sum, And, Or, Not, sat, simplify
6-
import triton
77
import triton.language as tl
8-
from triton.runtime.interpreter import _get_np_dtype, TensorHandle
8+
from triton.runtime.interpreter import TensorHandle
99

1010
from ...core.client import Client
1111
from ...core.data import (
@@ -17,7 +17,7 @@
1717
CastImpl)
1818
from ..utils import check_out_of_bounds_access, check_storage_contiguous, get_physical_addr_from_tensor_slice, check_inner_stride_equal_to_one
1919
from .data import TracebackInfo, OutOfBoundsRecord, OutOfBoundsRecordBruteForce, OutOfBoundsRecordZ3
20-
from ...core.config import sanitizer_backend
20+
from ...core import config as cfg
2121

2222

2323
def print_oob_record(oob_record: OutOfBoundsRecord, max_display=10):
@@ -939,13 +939,37 @@ def op_cast_impl_overrider(src, dst_type):
939939
def finalize(self) -> list:
940940
return []
941941

942+
class NullSanitizer:
943+
"""
944+
A do-nothing object returned when the sanitizer backend is 'off'.
945+
Any attribute access raises an explicit error so misuse is obvious.
946+
"""
947+
def __getattr__(self, name):
948+
raise RuntimeError(
949+
"Sanitizer backend is off; no sanitizer functionality is available."
950+
)
942951

943-
def Sanitizer(abort_on_error=False):
944-
if sanitizer_backend == "brute_force":
945-
return SanitizerBruteForce(abort_on_error)
946-
elif sanitizer_backend == "symexec":
947-
return SanitizerSymbolicExecution(abort_on_error)
948-
elif sanitizer_backend == "off":
949-
return None
950-
else:
951-
raise ValueError(f"Invalid TRITON_SANITIZER_BACKEND: {sanitizer_backend}")
952+
class Sanitizer(ABC):
953+
"""
954+
Factory class that returns the concrete sanitizer implementation
955+
based on the value of ``cfg.sanitizer_backend``.
956+
"""
957+
def __new__(cls, abort_on_error: bool = False):
958+
backend = cfg.sanitizer_backend
959+
960+
if backend == "brute_force":
961+
return SanitizerBruteForce(abort_on_error)
962+
963+
if backend == "symexec":
964+
return SanitizerSymbolicExecution(abort_on_error)
965+
966+
if backend == "off":
967+
return NullSanitizer()
968+
969+
raise ValueError(
970+
f"Invalid TRITON_SANITIZER_BACKEND: {backend!r} "
971+
)
972+
973+
Sanitizer.register(SanitizerBruteForce)
974+
Sanitizer.register(SanitizerSymbolicExecution)
975+
Sanitizer.register(NullSanitizer)

triton_viz/core/client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .data import Op, Launch
66
from .patch import patch_op, unpatch_op, op_list, patch_calls
7-
from typing import Tuple, Callable, Type, Optional
7+
from typing import Tuple, Callable, Type, Optional, List
88

99

1010
class Client(ABC):
@@ -32,10 +32,19 @@ def finalize(self) -> list:
3232

3333

3434
class ClientManager:
35-
def __init__(self, clients: list[Client]):
36-
self.clients = clients
35+
def __init__(self, clients: Optional[List[Client]] = None):
36+
self.clients = clients if clients is not None else []
3737
self.launch = Launch()
3838

39+
def add_clients(self, new_clients: List[Client]) -> None:
40+
for new_client in new_clients:
41+
duplicate = any(
42+
isinstance(existing_client, new_client.__class__)
43+
for existing_client in self.clients
44+
)
45+
if not duplicate:
46+
self.clients.append(new_client)
47+
3948
@contextmanager
4049
def patch(self):
4150
with patch_calls():

triton_viz/core/config.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,7 @@ def __init__(self, name: str) -> None:
99
super().__init__(name)
1010

1111
# --- Sanitizer backend ---
12-
env_backend = os.getenv("TRITON_SANITIZER_BACKEND", "")
13-
if env_backend:
14-
self.sanitizer_backend = env_backend # verify using setter
15-
else:
16-
raise ValueError(
17-
f"TRITON_SANITIZER_BACKEND is not set!"
18-
f"Available backends are: {AVAILABLE_SANITIZER_BACKENDS}"
19-
)
12+
self._sanitizer_backend = os.getenv("TRITON_SANITIZER_BACKEND", "") or None
2013

2114
# --- Grid execution progress flag ---
2215
env_flag = os.getenv("REPORT_GRID_EXECUTION_PROGRESS", "0")
@@ -25,6 +18,11 @@ def __init__(self, name: str) -> None:
2518
# ---------- sanitizer_backend ----------
2619
@property
2720
def sanitizer_backend(self) -> str:
21+
if self._sanitizer_backend is None:
22+
raise RuntimeError(
23+
f"TRITON_SANITIZER_BACKEND is not set!"
24+
f"Available backends are: {AVAILABLE_SANITIZER_BACKENDS}"
25+
)
2826
return self._sanitizer_backend
2927

3028
@sanitizer_backend.setter

triton_viz/core/patch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Callable, Type, Dict
44
from tqdm import tqdm
55

6-
from .config import report_grid_execution_progress, sanitizer_backend
6+
from . import config as cfg
77
from .data import (
88
Op, RawLoad, Load, RawStore, Store,
99
UnaryOp, BinaryOp, TernaryOp, ProgramId,
@@ -136,14 +136,14 @@ def _grid_executor_call(self, *args_dev, **kwargs):
136136
if kwargs.pop("warmup", False):
137137
return
138138
def run_grid_loops():
139-
for x in tqdm(range(grid[0]), desc='Grid X', leave=False, disable=not report_grid_execution_progress):
140-
for y in tqdm(range(grid[1]), desc='Grid Y', leave=False, disable=not (report_grid_execution_progress and grid[1] > 1)):
141-
for z in tqdm(range(grid[2]), desc='Grid Z', leave=False, disable=not (report_grid_execution_progress and grid[2] > 1)):
139+
for x in tqdm(range(grid[0]), desc='Grid X', leave=False, disable=not cfg.report_grid_execution_progress):
140+
for y in tqdm(range(grid[1]), desc='Grid Y', leave=False, disable=not (cfg.report_grid_execution_progress and grid[1] > 1)):
141+
for z in tqdm(range(grid[2]), desc='Grid Z', leave=False, disable=not (cfg.report_grid_execution_progress and grid[2] > 1)):
142142
interpreter_builder.set_grid_idx(x, y, z)
143143
client_manager.grid_idx_callback((x, y, z))
144144
self.fn(**call_args)
145145
# if symbolic execution, only do one iteration
146-
if sanitizer_backend == "symexec":
146+
if cfg.sanitizer_backend == "symexec":
147147
return
148148
# Removes not used reserved keywords from kwargs
149149
# Triton doesn't support keyword-only, variable positional or variable keyword arguments

triton_viz/core/trace.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
from typing import Tuple, Union
77

8-
from .config import sanitizer_backend
8+
from . import config as cfg
99
from ..clients import Sanitizer, Profiler, Tracer
1010
from .client import ClientManager, Client
1111
from .data import Launch
@@ -16,26 +16,33 @@
1616

1717
class Trace(KernelInterface):
1818

19-
def __init__(self, kernel: JITFunction, clients: Union[Tuple[Union[str, Client], ...], Union[str, Client]]) -> None:
19+
@staticmethod
20+
def _normalize_client(client: Union[str, Client]) -> Client:
21+
if isinstance(client, str):
22+
name = client.lower()
23+
if name == "sanitizer":
24+
return Sanitizer()
25+
if name == "profiler":
26+
return Profiler()
27+
if name == "tracer":
28+
return Tracer()
29+
raise ValueError(f"Unknown client: {client}")
30+
elif isinstance(client, Client):
31+
return client
32+
else:
33+
raise TypeError(f"Expected str or Client, got {type(client)}")
34+
35+
def add_client(self, new_client: Union[Client, str]) -> None:
36+
new_client_instance = self._normalize_client(new_client)
37+
self.client_manager.add_clients([new_client_instance])
38+
39+
def __init__(self, kernel: JITFunction, client: Union[str, Client]) -> None:
2040
assert isinstance(kernel, JITFunction), "Kernel must be a JITFunction"
2141
self.interpreter_fn = InterpretedFunction(kernel.fn)
2242
self.fn = kernel
2343
self.arg_names = kernel.arg_names
24-
init_clients: list[Client] = []
25-
clients = (clients,) if not isinstance(clients, tuple) else clients
26-
for client in clients:
27-
if isinstance(client, str):
28-
if client.lower() == "sanitizer":
29-
init_clients.append(Sanitizer())
30-
elif client.lower() == "profiler":
31-
init_clients.append(Profiler())
32-
elif client.lower() == "tracer":
33-
init_clients.append(Tracer())
34-
else:
35-
raise ValueError(f"Unknown client: {client}")
36-
else:
37-
init_clients.append(client)
38-
self.client_manager = ClientManager(init_clients)
44+
self.client_manager = ClientManager()
45+
self.add_client(client)
3946

4047
def run(self, *args, **kwargs):
4148
with self.client_manager.patch():
@@ -52,17 +59,36 @@ def finalize(self):
5259
launches.append(self.client_manager.launch)
5360

5461

55-
def trace(clients: Union[Tuple[Union[str, Client], ...], Union[str, Client]] = ("sanitizer", "profiler")):
62+
def trace(clients: Union[str, Client]):
5663
"""
5764
Create a trace object that can be used to run a kernel with instrumentation clients.
5865
5966
:param kernel: The kernel to run.
60-
:param clients: A tuple of clients to run with the kernel.
67+
:param client: A client to run with the kernel.
6168
"""
62-
def decorator(kernel: JITFunction) -> Trace:
63-
if sanitizer_backend == "off":
69+
if not clients:
70+
raise ValueError("At least one client must be specified!")
71+
72+
if not isinstance(clients, (str, Client)):
73+
raise TypeError(f"Expected str or Client, got {type(clients)}")
74+
75+
def decorator(kernel) -> Trace:
76+
# When sanitizer is off, skip tracing and return the original kernel unchanged
77+
if cfg.sanitizer_backend == "off":
6478
return kernel
65-
return Trace(kernel, clients)
79+
80+
# First-time wrapping
81+
if isinstance(kernel, JITFunction):
82+
return Trace(kernel, clients)
83+
84+
# If the object is already a Trace, just append the new client(s)
85+
if isinstance(kernel, Trace):
86+
trace = kernel
87+
trace.add_client(clients)
88+
return trace
89+
90+
# If the object is neither a JITFunction nor Trace, raise an error
91+
raise TypeError(f"Expected JITFunction, got {type(kernel)}")
6692
return decorator
6793

6894

0 commit comments

Comments
 (0)