Skip to content

Commit 9ff24b5

Browse files
author
Andrija Kolic
committed
[GR-64081] Add bench-suite execution context field
PullRequest: mx/1911
2 parents b15b40e + e459e69 commit 9ff24b5

File tree

2 files changed

+101
-29
lines changed

2 files changed

+101
-29
lines changed

src/mx/_impl/mx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18584,7 +18584,7 @@ def alarm_handler(signum, frame):
1858418584
_CACHE_DIR = get_env('MX_CACHE_DIR', join(dot_mx_dir(), 'cache'))
1858518585

1858618586
# The version must be updated for every PR (checked in CI) and the comment should reflect the PR's issue
18587-
version = VersionSpec("7.54.2") # GR-65332 improve mx benchmark docs
18587+
version = VersionSpec("7.54.3") # GR-64081: Supporting bench suite context classes for new 'graalos' benchmark suite
1858818588

1858918589
_mx_start_datetime = datetime.utcnow()
1859018590

src/mx/_impl/mx_benchmark.py

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"add_parser",
3838
"get_parser",
3939
"VmRegistry",
40+
"SingleBenchmarkExecutionContext",
4041
"BenchmarkSuite",
4142
"add_bm_suite",
4243
"bm_suite_valid_keys",
@@ -128,6 +129,7 @@
128129
from argparse import SUPPRESS
129130
from collections import OrderedDict, abc
130131
from typing import Callable, Sequence, Iterable, Optional, Dict, Any, List, Collection
132+
from dataclasses import dataclass
131133

132134
from .support.logging import log_deprecation
133135

@@ -448,6 +450,58 @@ def get_vms(self):
448450
return list(self._vms.values())
449451

450452

453+
@dataclass(frozen = True)
454+
class BenchmarkExecutionContext():
455+
"""
456+
Container class for the runtime context of a benchmark suite.
457+
* suite: The benchmark suite to which the currently running benchmarks belong to.
458+
* vm: The virtual machine on which the currently running benchmarks are executing on.
459+
* benchmarks: The names of the currently running benchmarks.
460+
* bmSuiteArgs: The arguments passed to the benchmark suite for the current benchmark run.
461+
"""
462+
suite: BenchmarkSuite
463+
virtual_machine: Vm
464+
benchmarks: List[str]
465+
bmSuiteArgs: List[str]
466+
467+
def __enter__(self):
468+
self.suite.push_execution_context(self)
469+
return self
470+
471+
def __exit__(self, exc_type, exc_value, exc_traceback):
472+
self.suite.pop_execution_context()
473+
474+
@dataclass(frozen = True)
475+
class SingleBenchmarkExecutionContext(BenchmarkExecutionContext):
476+
"""
477+
Container class for the runtime context of a benchmark suite that can only run a single benchmark.
478+
* benchmark: The name of the currently running benchmark.
479+
"""
480+
benchmark: str
481+
482+
def __init__(self, suite: BenchmarkSuite, vm: Vm, benchmarks: List[str], bmSuiteArgs: List[str]):
483+
super().__init__(suite, vm, benchmarks, bmSuiteArgs)
484+
self._enforce_single_benchmark(benchmarks, suite)
485+
# Assigning to the 'benchmark' field directly is not possible because of @dataclass(frozen = True).
486+
# If the 'frozen' attribute were to be set to false, we could assign directly here, but then assignments would
487+
# be allowed at any point.
488+
object.__setattr__(self, "benchmark", benchmarks[0])
489+
490+
def _enforce_single_benchmark(self, benchmarks: List[str], suite: BenchmarkSuite):
491+
"""
492+
Asserts that a single benchmark is requested to run.
493+
Raises an exception if none or multiple benchmarks are requested.
494+
"""
495+
if not isinstance(benchmarks, list):
496+
raise TypeError(f"{suite.__class__.__name__} expects to receive a list of benchmarks to run,"
497+
f" instead got an instance of {benchmarks.__class__.__name__}!"
498+
f" Please specify a single benchmark!")
499+
if len(benchmarks) != 1:
500+
raise ValueError(f"You have requested {benchmarks} to be run but {suite.__class__.__name__}"
501+
f" can only run a single benchmark at a time!"
502+
f" Please specify a single benchmark!")
503+
504+
451505
class BenchmarkSuite(object):
452506
"""
453507
A harness for a benchmark suite.
@@ -461,6 +515,7 @@ def __init__(self, *args, **kwargs):
461515
self._command_mapper_hooks = {}
462516
self._tracker = None
463517
self._currently_running_benchmark = None
518+
self._execution_context: List[BenchmarkExecutionContext] = []
464519

465520
def name(self):
466521
"""Returns the name of the suite to execute.
@@ -756,6 +811,21 @@ def expandBmSuiteArgs(self, benchmarks, bmSuiteArgs):
756811
"""
757812
return [bmSuiteArgs]
758813

814+
def new_execution_context(self, vm: Vm, benchmarks: List[str], bmSuiteArgs: List[str]) -> BenchmarkExecutionContext:
815+
return BenchmarkExecutionContext(self, vm, benchmarks, bmSuiteArgs)
816+
817+
def pop_execution_context(self) -> BenchmarkExecutionContext:
818+
return self._execution_context.pop()
819+
820+
def push_execution_context(self, context: BenchmarkExecutionContext):
821+
self._execution_context.append(context)
822+
823+
@property
824+
def execution_context(self) -> Optional[BenchmarkExecutionContext]:
825+
if len(self._execution_context) == 0:
826+
return None
827+
return self._execution_context[-1]
828+
759829

760830

761831
def add_bm_suite(suite, mxsuite=None):
@@ -1466,9 +1536,10 @@ class StdOutBenchmarkSuite(BenchmarkSuite):
14661536
5. Use the parse rules on the standard output to create data points.
14671537
"""
14681538
def run(self, benchmarks, bmSuiteArgs) -> DataPoints:
1469-
retcode, out, dims = self.runAndReturnStdOut(benchmarks, bmSuiteArgs)
1470-
datapoints = self.validateStdoutWithDimensions(out, benchmarks, bmSuiteArgs, retcode=retcode, dims=dims)
1471-
return datapoints
1539+
with self.new_execution_context(None, benchmarks, bmSuiteArgs):
1540+
retcode, out, dims = self.runAndReturnStdOut(benchmarks, bmSuiteArgs)
1541+
datapoints = self.validateStdoutWithDimensions(out, benchmarks, bmSuiteArgs, retcode=retcode, dims=dims)
1542+
return datapoints
14721543

14731544
def validateStdoutWithDimensions(
14741545
self, out, benchmarks, bmSuiteArgs, retcode=None, dims=None, extraRules=None) -> DataPoints:
@@ -1761,31 +1832,32 @@ def runAndReturnStdOut(self, benchmarks, bmSuiteArgs):
17611832
vm = self.get_vm_registry().get_vm_from_suite_args(bmSuiteArgs)
17621833
vm.extract_vm_info(self.vmArgs(bmSuiteArgs))
17631834
vm.command_mapper_hooks = [(name, func, self) for name, func in self._command_mapper_hooks.items()]
1764-
t = self._vmRun(vm, cwd, command, benchmarks, bmSuiteArgs)
1765-
if len(t) == 2:
1766-
ret_code, out = t
1767-
vm_dims = {}
1768-
else:
1769-
ret_code, out, vm_dims = t
1770-
host_vm = None
1771-
if isinstance(vm, GuestVm):
1772-
host_vm = vm.host_vm()
1773-
assert host_vm
1774-
dims = {
1775-
"vm": vm.name(),
1776-
"host-vm": host_vm.name() if host_vm else vm.name(),
1777-
"host-vm-config": self.host_vm_config_name(host_vm, vm),
1778-
"guest-vm": vm.name() if host_vm else "none",
1779-
"guest-vm-config": self.guest_vm_config_name(host_vm, vm),
1780-
}
1781-
for key, value in vm_dims.items():
1782-
if key in dims and value != dims[key]:
1783-
if value == 'none':
1784-
mx.warn(f"VM {vm.name()}:{vm.config_name()} ({vm.__class__.__name__}) tried overwriting {key}='{dims[key]}' with '{value}', keeping '{dims[key]}'")
1785-
continue
1786-
mx.warn(f"VM {vm.name()}:{vm.config_name()} ({vm.__class__.__name__}) is overwriting {key}='{dims[key]}' with '{value}'")
1787-
dims[key] = value
1788-
return ret_code, out, dims
1835+
with self.new_execution_context(vm, benchmarks, bmSuiteArgs):
1836+
t = self._vmRun(vm, cwd, command, benchmarks, bmSuiteArgs)
1837+
if len(t) == 2:
1838+
ret_code, out = t
1839+
vm_dims = {}
1840+
else:
1841+
ret_code, out, vm_dims = t
1842+
host_vm = None
1843+
if isinstance(vm, GuestVm):
1844+
host_vm = vm.host_vm()
1845+
assert host_vm
1846+
dims = {
1847+
"vm": vm.name(),
1848+
"host-vm": host_vm.name() if host_vm else vm.name(),
1849+
"host-vm-config": self.host_vm_config_name(host_vm, vm),
1850+
"guest-vm": vm.name() if host_vm else "none",
1851+
"guest-vm-config": self.guest_vm_config_name(host_vm, vm),
1852+
}
1853+
for key, value in vm_dims.items():
1854+
if key in dims and value != dims[key]:
1855+
if value == 'none':
1856+
mx.warn(f"VM {vm.name()}:{vm.config_name()} ({vm.__class__.__name__}) tried overwriting {key}='{dims[key]}' with '{value}', keeping '{dims[key]}'")
1857+
continue
1858+
mx.warn(f"VM {vm.name()}:{vm.config_name()} ({vm.__class__.__name__}) is overwriting {key}='{dims[key]}' with '{value}'")
1859+
dims[key] = value
1860+
return ret_code, out, dims
17891861

17901862
def host_vm_config_name(self, host_vm, vm):
17911863
return host_vm.config_name() if host_vm else vm.config_name()

0 commit comments

Comments
 (0)