@@ -33,6 +33,7 @@ def __new__(cls, clsname, bases, attrs, name=None):
33
33
cls = super ().__new__ (cls , clsname , bases , attrs )
34
34
cls ._name = name
35
35
cls ._module_dict = {}
36
+ cls ._inherit_dict = {} # track whether a module should expand its superclass init parameters when specified with **kwargs
36
37
return cls
37
38
38
39
@property
@@ -47,7 +48,7 @@ def __len__(cls):
47
48
return len (cls ._module_dict )
48
49
49
50
def __contains__ (cls , key ):
50
- return cls . get ( key ) is not None
51
+ return key in cls . _module_dict
51
52
52
53
def __repr__ (cls ):
53
54
format_str = cls .__name__ + f'(name={ cls ._name } , items={ cls ._module_dict } )'
@@ -57,14 +58,19 @@ def get(cls, key):
57
58
if key in cls ._module_dict :
58
59
return cls ._module_dict [key ]
59
60
raise KeyError (f'{ key } not found in { cls } ' )
60
-
61
+
62
+ def get_module_with_inherit (cls , key ):
63
+ if key in cls ._module_dict :
64
+ return cls ._module_dict [key ], cls ._inherit_dict [key ]
65
+ raise KeyError (f'{ key } not found in { cls } ' )
66
+
61
67
def inverse_get (cls , value ):
62
68
keys = [k for k , v in cls ._module_dict .items () if v == value ]
63
69
if len (keys ) != 1 :
64
70
raise ValueError (f'{ value } needs to appear exactly once in { cls } ' )
65
71
return keys [0 ]
66
72
67
- def _register_module (cls , module_class , module_name = None , force = False ):
73
+ def _register_module (cls , module_class , module_name = None , force = False , * , inherit = False ):
68
74
if not inspect .isclass (module_class ):
69
75
raise TypeError (f'module must be a class, but got { type (module_class )} ' )
70
76
@@ -76,8 +82,9 @@ def _register_module(cls, module_class, module_name=None, force=False):
76
82
if not force and name in cls ._module_dict :
77
83
raise KeyError (f'{ name } is already registered in { cls .name } ' )
78
84
cls ._module_dict [name ] = module_class
85
+ cls ._inherit_dict [name ] = inherit
79
86
80
- def register_module (cls , name : Optional [str ] = None , force : bool = False , module : Type = None ):
87
+ def register_module (cls , name : Optional [str ] = None , force : bool = False , module : Type = None , * , inherit = False ):
81
88
if not isinstance (force , bool ):
82
89
raise TypeError (f'force must be a boolean, but got { type (force )} ' )
83
90
@@ -89,12 +96,12 @@ def register_module(cls, name: Optional[str] = None, force: bool = False, module
89
96
90
97
# use it as a normal method: x.register_module(module=SomeClass)
91
98
if module is not None :
92
- cls ._register_module (module_class = module , module_name = name , force = force )
99
+ cls ._register_module (module_class = module , module_name = name , force = force , inherit = inherit )
93
100
return module
94
101
95
102
# use it as a decorator: @x.register_module()
96
103
def _register (reg_cls ):
97
- cls ._register_module (module_class = reg_cls , module_name = name , force = force )
104
+ cls ._register_module (module_class = reg_cls , module_name = name , force = force , inherit = inherit )
98
105
return reg_cls
99
106
100
107
return _register
@@ -139,37 +146,83 @@ class SubclassConfig(Generic[T], metaclass=DataclassType):
139
146
"""
140
147
141
148
142
- def dataclass_from_class (cls ):
149
+ def dataclass_from_class (cls , * , inherit_signature = False ):
143
150
"""Create a configurable dataclass for a class
144
151
based on its ``__init__`` signature.
145
152
"""
146
153
class_name = cls .__name__ + 'Config'
147
154
fields = [
148
155
('_type' , ClassVar [Type ], cls ),
149
156
]
157
+ non_default_fields = []
158
+ default_fields = []
150
159
init_signature = inspect .signature (cls .__init__ )
160
+ # Track presented param names. The same name may appear in different classes when **kwargs is passed.
161
+ existing_names = dict ()
162
+ expand_super = False
151
163
for idx , param in enumerate (init_signature .parameters .values ()):
152
164
if idx == 0 :
153
165
# skip self
154
166
continue
155
- if param .kind in (param .VAR_POSITIONAL , param .VAR_KEYWORD ):
156
- # FIXME add support for args and kwargs later
167
+ if param .kind == param .VAR_POSITIONAL :
168
+ if inherit_signature :
169
+ # Prohibit uncollected positional varibles
170
+ # TODO: should positional params be banned from all use cases?
171
+ raise TypeError (f'Use of positional params `*arg` in "{ cls } " is prehibitted. Try to use `**kwargs` instead to avoid possible confusion.' )
172
+ continue
173
+ if param .kind == param .VAR_KEYWORD :
174
+ # Expand __init__ of the super classes for signitures
175
+ if inherit_signature :
176
+ expand_super = True
157
177
continue
158
178
159
179
# TODO: fix type annotation for dependency injection
160
180
if param .annotation == param .empty :
161
181
raise TypeError (f'Parameter of `__init__` "{ param } " of "{ cls } " must have annotation.' )
182
+ existing_names [param .name ] = (cls , param .annotation )
162
183
if param .default != param .empty :
163
- fields .append ((param .name , param .annotation , param .default ))
184
+ default_fields .append ((param .name , param .annotation , param .default ))
164
185
else :
165
- fields .append ((param .name , param .annotation ))
186
+ non_default_fields .append ((param .name , param .annotation ))
187
+
188
+ # check the super classes of cls
189
+ for scls in cls .mro ()[1 :]:
190
+ scls_signature = inspect .signature (scls .__init__ )
191
+ for idx , param in enumerate (scls_signature .parameters .values ()):
192
+ if idx == 0 :
193
+ # skip self
194
+ continue
195
+ if param .kind in (param .VAR_POSITIONAL , param .VAR_KEYWORD ):
196
+ # mro has already contained all the super classes so we don't need to do expansion again.
197
+ continue
198
+
199
+ if param .annotation == param .empty :
200
+ raise TypeError (f'Parameter of `__init__` "{ param } " of the superclass "{ scls } " of "{ cls } " must have annotation.' )
201
+
202
+ if param .name in existing_names :
203
+ if existing_names [param .name ][1 ] != param .annotation :
204
+ raise TypeError (
205
+ f'Inconsist annotations found for the same param for inherited classes:\n '
206
+ f'\t Param name: { param .name } \n '
207
+ f'\t Annotation in { existing_names [param .name ][0 ]} : { existing_names [param .name ][1 ]} \n '
208
+ f'\t Annotation in { scls } : { param .annotation } '
209
+ )
210
+ else :
211
+ if expand_super :
212
+ if param .default != param .empty :
213
+ default_fields .append ((param .name , param .annotation , param .default ))
214
+ else :
215
+ non_default_fields .append ((param .name , param .annotation ))
216
+
217
+ fields = fields + non_default_fields + default_fields
166
218
167
219
def type_fn (self ): return self ._type
168
220
169
221
def build_fn (self , ** kwargs ):
170
222
result = {f .name : getattr (self , f .name ) for f in dataclasses .fields (self )}
171
223
for k in kwargs :
172
224
# silently overwrite the arguments with given ones.
225
+ # FIXME: add type check when building?
173
226
result [k ] = kwargs [k ]
174
227
try :
175
228
return self ._type (** result )
0 commit comments