-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transform submodule, parameter compression transform (#124)
This is the better way of compressing parameters compared to directly in the benchmark runner, which steals responsibility of the transform that we just introduced. Refactors `nnbench.io.transform->nnbench.transforms`, the latter being its own submodule. This is useful to have when adding new builtin transforms, so that they do not have to go into a single file.
- Loading branch information
1 parent
aac4162
commit d09cd3c
Showing
5 changed files
with
51 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .base import ManyToManyTransform, ManyToOneTransform, OneToOneTransform |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from typing import Any, Sequence | ||
|
||
from nnbench.transforms import ManyToManyTransform, OneToOneTransform | ||
from nnbench.types import BenchmarkRecord | ||
|
||
|
||
class CompressionMixin: | ||
def compress(self, params: dict[str, Any]) -> dict[str, Any]: | ||
containers = (tuple, list, set, frozenset) | ||
natives = (float, int, str, bool, bytes, complex) | ||
compressed: dict[str, Any] = {} | ||
|
||
def _compress_impl(val): | ||
if isinstance(val, natives): | ||
# save native types without modification... | ||
return val | ||
else: | ||
# ... or return the string repr. | ||
# TODO: Allow custom representations for types with formatters. | ||
return repr(val) | ||
|
||
for k, v in params.items(): | ||
if isinstance(v, containers): | ||
container_type = type(v) | ||
compressed[k] = container_type(_compress_impl(vv) for vv in v) | ||
elif isinstance(v, dict): | ||
compressed[k] = self.compress(v) | ||
else: | ||
compressed[k] = _compress_impl(v) | ||
|
||
return compressed | ||
|
||
|
||
class ParameterCompression1to1(OneToOneTransform, CompressionMixin): | ||
def apply(self, record: BenchmarkRecord) -> BenchmarkRecord: | ||
for bm in record.benchmarks: | ||
bm["params"] = self.compress(bm["params"]) | ||
|
||
return record | ||
|
||
|
||
class ParameterCompressionNtoN(ManyToManyTransform, CompressionMixin): | ||
def apply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]: | ||
for rec in record: | ||
for bm in rec.benchmarks: | ||
bm["params"] = self.compress(bm["params"]) | ||
|
||
return record |