Skip to content

Commit 773c366

Browse files
louisfauryLouis Faury
andauthored
[Format] Value-network can be None in GAE typing (#3029)
Co-authored-by: Louis Faury <[email protected]>
1 parent ed051bc commit 773c366

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

torchrl/objectives/value/advantages.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,7 +1252,9 @@ class GAE(ValueEstimatorBase):
12521252
Args:
12531253
gamma (scalar): exponential mean discount.
12541254
lmbda (scalar): trajectory discount.
1255-
value_network (TensorDictModule): value operator used to retrieve the value estimates.
1255+
value_network (TensorDictModule, optional): value operator used to retrieve the value estimates.
1256+
If ``None``, this module will expect the ``"state_value"`` keys to be already filled, and
1257+
will not call the value network to produce it.
12561258
average_gae (bool): if ``True``, the resulting GAE values will be standardized.
12571259
Default is ``False``.
12581260
differentiable (bool, optional): if ``True``, gradients are propagated through
@@ -1327,7 +1329,7 @@ def __init__(
13271329
*,
13281330
gamma: float | torch.Tensor,
13291331
lmbda: float | torch.Tensor,
1330-
value_network: TensorDictModule,
1332+
value_network: TensorDictModule | None,
13311333
average_gae: bool = False,
13321334
differentiable: bool = False,
13331335
vectorized: bool | None = None,
@@ -1499,6 +1501,15 @@ def forward(
14991501
value = tensordict.get(self.tensor_keys.value)
15001502
next_value = tensordict.get(("next", self.tensor_keys.value))
15011503

1504+
if value is None:
1505+
raise ValueError(
1506+
f"The tensor with key {self.tensor_keys.value} is missing, and no value network was provided."
1507+
)
1508+
if next_value is None:
1509+
raise ValueError(
1510+
f"The tensor with key {('next', self.tensor_keys.value)} is missing, and no value network was provided."
1511+
)
1512+
15021513
done = tensordict.get(("next", self.tensor_keys.done))
15031514
terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done)
15041515
time_dim = self._get_time_dim(time_dim, tensordict)

0 commit comments

Comments
 (0)