Skip to content

Commit

Permalink
Wire context values through to the runner in command-line mode
Browse files Browse the repository at this point in the history
Also give the option to disable typechecks, which is not as useful right
now because currently, benchmarks rely on defaults only in command-line
mode, so parameters are never passed, and defaults are not typechecked.
  • Loading branch information
nicholasjng committed Nov 15, 2024
1 parent fb3c487 commit f857672
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
12 changes: 9 additions & 3 deletions src/nnbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ def main() -> int:
"--output-file",
metavar="<file>",
dest="outfile",
help="File or stream to write results to.",
help="File or stream to write results to, defaults to stdout.",
default=sys.stdout,
)
parser.add_argument(
"--typecheck",
action=argparse.BooleanOptionalAction,
default=True,
help="Whether or not to strictly check types of benchmark inputs.",
)

Expand All @@ -54,12 +55,17 @@ def main() -> int:
k, v = val.split("=")
except ValueError:
raise ValueError("context values need to be of the form <key>=<value>")
# TODO: Support builtin providers in the runner
context[k] = v

record = BenchmarkRunner().run(args.benchmarks, tags=tuple(args.tags))
record = BenchmarkRunner(typecheck=args.typecheck).run(
args.benchmarks,
tags=tuple(args.tags),
context=[lambda: context],
)

outfile = args.outfile
if args.outfile == sys.stdout:
if outfile == sys.stdout:
reporter = BenchmarkReporter()
reporter.display(record)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/nnbench/reporter/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def get_extension(f: str | os.PathLike[str] | IO) -> str:
Given a file path or file-like object, returns file extension
(can be the empty string, if the file has no extension).
"""
if isinstance(f, str | bytes | os.PathLike):
if isinstance(f, str | os.PathLike):
return Path(f).suffix
else:
return Path(f.name).suffix


def get_protocol(url: str | os.PathLike[str]) -> str:
url = str(url)
parts = re.split(r"(\:\:|\://)", url, maxsplit=1)
parts = re.split(r"(::|://)", url, maxsplit=1)
if len(parts) > 1:
return parts[0]
return "file"
Expand Down
2 changes: 2 additions & 0 deletions src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def _issubtype(t1: type, t2: type) -> bool:
msng, *_ = missing
raise ValueError(f"missing value for required parameter {msng!r}")

# TODO(n.junge): This doesn't pick up mistyped defaults
# (admittedly, that's likely user error)
for k, v in params.items():
if k not in allvars:
warnings.warn(
Expand Down

0 comments on commit f857672

Please sign in to comment.