Skip to content

[DEV] Support nested wrapper #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: Python application

on:
push:
branches: [ "main" ]
branches: []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those changes are suspicious

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has been fixed in #71.
We use

   branches-ignore:
      - '**'

instead of

branches: []

to ignore all ci tests for now.

pull_request:
branches: [ "main" ]
branches: []

permissions:
contents: read
Expand All @@ -30,11 +30,11 @@ jobs:
with:
python-version: '3.10'

- name: Lint with pre-commit
run: |
cd triton_viz
pip install pre-commit
pre-commit run --all-files
# - name: Lint with pre-commit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And why merge without making all tests passed?

Copy link
Collaborator Author

@mark14wu mark14wu Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to focus on indirect load first and fix CI later.

As for CI problem, I figured out that once you set TRITON_INTERPRET=1, triton will skip driver checking.
See https://github.com/triton-lang/triton/blob/d141ab8b1bfa8e8dc703459412cd827b09f80b21/python/triton/_internal_testing.py#L30
However, triton-viz cannot work with TRITON_INTERPRET=1 and needs to be fixed.

I will do this after I finish indirect load.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it has to be compatible with TRITON_INTERPRET=1

# run: |
# cd triton_viz
# pip install pre-commit
# pre-commit run --all-files

- name: Install Dependencies
if: steps.cache-pip.outputs.cache-hit != 'true'
Expand All @@ -44,9 +44,7 @@ jobs:

- name: Clone Triton and Install
run: |
git clone https://github.com/openai/triton.git
cd triton/python
pip install -e .
pip install triton==3.1.0

- name: Install Triton-Viz
run: |
Expand All @@ -56,4 +54,4 @@ jobs:
- name: Test with pytest
run: |
cd triton_viz
python -m pytest examples
python -m pytest tests
File renamed without changes.
6 changes: 4 additions & 2 deletions tests/test_autotune_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import triton
import triton.language as tl

# Example import of the Trace decorator with a sanitizer client
# Adjust according to your actual project structure
import triton_viz
from triton_viz.clients import Sanitizer
from triton_viz import config as cfg


cfg.sanitizer_backend = "symexec"

@triton.autotune(
configs=[
Expand Down
3 changes: 3 additions & 0 deletions tests/test_print_traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

import triton_viz
from triton_viz.clients import Sanitizer
from triton_viz import config as cfg


cfg.sanitizer_backend = "symexec"

@triton.jit
def kernel_B(ptr, offset):
# a simple function that adds 1
Expand Down
43 changes: 43 additions & 0 deletions tests/test_trace_add_clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import triton
import triton.language as tl

import triton_viz
from triton_viz.clients import Sanitizer, Profiler, Tracer
from triton_viz import config as cfg


# Make sure sanitizer is on.
cfg.sanitizer_backend = "symexec"

def test_trace_decorator_add_clients():
"""
Test goal:
1. Apply @trace("sanitizer") and @trace("profiler") to add the Sanitizer and Profiler clients.
2. Apply @trace("tracer") to append a Tracer client.
3. Apply @trace(("sanitizer",)) with a duplicate Sanitizer, which should be
ignored by the de-duplication logic.

The final Trace object should contain exactly one instance each of
Sanitizer, Profiler, and Tracer (total = 3 clients).
"""
@triton_viz.trace("sanitizer")
@triton_viz.trace("profiler")
@triton_viz.trace("tracer")
@triton_viz.trace(Sanitizer(abort_on_error=True)) # Duplicate Sanitizer (should be ignored)
@triton.jit
def my_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(out_ptr + offs,
tl.load(x_ptr + offs) + tl.load(y_ptr + offs))

# Should be wrapped as a Trace object.
from triton_viz.core.trace import Trace
assert isinstance(my_kernel, Trace)

# Verify client de-duplication and addition logic
clients = my_kernel.client_manager.clients
assert len(clients) == 3
assert sum(isinstance(c, Sanitizer) for c in clients) == 1
assert sum(isinstance(c, Profiler) for c in clients) == 1
assert sum(isinstance(c, Tracer) for c in clients) == 1
1 change: 1 addition & 0 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _decorator(fn):
# load sitecustomize.py
env = os.environ.copy()
env["PYTHONPATH"] = str(tmp_path) + os.pathsep + env.get("PYTHONPATH", "")
env["TRITON_SANITIZER_BACKEND"] = "symexec"

# run the dummy program using triton-sanitizer
cmd = ["triton-sanitizer", str(tmp_path / "dummy_program.py")]
Expand Down
50 changes: 37 additions & 13 deletions triton_viz/clients/sanitizer/sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sys, os, datetime, traceback
import traceback
from abc import ABC
from typing import Tuple, Callable, Optional, Type
import numpy as np
from anytree import Node, RenderTree
from z3 import Solver, Int, IntVal, If, Sum, And, Or, Not, sat, simplify
import triton
import triton.language as tl
from triton.runtime.interpreter import _get_np_dtype, TensorHandle
from triton.runtime.interpreter import TensorHandle

from ...core.client import Client
from ...core.data import (
Expand All @@ -17,7 +17,7 @@
CastImpl)
from ..utils import check_out_of_bounds_access, check_storage_contiguous, get_physical_addr_from_tensor_slice, check_inner_stride_equal_to_one
from .data import TracebackInfo, OutOfBoundsRecord, OutOfBoundsRecordBruteForce, OutOfBoundsRecordZ3
from ...core.config import sanitizer_backend
from ...core import config as cfg


def print_oob_record(oob_record: OutOfBoundsRecord, max_display=10):
Expand Down Expand Up @@ -939,13 +939,37 @@ def op_cast_impl_overrider(src, dst_type):
def finalize(self) -> list:
return []

class NullSanitizer:
"""
A do-nothing object returned when the sanitizer backend is 'off'.
Any attribute access raises an explicit error so misuse is obvious.
"""
def __getattr__(self, name):
raise RuntimeError(
"Sanitizer backend is off; no sanitizer functionality is available."
)

def Sanitizer(abort_on_error=False):
if sanitizer_backend == "brute_force":
return SanitizerBruteForce(abort_on_error)
elif sanitizer_backend == "symexec":
return SanitizerSymbolicExecution(abort_on_error)
elif sanitizer_backend == "off":
return None
else:
raise ValueError(f"Invalid TRITON_SANITIZER_BACKEND: {sanitizer_backend}")
class Sanitizer(ABC):
"""
Factory class that returns the concrete sanitizer implementation
based on the value of ``cfg.sanitizer_backend``.
"""
def __new__(cls, abort_on_error: bool = False):
backend = cfg.sanitizer_backend

if backend == "brute_force":
return SanitizerBruteForce(abort_on_error)

if backend == "symexec":
return SanitizerSymbolicExecution(abort_on_error)

if backend == "off":
return NullSanitizer()

raise ValueError(
f"Invalid TRITON_SANITIZER_BACKEND: {backend!r} "
)

Sanitizer.register(SanitizerBruteForce)
Sanitizer.register(SanitizerSymbolicExecution)
Sanitizer.register(NullSanitizer)
15 changes: 12 additions & 3 deletions triton_viz/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .data import Op, Launch
from .patch import patch_op, unpatch_op, op_list, patch_calls
from typing import Tuple, Callable, Type, Optional
from typing import Tuple, Callable, Type, Optional, List


class Client(ABC):
Expand Down Expand Up @@ -32,10 +32,19 @@ def finalize(self) -> list:


class ClientManager:
def __init__(self, clients: list[Client]):
self.clients = clients
def __init__(self, clients: Optional[List[Client]] = None):
self.clients = clients if clients is not None else []
self.launch = Launch()

def add_clients(self, new_clients: List[Client]) -> None:
for new_client in new_clients:
duplicate = any(
isinstance(existing_client, new_client.__class__)
for existing_client in self.clients
)
if not duplicate:
self.clients.append(new_client)

@contextmanager
def patch(self):
with patch_calls():
Expand Down
14 changes: 6 additions & 8 deletions triton_viz/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@ def __init__(self, name: str) -> None:
super().__init__(name)

# --- Sanitizer backend ---
env_backend = os.getenv("TRITON_SANITIZER_BACKEND", "")
if env_backend:
self.sanitizer_backend = env_backend # verify using setter
else:
raise ValueError(
f"TRITON_SANITIZER_BACKEND is not set!"
f"Available backends are: {AVAILABLE_SANITIZER_BACKENDS}"
)
self._sanitizer_backend = os.getenv("TRITON_SANITIZER_BACKEND", "") or None

# --- Grid execution progress flag ---
env_flag = os.getenv("REPORT_GRID_EXECUTION_PROGRESS", "0")
Expand All @@ -25,6 +18,11 @@ def __init__(self, name: str) -> None:
# ---------- sanitizer_backend ----------
@property
def sanitizer_backend(self) -> str:
if self._sanitizer_backend is None:
raise RuntimeError(
f"TRITON_SANITIZER_BACKEND is not set!"
f"Available backends are: {AVAILABLE_SANITIZER_BACKENDS}"
)
return self._sanitizer_backend

@sanitizer_backend.setter
Expand Down
10 changes: 5 additions & 5 deletions triton_viz/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, Type, Dict
from tqdm import tqdm

from .config import report_grid_execution_progress, sanitizer_backend
from . import config as cfg
from .data import (
Op, RawLoad, Load, RawStore, Store,
UnaryOp, BinaryOp, TernaryOp, ProgramId,
Expand Down Expand Up @@ -136,14 +136,14 @@ def _grid_executor_call(self, *args_dev, **kwargs):
if kwargs.pop("warmup", False):
return
def run_grid_loops():
for x in tqdm(range(grid[0]), desc='Grid X', leave=False, disable=not report_grid_execution_progress):
for y in tqdm(range(grid[1]), desc='Grid Y', leave=False, disable=not (report_grid_execution_progress and grid[1] > 1)):
for z in tqdm(range(grid[2]), desc='Grid Z', leave=False, disable=not (report_grid_execution_progress and grid[2] > 1)):
for x in tqdm(range(grid[0]), desc='Grid X', leave=False, disable=not cfg.report_grid_execution_progress):
for y in tqdm(range(grid[1]), desc='Grid Y', leave=False, disable=not (cfg.report_grid_execution_progress and grid[1] > 1)):
for z in tqdm(range(grid[2]), desc='Grid Z', leave=False, disable=not (cfg.report_grid_execution_progress and grid[2] > 1)):
interpreter_builder.set_grid_idx(x, y, z)
client_manager.grid_idx_callback((x, y, z))
self.fn(**call_args)
# if symbolic execution, only do one iteration
if sanitizer_backend == "symexec":
if cfg.sanitizer_backend == "symexec":
return
# Removes not used reserved keywords from kwargs
# Triton doesn't support keyword-only, variable positional or variable keyword arguments
Expand Down
70 changes: 48 additions & 22 deletions triton_viz/core/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from typing import Tuple, Union

from .config import sanitizer_backend
from . import config as cfg
from ..clients import Sanitizer, Profiler, Tracer
from .client import ClientManager, Client
from .data import Launch
Expand All @@ -16,26 +16,33 @@

class Trace(KernelInterface):

def __init__(self, kernel: JITFunction, clients: Union[Tuple[Union[str, Client], ...], Union[str, Client]]) -> None:
@staticmethod
def _normalize_client(client: Union[str, Client]) -> Client:
if isinstance(client, str):
name = client.lower()
if name == "sanitizer":
return Sanitizer()
if name == "profiler":
return Profiler()
if name == "tracer":
return Tracer()
raise ValueError(f"Unknown client: {client}")
elif isinstance(client, Client):
return client
else:
raise TypeError(f"Expected str or Client, got {type(client)}")

def add_client(self, new_client: Union[Client, str]) -> None:
new_client_instance = self._normalize_client(new_client)
self.client_manager.add_clients([new_client_instance])

def __init__(self, kernel: JITFunction, client: Union[str, Client]) -> None:
assert isinstance(kernel, JITFunction), "Kernel must be a JITFunction"
self.interpreter_fn = InterpretedFunction(kernel.fn)
self.fn = kernel
self.arg_names = kernel.arg_names
init_clients: list[Client] = []
clients = (clients,) if not isinstance(clients, tuple) else clients
for client in clients:
if isinstance(client, str):
if client.lower() == "sanitizer":
init_clients.append(Sanitizer())
elif client.lower() == "profiler":
init_clients.append(Profiler())
elif client.lower() == "tracer":
init_clients.append(Tracer())
else:
raise ValueError(f"Unknown client: {client}")
else:
init_clients.append(client)
self.client_manager = ClientManager(init_clients)
self.client_manager = ClientManager()
self.add_client(client)

def run(self, *args, **kwargs):
with self.client_manager.patch():
Expand All @@ -52,17 +59,36 @@ def finalize(self):
launches.append(self.client_manager.launch)


def trace(clients: Union[Tuple[Union[str, Client], ...], Union[str, Client]] = ("sanitizer", "profiler")):
def trace(clients: Union[str, Client]):
"""
Create a trace object that can be used to run a kernel with instrumentation clients.

:param kernel: The kernel to run.
:param clients: A tuple of clients to run with the kernel.
:param client: A client to run with the kernel.
"""
def decorator(kernel: JITFunction) -> Trace:
if sanitizer_backend == "off":
if not clients:
raise ValueError("At least one client must be specified!")

if not isinstance(clients, (str, Client)):
raise TypeError(f"Expected str or Client, got {type(clients)}")

def decorator(kernel) -> Trace:
# When sanitizer is off, skip tracing and return the original kernel unchanged
if cfg.sanitizer_backend == "off":
return kernel
return Trace(kernel, clients)

# First-time wrapping
if isinstance(kernel, JITFunction):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tested this with @triton.autotune?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. Let me write a unittest with autotune.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can consider running all tests using CI.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just ran test/test_autotune_add.py and it worked well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's fix the bugs before merging

return Trace(kernel, clients)

# If the object is already a Trace, just append the new client(s)
if isinstance(kernel, Trace):
trace = kernel
trace.add_client(clients)
return trace

# If the object is neither a JITFunction nor Trace, raise an error
raise TypeError(f"Expected JITFunction, got {type(kernel)}")
return decorator


Expand Down
Loading