
Commit c1f8892

Add benchmark comparison tool for analyzing timing methods
Differential Revision: D90802828
Pull Request resolved: #791
1 parent 96de47e commit c1f8892

File tree

  • benchmarks/timing_accuracy

1 file changed: +367 −0 lines


benchmarks/timing_accuracy/run.py

Lines changed: 367 additions & 0 deletions
@@ -0,0 +1,367 @@
import argparse
import json
import statistics
import sys
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

import torch
import triton


@dataclass
class MethodStats:
    """Statistics for a benchmarking method."""

    method_name: str
    n_tests: int
    reps_per_test: int
    intra_test_medians: List[float] = field(default_factory=list)
    intra_test_stds: List[float] = field(default_factory=list)
    intra_test_cvs: List[float] = field(default_factory=list)
    intra_test_mins: List[float] = field(default_factory=list)
    intra_test_maxs: List[float] = field(default_factory=list)
    all_samples: List[List[float]] = field(default_factory=list)

    @property
    def inter_test_median(self) -> float:
        return statistics.median(self.intra_test_medians) if self.intra_test_medians else 0.0

    @property
    def inter_test_std(self) -> float:
        return (
            statistics.stdev(self.intra_test_medians)
            if len(self.intra_test_medians) > 1
            else 0.0
        )

    @property
    def inter_test_cv(self) -> float:
        median = self.inter_test_median
        return self.inter_test_std / median if median > 0 else 0.0

    @property
    def avg_intra_test_cv(self) -> float:
        return statistics.mean(self.intra_test_cvs) if self.intra_test_cvs else 0.0

    @property
    def avg_intra_test_std(self) -> float:
        return statistics.mean(self.intra_test_stds) if self.intra_test_stds else 0.0

    @property
    def inter_test_min(self) -> float:
        return min(self.intra_test_mins) if self.intra_test_mins else 0.0

    def to_dict(self) -> Dict[str, Any]:
        return {
            "method_name": self.method_name,
            "n_tests": self.n_tests,
            "reps_per_test": self.reps_per_test,
            "intra_test": {
                "avg_median_ms": statistics.mean(self.intra_test_medians)
                if self.intra_test_medians
                else 0,
                "avg_std_ms": statistics.mean(self.intra_test_stds)
                if self.intra_test_stds
                else 0,
                "avg_cv": self.avg_intra_test_cv,
                "avg_min_ms": statistics.mean(self.intra_test_mins)
                if self.intra_test_mins
                else 0,
                "avg_max_ms": statistics.mean(self.intra_test_maxs)
                if self.intra_test_maxs
                else 0,
            },
            "inter_test": {
                "median_ms": self.inter_test_median,
                "std_ms": self.inter_test_std,
                "cv": self.inter_test_cv,
            },
            "intra_test_medians": self.intra_test_medians,
            "intra_test_stds": self.intra_test_stds,
            "intra_test_cvs": self.intra_test_cvs,
            "all_samples": self.all_samples,
        }


def run_do_bench_standard(fn: Callable, warmup: int, rep: int) -> List[float]:
    return triton.runtime.driver.active.get_benchmarker()(
        fn, warmup=warmup, rep=rep, return_mode="all"
    )


def run_do_bench_profiler(fn: Callable, warmup: int, rep: int) -> List[float]:
    from tritonbench.components.do_bench.run import _do_bench_profiler

    return _do_bench_profiler(
        fn,
        warmup=warmup,
        rep=rep,
        return_mode="all",
        use_cudagraph=False,
        skip_cache_clearing=False,
    )


def run_do_bench_cudagraph(fn: Callable, warmup: int, rep: int) -> List[float]:
    from tritonbench.components.do_bench.run import _do_bench_cudagraph_with_cache_clear

    return _do_bench_cudagraph_with_cache_clear(
        fn, rep=rep, return_mode="all", skip_cache_clearing=False
    )


def run_do_bench_entropy(fn: Callable, warmup: int, rep: int) -> List[float]:
    from tritonbench.components.do_bench.run import _do_bench_entropy

    return _do_bench_entropy(fn, warmup=warmup, rep=rep, return_mode="all", repcnt=rep)


def run_do_bench_profiler_cudagraph(fn: Callable, warmup: int, rep: int) -> List[float]:
    from tritonbench.components.do_bench.run import _do_bench_profiler

    return _do_bench_profiler(
        fn,
        warmup=warmup,
        rep=rep,
        return_mode="all",
        use_cudagraph=True,
        skip_cache_clearing=False,
    )


def run_do_bench_gpu_events(fn: Callable, warmup: int, rep: int) -> List[float]:
    from tritonbench.components.do_bench.gpu_events import do_bench_events

    return do_bench_events(
        fn,
        warmup=warmup,
        rep=rep,
        return_mode="all",
        skip_cache_clearing=False,
    )


BENCHMARK_METHODS = {
    "standard": ("triton do_bench (standard)", run_do_bench_standard),
    "profiler": ("profiler", run_do_bench_profiler),
    "cudagraph": ("CUDA graph", run_do_bench_cudagraph),
    "entropy": ("entropy-based", run_do_bench_entropy),
    "profiler_cudagraph": ("profiler + CUDA graph", run_do_bench_profiler_cudagraph),
    "gpu_events": ("GPU events", run_do_bench_gpu_events),
}


def benchmark_method(
    method_name: str,
    method_fn: Callable,
    kernel_fn: Callable,
    n_tests: int,
    warmup: int,
    rep: int,
    sleep_between_tests: float = 0.5,
    verbose: bool = True,
) -> MethodStats:
    """Run `method_fn` on `kernel_fn` for `n_tests` independent tests and collect per-test noise statistics."""
    stats = MethodStats(method_name=method_name, n_tests=n_tests, reps_per_test=rep)

    for test_idx in range(n_tests):
        if verbose:
            print(f" Test {test_idx + 1}/{n_tests}...", end=" ", flush=True)

        if test_idx > 0 and sleep_between_tests > 0:
            time.sleep(sleep_between_tests)

        try:
            samples = method_fn(kernel_fn, warmup=warmup, rep=rep)
            if not samples:
                print("WARNING: No samples returned!")
                continue

            median = statistics.median(samples)
            mean = statistics.mean(samples)
            std = statistics.stdev(samples) if len(samples) > 1 else 0.0
            cv = std / median if median > 0 else 0.0

            stats.intra_test_medians.append(median)
            stats.intra_test_stds.append(std)
            stats.intra_test_cvs.append(cv)
            stats.intra_test_mins.append(min(samples))
            stats.intra_test_maxs.append(max(samples))
            stats.all_samples.append(samples)

            if verbose:
                print(f"median={median:.4f}ms, mean={mean:.4f}ms, std={std:.4f}ms, cv={cv:.4f}")

        except Exception as e:
            print(f"ERROR: {e}")
            import traceback

            traceback.print_exc()

    return stats


def print_summary_table(results: Dict[str, MethodStats], operation_name: str):
    """Print intra-/inter-test noise metrics for each method, sorted by inter-test CV (lowest first)."""
    print("\n" + "=" * 120)
    print(f"SUMMARY: Latency Noise Comparison for '{operation_name}'")
    print("=" * 120)

    header = f"{'Method':<25} | {'Min (ms)':<10} | {'Median (ms)':<12} | {'Intra-Std (ms)':<14} | {'Intra-CV':<10} | {'Inter-CV':<10} | {'Inter-Std (ms)':<14}"
    print(header)
    print("-" * 120)

    for method_name, stats in sorted(results.items(), key=lambda x: x[1].inter_test_cv):
        print(
            f"{method_name:<25} | "
            f"{stats.inter_test_min:<10.4f} | "
            f"{stats.inter_test_median:<12.4f} | "
            f"{stats.avg_intra_test_std:<14.4f} | "
            f"{stats.avg_intra_test_cv:<10.4f} | "
            f"{stats.inter_test_cv:<10.4f} | "
            f"{stats.inter_test_std:<14.4f}"
        )

    print("=" * 120)
    print("\nLegend: Intra-CV = noise within each run, Inter-CV = noise between runs. Lower = better.\n")


def main():
    parser = argparse.ArgumentParser(
        description="Compare latency noise across benchmarking methods",
        allow_abbrev=False,
    )

    parser.add_argument("--op", type=str, required=True, help="TritonBench operator")
    parser.add_argument("--only", type=str, default=None, help="Kernel implementation(s)")
    parser.add_argument("--input-id", type=str, default="0", help="Input config ID")
    parser.add_argument("--mode", choices=["fwd", "bwd", "fwd_bwd", "fwd_no_grad"], default="fwd")
    parser.add_argument("--precision", type=str, default="fp16")
    parser.add_argument("--n-tests", type=int, default=10, help="Benchmark runs per method")
    parser.add_argument("--reps-per-test", type=int, default=100, help="Reps per run")
    parser.add_argument("--warmup", type=int, default=25, help="Warmup (ms)")
    parser.add_argument("--sleep-between-tests", type=float, default=0.5)
    parser.add_argument(
        "--bench-methods",
        type=str,
        default="all",
        dest="methods",
        help=f"Methods: {','.join(BENCHMARK_METHODS.keys())},all",
    )
    parser.add_argument("--output", type=str, default=None, help="Output JSON path")
    parser.add_argument("--quiet", action="store_true")

    args, extra_args = parser.parse_known_args()

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available!")
        sys.exit(1)

    device_name = torch.cuda.get_device_name()
    print(f"\nLoading operator: {args.op}")

    # Use existing tritonbench infrastructure to load operator
    from tritonbench.utils.run_utils import load_operator_by_args

    tb_arg_list = [
        "--op", args.op,
        "--mode", args.mode,
        "--precision", args.precision,
        "--device", "cuda",
        "--input-id", args.input_id,
        "--num-inputs", "1",
        "--test-only",
    ]
    if args.only:
        tb_arg_list.extend(["--only", args.only])
    tb_arg_list.extend(extra_args)

    opbench = load_operator_by_args(tb_arg_list)
    opbench.example_inputs = opbench.get_example_inputs()

    if opbench.example_inputs is None:
        print(f"ERROR: No example inputs for operator '{args.op}'")
        sys.exit(1)

    # Get the benchmark function
    if args.only:
        backend_name = args.only.split(",")[0]
        bench_fn_factory = getattr(opbench, backend_name, None)
        if bench_fn_factory is None:
            print(f"ERROR: Backend '{backend_name}' not found")
            sys.exit(1)
    else:
        from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS

        registered = REGISTERED_BENCHMARKS.get(opbench.name, {})
        if not registered:
            print(f"ERROR: No benchmarks registered for '{args.op}'")
            sys.exit(1)
        backend_name = list(registered.keys())[0]
        bench_fn_factory = getattr(opbench, backend_name)

    example_inputs = opbench.example_inputs
    if isinstance(example_inputs, dict):
        kernel_fn = bench_fn_factory(**example_inputs)
    else:
        kernel_fn = bench_fn_factory(*example_inputs)

    operation_name = f"{args.op}:{backend_name} (input_id={args.input_id})"
    print(f"Device: {device_name}, Backend: {backend_name}, Tests: {args.n_tests}, Reps: {args.reps_per_test}\n")

    # Determine methods to run
    if args.methods == "all":
        methods_to_run = list(BENCHMARK_METHODS.keys())
    else:
        methods_to_run = [m.strip() for m in args.methods.split(",")]
        for m in methods_to_run:
            if m not in BENCHMARK_METHODS:
                print(f"ERROR: Unknown method '{m}'. Available: {', '.join(BENCHMARK_METHODS.keys())}")
                sys.exit(1)

    # Warmup
    print("GPU warmup...")
    for _ in range(10):
        kernel_fn()
    torch.cuda.synchronize()

    # Run benchmarks
    results: Dict[str, MethodStats] = {}
    for method_key in methods_to_run:
        method_display_name, method_fn = BENCHMARK_METHODS[method_key]
        print(f"\n{'='*60}\nBenchmarking: {method_display_name}\n{'='*60}")

        stats = benchmark_method(
            method_name=method_display_name,
            method_fn=method_fn,
            kernel_fn=kernel_fn,
            n_tests=args.n_tests,
            warmup=args.warmup,
            rep=args.reps_per_test,
            sleep_between_tests=args.sleep_between_tests,
            verbose=not args.quiet,
        )
        results[method_display_name] = stats

    print_summary_table(results, operation_name)

    if args.output:
        output_data = {
            "config": {
                "device": device_name,
                "operator": args.op,
                "backend": backend_name,
                "input_id": args.input_id,
                "mode": args.mode,
                "precision": args.precision,
                "n_tests": args.n_tests,
                "reps_per_test": args.reps_per_test,
                "warmup": args.warmup,
            },
            "results": {name: stats.to_dict() for name, stats in results.items()},
        }
        with open(args.output, "w") as f:
            json.dump(output_data, f, indent=2)
        print(f"Results saved to: {args.output}")


if __name__ == "__main__":
    main()
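
For reference, below is a minimal, self-contained sketch of how the noise metrics this tool reports are computed, together with a hypothetical invocation. The operator and backend names in the comment (softmax, triton_softmax) and the latency values are made-up placeholders; only the statistics mirror the MethodStats definitions in run.py above.

# Hypothetical invocation (operator/backend names are examples only):
#   python benchmarks/timing_accuracy/run.py --op softmax --only triton_softmax \
#       --n-tests 10 --reps-per-test 100 --output noise.json
import statistics
from typing import List

# Synthetic latencies (ms) for three independent tests of one kernel,
# standing in for what a single benchmarking method would return per test.
runs: List[List[float]] = [
    [0.512, 0.508, 0.515, 0.510, 0.509],
    [0.520, 0.518, 0.522, 0.519, 0.521],
    [0.511, 0.513, 0.509, 0.512, 0.510],
]

# Intra-test noise: variability of samples within each test.
intra_medians = [statistics.median(r) for r in runs]
intra_cvs = [statistics.stdev(r) / statistics.median(r) for r in runs]

# Inter-test noise: variability of the per-test medians across tests.
inter_median = statistics.median(intra_medians)
inter_std = statistics.stdev(intra_medians)
inter_cv = inter_std / inter_median if inter_median > 0 else 0.0

print(f"avg intra-test CV: {statistics.mean(intra_cvs):.4f}")
print(f"inter-test CV:     {inter_cv:.4f}")  # lower means less run-to-run noise

In the summary table, a lower Inter-CV means the method's reported medians are more repeatable across runs, while Intra-CV captures jitter within a single run.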
