Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Transform an abstract base class, delete iapply method ... #118

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 39 additions & 32 deletions src/nnbench/io/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Metaclasses for defining transforms acting on benchmark records."""

from abc import ABC, abstractmethod
from typing import Sequence

from nnbench.types import BenchmarkRecord
Expand All @@ -8,19 +9,18 @@
class Transform:
"""The basic transform which every transform has to inherit from."""

pass


class OneToOneTransform(Transform):
invertible: bool = True
"""
Whether this transform is invertible,
i.e. records can be converted back and forth with no changes or data loss.
"""
pass


class OneToOneTransform(ABC, Transform):
@abstractmethod
def apply(self, record: BenchmarkRecord) -> BenchmarkRecord:
"""
Apply this transform to a benchmark record.
"""Apply this transform to a benchmark record.
Parameters
----------
Expand All @@ -34,8 +34,7 @@ def apply(self, record: BenchmarkRecord) -> BenchmarkRecord:
"""

def iapply(self, record: BenchmarkRecord) -> BenchmarkRecord:
"""
Apply the inverse of this transform.
"""Apply the inverse of this transform.
In general, applying the inverse on a record not previously transformed
may yield unexpected results.
Expand All @@ -49,25 +48,26 @@ def iapply(self, record: BenchmarkRecord) -> BenchmarkRecord:
-------
BenchmarkRecord
The inversely transformed benchmark record.
Raises
------
RuntimeError
If the `Transform.invertible` attribute is set to `False`.
"""
if not self.invertible:
raise RuntimeError(f"{self.__class__.__name__}() is marked as not invertible")
raise NotImplementedError


class ManyToOneTransform(Transform):
"""
A many-to-one transform reducing a collection of records to a single record.
"""A many-to-one transform reducing a collection of records to a single record.
This is useful for computing statistics on a collection of runs.
"""

invertible: bool = True
"""
Whether this transform is invertible,
i.e. records can be converted back and forth with no changes or data loss.
"""

@abstractmethod
def apply(self, record: Sequence[BenchmarkRecord]) -> BenchmarkRecord:
"""
Apply this transform to a benchmark record.
"""Apply this transform to a benchmark record.
Parameters
----------
Expand All @@ -82,8 +82,7 @@ def apply(self, record: Sequence[BenchmarkRecord]) -> BenchmarkRecord:
"""

def iapply(self, record: BenchmarkRecord) -> Sequence[BenchmarkRecord]:
"""
Apply the inverse of this transform.
"""Apply the inverse of this transform.
In general, applying the inverse on a record not previously transformed
may yield unexpected results.
Expand All @@ -97,31 +96,32 @@ def iapply(self, record: BenchmarkRecord) -> Sequence[BenchmarkRecord]:
-------
Sequence[BenchmarkRecord]
The inversely transformed benchmark record sequence.
Raises
------
RuntimeError
If the `Transform.invertible` attribute is set to `False`.
"""
# TODO: Does this even make sense? Can't hurt to allow it on paper, though.
if not self.invertible:
raise RuntimeError(f"{self.__class__.__name__}() is marked as not invertible")
raise NotImplementedError


class ManyToManyTransform(Transform):
"""
A many-to-many transform mapping an input record collection to an output collection.
"""A many-to-many transform mapping an input record collection to an output collection.
Use this to programmatically wrangle metadata or types in records, or to
convert parameters into database-ready representations.
"""

invertible: bool = True
"""
Whether this transform is invertible,
i.e. records can be converted back and forth with no changes or data loss.
"""
length_invariant: bool = True
"""
Whether this transform preserves the number of records, i.e. no records are dropped.
"""

@abstractmethod
def apply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]:
"""
Apply this transform to a benchmark record.
"""Apply this transform to a benchmark record.
Parameters
----------
Expand All @@ -135,8 +135,7 @@ def apply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]:
"""

def iapply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]:
"""
Apply the inverse of this transform.
"""Apply the inverse of this transform.
In general, applying the inverse on a record not previously transformed
may yield unexpected results.
Expand All @@ -150,4 +149,12 @@ def iapply(self, record: Sequence[BenchmarkRecord]) -> Sequence[BenchmarkRecord]
-------
Sequence[BenchmarkRecord]
The inversely transformed benchmark record sequence.
Raises
------
RuntimeError
If the `Transform.invertible` attribute is set to `False`.
"""
if not self.invertible:
raise RuntimeError(f"{self.__class__.__name__}() is marked as not invertible")
raise NotImplementedError