@@ -104,6 +104,9 @@ class PPOLoss(LossModule):
104
104
* **Scalar**: one value applied to the summed entropy of every action head.
105
105
* **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
106
106
Defaults to ``0.01``.
107
+ log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
108
+ predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
109
+ This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
107
110
critic_coef (scalar, optional): critic loss multiplier when computing the total
108
111
loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
109
112
loss from the forward outputs.
@@ -349,6 +352,7 @@ def __init__(
349
352
entropy_bonus : bool = True ,
350
353
samples_mc_entropy : int = 1 ,
351
354
entropy_coeff : float | Mapping [str , float ] = 0.01 ,
355
+ log_explained_variance : bool = True ,
352
356
critic_coef : float | None = None ,
353
357
loss_critic_type : str = "smooth_l1" ,
354
358
normalize_advantage : bool = False ,
@@ -413,6 +417,7 @@ def __init__(
413
417
self .critic_network_params = None
414
418
self .target_critic_network_params = None
415
419
420
+ self .log_explained_variance = log_explained_variance
416
421
self .samples_mc_entropy = samples_mc_entropy
417
422
self .entropy_bonus = entropy_bonus
418
423
self .separate_losses = separate_losses
@@ -745,6 +750,16 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
745
750
self .loss_critic_type ,
746
751
)
747
752
753
+ explained_variance = None
754
+ if self .log_explained_variance :
755
+ with torch .no_grad (): # <-- break grad-flow
756
+ tgt = target_return .detach ()
757
+ pred = state_value .detach ()
758
+ eps = torch .finfo (tgt .dtype ).eps
759
+ resid = torch .var (tgt - pred , unbiased = False , dim = 0 )
760
+ total = torch .var (tgt , unbiased = False , dim = 0 )
761
+ explained_variance = 1.0 - resid / (total + eps )
762
+
748
763
self ._clear_weakrefs (
749
764
tensordict ,
750
765
"actor_network_params" ,
@@ -753,8 +768,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
753
768
"target_critic_network_params" ,
754
769
)
755
770
if self ._has_critic :
756
- return self .critic_coef * loss_value , clip_fraction
757
- return loss_value , clip_fraction
771
+ return self .critic_coef * loss_value , clip_fraction , explained_variance
772
+ return loss_value , clip_fraction , explained_variance
758
773
759
774
@property
760
775
@_cache_values
@@ -804,10 +819,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
804
819
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
805
820
td_out .set ("loss_entropy" , self ._weighted_loss_entropy (entropy ))
806
821
if self ._has_critic :
807
- loss_critic , value_clip_fraction = self .loss_critic (tensordict )
822
+ loss_critic , value_clip_fraction , explained_variance = self .loss_critic (tensordict )
808
823
td_out .set ("loss_critic" , loss_critic )
809
824
if value_clip_fraction is not None :
810
825
td_out .set ("value_clip_fraction" , value_clip_fraction )
826
+ if explained_variance is not None :
827
+ td_out .set ("explained_variance" , explained_variance )
811
828
td_out = td_out .named_apply (
812
829
lambda name , value : _reduce (value , reduction = self .reduction ).squeeze (- 1 )
813
830
if name .startswith ("loss_" )
@@ -1172,10 +1189,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1172
1189
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
1173
1190
td_out .set ("loss_entropy" , self ._weighted_loss_entropy (entropy ))
1174
1191
if self ._has_critic :
1175
- loss_critic , value_clip_fraction = self .loss_critic (tensordict )
1192
+ loss_critic , value_clip_fraction , explained_variance = self .loss_critic (tensordict )
1176
1193
td_out .set ("loss_critic" , loss_critic )
1177
1194
if value_clip_fraction is not None :
1178
1195
td_out .set ("value_clip_fraction" , value_clip_fraction )
1196
+ if explained_variance is not None :
1197
+ td_out .set ("explained_variance" , explained_variance )
1179
1198
1180
1199
td_out .set ("ESS" , _reduce (ess , self .reduction ) / batch )
1181
1200
td_out = td_out .named_apply (
@@ -1518,10 +1537,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
1518
1537
td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
1519
1538
td_out .set ("loss_entropy" , self ._weighted_loss_entropy (entropy ))
1520
1539
if self ._has_critic :
1521
- loss_critic , value_clip_fraction = self .loss_critic (tensordict_copy )
1540
+ loss_critic , value_clip_fraction , explained_variance = self .loss_critic (tensordict_copy )
1522
1541
td_out .set ("loss_critic" , loss_critic )
1523
1542
if value_clip_fraction is not None :
1524
1543
td_out .set ("value_clip_fraction" , value_clip_fraction )
1544
+ if explained_variance is not None :
1545
+ td_out .set ("explained_variance" , explained_variance )
1525
1546
td_out = td_out .named_apply (
1526
1547
lambda name , value : _reduce (value , reduction = self .reduction ).squeeze (- 1 )
1527
1548
if name .startswith ("loss_" )
0 commit comments