Skip to content

Commit 1fbaca2

Browse files
[GR-64688] Always apply command mapper hooks (trackers) for the image stages
PullRequest: mx/1931
2 parents 7533e1f + 720a254 commit 1fbaca2

File tree

4 files changed

+143
-15
lines changed

4 files changed

+143
-15
lines changed

src/mx/_impl/mx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13906,7 +13906,7 @@ def apply_command_mapper_hooks(command, hooks):
1390613906
for hook in reversed(hooks):
1390713907
hook_name, hook_func, suite = hook[:3]
1390813908
logv(f"Applying command mapper hook '{hook_name}'")
13909-
new_cmd = hook_func(new_cmd, suite)
13909+
new_cmd = hook_func.hook(new_cmd, suite)
1391013910
logv(f"New command: {new_cmd}")
1391113911
else:
1391213912
log("Skipping command mapper hooks as they were disabled explicitly.")

src/mx/_impl/mx_benchmark.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
from typing import Callable, Sequence, Iterable, Optional, Dict, Any, List, Collection
133133
from dataclasses import dataclass
134134

135+
from .mx_util import Stage, MapperHook, FunctionHookAdapter
135136
from .support.logging import log_deprecation
136137

137138
from . import mx
@@ -513,7 +514,7 @@ def __init__(self, *args, **kwargs):
513514
super(BenchmarkSuite, self).__init__(*args, **kwargs)
514515
self._desired_version = None
515516
self._suite_dimensions: DataPoint = {}
516-
self._command_mapper_hooks = {}
517+
self._command_mapper_hooks: Dict[str, MapperHook] = {}
517518
self._tracker = None
518519
self._currently_running_benchmark = None
519520
self._execution_context: List[BenchmarkExecutionContext] = []
@@ -586,23 +587,27 @@ def currently_running_benchmark(self):
586587
"""
587588
return self._currently_running_benchmark
588589

589-
def register_command_mapper_hook(self, name, func):
590-
"""Registers a function that takes as input the benchmark suite object and the command to execute and returns
591-
a modified command line.
590+
def register_command_mapper_hook(self, name: str, hook: Callable | MapperHook):
591+
"""Registers a command modification hook with stage awareness.
592592
593593
:param string name: Unique name of the hook.
594-
:param function func:
594+
:param hook: MapperHook instance or callable
595595
:return: None
596596
"""
597-
self._command_mapper_hooks[name] = func
597+
if isinstance(hook, MapperHook):
598+
self._command_mapper_hooks[name] = hook
599+
elif callable(hook):
600+
mx.warn(f"Registering a Callable as a command mapper hook '{name}' for the benchmark suite. Please use the MapperHook interface instead of Callable as it will be deprecated in the future!")
601+
self._command_mapper_hooks[name] = FunctionHookAdapter(hook)
602+
else:
603+
raise ValueError(f"Hook must be MapperHook or callable, got {type(hook)}")
604+
598605

599606
def register_tracker(self, name, tracker_type):
600607
tracker = tracker_type(self)
601608
self._tracker = tracker
609+
hook = tracker.get_hook()
602610

603-
def hook(cmd, suite):
604-
assert suite == self
605-
return tracker.map_command(cmd)
606611
self.register_command_mapper_hook(name, hook)
607612

608613
def version(self):
@@ -1764,9 +1769,10 @@ def vmArgs(self, bmSuiteArgs):
17641769

17651770
def func(cmd, bmSuite, prefix_command=prefix_command):
17661771
return prefix_command + cmd
1772+
17671773
if self._command_mapper_hooks and any(hook_name != profiler for hook_name in self._command_mapper_hooks):
17681774
mx.abort(f"Profiler '{profiler}' conflicts with trackers '{', '.join([hook_name for hook_name in self._command_mapper_hooks if hook_name != profiler])}'\nUse --tracker none to disable all trackers")
1769-
self._command_mapper_hooks = {profiler: func}
1775+
self.register_command_mapper_hook(profiler, FunctionHookAdapter(func))
17701776
return args
17711777

17721778
def parserNames(self):
@@ -2010,7 +2016,7 @@ def command_mapper_hooks(self, hooks):
20102016
"""
20112017
Registers a list of `hooks` (given as a tuple 'name', 'func', 'suite') to manipulate the command line before its
20122018
execution.
2013-
:param list[tuple] hooks: the list of hooks given as tuples of names and functions
2019+
:param list[tuple] hooks: the list of hooks given as tuples of names and MapperHook instances
20142020
"""
20152021
self._command_mapper_hooks = hooks
20162022

@@ -3078,6 +3084,14 @@ def disable_tracker():
30783084
global _use_tracker
30793085
_use_tracker = False
30803086

3087+
class DefaultTrackerHook(MapperHook):
3088+
def __init__(self, tracker: Tracker):
3089+
self.tracker = tracker
3090+
3091+
def hook(self, cmd, suite=None):
3092+
assert suite == self.tracker.bmSuite
3093+
return self.tracker.map_command(cmd)
3094+
30813095
class Tracker(object):
30823096
def __init__(self, bmSuite):
30833097
self.bmSuite = bmSuite
@@ -3088,6 +3102,9 @@ def map_command(self, cmd):
30883102
def get_rules(self, bmSuiteArgs):
30893103
raise NotImplementedError()
30903104

3105+
def get_hook(self) -> MapperHook:
3106+
return DefaultTrackerHook(self)
3107+
30913108
class RssTracker(Tracker):
30923109
def __init__(self, bmSuite):
30933110
super().__init__(bmSuite)
@@ -3461,6 +3478,10 @@ def __init__(self, bmSuite):
34613478
# GR-65536
34623479
atexit.register(self.cleanup)
34633480

3481+
def get_hook(self) -> MapperHook:
3482+
"""Returns an energy tracking hook"""
3483+
return EnergyTrackerHook(self)
3484+
34643485
@property
34653486
def baseline_power(self):
34663487
"""Caches the average baseline power value"""
@@ -3559,6 +3580,15 @@ def get_rules(self, bmSuiteArgs):
35593580

35603581
return rules
35613582

3583+
class EnergyTrackerHook(DefaultTrackerHook):
3584+
"""Hook for energy consumption tracking."""
3585+
def __init__(self, tracker: EnergyConsumptionTracker):
3586+
assert isinstance(tracker, EnergyConsumptionTracker)
3587+
super().__init__(tracker)
3588+
3589+
def should_apply(self, stage: Optional[Stage]) -> bool:
3590+
return stage.is_final() if stage else False
3591+
35623592
_available_trackers = {
35633593
"rss": RssTracker,
35643594
"psrecord": PsrecordTracker,

src/mx/_impl/mx_util.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#
2525
# ----------------------------------------------------------------------------------------------------
2626
#
27+
from __future__ import annotations
2728

2829
#
2930
# Utility functions for use by mx and mx extensions.
@@ -37,14 +38,21 @@
3738
"get_file_extension",
3839
"ensure_dirname_exists",
3940
"ensure_dir_exists",
40-
"SafeFileCreation"
41+
"SafeFileCreation",
42+
"Stage",
43+
"StageName",
44+
"Layer",
45+
"MapperHook",
46+
"FunctionHookAdapter"
4147
]
4248

4349
import os.path
4450
import errno
4551
import sys
4652
import tempfile
47-
from typing import Optional
53+
from dataclasses import dataclass
54+
from enum import Enum
55+
from typing import Optional, List, Callable
4856
from os.path import dirname, exists, join, isdir, basename
4957

5058
min_required_python_version = (3, 8)
@@ -215,3 +223,93 @@ def _create_tmp_files(tmp_dir, num):
215223
pid = os.getpid()
216224
with open(sfc.tmpPath, 'w') as out:
217225
print(f"file {i} created by process {pid}", file=out)
226+
227+
228+
@dataclass(frozen = True)
229+
class Stage:
230+
stage_name: StageName
231+
layer_info: Layer = None
232+
233+
@staticmethod
234+
def from_string(s: str) -> Stage:
235+
return Stage(StageName(s))
236+
237+
def is_image(self) -> bool:
238+
return self.stage_name.is_image()
239+
240+
def is_instrument(self) -> bool:
241+
return self.stage_name.is_instrument()
242+
243+
def is_agent(self) -> bool:
244+
return self.stage_name.is_agent()
245+
246+
def is_final(self) -> bool:
247+
return self.stage_name.is_final()
248+
249+
def is_layered(self) -> bool:
250+
return self.layer_info is not None
251+
252+
def is_requested(self, request: str):
253+
"""Whether the 'request' is equal to either the full name of the stage or it's name without layer info."""
254+
return str(self) == request or str(self.stage_name) == request
255+
256+
def __str__(self):
257+
if not self.is_layered():
258+
return str(self.stage_name)
259+
return f"{self.stage_name}-{self.layer_info}"
260+
261+
class StageName(Enum):
262+
AGENT = "agent"
263+
INSTRUMENT_IMAGE = "instrument-image"
264+
INSTRUMENT_RUN = "instrument-run"
265+
IMAGE = "image"
266+
RUN = "run"
267+
268+
def __str__(self):
269+
return self.value
270+
271+
def is_image(self) -> bool:
272+
"""Whether this is an image stage (a stage that performs an image build)"""
273+
return self in [StageName.INSTRUMENT_IMAGE, StageName.IMAGE]
274+
275+
def is_instrument(self) -> bool:
276+
"""Whether this is an image stage (a stage that performs an image build)"""
277+
return self in [StageName.INSTRUMENT_IMAGE, StageName.INSTRUMENT_RUN]
278+
279+
def is_agent(self) -> bool:
280+
return self == StageName.AGENT
281+
282+
def is_final(self) -> bool:
283+
return self in [StageName.IMAGE, StageName.RUN]
284+
285+
@dataclass(frozen = True)
286+
class Layer:
287+
index: int
288+
is_shared_library: bool
289+
290+
def __str__(self):
291+
return f"layer{self.index}"
292+
293+
294+
class MapperHook:
295+
"""Base class for all command mapper hooks."""
296+
def hook(self, cmd: List[str], suite=None) -> List[str]:
297+
raise NotImplementedError()
298+
299+
def should_apply(self, stage: Optional[Stage]) -> bool:
300+
"""Determines if this hook should be applied for the given stage. By default, hooks shouldn't apply to image stages.
301+
Args:
302+
stage: For NativeImage suites, the current stage
303+
For JIT suites, None
304+
Returns:
305+
bool: True if the hook should be applied
306+
"""
307+
return not stage.is_image() if stage else True
308+
309+
class FunctionHookAdapter(MapperHook):
310+
"""Adapter for function hooks."""
311+
def __init__(self, func: Callable[[List[str]], List[str]]):
312+
self.func = func
313+
314+
def hook(self, cmd: List[str], suite=None) -> List[str]:
315+
return self.func(cmd, suite)

src/mx/mx_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# The version must be updated for every PR (checked in CI) and the comment should reflect the PR's issue
2-
version = "7.58.5" # GR-66866: Suppress compile warnings in shaded jar projects in Truffle for Intellij.
2+
version = "7.58.6" # GR-64688 Always apply command mapper hooks (trackers) for the image stages

0 commit comments

Comments
 (0)