Skip to content

Commit 670a857

Browse files
committed
wip - State injection into setup and teardown
1 parent 65fc45b commit 670a857

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

src/nnbench/core.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload
1111

12-
from nnbench.types import Benchmark
12+
from nnbench.types import Benchmark, State
1313
from nnbench.types.util import is_memo, is_memo_type
1414

1515

@@ -178,7 +178,9 @@ def decorator(fn: Callable) -> list[Benchmark]:
178178
)
179179
names.add(name)
180180

181-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
181+
bm = Benchmark(
182+
fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags
183+
)
182184
benchmarks.append(bm)
183185
return benchmarks
184186

@@ -224,7 +226,8 @@ def decorator(fn: Callable) -> list[Benchmark]:
224226
benchmarks = []
225227
names = set()
226228
varnames = iterables.keys()
227-
for values in itertools.product(*iterables.values()):
229+
cartesian_product = itertools.product(*iterables.values())
230+
for idx, values in enumerate(cartesian_product):
228231
params = dict(zip(varnames, values))
229232
_check_against_interface(params, fn)
230233

@@ -235,8 +238,23 @@ def decorator(fn: Callable) -> list[Benchmark]:
235238
f"Perhaps you specified a parameter configuration twice?"
236239
)
237240
names.add(name)
241+
state = State(
242+
name=name,
243+
function=fn,
244+
family=fn.__name__,
245+
family_size=len(list(cartesian_product)),
246+
family_index=idx,
247+
)
238248

239-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
249+
bm = Benchmark(
250+
fn,
251+
name=name,
252+
params=params,
253+
setUp=setUp,
254+
tearDown=tearDown,
255+
tags=tags,
256+
state=state,
257+
)
240258
benchmarks.append(bm)
241259
return benchmarks
242260

src/nnbench/types/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .types import Benchmark, BenchmarkRecord, Memo, Parameters
1+
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State

src/nnbench/types/types.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
from nnbench.context import Context
18+
from nnbench import __version__
1819

1920
T = TypeVar("T")
2021
Variable = tuple[str, type, Any]
@@ -101,6 +102,16 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
101102
# context data.
102103

103104

105+
@dataclass(frozen=True)
106+
class State:
107+
name: str
108+
function: Callable
109+
family: str
110+
family_size: int
111+
family_index: int
112+
nnbench_version: str = __version__
113+
114+
104115
class Memo(Generic[T]):
105116
@functools.cache
106117
# TODO: Swap this out for a local type-wide memo cache.
@@ -164,11 +175,41 @@ class Benchmark:
164175
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
165176
tags: tuple[str, ...] = field(repr=False, default=())
166177
interface: Interface = field(init=False, repr=False)
178+
state: State | None = field(default=None)
167179

168180
def __post_init__(self):
169181
if not self.name:
170182
super().__setattr__("name", self.fn.__name__)
171183
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))
184+
if not self.state:
185+
super().__setattr__(
186+
"state",
187+
State(
188+
name=self.name or "",
189+
function=self.fn,
190+
family=self.fn.__name__,
191+
family_size=1,
192+
family_index=0,
193+
),
194+
)
195+
196+
original_setUp = self.setUp
197+
198+
def wrapped_setUp(*args, **kwargs):
199+
state = self.state
200+
# TODO: setUp and Teardown logic
201+
original_setUp(*args, **kwargs)
202+
203+
super().__setattr__("setUp", wrapped_setUp)
204+
205+
original_tearDown = self.tearDown
206+
207+
def wrapped_tearDown(*args, **kwargs):
208+
state = self.state
209+
# TODO: tearDown logic
210+
original_tearDown(*args, **kwargs)
211+
212+
super().__setattr__("tearDown", wrapped_tearDown)
172213

173214

174215
@dataclass(frozen=True)

0 commit comments

Comments
 (0)