@@ -1252,7 +1252,9 @@ class GAE(ValueEstimatorBase):
1252
1252
Args:
1253
1253
gamma (scalar): exponential mean discount.
1254
1254
lmbda (scalar): trajectory discount.
1255
- value_network (TensorDictModule): value operator used to retrieve the value estimates.
1255
+ value_network (TensorDictModule, optional): value operator used to retrieve the value estimates.
1256
+ If ``None``, the value entries (e.g. ``"state_value"`` and ``("next", "state_value")``) must
1257
+ already be present in the input tensordict; no value network is called to compute them.
1256
1258
average_gae (bool): if ``True``, the resulting GAE values will be standardized.
1257
1259
Default is ``False``.
1258
1260
differentiable (bool, optional): if ``True``, gradients are propagated through
@@ -1327,7 +1329,7 @@ def __init__(
1327
1329
* ,
1328
1330
gamma : float | torch .Tensor ,
1329
1331
lmbda : float | torch .Tensor ,
1330
- value_network : TensorDictModule ,
1332
+ value_network : TensorDictModule | None ,
1331
1333
average_gae : bool = False ,
1332
1334
differentiable : bool = False ,
1333
1335
vectorized : bool | None = None ,
@@ -1499,6 +1501,15 @@ def forward(
1499
1501
value = tensordict .get (self .tensor_keys .value )
1500
1502
next_value = tensordict .get (("next" , self .tensor_keys .value ))
1501
1503
1504
+ if value is None :
1505
+ raise ValueError (
1506
+ f"The tensor with key { self .tensor_keys .value } is missing, and no value network was provided."
1507
+ )
1508
+ if next_value is None :
1509
+ raise ValueError (
1510
+ f"The tensor with key { ('next' , self .tensor_keys .value )} is missing, and no value network was provided."
1511
+ )
1512
+
1502
1513
done = tensordict .get (("next" , self .tensor_keys .done ))
1503
1514
terminated = tensordict .get (("next" , self .tensor_keys .terminated ), default = done )
1504
1515
time_dim = self ._get_time_dim (time_dim , tensordict )
0 commit comments