132132from typing import Callable , Sequence , Iterable , Optional , Dict , Any , List , Collection
133133from dataclasses import dataclass
134134
135+ from .mx_util import Stage , MapperHook , FunctionHookAdapter
135136from .support .logging import log_deprecation
136137
137138from . import mx
@@ -513,7 +514,7 @@ def __init__(self, *args, **kwargs):
513514 super (BenchmarkSuite , self ).__init__ (* args , ** kwargs )
514515 self ._desired_version = None
515516 self ._suite_dimensions : DataPoint = {}
516- self ._command_mapper_hooks = {}
517+ self ._command_mapper_hooks : Dict [ str , MapperHook ] = {}
517518 self ._tracker = None
518519 self ._currently_running_benchmark = None
519520 self ._execution_context : List [BenchmarkExecutionContext ] = []
@@ -586,23 +587,27 @@ def currently_running_benchmark(self):
586587 """
587588 return self ._currently_running_benchmark
588589
589- def register_command_mapper_hook (self , name , func ):
590- """Registers a function that takes as input the benchmark suite object and the command to execute and returns
591- a modified command line.
590+ def register_command_mapper_hook (self , name : str , hook : Callable | MapperHook ):
591+ """Registers a command modification hook with stage awareness.
592592
593593 :param string name: Unique name of the hook.
594- :param function func:
594+ :param hook: MapperHook instance or callable
595595 :return: None
596596 """
597- self ._command_mapper_hooks [name ] = func
597+ if isinstance (hook , MapperHook ):
598+ self ._command_mapper_hooks [name ] = hook
599+ elif callable (hook ):
600+ mx .warn (f"Registering a Callable as a command mapper hook '{ name } ' for the benchmark suite. Please use the MapperHook interface instead of Callable as it will be deprecated in the future!" )
601+ self ._command_mapper_hooks [name ] = FunctionHookAdapter (hook )
602+ else :
603+ raise ValueError (f"Hook must be MapperHook or callable, got { type (hook )} " )
604+
598605
599606 def register_tracker (self , name , tracker_type ):
600607 tracker = tracker_type (self )
601608 self ._tracker = tracker
609+ hook = tracker .get_hook ()
602610
603- def hook (cmd , suite ):
604- assert suite == self
605- return tracker .map_command (cmd )
606611 self .register_command_mapper_hook (name , hook )
607612
608613 def version (self ):
@@ -1764,9 +1769,10 @@ def vmArgs(self, bmSuiteArgs):
17641769
17651770 def func (cmd , bmSuite , prefix_command = prefix_command ):
17661771 return prefix_command + cmd
1772+
17671773 if self ._command_mapper_hooks and any (hook_name != profiler for hook_name in self ._command_mapper_hooks ):
17681774 mx .abort (f"Profiler '{ profiler } ' conflicts with trackers '{ ', ' .join ([hook_name for hook_name in self ._command_mapper_hooks if hook_name != profiler ])} '\n Use --tracker none to disable all trackers" )
1769- self ._command_mapper_hooks = { profiler : func }
1775+ self .register_command_mapper_hook ( profiler , FunctionHookAdapter ( func ))
17701776 return args
17711777
17721778 def parserNames (self ):
@@ -2010,7 +2016,7 @@ def command_mapper_hooks(self, hooks):
20102016 """
20112017 Registers a list of `hooks` (given as a tuple 'name', 'func', 'suite') to manipulate the command line before its
20122018 execution.
2013- :param list[tuple] hooks: the list of hooks given as tuples of names and functions
2019+ :param list[tuple] hooks: the list of hooks given as tuples of names and MapperHook instances
20142020 """
20152021 self ._command_mapper_hooks = hooks
20162022
@@ -3078,6 +3084,14 @@ def disable_tracker():
30783084 global _use_tracker
30793085 _use_tracker = False
30803086
3087+ class DefaultTrackerHook (MapperHook ):
3088+ def __init__ (self , tracker : Tracker ):
3089+ self .tracker = tracker
3090+
3091+ def hook (self , cmd , suite = None ):
3092+ assert suite == self .tracker .bmSuite
3093+ return self .tracker .map_command (cmd )
3094+
30813095class Tracker (object ):
30823096 def __init__ (self , bmSuite ):
30833097 self .bmSuite = bmSuite
@@ -3088,6 +3102,9 @@ def map_command(self, cmd):
30883102 def get_rules (self , bmSuiteArgs ):
30893103 raise NotImplementedError ()
30903104
3105+ def get_hook (self ) -> MapperHook :
3106+ return DefaultTrackerHook (self )
3107+
30913108class RssTracker (Tracker ):
30923109 def __init__ (self , bmSuite ):
30933110 super ().__init__ (bmSuite )
@@ -3461,6 +3478,10 @@ def __init__(self, bmSuite):
34613478 # GR-65536
34623479 atexit .register (self .cleanup )
34633480
3481+ def get_hook (self ) -> MapperHook :
3482+ """Returns an energy tracking hook"""
3483+ return EnergyTrackerHook (self )
3484+
34643485 @property
34653486 def baseline_power (self ):
34663487 """Caches the average baseline power value"""
@@ -3559,6 +3580,15 @@ def get_rules(self, bmSuiteArgs):
35593580
35603581 return rules
35613582
3583+ class EnergyTrackerHook (DefaultTrackerHook ):
3584+ """Hook for energy consumption tracking."""
3585+ def __init__ (self , tracker : EnergyConsumptionTracker ):
3586+ assert isinstance (tracker , EnergyConsumptionTracker )
3587+ super ().__init__ (tracker )
3588+
3589+ def should_apply (self , stage : Optional [Stage ]) -> bool :
3590+ return stage .is_final () if stage else False
3591+
35623592_available_trackers = {
35633593 "rss" : RssTracker ,
35643594 "psrecord" : PsrecordTracker ,
0 commit comments