99import warnings
1010from typing import Any , Callable , Iterable , Union , get_args , get_origin , overload
1111
12- from nnbench .types import Benchmark
12+ from nnbench .types import Benchmark , State
1313from nnbench .types .util import is_memo , is_memo_type
1414
1515
@@ -178,7 +178,9 @@ def decorator(fn: Callable) -> list[Benchmark]:
178178 )
179179 names .add (name )
180180
181- bm = Benchmark (fn , name = name , params = params , setUp = setUp , tearDown = tearDown , tags = tags )
181+ bm = Benchmark (
182+ fn , name = name , params = params , setUp = setUp , tearDown = tearDown , tags = tags
183+ )
182184 benchmarks .append (bm )
183185 return benchmarks
184186
@@ -224,7 +226,8 @@ def decorator(fn: Callable) -> list[Benchmark]:
224226 benchmarks = []
225227 names = set ()
226228 varnames = iterables .keys ()
227- for values in itertools .product (* iterables .values ()):
229+ cartesian_product = itertools .product (* iterables .values ())
230+ for idx , values in enumerate (cartesian_product ):
228231 params = dict (zip (varnames , values ))
229232 _check_against_interface (params , fn )
230233
@@ -235,8 +238,23 @@ def decorator(fn: Callable) -> list[Benchmark]:
235238 f"Perhaps you specified a parameter configuration twice?"
236239 )
237240 names .add (name )
241+ state = State (
242+ name = name ,
243+ function = fn ,
244+ family = fn .__name__ ,
245+ family_size = len (list (cartesian_product )),
246+ family_index = idx ,
247+ )
238248
239- bm = Benchmark (fn , name = name , params = params , setUp = setUp , tearDown = tearDown , tags = tags )
249+ bm = Benchmark (
250+ fn ,
251+ name = name ,
252+ params = params ,
253+ setUp = setUp ,
254+ tearDown = tearDown ,
255+ tags = tags ,
256+ state = state ,
257+ )
240258 benchmarks .append (bm )
241259 return benchmarks
242260
0 commit comments