33# See LICENSE.TXT
44# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55
6- from dataclasses import dataclass
76import os
87import shutil
98import subprocess
1211from options import options
1312from utils .utils import download , run
1413from abc import ABC , abstractmethod
14+ from utils .unitrace import get_unitrace
15+ from utils .logger import log
1516
1617benchmark_tags = [
1718 BenchmarkTag ("SYCL" , "Benchmark uses SYCL runtime" ),
@@ -61,6 +62,12 @@ def enabled(self) -> bool:
6162 By default, it returns True, but can be overridden to disable a benchmark."""
6263 return True
6364
65+ def traceable (self ) -> bool :
66+ """Returns whether this benchmark should be traced by Unitrace.
67+ By default, it returns True, but can be overridden to disable tracing for a benchmark.
68+ """
69+ return True
70+
6471 @abstractmethod
6572 def setup (self ):
6673 pass
@@ -70,11 +77,12 @@ def teardown(self):
7077 pass
7178
7279 @abstractmethod
73- def run (self , env_vars ) -> list [Result ]:
80+ def run (self , env_vars , run_unitrace : bool = False ) -> list [Result ]:
7481 """Execute the benchmark with the given environment variables.
7582
7683 Args:
7784 env_vars: Environment variables to use when running the benchmark.
85+ run_unitrace: Whether to run benchmark under Unitrace.
7886
7987 Returns:
8088 A list of Result objects with the benchmark results.
@@ -97,7 +105,14 @@ def get_adapter_full_path():
97105 ), f"could not find adapter file { adapter_path } (and in similar lib paths)"
98106
99107 def run_bench (
100- self , command , env_vars , ld_library = [], add_sycl = True , use_stdout = True
108+ self ,
109+ command ,
110+ env_vars ,
111+ ld_library = [],
112+ add_sycl = True ,
113+ use_stdout = True ,
114+ run_unitrace = False ,
115+ extra_unitrace_opt = None ,
101116 ):
102117 env_vars = env_vars .copy ()
103118 if options .ur is not None :
@@ -110,13 +125,30 @@ def run_bench(
110125 ld_libraries = options .extra_ld_libraries .copy ()
111126 ld_libraries .extend (ld_library )
112127
113- result = run (
114- command = command ,
115- env_vars = env_vars ,
116- add_sycl = add_sycl ,
117- cwd = options .benchmark_cwd ,
118- ld_library = ld_libraries ,
119- )
128+ if self .traceable () and run_unitrace :
129+ if extra_unitrace_opt is None :
130+ extra_unitrace_opt = []
131+ unitrace_output , command = get_unitrace ().setup (
132+ self .name (), command , extra_unitrace_opt
133+ )
134+ log .debug (f"Unitrace output: { unitrace_output } " )
135+ log .debug (f"Unitrace command: { ' ' .join (command )} " )
136+
137+ try :
138+ result = run (
139+ command = command ,
140+ env_vars = env_vars ,
141+ add_sycl = add_sycl ,
142+ cwd = options .benchmark_cwd ,
143+ ld_library = ld_libraries ,
144+ )
145+ except subprocess .CalledProcessError :
146+ if run_unitrace :
147+ get_unitrace ().cleanup (options .benchmark_cwd , unitrace_output )
148+ raise
149+
150+ if self .traceable () and run_unitrace :
151+ get_unitrace ().handle_output (unitrace_output )
120152
121153 if use_stdout :
122154 return result .stdout .decode ()
0 commit comments