
Commit 205243c

[Feature] Compose.pop (#3026)
1 parent 77dbc6c commit 205243c

3 files changed: +98 -8 lines changed


test/test_transforms.py

Lines changed: 33 additions & 0 deletions
@@ -9674,6 +9674,7 @@ def _test_vecnorm_subproc_auto(
     def rename_t(self):
         return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])
 
+    @retry(AssertionError, tries=10, delay=0)
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
         queues = []
@@ -10619,6 +10620,38 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4):
             [nchannels * N, 16, 16]
         )
 
+    def test_compose_pop(self):
+        t1 = CatFrames(in_keys=["a", "b"], N=2, dim=-1)
+        t2 = FiniteTensorDictCheck()
+        t3 = ExcludeTransform()
+        compose = Compose(t1, t2, t3)
+        assert len(compose.transforms) == 3
+        p = compose.pop()
+        assert p is t3
+        assert len(compose.transforms) == 2
+        p = compose.pop(0)
+        assert p is t1
+        assert len(compose.transforms) == 1
+        p = compose.pop()
+        assert p is t2
+        assert len(compose.transforms) == 0
+        with pytest.raises(IndexError, match="index -1 is out of range"):
+            compose.pop()
+
+    def test_compose_pop_parent_modification(self):
+        t1 = CatFrames(in_keys=["a", "b"], N=2, dim=-1)
+        t2 = FiniteTensorDictCheck()
+        t3 = ExcludeTransform()
+        compose = Compose(t1, t2, t3)
+        env = TransformedEnv(ContinuousActionVecMockEnv(), compose)
+        p = t2.parent
+        assert isinstance(p.transform[0], CatFrames)
+        env.transform.pop(0)
+        assert env.transform[0] is t2
+        new_p = t2.parent
+        assert new_p is not p
+        assert len(new_p.transform) == 0
+
     def test_lambda_functions(self):
         def trsf(data):
             if "y" in data.keys():

torchrl/envs/transforms/transforms.py

Lines changed: 52 additions & 1 deletion
@@ -738,7 +738,7 @@ def __setstate__(self, state):
         self.__dict__.update(state)
 
     @property
-    def parent(self) -> EnvBase | None:
+    def parent(self) -> TransformedEnv | None:
         """Returns the parent env of the transform.
 
         The parent env is the env that contains all the transforms up until the current one.
@@ -1249,6 +1249,7 @@ def close(self, *, raise_if_closed: bool = True):
     def empty_cache(self):
         self.__dict__["_output_spec"] = None
         self.__dict__["_input_spec"] = None
+        self.transform.empty_cache()
         super().empty_cache()
 
     def append_transform(
@@ -1429,6 +1430,50 @@ def map_transform(trsf):
         for t in transforms:
             t.set_container(self)
 
+    def pop(self, index: int | None = None) -> Transform:
+        """Pop a transform from the chain.
+
+        Args:
+            index (int, optional): The index of the transform to pop. If None, the last transform is popped.
+
+        Returns:
+            The popped transform.
+        """
+        if index is None:
+            index = len(self.transforms) - 1
+        result = self.transforms.pop(index)
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+        return result
+
+    def __delitem__(self, index: int | slice | list):
+        """Delete a transform in the chain.
+
+        :class:`~torchrl.envs.transforms.Transform` or callable are accepted.
+        """
+        del self.transforms[index]
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
+    def __setitem__(
+        self,
+        index: int | slice | list,
+        value: Transform | Callable[[TensorDictBase], TensorDictBase],
+    ):
+        """Set a transform in the chain.
+
+        :class:`~torchrl.envs.transforms.Transform` or callable are accepted.
+        """
+        self.transforms[index] = value
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
     def close(self):
         """Close the transform."""
         for t in self.transforms:
@@ -1594,6 +1639,9 @@ def append(
         else:
             self.transforms.append(transform)
         transform.set_container(self)
+        parent = self.parent
+        if parent is not None:
+            parent.empty_cache()
 
     def set_container(self, container: Transform | EnvBase) -> None:
         self.reset_parent()
@@ -1626,6 +1674,9 @@ def insert(
 
         # empty cache of all transforms to reset parents and specs
         self.empty_cache()
+        parent = self.parent
+        if parent is not None:
+            parent.empty_cache()
         if index < 0:
             index = index + len(self.transforms)
         transform.eval()
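To illustrate the cache invalidation that pop, __delitem__ and append now perform on the parent env, here is a hedged sketch of the intended behaviour. The environment backend (GymEnv("CartPole-v1")) and the transforms used are placeholders chosen for the example, not part of this commit.

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import Compose, DoubleToFloat, StepCounter

env = TransformedEnv(GymEnv("CartPole-v1"), Compose(DoubleToFloat(), StepCounter()))
_ = env.observation_spec  # populates the cached specs on the parent env

popped = env.transform.pop()         # removes StepCounter and empties both caches
del env.transform[0]                 # __delitem__ empties the caches as well
env.transform.append(StepCounter())  # append now also refreshes the parent cache

_ = env.observation_spec  # recomputed lazily against the updated transform chain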

torchrl/objectives/ppo.py

Lines changed: 13 additions & 7 deletions
@@ -752,10 +752,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
 
         explained_variance = None
         if self.log_explained_variance:
-            with torch.no_grad():  # <‑‑ break grad‐flow
-                tgt = target_return.detach()
-                pred = state_value.detach()
-                eps = torch.finfo(tgt.dtype).eps
+            with torch.no_grad():  # <‑‑ break grad‐flow
+                tgt = target_return.detach()
+                pred = state_value.detach()
+                eps = torch.finfo(tgt.dtype).eps
                 resid = torch.var(tgt - pred, unbiased=False, dim=0)
                 total = torch.var(tgt, unbiased=False, dim=0)
                 explained_variance = 1.0 - resid / (total + eps)
@@ -819,7 +819,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out.set("entropy", entropy.detach().mean())  # for logging
         td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1189,7 +1191,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out.set("entropy", entropy.detach().mean())  # for logging
         td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1537,7 +1541,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         td_out.set("entropy", entropy.detach().mean())  # for logging
         td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict_copy
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
