Skip to content

Commit

Permalink
Make Transform an abstract base class, delete iapply method ...
Browse files Browse the repository at this point in the history
...based on the `invertible` attribute.

The iapply() itself is not abstract since it is strictly speaking not
needed for a transform.
  • Loading branch information
nicholasjng committed Mar 18, 2024
1 parent cdf4884 commit 65c54e5
Showing 1 changed file with 39 additions and 32 deletions.
71 changes: 39 additions & 32 deletions src/nnbench/io/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Metaclasses for defining transforms acting on benchmark records."""

from abc import ABC, abstractmethod
from typing import Sequence

from nnbench.types import BenchmarkRecord
Expand All @@ -8,19 +9,18 @@
class Transform:
"""The basic transform which every transform has to inherit from."""

pass


class OneToOneTransform(Transform):
invertible: bool = True
"""
Whether this transform is invertible,
i.e. records can be converted back and forth with no changes or data loss.
"""
pass


class OneToOneTransform(ABC, Transform):
@abstractmethod
def apply(self, record: BenchmarkRecord) -> BenchmarkRecord:
"""
Apply this transform to a benchmark record.
"""Apply this transform to a benchmark record.
Parameters
----------
Expand All @@ -34,8 +34,7 @@ def apply(self, record: BenchmarkRecord) -> BenchmarkRecord:
"""

def iapply(self, record: BenchmarkRecord) -> BenchmarkRecord:
"""
Apply the inverse of this transform.
"""Apply the inverse of this transform.
In general, applying the inverse on a record not previously transformed
may yield unexpected results.
Expand All @@ -49,25 +48,26 @@ def iapply(self, record: BenchmarkRecord) -> BenchmarkRecord:
-------
BenchmarkRecord
The inversely transformed benchmark record.
Raises
------
RuntimeError
If the `Transform.invertible` attribute is set to `False`.
"""
if not self.invertible:
raise RuntimeError(f"{self.__class__.__name__}() is marked as not invertible")
raise NotImplementedError


class ManyToOneTransform(Transform):
"""
A many-to-one transform reducing a collection of records to a single record.
"""A many-to-one transform reducing a collection of records to a single record.
This is useful for computing statistics on a collection of runs.
"""

invertible: bool = True
"""
Whether this transform is invertible,
i.e. records can be converted back and forth with no changes or data loss.
"""

@abstractmethod
def apply(self, record: Sequence[BenchmarkRecord]) -> BenchmarkRecord:
"""
Apply this transform to a benchmark record.
"""Apply this transform to a benchmark record.
Parameters
----------
Expand All @@ -82,8 +82,7 @@ def apply(self, record: Sequence[BenchmarkRecord]) -> BenchmarkRecord:
"""

def iapply(self, record: BenchmarkRecord) -> Sequence[BenchmarkRecord]:
"""
Apply the inverse of this transform.
"""Apply the inverse of this transform.
In general, applying the inverse on a record not previously transformed
may yield unexpected results.
Expand All @@ -97,31 +96,32 @@ def iapply(self, record: BenchmarkRecord) -> Sequence[BenchmarkRecord]:
-------
Sequence[BenchmarkRecord]
The inversely transformed benchmark record sequence.
Raises
------
RuntimeError
If the `Transform.invertible` attribute is set to `False`.
"""
# TODO: Does this even make sense? Can't hurt to allow it on paper, though.
if not self.invertible:
raise RuntimeError(f"{self.__class__.__name__}() is marked as not invertible")
raise NotImplementedError


class ManyToManyTransform(Transform):
"""
A many-to-many transform mapping an input record collection to an output collection.
"""A many-to-many transform mapping an input record collection to an output collection.
Use this to programmatically wrangle metadata or types in records, or to
convert parameters into database-ready representations.
"""

invertible: bool = True
"""
Whether this transform is invertible,
i.e. records can be converted back and forth with no changes or data loss.
"""
length_invariant: bool = True
"""
Whether this transform preserves the number of records, i.e. no records are dropped.
"""

@abstractmethod
def apply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]:
"""
Apply this transform to a benchmark record.
"""Apply this transform to a benchmark record.
Parameters
----------
Expand All @@ -135,8 +135,7 @@ def apply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]:
"""

def iapply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]:
"""
Apply the inverse of this transform.
"""Apply the inverse of this transform.
In general, applying the inverse on a record not previously transformed
may yield unexpected results.
Expand All @@ -150,4 +149,12 @@ def iapply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]
-------
Sequence[BenchmarkRecord]
The inversely transformed benchmark record sequence.
Raises
------
RuntimeError
If the `Transform.invertible` attribute is set to `False`.
"""
if not self.invertible:
raise RuntimeError(f"{self.__class__.__name__}() is marked as not invertible")
raise NotImplementedError

0 comments on commit 65c54e5

Please sign in to comment.