Commit e1e15d6

NTR0314 and Oswald Zink authored
[Feature] Add optional Explained Variance logging (#3010)
Co-authored-by: Oswald Zink <[email protected]>
1 parent 92b52a0 commit e1e15d6

1 file changed: +26 additions, -5 deletions

torchrl/objectives/ppo.py

@@ -104,6 +104,9 @@ class PPOLoss(LossModule):
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
+        log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
+            predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
+            This can help monitor critic quality during training. The best possible score is 1.0; lower values are worse. Defaults to ``True``.
         critic_coef (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
             loss from the forward outputs.
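
For illustration, a minimal construction sketch showing the new flag in use. This is not part of the diff; the toy networks, sizes, and key names are assumptions following common torchrl patterns:

    from torch import nn
    from tensordict.nn import NormalParamExtractor, TensorDictModule
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import PPOLoss

    n_obs, n_act = 4, 2  # made-up toy dimensions
    # Policy head emits loc/scale parameters for a TanhNormal distribution.
    actor_net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
    actor = ProbabilisticActor(
        module=TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )
    critic = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])

    loss_module = PPOLoss(
        actor_network=actor,
        value_network=critic,
        critic_coef=1.0,
        log_explained_variance=True,  # the flag added here; True is already the default
    )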
@@ -349,6 +352,7 @@ def __init__(
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float | Mapping[str, float] = 0.01,
+        log_explained_variance: bool = True,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -413,6 +417,7 @@ def __init__(
             self.critic_network_params = None
             self.target_critic_network_params = None

+        self.log_explained_variance = log_explained_variance
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus
         self.separate_losses = separate_losses
@@ -745,6 +750,16 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
             self.loss_critic_type,
         )

+        explained_variance = None
+        if self.log_explained_variance:
+            with torch.no_grad():  # <-- break grad-flow
+                tgt = target_return.detach()
+                pred = state_value.detach()
+                eps = torch.finfo(tgt.dtype).eps
+                resid = torch.var(tgt - pred, unbiased=False, dim=0)
+                total = torch.var(tgt, unbiased=False, dim=0)
+                explained_variance = 1.0 - resid / (total + eps)
+
         self._clear_weakrefs(
             tensordict,
             "actor_network_params",
@@ -753,8 +768,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
             "target_critic_network_params",
         )
         if self._has_critic:
-            return self.critic_coef * loss_value, clip_fraction
-        return loss_value, clip_fraction
+            return self.critic_coef * loss_value, clip_fraction, explained_variance
+        return loss_value, clip_fraction, explained_variance

     @property
     @_cache_values
@@ -804,10 +819,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
+            if explained_variance is not None:
+                td_out.set("explained_variance", explained_variance)
         td_out = td_out.named_apply(
             lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
@@ -1172,10 +1189,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
+            if explained_variance is not None:
+                td_out.set("explained_variance", explained_variance)

         td_out.set("ESS", _reduce(ess, self.reduction) / batch)
         td_out = td_out.named_apply(
@@ -1518,10 +1537,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
+            if explained_variance is not None:
+                td_out.set("explained_variance", explained_variance)
         td_out = td_out.named_apply(
             lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
             if name.startswith("loss_")
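
With the default left on, the metric appears next to the other diagnostics in the TensorDict returned by ``forward``. A hedged usage sketch, assuming ``loss_module`` is a configured PPOLoss (e.g. as above) and ``data`` is a rollout batch it accepts:

    td_out = loss_module(data)  # `loss_module` and `data` assumed from context
    loss = sum(v for k, v in td_out.items() if k.startswith("loss_"))
    loss.backward()
    if "explained_variance" in td_out.keys():
        # Scalar for a single value dimension; detached, so safe to log directly.
        print("explained_variance:", td_out["explained_variance"].item())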
