Skip to content

Commit 5ef5e0c

Browse files
authored
Add support for inherent signature. (#6)
1 parent 85adb07 commit 5ef5e0c

File tree

3 files changed

+166
-14
lines changed

3 files changed

+166
-14
lines changed

tests/test_registry.py

+100
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from utilsd.config import ClassConfig, Registry, RegistryConfig, SubclassConfig, configclass
55
from utilsd.config.type_def import TypeDef
6+
from utilsd.config.exception import ValidationError
67
from tests.assets.import_class import BaseBar
78

89

@@ -84,6 +85,11 @@ def test_registry_config():
8485
assert config.m.a == 1
8586
assert config.m.type() == Converter1
8687
assert config.m.build().a == 1
88+
89+
# test overwrite during build() call
90+
assert config.m.build(a=2).a == 2
91+
with pytest.raises(RuntimeError):
92+
config.m.build(c=3)
8793
assert isinstance(config.m.build(), Converter1)
8894

8995

@@ -145,3 +151,97 @@ def test_registry():
145151

146152
Converters.register_module(module=Converter2)
147153
assert len(Converters) == 2
154+
155+
156+
class TestInhReg(metaclass=Registry, name='TestInh'):
157+
pass
158+
159+
160+
@TestInhReg.register_module()
161+
class InhBase():
162+
def __init__(self, a: int, b: int, **kwargs):
163+
self.a = a
164+
self.b = b
165+
self.uncollected = kwargs
166+
167+
168+
@TestInhReg.register_module(inherit=True)
169+
class InhChild1(InhBase):
170+
def __init__(self, c: int, d: int, **kwargs):
171+
super().__init__(**kwargs)
172+
self.c = c
173+
self.d = d
174+
175+
176+
@TestInhReg.register_module()
177+
class InhChild2(InhBase):
178+
def __init__(self, c: int, d: int, **kwargs):
179+
super().__init__(**kwargs)
180+
self.c = c
181+
self.d = d
182+
183+
184+
@configclass
185+
class InhRegCfg:
186+
m: RegistryConfig[TestInhReg]
187+
188+
189+
def test_superclass_registry():
190+
assert len(TestInhReg) == 3
191+
assert "InhChild1" in TestInhReg
192+
assert "InhChildNotDefined" not in TestInhReg
193+
194+
# when inherit is set True, the module will look back to its super class for areas
195+
config = TypeDef.load(
196+
InhRegCfg, dict(m=dict(type="InhChild1", a=1, b=2, c=3, d=4))
197+
)
198+
child1 = config.m.build()
199+
assert child1.a == 1
200+
assert child1.c == 3
201+
assert config.m.build(e=5).uncollected['e'] == 5
202+
203+
with pytest.raises(ValidationError):
204+
# when inherit is set False (default), all the keys must be specifed in the param list of __init__
205+
config = TypeDef.load(
206+
InhRegCfg, dict(m=dict(type="InhChild2", a=1, b=2, c=3, d=4))
207+
)
208+
209+
with pytest.raises(ValidationError):
210+
# when inherit is set True, use of positional arguments is banned to remove possible confusions
211+
@TestInhReg.register_module(inherit=True)
212+
class InhChildwithPosArgs(InhBase):
213+
def __init__(self, c: int, d: int, *args, **kwargs):
214+
super().__init__(**kwargs)
215+
self.c = c
216+
self.d = d
217+
218+
config = TypeDef.load(
219+
InhRegCfg, dict(m=dict(type="InhChildwithPosArgs", a=1, b=2, c=3, d=4))
220+
)
221+
222+
TestInhReg.unregister_module("InhChildwithPosArgs")
223+
224+
# won't raise TypeError with *args when inherit is set False (default) TODO: should this use also be banned?
225+
@TestInhReg.register_module()
226+
class InhChildwithPosArgs(InhBase):
227+
def __init__(self, a: int, b: int, c: int, d: int, *args, **kwargs):
228+
super().__init__(a=a, b=b)
229+
self.c = c
230+
self.d = d
231+
self.uncollected_args = args
232+
233+
config = TypeDef.load(
234+
InhRegCfg, dict(m=dict(type="InhChildwithPosArgs", a=1, b=2, c=3, d=4))
235+
)
236+
childwithPosArgs = config.m.build()
237+
assert childwithPosArgs.uncollected_args == ()
238+
239+
with pytest.raises(ValidationError):
240+
@TestInhReg.register_module()
241+
class InhChildwithOverwriteArgs(InhBase):
242+
def __init__(self, a: int, b: float):
243+
super().__init__(a=a, b=int(b))
244+
245+
config = TypeDef.load(
246+
InhRegCfg, dict(m=dict(type="InhChildwithOverwriteArgs", a=1, b=2.5))
247+
)

utilsd/config/registry.py

+64-11
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __new__(cls, clsname, bases, attrs, name=None):
3333
cls = super().__new__(cls, clsname, bases, attrs)
3434
cls._name = name
3535
cls._module_dict = {}
36+
cls._inherit_dict = {} # track whether a module should expand its superclass init parameters when specified with **kwargs
3637
return cls
3738

3839
@property
@@ -47,7 +48,7 @@ def __len__(cls):
4748
return len(cls._module_dict)
4849

4950
def __contains__(cls, key):
50-
return cls.get(key) is not None
51+
return key in cls._module_dict
5152

5253
def __repr__(cls):
5354
format_str = cls.__name__ + f'(name={cls._name}, items={cls._module_dict})'
@@ -57,14 +58,19 @@ def get(cls, key):
5758
if key in cls._module_dict:
5859
return cls._module_dict[key]
5960
raise KeyError(f'{key} not found in {cls}')
60-
61+
62+
def get_module_with_inherit(cls, key):
63+
if key in cls._module_dict:
64+
return cls._module_dict[key], cls._inherit_dict[key]
65+
raise KeyError(f'{key} not found in {cls}')
66+
6167
def inverse_get(cls, value):
6268
keys = [k for k, v in cls._module_dict.items() if v == value]
6369
if len(keys) != 1:
6470
raise ValueError(f'{value} needs to appear exactly once in {cls}')
6571
return keys[0]
6672

67-
def _register_module(cls, module_class, module_name=None, force=False):
73+
def _register_module(cls, module_class, module_name=None, force=False, *, inherit=False):
6874
if not inspect.isclass(module_class):
6975
raise TypeError(f'module must be a class, but got {type(module_class)}')
7076

@@ -76,8 +82,9 @@ def _register_module(cls, module_class, module_name=None, force=False):
7682
if not force and name in cls._module_dict:
7783
raise KeyError(f'{name} is already registered in {cls.name}')
7884
cls._module_dict[name] = module_class
85+
cls._inherit_dict[name] = inherit
7986

80-
def register_module(cls, name: Optional[str] = None, force: bool = False, module: Type = None):
87+
def register_module(cls, name: Optional[str] = None, force: bool = False, module: Type = None, *, inherit=False):
8188
if not isinstance(force, bool):
8289
raise TypeError(f'force must be a boolean, but got {type(force)}')
8390

@@ -89,12 +96,12 @@ def register_module(cls, name: Optional[str] = None, force: bool = False, module
8996

9097
# use it as a normal method: x.register_module(module=SomeClass)
9198
if module is not None:
92-
cls._register_module(module_class=module, module_name=name, force=force)
99+
cls._register_module(module_class=module, module_name=name, force=force, inherit=inherit)
93100
return module
94101

95102
# use it as a decorator: @x.register_module()
96103
def _register(reg_cls):
97-
cls._register_module(module_class=reg_cls, module_name=name, force=force)
104+
cls._register_module(module_class=reg_cls, module_name=name, force=force, inherit=inherit)
98105
return reg_cls
99106

100107
return _register
@@ -139,37 +146,83 @@ class SubclassConfig(Generic[T], metaclass=DataclassType):
139146
"""
140147

141148

142-
def dataclass_from_class(cls):
149+
def dataclass_from_class(cls, *, inherit_signature=False):
143150
"""Create a configurable dataclass for a class
144151
based on its ``__init__`` signature.
145152
"""
146153
class_name = cls.__name__ + 'Config'
147154
fields = [
148155
('_type', ClassVar[Type], cls),
149156
]
157+
non_default_fields = []
158+
default_fields = []
150159
init_signature = inspect.signature(cls.__init__)
160+
# Track presented param names. The same name may appear in different classes when **kwargs is passed.
161+
existing_names = dict()
162+
expand_super = False
151163
for idx, param in enumerate(init_signature.parameters.values()):
152164
if idx == 0:
153165
# skip self
154166
continue
155-
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
156-
# FIXME add support for args and kwargs later
167+
if param.kind == param.VAR_POSITIONAL:
168+
if inherit_signature:
169+
# Prohibit uncollected positional varibles
170+
# TODO: should positional params be banned from all use cases?
171+
raise TypeError(f'Use of positional params `*arg` in "{cls}" is prehibitted. Try to use `**kwargs` instead to avoid possible confusion.')
172+
continue
173+
if param.kind == param.VAR_KEYWORD:
174+
# Expand __init__ of the super classes for signitures
175+
if inherit_signature:
176+
expand_super = True
157177
continue
158178

159179
# TODO: fix type annotation for dependency injection
160180
if param.annotation == param.empty:
161181
raise TypeError(f'Parameter of `__init__` "{param}" of "{cls}" must have annotation.')
182+
existing_names[param.name] = (cls, param.annotation)
162183
if param.default != param.empty:
163-
fields.append((param.name, param.annotation, param.default))
184+
default_fields.append((param.name, param.annotation, param.default))
164185
else:
165-
fields.append((param.name, param.annotation))
186+
non_default_fields.append((param.name, param.annotation))
187+
188+
# check the super classes of cls
189+
for scls in cls.mro()[1:]:
190+
scls_signature = inspect.signature(scls.__init__)
191+
for idx, param in enumerate(scls_signature.parameters.values()):
192+
if idx == 0:
193+
# skip self
194+
continue
195+
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
196+
# mro has already contained all the super classes so we don't need to do expansion again.
197+
continue
198+
199+
if param.annotation == param.empty:
200+
raise TypeError(f'Parameter of `__init__` "{param}" of the superclass "{scls}" of "{cls}" must have annotation.')
201+
202+
if param.name in existing_names:
203+
if existing_names[param.name][1] != param.annotation:
204+
raise TypeError(
205+
f'Inconsist annotations found for the same param for inherited classes:\n'
206+
f'\tParam name: {param.name}\n'
207+
f'\tAnnotation in {existing_names[param.name][0]}: {existing_names[param.name][1]}\n'
208+
f'\tAnnotation in {scls}: {param.annotation}'
209+
)
210+
else:
211+
if expand_super:
212+
if param.default != param.empty:
213+
default_fields.append((param.name, param.annotation, param.default))
214+
else:
215+
non_default_fields.append((param.name, param.annotation))
216+
217+
fields = fields + non_default_fields + default_fields
166218

167219
def type_fn(self): return self._type
168220

169221
def build_fn(self, **kwargs):
170222
result = {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
171223
for k in kwargs:
172224
# silently overwrite the arguments with given ones.
225+
# FIXME: add type check when building?
173226
result[k] = kwargs[k]
174227
try:
175228
return self._type(**result)

utilsd/config/type_def.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -584,9 +584,8 @@ def from_plain(self, plain, ctx):
584584
raise TypeError(f'Expect a dict with key "type", but found {type(plain)}: {plain}')
585585
# copy the raw object to prevent unexpected modification
586586
plain = copy.copy(plain)
587-
588-
type_ = self.registry.get(plain.pop('type'))
589-
dataclass = dataclass_from_class(type_)
587+
type_, inherit = self.registry.get_module_with_inherit(plain.pop('type'))
588+
dataclass = dataclass_from_class(type_, inherit_signature=inherit)
590589
return super().from_plain(plain, ctx, type_=dataclass)
591590

592591
def to_plain(self, obj, ctx):

0 commit comments

Comments
 (0)