Skip to content

Commit

Permalink
Merge pull request #15 from aai-institute/remove-benchmark-params-slot
Browse files Browse the repository at this point in the history
Remove benchmark `params` slot, type core decorators
  • Loading branch information
nicholasjng authored Jan 23, 2024
2 parents d6ec79f + d9a8d6b commit bc5e76a
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 43 deletions.
126 changes: 90 additions & 36 deletions src/nnbench/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, Callable, Iterable
from functools import partial, update_wrapper
from typing import Any, Callable, Iterable, overload

from nnbench.types import Benchmark

Expand All @@ -11,13 +12,41 @@ def NoOp(**kwargs: Any) -> None:
pass


# Overloads for the ``benchmark`` decorator.
# Case #1: Bare application without parentheses
# @nnbench.benchmark
# def foo() -> int:
# return 0
@overload
def benchmark(
func: None = None,
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
) -> Callable[[Callable], Benchmark]:
...


# Case #2: Application with arguments
# @nnbench.benchmark(tags=("hello", "world"))
# def foo() -> int:
# return 0
@overload
def benchmark(
func: Callable[..., Any],
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
) -> Benchmark:
...


def benchmark(
func: Callable[..., Any] | None = None,
params: dict[str, Any] | None = None,
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
) -> Callable:
) -> Benchmark | Callable[[Callable], Benchmark]:
"""
Define a benchmark from a function.
Expand All @@ -28,43 +57,69 @@ def benchmark(
Parameters
----------
func: Callable[..., Any] | None
The function to benchmark.
params: dict[str, Any] | None
The parameters (or a subset thereof) defining the benchmark.
The function to benchmark. This slot only exists to allow application of the decorator
without parentheses, you should never fill it explicitly.
setUp: Callable[..., None]
A setup hook to run before each of the benchmarks.
A setup hook to run before the benchmark.
tearDown: Callable[..., None]
A teardown hook to run after each of the benchmarks.
A teardown hook to run after the benchmark.
tags: tuple[str, ...]
Additional tags to attach for bookkeeping and selective filtering during runs.
Returns
-------
Callable
A decorated callable yielding the benchmark.
Benchmark | Callable[[Callable], Benchmark]
The resulting benchmark (if no arguments were given), or a parametrized decorator
returning the benchmark.
"""

# TODO: The above return typing is incorrect
# (needs a func is None vs. func is not None overload)
def inner(fn: Callable) -> Benchmark:
name = fn.__name__
if params:
name += "_" + "_".join(f"{k}={v}" for k, v in params.items())
return Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
def decorator(fun: Callable) -> Benchmark:
return Benchmark(fun, setUp=setUp, tearDown=tearDown, tags=tags)

if func:
return inner(func) # type: ignore
if func is not None:
return decorator(func)
else:
return inner
return decorator


# Overloads for the ``parametrize`` decorator.
# Case #1: Bare application without parentheses (rarely used)
# @nnbench.parametrize
# def foo() -> int:
# return 0
@overload
def parametrize(
func: None = None,
parameters: Iterable[dict] = (),
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
) -> Callable[[Callable], list[Benchmark]]:
...


# Case #2: Application with arguments
# @nnbench.parametrize(parameters=..., tags=("hello", "world"))
# def foo() -> int:
# return 0
@overload
def parametrize(
func: Callable[..., Any],
parameters: Iterable[dict] = (),
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
) -> list[Benchmark]:
...


def parametrize(
func: Callable[..., Any] | None = None,
parameters: Iterable[dict] | None = None,
parameters: Iterable[dict] = (),
setUp: Callable[..., None] = NoOp,
tearDown: Callable[..., None] = NoOp,
tags: tuple[str, ...] = (),
) -> Callable:
) -> list[Benchmark] | Callable[[Callable], list[Benchmark]]:
"""
Define a family of benchmarks over a function with varying parameters.
Expand All @@ -75,8 +130,9 @@ def parametrize(
Parameters
----------
func: Callable[..., Any] | None
The function to benchmark.
parameters: Iterable[dict] | None
The function to benchmark. This slot only exists to allow application of the decorator
without parentheses, you should never fill it explicitly.
parameters: Iterable[dict]
The different sets of parameters defining the benchmark family.
setUp: Callable[..., None]
A setup hook to run before each of the benchmarks.
Expand All @@ -87,23 +143,21 @@ def parametrize(
Returns
-------
Callable
A decorated callable yielding the benchmark family.
list[Benchmark] | Callable[[Callable], list[Benchmark]]
The resulting benchmark family (if no arguments were given), or a parametrized decorator
returning the benchmark family.
"""

# TODO: The above return typing is incorrect
# (needs a func is None vs. func is not None overload)
def inner(fn: Callable) -> list[Benchmark]:
def decorator(fn: Callable) -> list[Benchmark]:
benchmarks = []
for params in parameters:
name = fn.__name__
if params:
name += "_" + "_".join(f"{k}={v}" for k, v in params.items())
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
name = fn.__name__ + "_" + "_".join(f"{k}={v}" for k, v in params.items())
wrapper = update_wrapper(partial(fn, **params), fn)
bm = Benchmark(wrapper, name=name, setUp=setUp, tearDown=tearDown, tags=tags)
benchmarks.append(bm)
return benchmarks

if func:
return inner(func) # type: ignore
if func is not None:
return decorator(func)
else:
return inner
return decorator
5 changes: 0 additions & 5 deletions src/nnbench/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ class Benchmark:
name: str | None
A name to display for the given benchmark. If not given, will be constructed from the
function name and given parameters.
params: dict[str, Any]
Fixed parameters to pass to the benchmark.
setUp: Callable[..., None]
A setup hook run before the benchmark. Must take all members of `params` as inputs.
tearDown: Callable[..., None]
Expand All @@ -91,16 +89,13 @@ class Benchmark:

fn: Callable[..., Any]
name: str | None = field(default=None)
params: dict[str, Any] = field(repr=False, default_factory=dict)
setUp: Callable[..., None] = field(repr=False, default=NoOp)
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
tags: tuple[str, ...] = field(repr=False, default=())

def __post_init__(self):
if not self.name:
name = self.fn.__name__
if self.params:
name += "_" + "_".join(f"{k}={v}" for k, v in self.params.items())

super().__setattr__("name", name)
# TODO: Parse interface using `inspect`, attach to the class
4 changes: 2 additions & 2 deletions tests/testproject/hello.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from nnbench import benchmark
import nnbench


@benchmark
@nnbench.benchmark
def double(x: int) -> int:
return x * 2

0 comments on commit bc5e76a

Please sign in to comment.