Skip to content

Commit cb81b5e

Browse files
committed
[BugFix] Fix tensorclass update
ghstack-source-id: 3f50604b340b7a7cdc710dfaceedf563295b2911 Pull Request resolved: #1255
1 parent 55fab2a commit cb81b5e

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

tensordict/nn/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1089,9 +1089,9 @@ def __init__(
10891089
)
10901090

10911091
self.module = module
1092-
if inplace not in (True, False, "empty"):
1092+
if inplace not in (None, True, False, "empty"):
10931093
raise ValueError(
1094-
f"The only accepted valued for inplace is `True`, `False`, or `'empty'`. Got inplace={inplace} "
1094+
f"The only accepted valued for inplace is `None`, `True`, `False`, or `'empty'`. Got inplace={inplace} "
10951095
"instead."
10961096
)
10971097
self.inplace = inplace

tensordict/nn/sequence.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def __init__(
218218
**{key: val for key, val in _zip_strict(modules[0], modules_vals)}
219219
)
220220
super().__init__(
221-
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
221+
module=nn.ModuleDict(modules),
222+
in_keys=in_keys,
223+
out_keys=out_keys,
222224
)
223225
elif len(modules) == 1 and isinstance(
224226
modules[0], collections.abc.MutableSequence
@@ -227,20 +229,25 @@ def __init__(
227229
in_keys, out_keys = self._compute_in_and_out_keys(modules)
228230
self._complete_out_keys = list(out_keys)
229231
super().__init__(
230-
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
232+
module=nn.ModuleList(modules),
233+
in_keys=in_keys,
234+
out_keys=out_keys,
231235
)
232236
elif len(modules) == 1 and isinstance(modules[0], dict):
233237
return self.__init__(
234238
collections.OrderedDict(modules[0]),
235239
partial_tolerant=partial_tolerant,
236240
selected_out_keys=selected_out_keys,
241+
inplace=inplace,
237242
)
238243
else:
239244
modules = self._convert_modules(modules)
240245
in_keys, out_keys = self._compute_in_and_out_keys(modules)
241246
self._complete_out_keys = list(out_keys)
242247
super().__init__(
243-
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
248+
module=nn.ModuleList(list(modules)),
249+
in_keys=in_keys,
250+
out_keys=out_keys,
244251
)
245252

246253
self.inplace = inplace

tensordict/tensorclass.py

+6
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,12 @@ def _update(
16941694
update_batch_size=update_batch_size,
16951695
ignore_lock=ignore_lock,
16961696
)
1697+
# We also need to remove things from non_tensordict
1698+
if self._non_tensordict:
1699+
keys = set(self._tensordict.keys())
1700+
ntd = {k: val for k, val in self._non_tensordict.items() if k not in keys}
1701+
self._non_tensordict.clear()
1702+
self._non_tensordict.update(ntd)
16971703
return self
16981704

16991705

0 commit comments

Comments
 (0)