@@ -218,7 +218,9 @@ def __init__(
218
218
** {key : val for key , val in _zip_strict (modules [0 ], modules_vals )}
219
219
)
220
220
super ().__init__ (
221
- module = nn .ModuleDict (modules ), in_keys = in_keys , out_keys = out_keys
221
+ module = nn .ModuleDict (modules ),
222
+ in_keys = in_keys ,
223
+ out_keys = out_keys ,
222
224
)
223
225
elif len (modules ) == 1 and isinstance (
224
226
modules [0 ], collections .abc .MutableSequence
@@ -227,20 +229,25 @@ def __init__(
227
229
in_keys , out_keys = self ._compute_in_and_out_keys (modules )
228
230
self ._complete_out_keys = list (out_keys )
229
231
super ().__init__ (
230
- module = nn .ModuleList (modules ), in_keys = in_keys , out_keys = out_keys
232
+ module = nn .ModuleList (modules ),
233
+ in_keys = in_keys ,
234
+ out_keys = out_keys ,
231
235
)
232
236
elif len (modules ) == 1 and isinstance (modules [0 ], dict ):
233
237
return self .__init__ (
234
238
collections .OrderedDict (modules [0 ]),
235
239
partial_tolerant = partial_tolerant ,
236
240
selected_out_keys = selected_out_keys ,
241
+ inplace = inplace ,
237
242
)
238
243
else :
239
244
modules = self ._convert_modules (modules )
240
245
in_keys , out_keys = self ._compute_in_and_out_keys (modules )
241
246
self ._complete_out_keys = list (out_keys )
242
247
super ().__init__ (
243
- module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys
248
+ module = nn .ModuleList (list (modules )),
249
+ in_keys = in_keys ,
250
+ out_keys = out_keys ,
244
251
)
245
252
246
253
self .inplace = inplace
0 commit comments