Skip to content

Commit 0d4b82a

Browse files
authored
Callable Types & Post Init Hooks (#226)
* Added support for simple `typing.Callable` types (WIP: advanced versions) * Added support for post init hooks that allow for validation on parameters defined within `@spock` decorated classes. Additionally, added some common validation check to utils (within, greater than, less than, etc.) * Updated unit tests to support Python 3.10 * Additional unit tests * linted
1 parent 2efe7d0 commit 0d4b82a

23 files changed

+695
-61
lines changed

.github/workflows/python-pytest-s3.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
runs-on: ubuntu-latest
1515
strategy:
1616
matrix:
17-
python-version: [3.6, 3.7, 3.8, 3.9]
17+
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
1818

1919
steps:
2020
- uses: actions/checkout@v2
@@ -26,7 +26,7 @@ jobs:
2626
- uses: actions/cache@v2
2727
with:
2828
path: ${{ env.pythonLocation }}
29-
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}
29+
key: cache-v1-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}
3030

3131
- name: Install dependencies
3232
run: |

.github/workflows/python-pytest-tune.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
runs-on: ubuntu-latest
1515
strategy:
1616
matrix:
17-
python-version: [3.7, 3.8, 3.9]
17+
python-version: ["3.7", "3.8", "3.9", "3.10"]
1818

1919
steps:
2020
- uses: actions/checkout@v2
@@ -26,7 +26,7 @@ jobs:
2626
- uses: actions/cache@v2
2727
with:
2828
path: ${{ env.pythonLocation }}
29-
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS_REQUIREMENTS.txt') }}
29+
key: cache-v1-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TUNE_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/TEST_EXTRAS_REQUIREMENTS_REQUIREMENTS.txt') }}
3030

3131
- name: Install dependencies
3232
run: |

.github/workflows/python-pytest.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
runs-on: ubuntu-latest
1515
strategy:
1616
matrix:
17-
python-version: [3.6, 3.7, 3.8, 3.9]
17+
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
1818

1919
steps:
2020
- uses: actions/checkout@v2
@@ -26,7 +26,7 @@ jobs:
2626
- uses: actions/cache@v2
2727
with:
2828
path: ${{ env.pythonLocation }}
29-
key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}
29+
key: cache-v1-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}
3030

3131
- name: Install dependencies
3232
run: |

README.md

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
<p align="center">
88
<a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache%202.0-9cf"/></a>
99
<a href="https://bestpractices.coreinfrastructure.org/projects/5551"><img src="https://bestpractices.coreinfrastructure.org/projects/5551/badge"/></a>
10+
<a><img src="https://github.com/fidelity/spock/workflows/pytest/badge.svg?branch=master"/></a>
11+
<a href="https://coveralls.io/github/fidelity/spock?branch=master"><img src="https://coveralls.io/repos/github/fidelity/spock/badge.svg?branch=master"/></a>
12+
<a><img src="https://github.com/fidelity/spock/workflows/docs/badge.svg"/></a>
13+
</p>
14+
15+
<p align="center">
1016
<a><img src="https://img.shields.io/badge/python-3.6+-informational.svg"/></a>
1117
<a href="https://github.com/psf/black"><img src="https://img.shields.io/badge/code%20style-black-000000.svg"/></a>
1218
<a href="https://badge.fury.io/py/spock-config"><img src="https://badge.fury.io/py/spock-config.svg"/></a>
13-
<a href="https://coveralls.io/github/fidelity/spock?branch=master"><img src="https://coveralls.io/repos/github/fidelity/spock/badge.svg?branch=master"/></a>
14-
<a><img src="https://github.com/fidelity/spock/workflows/pytest/badge.svg?branch=master"/></a>
15-
<a><img src="https://github.com/fidelity/spock/workflows/docs/badge.svg"/></a>
19+
<a href="https://pepy.tech/badge/spock-config"><img src="https://static.pepy.tech/personalized-badge/spock-config?period=total&units=international_system&left_color=grey&right_color=orange&left_text=Downloads"/></a>
1620
</p>
1721

1822
<h3 align="center">
@@ -97,6 +101,12 @@ See [Releases](https://github.com/fidelity/spock/releases) for more information.
97101

98102
<details>
99103

104+
#### March 11th, 2022
105+
* Added support for simple `typing.Callable` types (WIP: advanced versions)
106+
* Added support for post init hooks that allow for validation on parameters defined within `@spock` decorated classes.
107+
Additionally, added some common validation check to utils (within, greater than, less than, etc.)
108+
* Updated unit tests to support Python 3.10
109+
100110
#### January 26th, 2022
101111
* Added `evolve` support to the underlying `SpockBuilder` class. This provides functionality similar to the underlying
102112
attrs library ([attrs.evolve](https://www.attrs.org/en/stable/api.html#attrs.evolve)). `evolve()` creates a new
@@ -110,12 +120,6 @@ passed into `*args` within the main `SpockBuilder` API
110120
* Updated main API interface for better top-level imports (backwards compatible): `ConfigArgBuilder`->`SpockBuilder`
111121
* Added stubs to the underlying decorator that should help with type hinting in VSCode (pylance/pyright)
112122

113-
#### December 14, 2021
114-
* Refactored the backend to better handle nested dependencies (and for clarity)
115-
* Refactored the docs to use Docusaurus
116-
117-
#### August 17, 2021
118-
* Added hyper-parameter tuning backend support for Ax via Service API
119123
</details>
120124

121125
## Original Implementation

requirements/S3_REQUIREMENTS.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
boto3~=1.20
22
botocore~=1.24
3-
hurry.filesize==0.9
3+
hurry.filesize~=0.9
44
s3transfer~=0.5

spock/backend/builder.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
"""Handles the building/saving of the configurations from the Spock config classes"""
7-
7+
import sys
8+
import typing
89
from abc import ABC, abstractmethod
910
from enum import EnumMeta
1011
from typing import List
@@ -17,7 +18,7 @@
1718
from spock.backend.spaces import BuilderSpace
1819
from spock.backend.wrappers import Spockspace
1920
from spock.graph import Graph
20-
from spock.utils import make_argument
21+
from spock.utils import _SpockVariadicGenericAlias, make_argument
2122

2223

2324
class BaseBuilder(ABC): # pylint: disable=too-few-public-methods
@@ -255,10 +256,11 @@ def _make_group_override_parser(parser, class_obj, class_name):
255256
)
256257
for val in class_obj.__attrs_attrs__:
257258
val_type = val.metadata["type"] if "type" in val.metadata else val.type
258-
# Check if the val type has __args__ -- this catches lists?
259+
# Check if the val type has __args__ -- this catches GenericAlias classes
259260
# TODO (ncilfone): Fix up this super super ugly logic
260261
if (
261-
hasattr(val_type, "__args__")
262+
not isinstance(val_type, _SpockVariadicGenericAlias)
263+
and hasattr(val_type, "__args__")
262264
and ((list(set(val_type.__args__))[0]).__module__ == class_name)
263265
and attr.has((list(set(val_type.__args__))[0]))
264266
):
@@ -274,6 +276,10 @@ def _make_group_override_parser(parser, class_obj, class_name):
274276
arg_name = f"--{str(attr_name)}.{val.name}"
275277
val_type = str
276278
group_parser = make_argument(arg_name, val_type, group_parser)
279+
# This catches callables -- need to be of type str which will be use in importlib
280+
elif isinstance(val.type, _SpockVariadicGenericAlias):
281+
arg_name = f"--{str(attr_name)}.{val.name}"
282+
group_parser = make_argument(arg_name, str, group_parser)
277283
else:
278284
arg_name = f"--{str(attr_name)}.{val.name}"
279285
group_parser = make_argument(arg_name, val_type, group_parser)

spock/backend/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ def _process_class(cls, kw_only: bool, make_init: bool, dynamic: bool):
117117
auto_attribs=True,
118118
init=make_init,
119119
)
120+
# Copy over the post init function
121+
if hasattr(cls, "__post_hook__"):
122+
obj.__post_hook__ = cls.__post_hook__
120123
# For each class we dynamically create we need to register it within the system modules for pickle to work
121124
setattr(sys.modules["spock"].backend.config, obj.__name__, obj)
122125
# Swap the __doc__ string from cls to obj

spock/backend/field_handlers.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,26 @@
55

66
"""Handles registering field attributes for spock classes -- deals with the recursive nature of dependencies"""
77

8+
import importlib
9+
import sys
810
from abc import ABC, abstractmethod
911
from enum import EnumMeta
1012
from typing import List, Type
1113

1214
from attr import NOTHING, Attribute
1315

14-
from spock.args import SpockArguments
1516
from spock.backend.spaces import AttributeSpace, BuilderSpace, ConfigSpace
16-
from spock.exceptions import _SpockInstantiationError, _SpockNotOptionalError
17-
from spock.utils import _check_iterable, _is_spock_instance, _is_spock_tune_instance
17+
from spock.exceptions import (
18+
_SpockInstantiationError,
19+
_SpockNotOptionalError,
20+
_SpockValueError,
21+
)
22+
from spock.utils import (
23+
_check_iterable,
24+
_is_spock_instance,
25+
_is_spock_tune_instance,
26+
_SpockVariadicGenericAlias,
27+
)
1828

1929

2030
class RegisterFieldTemplate(ABC):
@@ -318,6 +328,69 @@ def _handle_and_register_enum(
318328
builder_space.spock_space[enum_cls.__name__] = attr_space.field
319329

320330

331+
class RegisterCallableField(RegisterFieldTemplate):
332+
"""Class that registers callable types
333+
334+
Attributes:
335+
special_keys: dictionary to check special keys
336+
337+
"""
338+
339+
def __init__(self):
340+
"""Init call to RegisterSimpleField
341+
342+
Args:
343+
"""
344+
super(RegisterCallableField, self).__init__()
345+
346+
def handle_attribute_from_config(
347+
self, attr_space: AttributeSpace, builder_space: BuilderSpace
348+
):
349+
"""Handles setting a simple attribute when it is a spock class type
350+
351+
Args:
352+
attr_space: holds information about a single attribute that is mapped to a ConfigSpace
353+
builder_space: named_tuple containing the arguments and spock_space
354+
355+
Returns:
356+
"""
357+
# These are always going to be strings... cast just in case
358+
str_field = str(
359+
builder_space.arguments[attr_space.config_space.name][
360+
attr_space.attribute.name
361+
]
362+
)
363+
module, fn = str_field.rsplit(".", 1)
364+
try:
365+
call_ref = getattr(importlib.import_module(module), fn)
366+
attr_space.field = call_ref
367+
except Exception as e:
368+
raise _SpockValueError(
369+
f"Attempted to import module {module} and callable {fn} however it could not be found on the current "
370+
f"python path: {e}"
371+
)
372+
373+
def handle_optional_attribute_type(
374+
self, attr_space: AttributeSpace, builder_space: BuilderSpace
375+
):
376+
"""Not implemented for this type
377+
378+
Args:
379+
attr_space: holds information about a single attribute that is mapped to a ConfigSpace
380+
builder_space: named_tuple containing the arguments and spock_space
381+
382+
Raises:
383+
_SpockNotOptionalError
384+
385+
"""
386+
print("hi")
387+
raise _SpockNotOptionalError(
388+
f"Parameter `{attr_space.attribute.name}` within `{attr_space.config_space.name}` is of "
389+
f"type `{type(attr_space.attribute.type)}` which seems to be unsupported -- "
390+
f"are you missing an @spock decorator on a base python class?"
391+
)
392+
393+
321394
class RegisterSimpleField(RegisterFieldTemplate):
322395
"""Class that registers basic python types
323396
@@ -606,6 +679,9 @@ def recurse_generate(cls, spock_cls, builder_space: BuilderSpace):
606679
# References to tuner classes
607680
elif _is_spock_tune_instance(attribute.type):
608681
handler = RegisterTuneCls()
682+
# References to callables
683+
elif isinstance(attribute.type, _SpockVariadicGenericAlias):
684+
handler = RegisterCallableField()
609685
# Basic field
610686
else:
611687
handler = RegisterSimpleField()
@@ -617,6 +693,9 @@ def recurse_generate(cls, spock_cls, builder_space: BuilderSpace):
617693
# error on instantiation
618694
try:
619695
spock_instance = spock_cls(**fields)
696+
# If there is a __post_hook__ dunder method then call it
697+
if hasattr(spock_cls, "__post_hook__"):
698+
spock_instance.__post_hook__()
620699
except Exception as e:
621700
raise _SpockInstantiationError(
622701
f"Spock class `{spock_cls.__name__}` could not be instantiated -- attrs message: {e}"

spock/backend/saver.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,31 @@ def _clean_output(self, out_dict):
146146
for idx, list_val in enumerate(val):
147147
tmp_dict = {}
148148
for inner_key, inner_val in list_val.items():
149-
tmp_dict = self._convert(tmp_dict, inner_val, inner_key)
149+
tmp_dict = self._convert_tuples_2_lists(
150+
tmp_dict, inner_val, inner_key
151+
)
150152
val[idx] = tmp_dict
151153
clean_inner_dict = val
152154
else:
153155
for inner_key, inner_val in val.items():
154-
clean_inner_dict = self._convert(
156+
clean_inner_dict = self._convert_tuples_2_lists(
155157
clean_inner_dict, inner_val, inner_key
156158
)
157159
clean_dict.update({key: clean_inner_dict})
158160
return clean_dict
159161

160-
def _convert(self, clean_inner_dict, inner_val, inner_key):
162+
def _convert_tuples_2_lists(self, clean_inner_dict, inner_val, inner_key):
163+
"""Convert tuples to lists
164+
165+
Args:
166+
clean_inner_dict: dictionary to update
167+
inner_val: current value
168+
inner_key: current key
169+
170+
Returns:
171+
updated dictionary where tuples are cast back to lists
172+
173+
"""
161174
# Convert tuples to lists so they get written correctly
162175
if isinstance(inner_val, tuple):
163176
clean_inner_dict.update(
@@ -277,6 +290,10 @@ def _recursively_handle_clean(
277290
if repeat_flag:
278291
clean_val = list(set(clean_val))[-1]
279292
out_dict.update({key: clean_val})
293+
# Catch any callables -- convert back to the str representation
294+
elif callable(val):
295+
call_2_str = f"{val.__module__}.{val.__name__}"
296+
out_dict.update({key: call_2_str})
280297
# If it's a spock class but has a parent then just use the class name to reference the values
281298
elif (val_name in all_cls) and parent_name is not None:
282299
out_dict.update({key: val_name})

0 commit comments

Comments
 (0)