Skip to content

Commit 7863925

Browse files
committed
Inject state class into setUp and tearDown tasks
in the `runner.run()` method. This is necessary such that the setUp and tearDown tasks know the benchmark states. Namely, how many bench- marks are in the benchmark family and the index of the current. In a follow up implementation of a cache we will use the index and family lenght to compute a condition to empty the cache.
1 parent 65fc45b commit 7863925

File tree

4 files changed

+62
-12
lines changed

4 files changed

+62
-12
lines changed

src/nnbench/core.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
178178
)
179179
names.add(name)
180180

181-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
181+
bm = Benchmark(
182+
fn,
183+
name=name,
184+
params=params,
185+
setUp=setUp,
186+
tearDown=tearDown,
187+
tags=tags,
188+
)
182189
benchmarks.append(bm)
183190
return benchmarks
184191

@@ -236,7 +243,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
236243
)
237244
names.add(name)
238245

239-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
246+
bm = Benchmark(
247+
fn,
248+
name=name,
249+
params=params,
250+
setUp=setUp,
251+
tearDown=tearDown,
252+
tags=tags,
253+
)
240254
benchmarks.append(bm)
241255
return benchmarks
242256

src/nnbench/runner.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import collections
56
import contextlib
67
import inspect
78
import logging
@@ -16,6 +17,7 @@
1617

1718
from nnbench.context import Context, ContextProvider
1819
from nnbench.types import Benchmark, BenchmarkRecord, Parameters
20+
from nnbench.types.types import State
1921
from nnbench.types.util import is_memo, is_memo_type
2022
from nnbench.util import import_file_as_module, ismodule
2123

@@ -247,6 +249,12 @@ def run(
247249
if not self.benchmarks:
248250
self.collect(path_or_module, tags)
249251

252+
family_sizes: dict[str, Any] = collections.defaultdict(int)
253+
for bm in self.benchmarks:
254+
family_sizes[bm.fn.__name__] += 1
255+
256+
family_indices: dict[str, Any] = collections.defaultdict(int)
257+
250258
if isinstance(context, Context):
251259
ctx = context
252260
else:
@@ -274,6 +282,14 @@ def _maybe_dememo(v, expected_type):
274282
return v
275283

276284
for benchmark in self.benchmarks:
285+
bm_family = benchmark.fn.__name__
286+
bm_state = State(
287+
name=benchmark.name or bm_family,
288+
family=bm_family,
289+
family_size=family_sizes[bm_family],
290+
family_index=family_indices[bm_family],
291+
)
292+
family_indices[bm_family] += 1
277293
bmtypes = dict(zip(benchmark.interface.names, benchmark.interface.types))
278294
bmparams = dict(zip(benchmark.interface.names, benchmark.interface.defaults))
279295
# TODO: Does this need a copy.deepcopy()?
@@ -291,14 +307,14 @@ def _maybe_dememo(v, expected_type):
291307
"parameters": bmparams,
292308
}
293309
try:
294-
benchmark.setUp(**bmparams)
310+
benchmark.setUp(bm_state, **bmparams)
295311
with timer(res):
296312
res["value"] = benchmark.fn(**bmparams)
297313
except Exception as e:
298314
res["error_occurred"] = True
299315
res["error_message"] = str(e)
300316
finally:
301-
benchmark.tearDown(**bmparams)
317+
benchmark.tearDown(bm_state, **bmparams)
302318
results.append(res)
303319

304320
return BenchmarkRecord(

src/nnbench/types/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .types import Benchmark, BenchmarkRecord, Memo, Parameters
1+
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State

src/nnbench/types/types.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,7 @@
66
import functools
77
import inspect
88
from dataclasses import dataclass, field
9-
from typing import (
10-
Any,
11-
Callable,
12-
Generic,
13-
Literal,
14-
TypeVar,
15-
)
9+
from typing import Any, Callable, Generic, Literal, TypeVar
1610

1711
from nnbench.context import Context
1812

@@ -101,6 +95,14 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
10195
# context data.
10296

10397

98+
@dataclass(frozen=True)
99+
class State:
100+
name: str
101+
family: str
102+
family_size: int
103+
family_index: int
104+
105+
104106
class Memo(Generic[T]):
105107
@functools.cache
106108
# TODO: Swap this out for a local type-wide memo cache.
@@ -170,6 +172,24 @@ def __post_init__(self):
170172
super().__setattr__("name", self.fn.__name__)
171173
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))
172174

175+
original_setUp = self.setUp
176+
177+
def wrapped_setUp(state: State, /, *args: Any, **kwargs: Any) -> None:
178+
# TODO: setUp logic
179+
print("SetUp: ", state)
180+
original_setUp(*args, **kwargs)
181+
182+
super().__setattr__("setUp", wrapped_setUp)
183+
184+
original_tearDown = self.tearDown
185+
186+
def wrapped_tearDown(state: State, /, *args: Any, **kwargs: Any) -> None:
187+
# TODO: tearDown logic
188+
print("tearDown: ", state)
189+
original_tearDown(*args, **kwargs)
190+
191+
super().__setattr__("tearDown", wrapped_tearDown)
192+
173193

174194
@dataclass(frozen=True)
175195
class Interface:

0 commit comments

Comments
 (0)