-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
State injection into setup and teardown #127
Conversation
670a857
to
46c0ba3
Compare
59db52c
to
9271c50
Compare
src/nnbench/core.py
Outdated
@@ -224,7 +239,8 @@ def decorator(fn: Callable) -> list[Benchmark]: | |||
benchmarks = [] | |||
names = set() | |||
varnames = iterables.keys() | |||
for values in itertools.product(*iterables.values()): | |||
cartesian_product = list(itertools.product(*iterables.values())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
list
here forces execution of the generator. I think that is no problem as the number of benchmarks will not be too computationally extensive. Additionally, we nee the number of values in the cartesian product as well as the index anyway, so we might as well unravel the generator here instead of in the following for loop.
src/nnbench/types/types.py
Outdated
if self.state is None: | ||
super().__setattr__( | ||
"state", | ||
State( | ||
name=self.name or "", | ||
function=self.fn, | ||
family=self.fn.__name__, | ||
family_size=1, | ||
family_index=0, | ||
), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it sufficiently clear that this is the default if the Benchmark state is not set? that is the case for standalone benchmarks (which do not come out of parametrize
or product
). Therefore family_size=1
and family_index=0
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the state should not be tied to the benchmark at all. Right now, the dependency is Benchmark -> State
(as in, the state is a member of the benchmark struct), and State even has a reference to the benchmark function.
What I am looking for is more an on-the-fly constructed POD struct (i.e. holding data only, to be discussed) in runner.run()
, mapping a benchmark to its corresponding state irrespective of its origin.
You can infer name, family name, and index from the benchmark data no problem, and you just need to bookkeep a small map of family name -> index.
But as for the family size, you're right that this will be needed before running the benchmarks. The single benchmark case is easy (size 1), with the parametrized ones it's a little harder.
Maybe the @parametrize/@product
decorators should return a BenchmarkFamily
object, which holds its size as a property?
But for now, just go through the self.benchmarks
list and populate the family sizes like so:
family_sizes = collections.defaultdict(int)
family_idxs = collections.defaultdict(int)
for bm in self.benchmarks:
family_sizes[bm.fn.__name__] += 1
...
# later, in execution loop:
state = State(..., family_size=family_sizes[bm.fn.__name__], family_idx=family_idxs[bm.fn.__name__])
# advance index after state dispatch.
family_idxs[bm.fn.__name__] += 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! We're almost there.
@@ -178,7 +178,14 @@ def decorator(fn: Callable) -> list[Benchmark]: | |||
) | |||
names.add(name) | |||
|
|||
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags) | |||
bm = Benchmark( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did ruff change this formatting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that was my pyright autoformatter.
Do you prefer it in a single line?
src/nnbench/runner.py
Outdated
@@ -274,6 +282,14 @@ def _maybe_dememo(v, expected_type): | |||
return v | |||
|
|||
for benchmark in self.benchmarks: | |||
bm_family = benchmark.fn.__name__ | |||
bm_state = State( | |||
name=benchmark.name or bm_family, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
benchmark.name
should never be None (or ""), so you can leave out the or branch.
src/nnbench/types/types.py
Outdated
original_setUp = self.setUp | ||
|
||
def wrapped_setUp(state: State, /, *args: Any, **kwargs: Any) -> None: | ||
# TODO: setUp logic | ||
print("SetUp: ", state) | ||
original_setUp(*args, **kwargs) | ||
|
||
super().__setattr__("setUp", wrapped_setUp) | ||
|
||
original_tearDown = self.tearDown | ||
|
||
def wrapped_tearDown(state: State, /, *args: Any, **kwargs: Any) -> None: | ||
# TODO: tearDown logic | ||
print("tearDown: ", state) | ||
original_tearDown(*args, **kwargs) | ||
|
||
super().__setattr__("tearDown", wrapped_tearDown) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, no - I meant we should change the setUp/tearDown calling convention, as in, require callables of signature (state, params) => Any
, not wrap the callable itself.
(We will need to adjust the docs on setUp and tearDown tasks afterwards.)
NB: The NoOp
from that module must then become def NoOp(state, params): pass
.
I guess it is nicer to pass the dict instead of the kwargs. But this requires wrapping it into a form where the user cannot inject any parameters, e.g. types.MappingProxy
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work!
class CallbackProtocol(Protocol): | ||
def __call__(self, state: State, params: Mapping[str, Any]) -> None: ... | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would just go with Callable[[State, Mapping[str, Any]], None]
, unless I'm missing something?
in the `runner.run()` method. This is necessary such that the setUp and tearDown tasks know the benchmark states. Namely, how many bench- marks are in the benchmark family and the index of the current. In a follow up implementation of a cache we will use the index and family length to compute a condition to empty the cache.
Closes #125
Trailing ToDos (actual work with the
State
Object in setup and teardown tasks will be addressed in the PR for #126 .