Skip to content

Commit d68305f

Browse files
committed
Inject state class into setUp and tearDown tasks
of the Benchmark class. This is necessary such that the setUp and tearDown tasks know the benchmark states. Namely, how many bench- marks are in the benchmark family and the index of the current. In a follow up implementation of a cache we will use the index and family lenght to compute a condition to empty the cache.
1 parent 65fc45b commit d68305f

File tree

3 files changed

+78
-6
lines changed

3 files changed

+78
-6
lines changed

src/nnbench/core.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload
1111

12-
from nnbench.types import Benchmark
12+
from nnbench.types import Benchmark, State
1313
from nnbench.types.util import is_memo, is_memo_type
1414

1515

@@ -167,7 +167,7 @@ def parametrize(
167167
def decorator(fn: Callable) -> list[Benchmark]:
168168
benchmarks = []
169169
names = set()
170-
for params in parameters:
170+
for idx, params in enumerate(parameters):
171171
_check_against_interface(params, fn)
172172

173173
name = namegen(fn, **params)
@@ -177,8 +177,23 @@ def decorator(fn: Callable) -> list[Benchmark]:
177177
f"Perhaps you specified a parameter configuration twice?"
178178
)
179179
names.add(name)
180+
state = State(
181+
name=name,
182+
function=fn,
183+
family=fn.__name__,
184+
family_size=len(list(parameters)),
185+
family_index=idx,
186+
)
180187

181-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
188+
bm = Benchmark(
189+
fn,
190+
name=name,
191+
params=params,
192+
setUp=setUp,
193+
tearDown=tearDown,
194+
tags=tags,
195+
state=state,
196+
)
182197
benchmarks.append(bm)
183198
return benchmarks
184199

@@ -224,7 +239,8 @@ def decorator(fn: Callable) -> list[Benchmark]:
224239
benchmarks = []
225240
names = set()
226241
varnames = iterables.keys()
227-
for values in itertools.product(*iterables.values()):
242+
cartesian_product = list(itertools.product(*iterables.values()))
243+
for idx, values in enumerate(cartesian_product):
228244
params = dict(zip(varnames, values))
229245
_check_against_interface(params, fn)
230246

@@ -235,8 +251,23 @@ def decorator(fn: Callable) -> list[Benchmark]:
235251
f"Perhaps you specified a parameter configuration twice?"
236252
)
237253
names.add(name)
254+
state = State(
255+
name=name,
256+
function=fn,
257+
family=fn.__name__,
258+
family_size=len(cartesian_product),
259+
family_index=idx,
260+
)
238261

239-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
262+
bm = Benchmark(
263+
fn,
264+
name=name,
265+
params=params,
266+
setUp=setUp,
267+
tearDown=tearDown,
268+
tags=tags,
269+
state=state,
270+
)
240271
benchmarks.append(bm)
241272
return benchmarks
242273

src/nnbench/types/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .types import Benchmark, BenchmarkRecord, Memo, Parameters
1+
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State

src/nnbench/types/types.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
TypeVar,
1515
)
1616

17+
from nnbench import __version__
1718
from nnbench.context import Context
1819

1920
T = TypeVar("T")
@@ -101,6 +102,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
101102
# context data.
102103

103104

105+
@dataclass(frozen=True)
106+
class State:
107+
name: str
108+
function: Callable[..., Any]
109+
family: str
110+
family_size: int
111+
family_index: int
112+
nnbench_version: str = __version__
113+
114+
104115
class Memo(Generic[T]):
105116
@functools.cache
106117
# TODO: Swap this out for a local type-wide memo cache.
@@ -164,11 +175,41 @@ class Benchmark:
164175
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
165176
tags: tuple[str, ...] = field(repr=False, default=())
166177
interface: Interface = field(init=False, repr=False)
178+
state: State | None = field(default=None)
167179

168180
def __post_init__(self):
169181
if not self.name:
170182
super().__setattr__("name", self.fn.__name__)
171183
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))
184+
if self.state is None:
185+
super().__setattr__(
186+
"state",
187+
State(
188+
name=self.name or "",
189+
function=self.fn,
190+
family=self.fn.__name__,
191+
family_size=1,
192+
family_index=0,
193+
),
194+
)
195+
196+
original_setUp = self.setUp
197+
198+
def wrapped_setUp(*args: Any, **kwargs: Any) -> None:
199+
state = self.state
200+
# TODO: setUp logic
201+
original_setUp(*args, **kwargs)
202+
203+
super().__setattr__("setUp", wrapped_setUp)
204+
205+
original_tearDown = self.tearDown
206+
207+
def wrapped_tearDown(*args: Any, **kwargs: Any) -> None:
208+
state = self.state
209+
# TODO: tearDown logic
210+
original_tearDown(*args, **kwargs)
211+
212+
super().__setattr__("tearDown", wrapped_tearDown)
172213

173214

174215
@dataclass(frozen=True)

0 commit comments

Comments
 (0)