Skip to content

Commit 46c0ba3

Browse files
committed
Inject state class into setUp and tearDown tasks
of the Benchmark class. This is necessary such that the setUp and tearDown tasks know the benchmark states. Namely, how many bench- marks are in the benchmark family and the index of the current. In a follow up implementation of a cache we will use the index and family lenght to compute a condition to empty the cache.
1 parent 65fc45b commit 46c0ba3

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

src/nnbench/core.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload
1111

12-
from nnbench.types import Benchmark
12+
from nnbench.types import Benchmark, State
1313
from nnbench.types.util import is_memo, is_memo_type
1414

1515

@@ -178,7 +178,9 @@ def decorator(fn: Callable) -> list[Benchmark]:
178178
)
179179
names.add(name)
180180

181-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
181+
bm = Benchmark(
182+
fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags
183+
)
182184
benchmarks.append(bm)
183185
return benchmarks
184186

@@ -224,7 +226,8 @@ def decorator(fn: Callable) -> list[Benchmark]:
224226
benchmarks = []
225227
names = set()
226228
varnames = iterables.keys()
227-
for values in itertools.product(*iterables.values()):
229+
cartesian_product = itertools.product(*iterables.values())
230+
for idx, values in enumerate(cartesian_product):
228231
params = dict(zip(varnames, values))
229232
_check_against_interface(params, fn)
230233

@@ -235,8 +238,23 @@ def decorator(fn: Callable) -> list[Benchmark]:
235238
f"Perhaps you specified a parameter configuration twice?"
236239
)
237240
names.add(name)
241+
state = State(
242+
name=name,
243+
function=fn,
244+
family=fn.__name__,
245+
family_size=len(list(cartesian_product)),
246+
family_index=idx,
247+
)
238248

239-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
249+
bm = Benchmark(
250+
fn,
251+
name=name,
252+
params=params,
253+
setUp=setUp,
254+
tearDown=tearDown,
255+
tags=tags,
256+
state=state,
257+
)
240258
benchmarks.append(bm)
241259
return benchmarks
242260

src/nnbench/types/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .types import Benchmark, BenchmarkRecord, Memo, Parameters
1+
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State

src/nnbench/types/types.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
from nnbench.context import Context
18+
from nnbench import __version__
1819

1920
T = TypeVar("T")
2021
Variable = tuple[str, type, Any]
@@ -101,6 +102,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
101102
# context data.
102103

103104

105+
@dataclass(frozen=True)
106+
class State:
107+
name: str
108+
function: Callable
109+
family: str
110+
family_size: int
111+
family_index: int
112+
nnbench_version: str = __version__
113+
114+
104115
class Memo(Generic[T]):
105116
@functools.cache
106117
# TODO: Swap this out for a local type-wide memo cache.
@@ -164,11 +175,41 @@ class Benchmark:
164175
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
165176
tags: tuple[str, ...] = field(repr=False, default=())
166177
interface: Interface = field(init=False, repr=False)
178+
state: State | None = field(default=None)
167179

168180
def __post_init__(self):
169181
if not self.name:
170182
super().__setattr__("name", self.fn.__name__)
171183
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))
184+
if not self.state:
185+
super().__setattr__(
186+
"state",
187+
State(
188+
name=self.name or "",
189+
function=self.fn,
190+
family=self.fn.__name__,
191+
family_size=1,
192+
family_index=0,
193+
),
194+
)
195+
196+
original_setUp = self.setUp
197+
198+
def wrapped_setUp(*args, **kwargs):
199+
state = self.state
200+
# TODO: setUp and Teardown logic
201+
original_setUp(*args, **kwargs)
202+
203+
super().__setattr__("setUp", wrapped_setUp)
204+
205+
original_tearDown = self.tearDown
206+
207+
def wrapped_tearDown(*args, **kwargs):
208+
state = self.state
209+
# TODO: tearDown logic
210+
original_tearDown(*args, **kwargs)
211+
212+
super().__setattr__("tearDown", wrapped_tearDown)
172213

173214

174215
@dataclass(frozen=True)

0 commit comments

Comments
 (0)