@@ -4596,10 +4596,11 @@ class TensorDictPrimer(Transform):
4596
4596
The corresponding value has to be a TensorSpec instance indicating
4597
4597
what the value must be.
4598
4598
4599
- When used in a TransfomedEnv, the spec shapes must match the envs shape if
4600
- the parent env is batch-locked (:obj:`env.batch_locked=True`).
4601
- If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
4602
- given by the input tensordict instead.
4599
+ When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
4600
+ the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
4601
+ parent shapes do not match, the spec shapes are modified in-place to match the leading
4602
+ dimensions of the parent's batch size. This adjustment is made for cases where the parent
4603
+ batch size dimension is not known during instantiation.
4603
4604
4604
4605
Examples:
4605
4606
>>> from torchrl.envs.libs.gym import GymEnv
@@ -4639,6 +4640,40 @@ class TensorDictPrimer(Transform):
4639
4640
tensor([[1., 1., 1.],
4640
4641
[1., 1., 1.]])
4641
4642
4643
+ Examples:
4644
+ >>> from torchrl.envs.libs.gym import GymEnv
4645
+ >>> from torchrl.envs import SerialEnv, TransformedEnv
4646
+ >>> from torchrl.modules.utils import get_primers_from_module
4647
+ >>> from torchrl.modules import GRUModule
4648
+ >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
4649
+ >>> env = TransformedEnv(base_env)
4650
+ >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
4651
+ >>> primers = get_primers_from_module(model)
4652
+ >>> print(primers) # Primers shape is independent of the env batch size
4653
+ TensorDictPrimer(primers=Composite(
4654
+ recurrent_state: UnboundedContinuous(
4655
+ shape=torch.Size([1, 2]),
4656
+ space=ContinuousBox(
4657
+ low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
4658
+ high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
4659
+ device=cpu,
4660
+ dtype=torch.float32,
4661
+ domain=continuous),
4662
+ device=None,
4663
+ shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
4664
+ >>> env.append_transform(primers)
4665
+ >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
4666
+ TensorDict(
4667
+ fields={
4668
+ done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4669
+ observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
4670
+ recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
4671
+ terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4672
+ truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
4673
+ batch_size=torch.Size([2]),
4674
+ device=None,
4675
+ is_shared=False)
4676
+
4642
4677
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
4643
4678
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
4644
4679
To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4764,7 +4799,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
4764
4799
# We try to set the primer shape to the observation spec shape
4765
4800
self.primers.shape = observation_spec.shape
4766
4801
except ValueError:
4767
- # If we fail, we expnad them to that shape
4802
+ # If we fail, we expand them to that shape
4768
4803
self.primers = self._expand_shape(self.primers)
4769
4804
device = observation_spec.device
4770
4805
observation_spec.update(self.primers.clone().to(device))
@@ -4831,12 +4866,17 @@ def _reset(
4831
4866
) -> TensorDictBase:
4832
4867
"""Sets the default values in the input tensordict.
4833
4868
4834
- If the parent is batch-locked, we assume that the specs have the appropriate leading
4869
+ If the parent is batch-locked, we make sure the specs have the appropriate leading
4835
4870
shape. We allow for execution when the parent is missing, in which case the
4836
4871
spec shape is assumed to match the tensordict's.
4837
-
4838
4872
"""
4839
4873
_reset = _get_reset(self.reset_key, tensordict)
4874
+ if (
4875
+ self.parent
4876
+ and self.parent.batch_locked
4877
+ and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
4878
+ ):
4879
+ self.primers = self._expand_shape(self.primers)
4840
4880
if _reset.any():
4841
4881
for key, spec in self.primers.items(True, True):
4842
4882
if self.random:
0 commit comments