11
11
from typing import Any , Callable , Iterable , List , OrderedDict , overload
12
12
13
13
from tensordict ._nestedkey import NestedKey
14
+ from tensordict ._td import TensorDict
14
15
15
16
from tensordict .nn .common import (
16
17
dispatch ,
@@ -61,14 +62,17 @@ class TensorDictSequential(TensorDictModule):
61
62
Regular ``dict`` inputs will be converted to ``OrderedDict`` if necessary.
62
63
63
64
Keyword Args:
64
- partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
65
+ partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
65
66
If so, the only module that will be executed are those who can be executed given the keys that
66
67
are present.
67
68
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
68
69
stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
69
70
looking for those that have the required keys, if any. Defaults to False.
70
- selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
71
+ selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
71
72
``out_keys`` will be written.
73
+ inplace (bool, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
74
+ :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the
75
+ output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).
72
76
73
77
.. note::
74
78
A :class:`TensorDictSequential` instance may have a long list of output keys, and one may wish to remove
@@ -185,6 +189,7 @@ def __init__(
185
189
* ,
186
190
partial_tolerant : bool = False ,
187
191
selected_out_keys : List [NestedKey ] | None = None ,
192
+ inplace : bool | None = None ,
188
193
) -> None : ...
189
194
190
195
@overload
@@ -194,13 +199,15 @@ def __init__(
194
199
* ,
195
200
partial_tolerant : bool = False ,
196
201
selected_out_keys : List [NestedKey ] | None = None ,
202
+ inplace : bool | None = None ,
197
203
) -> None : ...
198
204
199
205
def __init__ (
200
206
self ,
201
207
* modules : Callable [[TensorDictBase ], TensorDictBase ],
202
208
partial_tolerant : bool = False ,
203
209
selected_out_keys : List [NestedKey ] | None = None ,
210
+ inplace : bool | None = None ,
204
211
) -> None :
205
212
206
213
if len (modules ) == 1 and isinstance (modules [0 ], collections .OrderedDict ):
@@ -236,6 +243,7 @@ def __init__(
236
243
module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys
237
244
)
238
245
246
+ self .inplace = inplace
239
247
self .partial_tolerant = partial_tolerant
240
248
if selected_out_keys :
241
249
self ._select_before_return = True
@@ -452,6 +460,43 @@ def select_subsequence(
452
460
in_keys=['b'],
453
461
out_keys=['d', 'e'])
454
462
463
+ The `inplace` argument allows for a fine-grained control over the output type, allowing for instance to write
464
+ the result of the computational graph in the input object without tracking the intermediate tensors.
465
+
466
+ Example:
467
+ >>> import torch
468
+ >>> from tensordict import TensorClass
469
+ >>> from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
470
+ >>>
471
+ >>> class MyClass(TensorClass):
472
+ ... input: torch.Tensor
473
+ ... output: torch.Tensor | None = None
474
+ >>>
475
+ >>> obj = MyClass(torch.randn(2, 3), batch_size=(2,))
476
+ >>>
477
+ >>> model = Seq(
478
+ ... Mod(
479
+ ... lambda x: (x + 1, x - 1),
480
+ ... in_keys=["input"],
481
+ ... out_keys=[("intermediate", "0"), ("intermediate", "1")],
482
+ ... inplace=False
483
+ ... ),
484
+ ... Mod(
485
+ ... lambda y0, y1: y0 * y1,
486
+ ... in_keys=[("intermediate", "0"), ("intermediate", "1")],
487
+ ... out_keys=["output"],
488
+ ... inplace=False
489
+ ... ),
490
+ ... inplace=True, )
491
+ >>> print(model(obj))
492
+ MyClass(
493
+ input=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
494
+ output=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
495
+ output=None,
496
+ batch_size=torch.Size([2]),
497
+ device=None,
498
+ is_shared=False)
499
+
455
500
"""
456
501
if in_keys is None :
457
502
in_keys = deepcopy (self .in_keys )
@@ -558,6 +603,14 @@ def forward(
558
603
tensordict_exec = tensordict .copy ()
559
604
else :
560
605
tensordict_exec = tensordict
606
+ if tensordict_out is None :
607
+ if self .inplace is True :
608
+ tensordict_out = tensordict
609
+ elif self .inplace is False :
610
+ tensordict_out = TensorDict ()
611
+ elif self .inplace == "empty" :
612
+ tensordict_out = tensordict .empty ()
613
+
561
614
if not len (kwargs ):
562
615
for module in self ._module_iter ():
563
616
try :
0 commit comments