
Commit 98b45a6

[BugFix] Allow expanding TensorDictPrimer transforms shape with parent batch size (#2521)
1 parent 2a07f4c commit 98b45a6

2 files changed (+75 -7)

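
In short, this commit lets a `TensorDictPrimer` whose specs were created without a leading batch dimension (as returned by `get_primers_from_module`, since a module is agnostic to the batch size it will run under) be used inside a batch-locked environment: on reset, the transform now expands its spec shapes in place to prepend the parent's batch size instead of erroring. A condensed sketch of the resulting behavior, assuming Gym's Pendulum-v1 is available (the full doctest is added in the docstring diff below):

>>> from torchrl.envs import SerialEnv, TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.modules import GRUModule
>>> from torchrl.modules.utils import get_primers_from_module
>>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
>>> primers = get_primers_from_module(model)  # recurrent-state spec shape: [1, 2]
>>> env = TransformedEnv(SerialEnv(2, lambda: GymEnv("Pendulum-v1")))
>>> env.append_transform(primers)
>>> td = env.reset()  # primer specs are expanded in place to [2, 1, 2]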

Diff for: test/test_transforms.py

+28 -0
@@ -159,6 +159,7 @@
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
+from torchrl.modules.utils import get_primers_from_module

 IS_WIN = platform == "win32"
 if IS_WIN:
@@ -7163,6 +7164,33 @@ def test_dict_default_value(self):
             rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
         ).all

+    def test_spec_shape_inplace_correction(self):
+        hidden_size = input_size = num_layers = 2
+        model = GRUModule(
+            input_size, hidden_size, num_layers, in_key="observation", out_key="action"
+        )
+        env = TransformedEnv(
+            SerialEnv(2, lambda: GymEnv("Pendulum-v1")),
+        )
+        # These primers do not have the leading batch dimension
+        # since model is agnostic to batch dimension that will be used.
+        primers = get_primers_from_module(model)
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [num_layers, hidden_size]
+            )
+        env.append_transform(primers)
+
+        # Reset should add the batch dimension to the primers
+        # since the parent exists and is batch_locked.
+        td = env.reset()
+
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [2, num_layers, hidden_size]
+            )
+            assert td.get(primer).shape == torch.Size([2, num_layers, hidden_size])
+

 class TestTimeMaxPool(TransformBase):
     @pytest.mark.parametrize("T", [2, 4])
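
The new test can be run on its own with pytest's `-k` filter (assuming a TorchRL dev checkout with a Gym backend installed):

pytest test/test_transforms.py -k test_spec_shape_inplace_correction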

Diff for: torchrl/envs/transforms/transforms.py

+47 -7
@@ -4596,10 +4596,11 @@ class TensorDictPrimer(Transform):
     The corresponding value has to be a TensorSpec instance indicating
     what the value must be.

-    When used in a TransfomedEnv, the spec shapes must match the envs shape if
-    the parent env is batch-locked (:obj:`env.batch_locked=True`).
-    If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
-    given by the input tensordict instead.
+    When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
+    the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
+    parent shapes do not match, the spec shapes are modified in-place to match the leading
+    dimensions of the parent's batch size. This adjustment is made for cases where the parent
+    batch size dimension is not known during instantiation.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -4639,6 +4640,40 @@ class TensorDictPrimer(Transform):
         tensor([[1., 1., 1.],
                 [1., 1., 1.]])

+    Examples:
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> from torchrl.envs import SerialEnv, TransformedEnv
+        >>> from torchrl.modules.utils import get_primers_from_module
+        >>> from torchrl.modules import GRUModule
+        >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
+        >>> env = TransformedEnv(base_env)
+        >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
+        >>> primers = get_primers_from_module(model)
+        >>> print(primers) # Primers shape is independent of the env batch size
+        TensorDictPrimer(primers=Composite(
+            recurrent_state: UnboundedContinuous(
+                shape=torch.Size([1, 2]),
+                space=ContinuousBox(
+                    low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
+                    high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
+                device=cpu,
+                dtype=torch.float32,
+                domain=continuous),
+            device=None,
+            shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
+        >>> env.append_transform(primers)
+        >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
+        TensorDict(
+            fields={
+                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+

     .. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
         like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
         To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4764,7 +4799,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             # We try to set the primer shape to the observation spec shape
             self.primers.shape = observation_spec.shape
         except ValueError:
-            # If we fail, we expnad them to that shape
+            # If we fail, we expand them to that shape
             self.primers = self._expand_shape(self.primers)
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
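
The `_expand_shape` helper called here (and again in `_reset` below) predates this commit and is not shown in the diff. As a rough, hypothetical sketch of its role, inferred purely from the call sites (the actual TorchRL helper may differ), it expands each primer spec so that the parent's batch dims become leading dims:

# Hypothetical sketch inferred from the call sites in this diff;
# not the actual TorchRL implementation.
def _expand_shape(self, spec):
    # Prepend the parent env's batch dims to the spec's own shape.
    return spec.expand((*self.parent.batch_size, *spec.shape))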
@@ -4831,12 +4866,17 @@ def _reset(
     ) -> TensorDictBase:
         """Sets the default values in the input tensordict.

-        If the parent is batch-locked, we assume that the specs have the appropriate leading
+        If the parent is batch-locked, we make sure the specs have the appropriate leading
         shape. We allow for execution when the parent is missing, in which case the
         spec shape is assumed to match the tensordict's.
-
         """
         _reset = _get_reset(self.reset_key, tensordict)
+        if (
+            self.parent
+            and self.parent.batch_locked
+            and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
+        ):
+            self.primers = self._expand_shape(self.primers)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
                 if self.random:
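
To make the new guard in `_reset` concrete, here is the shape arithmetic for the docstring example above, using plain `torch.Size` operations (illustration only):

>>> import torch
>>> parent_batch = torch.Size([2])     # SerialEnv(2, ...)
>>> primer_shape = torch.Size([1, 2])  # GRU recurrent state, no batch dim
>>> primer_shape[: len(parent_batch)] != parent_batch  # guard fires
True
>>> expanded = torch.Size([*parent_batch, *primer_shape])
>>> expanded
torch.Size([2, 1, 2])
>>> expanded[: len(parent_batch)] != parent_batch  # after expansion, guard is quiet
False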
