Skip to content

Commit 55fab2a

Browse files
committed
[Feature] inplace arg for TensorDictSequential
ghstack-source-id: 3f392cbce42b5adf696604211c16f807cad048c3 Pull Request resolved: #1253
1 parent eb2fd8e commit 55fab2a

File tree

3 files changed

+92
-3
lines changed

3 files changed

+92
-3
lines changed

tensordict/nn/sequence.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Any, Callable, Iterable, List, OrderedDict, overload
1212

1313
from tensordict._nestedkey import NestedKey
14+
from tensordict._td import TensorDict
1415

1516
from tensordict.nn.common import (
1617
dispatch,
@@ -61,14 +62,17 @@ class TensorDictSequential(TensorDictModule):
6162
Regular ``dict`` inputs will be converted to ``OrderedDict`` if necessary.
6263
6364
Keyword Args:
64-
partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
65+
partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
6566
If so, the only modules that will be executed are those that can be executed given the keys that
6667
are present.
6768
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
6869
stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
6970
looking for those that have the required keys, if any. Defaults to False.
70-
selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
71+
selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
7172
``out_keys`` will be written.
73+
inplace (bool or `"empty"`, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
74+
:class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (i.e., the
75+
output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).
7276
7377
.. note::
7478
A :class:`TensorDictSequential` instance may have a long list of output keys, and one may wish to remove
@@ -185,6 +189,7 @@ def __init__(
185189
*,
186190
partial_tolerant: bool = False,
187191
selected_out_keys: List[NestedKey] | None = None,
192+
inplace: bool | None = None,
188193
) -> None: ...
189194

190195
@overload
@@ -194,13 +199,15 @@ def __init__(
194199
*,
195200
partial_tolerant: bool = False,
196201
selected_out_keys: List[NestedKey] | None = None,
202+
inplace: bool | None = None,
197203
) -> None: ...
198204

199205
def __init__(
200206
self,
201207
*modules: Callable[[TensorDictBase], TensorDictBase],
202208
partial_tolerant: bool = False,
203209
selected_out_keys: List[NestedKey] | None = None,
210+
inplace: bool | None = None,
204211
) -> None:
205212

206213
if len(modules) == 1 and isinstance(modules[0], collections.OrderedDict):
@@ -236,6 +243,7 @@ def __init__(
236243
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
237244
)
238245

246+
self.inplace = inplace
239247
self.partial_tolerant = partial_tolerant
240248
if selected_out_keys:
241249
self._select_before_return = True
@@ -452,6 +460,43 @@ def select_subsequence(
452460
in_keys=['b'],
453461
out_keys=['d', 'e'])
454462
463+
The `inplace` argument gives fine-grained control over the output type, making it possible, for instance, to write
464+
the result of the computational graph in the input object without tracking the intermediate tensors.
465+
466+
Example:
467+
>>> import torch
468+
>>> from tensordict import TensorClass
469+
>>> from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
470+
>>>
471+
>>> class MyClass(TensorClass):
472+
... input: torch.Tensor
473+
... output: torch.Tensor | None = None
474+
>>>
475+
>>> obj = MyClass(torch.randn(2, 3), batch_size=(2,))
476+
>>>
477+
>>> model = Seq(
478+
... Mod(
479+
... lambda x: (x + 1, x - 1),
480+
... in_keys=["input"],
481+
... out_keys=[("intermediate", "0"), ("intermediate", "1")],
482+
... inplace=False
483+
... ),
484+
... Mod(
485+
... lambda y0, y1: y0 * y1,
486+
... in_keys=[("intermediate", "0"), ("intermediate", "1")],
487+
... out_keys=["output"],
488+
... inplace=False
489+
... ),
490+
... inplace=True, )
491+
>>> print(model(obj))
492+
MyClass(
493+
input=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
494+
output=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
495+
output=None,
496+
batch_size=torch.Size([2]),
497+
device=None,
498+
is_shared=False)
499+
455500
"""
456501
if in_keys is None:
457502
in_keys = deepcopy(self.in_keys)
@@ -558,6 +603,14 @@ def forward(
558603
tensordict_exec = tensordict.copy()
559604
else:
560605
tensordict_exec = tensordict
606+
if tensordict_out is None:
607+
if self.inplace is True:
608+
tensordict_out = tensordict
609+
elif self.inplace is False:
610+
tensordict_out = TensorDict()
611+
elif self.inplace == "empty":
612+
tensordict_out = tensordict.empty()
613+
561614
if not len(kwargs):
562615
for module in self._module_iter():
563616
try:

tensordict/tensorclass.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1939,7 +1939,8 @@ def _from_dict(
19391939
# We don't want to enforce them to be tensorclasses so we can't do much about it...
19401940
return cls.from_tensordict(
19411941
tensordict=TensorDict(
1942-
batch_size=batch_size, device=device, batch_dims=batch_dims
1942+
batch_size=batch_size,
1943+
device=device,
19431944
),
19441945
non_tensordict=input_dict,
19451946
)

test/test_nn.py

+35
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,41 @@ def test_tdmodule_inplace(self):
10091009

10101010

10111011
class TestTDSequence:
1012+
@pytest.mark.parametrize("inplace", [True, False, None])
@pytest.mark.parametrize("module_inplace", [True, False])
def test_tdseq_inplace(self, inplace, module_inplace):
    """Check how the sequence-level ``inplace`` flag combines with the sub-modules' flag."""
    # First stage: fan the input out into two nested intermediate entries.
    split_stage = TensorDictModule(
        lambda x: (x + 1, x - 1),
        in_keys=["input"],
        out_keys=[("intermediate", "0"), ("intermediate", "1")],
        inplace=module_inplace,
    )
    # Second stage: combine the two intermediates into the final output.
    merge_stage = TensorDictModule(
        lambda y0, y1: y0 * y1,
        in_keys=[("intermediate", "0"), ("intermediate", "1")],
        out_keys=["output"],
        inplace=module_inplace,
    )
    model = TensorDictSequential(split_stage, merge_stage, inplace=inplace)

    td_in = TensorDict(input=torch.zeros(()))
    td_out = model(td_in)

    if inplace:
        # The sequence writes straight into the object it was given.
        assert td_out is td_in
        assert "input" in td_out
    elif not module_inplace or inplace is False:
        # Here inplace=False and inplace=None behave alike: the result is a
        # distinct container that does not carry the input entry.
        assert td_out is not td_in, (module_inplace, inplace)
        assert "input" not in td_out, (module_inplace, inplace)
    else:
        # inplace=None with in-place sub-modules: the sequence defers to the
        # sub-modules, so the input object itself is returned.
        assert td_out is td_in, (module_inplace, inplace)
        assert "input" in td_out, (module_inplace, inplace)

    # Whatever the container, the final result must be present.
    assert "output" in td_out
1046+
10121047
def test_ordered_dict(self):
10131048
linear = nn.Linear(3, 4)
10141049
linear.weight.data.fill_(0)

0 commit comments

Comments
 (0)