Skip to content

Commit

Permalink
Transform[Observation/Action] single_[observation/action]_space fix (F…
Browse files Browse the repository at this point in the history
  • Loading branch information
howardh committed Jan 7, 2025
1 parent fc74bb8 commit 509ebc5
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 4 deletions.
36 changes: 34 additions & 2 deletions gymnasium/wrappers/vector/vectorize_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,57 @@ def __init__(
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
single_action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
action_space: The action spaces of the wrapper. If None, then it is computed from ``single_action_space``. If ``single_action_space`` is not provided either, then it is assumed to be the same as ``env.action_space``.
single_action_space: The action space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_action_space``.
"""
super().__init__(env)

if action_space is not None:
"""
self._single_observation_space_error = None
self._single_observation_space = self.env.single_observation_space
if observation_space is None:
if single_observation_space is not None:
self.observation_space = batch_space(single_observation_space, self.num_envs)
else:
self.observation_space = observation_space
if single_observation_space is None:
# TODO: We could compute this from the observation_space.
self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space."
else:
self._single_observation_space = single_observation_space
"""
self._single_action_space_error = None
self._single_action_space = self.env.single_action_space
if action_space is None:
if single_action_space is not None:
self.action_space = batch_space(single_action_space, self.num_envs)
else:
self.action_space = action_space
if single_action_space is None:
self._single_action_space_error = "`single_action_space` not defined. A new action space was provided to the TransformAction wrapper, but not the single action space."
else:
self._single_action_space = single_action_space

self.func = func

def actions(self, actions: ActType) -> ActType:
    """Transform the given actions with the user-supplied :attr:`func`.

    Args:
        actions: The actions passed to the wrapper.

    Returns:
        Whatever :attr:`func` returns for ``actions``.
    """
    transform = self.func
    return transform(actions)

@property
def single_action_space(self) -> Space:
    """Return the action space of a single (non-vectorized) sub-environment.

    Raises:
        AttributeError: If an error message was recorded at construction time, which happens when a new ``action_space`` was given to the wrapper without a matching ``single_action_space``.
    """
    # The error is raised lazily on access so that the wrapper itself stays usable
    # even when the single action space could not be determined.
    if self._single_action_space_error is not None:
        raise AttributeError(self._single_action_space_error)
    return self._single_action_space


class VectorizeTransformAction(VectorActionWrapper):
"""Vectorizes a single-agent transform action wrapper for vector environments.
Expand Down
25 changes: 23 additions & 2 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,46 @@ def __init__(
env: VectorEnv,
func: Callable[[ObsType], Any],
observation_space: Space | None = None,
single_observation_space: Space | None = None,
):
"""Constructor for the transform observation wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
observation_space: The observation spaces of the wrapper. If None, then it is computed from ``single_observation_space``. If ``single_observation_space`` is not provided either, then it is assumed to be the same as ``env.observation_space``.
single_observation_space: The observation space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_observation_space``.
"""
super().__init__(env)

if observation_space is not None:
self._single_observation_space_error = None
self._single_observation_space = self.env.single_observation_space
if observation_space is None:
if single_observation_space is not None:
self.observation_space = batch_space(
single_observation_space, self.num_envs
)
else:
self.observation_space = observation_space
if single_observation_space is None:
# TODO: We could compute this from the observation_space.
self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space."
else:
self._single_observation_space = single_observation_space

self.func = func

def observations(self, observations: ObsType) -> ObsType:
    """Transform the vector observation with the user-supplied :attr:`func`.

    Args:
        observations: The vector observation to transform.

    Returns:
        Whatever :attr:`func` returns for ``observations``.
    """
    transform = self.func
    return transform(observations)

@property
def single_observation_space(self) -> Space:
    """Return the observation space of a single (non-vectorized) sub-environment.

    Raises:
        AttributeError: If an error message was recorded at construction time, which happens when a new ``observation_space`` was given to the wrapper without a matching ``single_observation_space``.
    """
    error_message = self._single_observation_space_error
    if error_message is None:
        return self._single_observation_space
    raise AttributeError(error_message)


class VectorizeTransformObservation(VectorObservationWrapper):
"""Vectorizes a single-agent transform observation wrapper for vector environments.
Expand Down
41 changes: 41 additions & 0 deletions tests/wrappers/vector/test_transform_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Test suite for vector TransformObservation wrapper."""

import numpy as np

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
    """Build a ``GenericTestEnv`` whose action space is a 3-element float32 ``Box``."""
    lower = np.array([0, -10, -5], dtype=np.float32)
    upper = np.array([10, -5, 10], dtype=np.float32)
    return GenericTestEnv(action_space=spaces.Box(low=lower, high=upper))


def test_observation_space_from_single_observation_space(
    n_envs: int = 5,
):
    """Check that the batched action space is derived from ``single_action_space``.

    Despite the function name, this test targets ``TransformAction``: only
    ``single_action_space`` is passed to the wrapper, and the assertions verify
    that ``action_space`` is that space batched over ``n_envs``.
    """
    shifted_single_space = spaces.Box(
        low=np.array([0, -10, -5], dtype=np.float32) + 100,
        high=np.array([10, -5, 10], dtype=np.float32) + 100,
    )
    base_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    vec_env = wrappers.vector.TransformAction(
        base_env,
        func=lambda x: x + 100,
        single_action_space=shifted_single_space,
    )

    batched_space = vec_env.action_space
    expected_low = np.array([[100, 90, 95] for _ in range(n_envs)], dtype=np.float32)
    expected_high = np.array(
        [[110, 95, 110] for _ in range(n_envs)], dtype=np.float32
    )

    assert isinstance(batched_space, spaces.Box)
    assert batched_space.shape == (n_envs, 3)
    assert batched_space.dtype == np.float32
    assert (batched_space.low == expected_low).all()
    assert (batched_space.high == expected_high).all()
89 changes: 89 additions & 0 deletions tests/wrappers/vector/test_transform_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Test suite for vector TransformObservation wrapper."""

import numpy as np
import pytest

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
    """Build a ``GenericTestEnv`` whose observation space is a 3-element float32 ``Box``."""
    lower = np.array([0, -10, -5], dtype=np.float32)
    upper = np.array([10, -5, 10], dtype=np.float32)
    return GenericTestEnv(observation_space=spaces.Box(low=lower, high=upper))


def test_transform(n_envs: int = 2):
    """Check that ``TransformObservation`` applies ``func`` on ``reset`` and ``step``.

    The wrapped observations are shifted by +100, so every observation returned
    by the wrapper is asserted to lie within the shifted bounds.
    """
    vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    vec_env = wrappers.vector.TransformObservation(
        vec_env,
        func=lambda x: x + 100,
        # The declared space must describe the *transformed* observations, so its
        # bounds are shifted by the same offset that ``func`` applies.
        single_observation_space=spaces.Box(
            low=np.array([0, -10, -5], dtype=np.float32) + 100,
            high=np.array([10, -5, 10], dtype=np.float32) + 100,
        ),
    )

    obs, _ = vec_env.reset(seed=123)
    vec_env.observation_space.seed(123)
    vec_env.action_space.seed(123)

    assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
    assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()

    obs, *_ = vec_env.step(vec_env.action_space.sample())

    assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
    assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()


def test_observation_space_from_single_observation_space(
    n_envs: int = 5,
):
    """Check that the batched observation space is derived from ``single_observation_space``.

    Only ``single_observation_space`` is passed to ``TransformObservation``; the
    assertions verify that ``observation_space`` is that space batched over
    ``n_envs``.
    """
    shifted_single_space = spaces.Box(
        low=np.array([0, -10, -5], dtype=np.float32) + 100,
        high=np.array([10, -5, 10], dtype=np.float32) + 100,
    )
    base_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    vec_env = wrappers.vector.TransformObservation(
        base_env,
        func=lambda x: x + 100,
        single_observation_space=shifted_single_space,
    )

    batched_space = vec_env.observation_space
    expected_low = np.array([[100, 90, 95] for _ in range(n_envs)], dtype=np.float32)
    expected_high = np.array(
        [[110, 95, 110] for _ in range(n_envs)], dtype=np.float32
    )

    assert isinstance(batched_space, spaces.Box)
    assert batched_space.shape == (n_envs, 3)
    assert batched_space.dtype == np.float32
    assert (batched_space.low == expected_low).all()
    assert (batched_space.high == expected_high).all()


def test_error_on_unspecified_single_observation_space(
    n_envs: int = 5,
):
    """Check that ``single_observation_space`` raises when only ``observation_space`` is given.

    The wrapper is built with a batched ``observation_space`` but no
    ``single_observation_space``. Resetting and stepping must still work, while
    reading ``single_observation_space`` must raise ``AttributeError``.
    """
    base_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    batched_space = spaces.Box(
        low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
        high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
    )
    vec_env = wrappers.vector.TransformObservation(
        base_env,
        func=lambda x: x + 100,
        observation_space=batched_space,
    )

    # Normal interaction with the environment is unaffected by the missing space.
    _obs, _info = vec_env.reset()
    _obs, *_rest = vec_env.step(vec_env.action_space.sample())

    # Only the attribute access itself is expected to fail.
    with pytest.raises(AttributeError):
        vec_env.single_observation_space

0 comments on commit 509ebc5

Please sign in to comment.