Skip to content

Commit 20e9fe8

Browse files
committed
Inject state class into setUp and tearDown tasks
in the `runner.run()` method. This is necessary such that the setUp and tearDown tasks know the benchmark states. Namely, how many bench- marks are in the benchmark family and the index of the current. In a follow up implementation of a cache we will use the index and family length to compute a condition to empty the cache.
1 parent d5e033f commit 20e9fe8

File tree

4 files changed

+54
-21
lines changed

4 files changed

+54
-21
lines changed

src/nnbench/core.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload
1111

1212
from nnbench.types import Benchmark
13+
from nnbench.types.types import NoOp
1314
from nnbench.types.util import is_memo, is_memo_type
1415

1516

@@ -52,10 +53,6 @@ def _default_namegen(fn: Callable, **kwargs: Any) -> str:
5253
return fn.__name__ + "_" + "_".join(f"{k}={v}" for k, v in kwargs.items())
5354

5455

55-
def NoOp(**kwargs: Any) -> None:
56-
pass
57-
58-
5956
# Overloads for the ``benchmark`` decorator.
6057
# Case #1: Bare application without parentheses
6158
# @nnbench.benchmark
@@ -178,7 +175,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
178175
)
179176
names.add(name)
180177

181-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
178+
bm = Benchmark(
179+
fn,
180+
name=name,
181+
params=params,
182+
setUp=setUp,
183+
tearDown=tearDown,
184+
tags=tags,
185+
)
182186
benchmarks.append(bm)
183187
return benchmarks
184188

@@ -236,7 +240,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
236240
)
237241
names.add(name)
238242

239-
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
243+
bm = Benchmark(
244+
fn,
245+
name=name,
246+
params=params,
247+
setUp=setUp,
248+
tearDown=tearDown,
249+
tags=tags,
250+
)
240251
benchmarks.append(bm)
241252
return benchmarks
242253

src/nnbench/runner.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import collections
56
import contextlib
67
import inspect
78
import logging
@@ -15,7 +16,7 @@
1516
from typing import Any, Callable, Generator, Sequence, get_origin
1617

1718
from nnbench.context import Context, ContextProvider
18-
from nnbench.types import Benchmark, BenchmarkRecord, Parameters
19+
from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State
1920
from nnbench.types.util import is_memo, is_memo_type
2021
from nnbench.util import import_file_as_module, ismodule
2122

@@ -247,6 +248,9 @@ def run(
247248
if not self.benchmarks:
248249
self.collect(path_or_module, tags)
249250

251+
family_sizes: dict[str, Any] = collections.defaultdict(int)
252+
family_indices: dict[str, Any] = collections.defaultdict(int)
253+
250254
if isinstance(context, Context):
251255
ctx = context
252256
else:
@@ -259,6 +263,9 @@ def run(
259263
warnings.warn(f"No benchmarks found in path/module {str(path_or_module)!r}.")
260264
return BenchmarkRecord(context=ctx, benchmarks=[])
261265

266+
for bm in self.benchmarks:
267+
family_sizes[bm.fn.__name__] += 1
268+
262269
if isinstance(params, Parameters):
263270
dparams = asdict(params)
264271
else:
@@ -274,6 +281,14 @@ def _maybe_dememo(v, expected_type):
274281
return v
275282

276283
for benchmark in self.benchmarks:
284+
bm_family = benchmark.fn.__name__
285+
state = State(
286+
name=benchmark.name,
287+
family=bm_family,
288+
family_size=family_sizes[bm_family],
289+
family_index=family_indices[bm_family],
290+
)
291+
family_indices[bm_family] += 1
277292
bmtypes = dict(zip(benchmark.interface.names, benchmark.interface.types))
278293
bmparams = dict(zip(benchmark.interface.names, benchmark.interface.defaults))
279294
# TODO: Does this need a copy.deepcopy()?
@@ -291,14 +306,14 @@ def _maybe_dememo(v, expected_type):
291306
"parameters": bmparams,
292307
}
293308
try:
294-
benchmark.setUp(**bmparams)
309+
benchmark.setUp(state, bmparams)
295310
with timer(res):
296311
res["value"] = benchmark.fn(**bmparams)
297312
except Exception as e:
298313
res["error_occurred"] = True
299314
res["error_message"] = str(e)
300315
finally:
301-
benchmark.tearDown(**bmparams)
316+
benchmark.tearDown(state, bmparams)
302317
results.append(res)
303318

304319
return BenchmarkRecord(

src/nnbench/types/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .types import Benchmark, BenchmarkRecord, Memo, Parameters
1+
from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State

src/nnbench/types/types.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,23 @@
66
import functools
77
import inspect
88
from dataclasses import dataclass, field
9-
from typing import (
10-
Any,
11-
Callable,
12-
Generic,
13-
Literal,
14-
TypeVar,
15-
)
9+
from types import MappingProxyType
10+
from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
1611

1712
from nnbench.context import Context
1813

1914
T = TypeVar("T")
2015
Variable = tuple[str, type, Any]
2116

2217

23-
def NoOp(**kwargs: Any) -> None:
18+
def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
2419
pass
2520

2621

22+
class CallbackProtocol(Protocol):
23+
def __call__(self, state: State, params: Mapping[str, Any]) -> None: ...
24+
25+
2726
@dataclass(frozen=True)
2827
class BenchmarkRecord:
2928
context: Context
@@ -101,6 +100,14 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
101100
# context data.
102101

103102

103+
@dataclass(frozen=True)
104+
class State:
105+
name: str
106+
family: str
107+
family_size: int
108+
family_index: int
109+
110+
104111
class Memo(Generic[T]):
105112
@functools.cache
106113
# TODO: Swap this out for a local type-wide memo cache.
@@ -158,10 +165,10 @@ class Benchmark:
158165
"""
159166

160167
fn: Callable[..., Any]
161-
name: str | None = field(default=None)
168+
name: str = field(default="")
162169
params: dict[str, Any] = field(default_factory=dict)
163-
setUp: Callable[..., None] = field(repr=False, default=NoOp)
164-
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
170+
setUp: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
171+
tearDown: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
165172
tags: tuple[str, ...] = field(repr=False, default=())
166173
interface: Interface = field(init=False, repr=False)
167174

0 commit comments

Comments
 (0)