Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrappers.vector.TransformObs/Action single obs/action space argument #1288

Merged
merged 5 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions gymnasium/wrappers/vector/vectorize_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,44 @@ def __init__(
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
single_action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.

Args:
env: The vector environment to wrap
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
action_space: The action spaces of the wrapper. If None, then it is computed from ``single_action_space``. If ``single_action_space`` is not provided either, then it is assumed to be the same as ``env.action_space``.
single_action_space: The action space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_action_space``.
"""
super().__init__(env)

if action_space is not None:
self._single_action_space_error = None
self._single_action_space = self.env.single_action_space
if action_space is None:
if single_action_space is not None:
self.action_space = batch_space(single_action_space, self.num_envs)
self._single_action_space = single_action_space
else:
self.action_space = action_space
if single_action_space is None:
self._single_action_space_error = "`single_action_space` not defined. A new action space was provided to the TransformAction wrapper, but not the single action space."
else:
self._single_action_space = single_action_space

self.func = func

def actions(self, actions: ActType) -> ActType:
    """Transform the given actions by passing them through :attr:`func`.

    Args:
        actions: The actions handed to the wrapper.

    Returns:
        Whatever :attr:`func` returns for ``actions``.
    """
    transformed_actions = self.func(actions)
    return transformed_actions

@property
def single_action_space(self) -> Space:
    """Return the action space of a single (non-vectorized) sub-environment.

    Raises:
        AttributeError: If ``_single_action_space_error`` holds a message,
            i.e. the single action space could not be determined when the
            wrapper was constructed.
    """
    error_message = self._single_action_space_error
    if error_message is None:
        return self._single_action_space
    raise AttributeError(error_message)


class VectorizeTransformAction(VectorActionWrapper):
"""Vectorizes a single-agent transform action wrapper for vector environments.
Expand Down
26 changes: 24 additions & 2 deletions gymnasium/wrappers/vector/vectorize_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,47 @@ def __init__(
env: VectorEnv,
func: Callable[[ObsType], Any],
observation_space: Space | None = None,
single_observation_space: Space | None = None,
):
"""Constructor for the transform observation wrapper.

Args:
env: The vector environment to wrap
func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
observation_space: The observation spaces of the wrapper. If None, then it is computed from ``single_observation_space``. If ``single_observation_space`` is not provided either, then it is assumed to be the same as ``env.observation_space``.
single_observation_space: The observation space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_observation_space``.
"""
super().__init__(env)

if observation_space is not None:
self._single_observation_space_error = None
self._single_observation_space = self.env.single_observation_space
if observation_space is None:
if single_observation_space is not None:
self.observation_space = batch_space(
single_observation_space, self.num_envs
)
self._single_observation_space = single_observation_space
else:
self.observation_space = observation_space
if single_observation_space is None:
# TODO: We could compute this from the observation_space.
self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space."
else:
self._single_observation_space = single_observation_space

self.func = func

def observations(self, observations: ObsType) -> ObsType:
    """Transform the vector observation by passing it through :attr:`func`.

    Args:
        observations: The observations handed to the wrapper.

    Returns:
        Whatever :attr:`func` returns for ``observations``.
    """
    transformed_observations = self.func(observations)
    return transformed_observations

@property
def single_observation_space(self) -> Space:
    """Return the observation space of a single (non-vectorized) sub-environment.

    Raises:
        AttributeError: If ``_single_observation_space_error`` holds a
            message, i.e. the single observation space could not be
            determined when the wrapper was constructed.
    """
    error_message = self._single_observation_space_error
    if error_message is None:
        return self._single_observation_space
    raise AttributeError(error_message)


class VectorizeTransformObservation(VectorObservationWrapper):
"""Vectorizes a single-agent transform observation wrapper for vector environments.
Expand Down
53 changes: 53 additions & 0 deletions tests/wrappers/vector/test_transform_action.py
howardh marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Test suite for vector TransformAction wrapper."""

import numpy as np

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
    """Build a ``GenericTestEnv`` with a 3-dimensional float32 ``Box`` action space."""
    lower_bound = np.array([0, -10, -5], dtype=np.float32)
    upper_bound = np.array([10, -5, 10], dtype=np.float32)
    action_space = spaces.Box(low=lower_bound, high=upper_bound)
    return GenericTestEnv(action_space=action_space)


def test_observation_space_from_single_observation_space(
    n_envs: int = 5,
):
    """Check that the batched action space is derived from ``single_action_space``.

    NOTE: despite its name, this test exercises the *action* spaces of the
    vector ``TransformAction`` wrapper.
    """
    single_low = np.array([0, -10, -5], dtype=np.float32) + 100
    single_high = np.array([10, -5, 10], dtype=np.float32) + 100

    base_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    vec_env = wrappers.vector.TransformAction(
        base_env,
        func=lambda x: x + 100,
        single_action_space=spaces.Box(low=single_low, high=single_high),
    )

    # The batched space must stack the single space once per sub-environment.
    batched_space = vec_env.action_space
    expected_batched_low = np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
    expected_batched_high = np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
    assert isinstance(batched_space, spaces.Box)
    assert batched_space.shape == (n_envs, 3)
    assert batched_space.dtype == np.float32
    assert (batched_space.low == expected_batched_low).all()
    assert (batched_space.high == expected_batched_high).all()

    # The single space must be exactly the one passed to the wrapper.
    single_space = vec_env.single_action_space
    expected_single_low = np.array([100, 90, 95], dtype=np.float32)
    expected_single_high = np.array([110, 95, 110], dtype=np.float32)
    assert isinstance(single_space, spaces.Box)
    assert single_space.shape == (3,)
    assert single_space.dtype == np.float32
    assert (single_space.low == expected_single_low).all()
    assert (single_space.high == expected_single_high).all()
103 changes: 103 additions & 0 deletions tests/wrappers/vector/test_transform_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Test suite for vector TransformObservation wrapper."""

import numpy as np
import pytest

from gymnasium import spaces, wrappers
from gymnasium.vector import SyncVectorEnv
from tests.testing_env import GenericTestEnv


def create_env():
    """Build a ``GenericTestEnv`` with a 3-dimensional float32 ``Box`` observation space."""
    lower_bound = np.array([0, -10, -5], dtype=np.float32)
    upper_bound = np.array([10, -5, 10], dtype=np.float32)
    observation_space = spaces.Box(low=lower_bound, high=upper_bound)
    return GenericTestEnv(observation_space=observation_space)


def test_transform(n_envs: int = 2):
    """Check that ``func`` is applied to observations from ``reset`` and ``step``.

    The wrapped environments emit observations inside
    ``[0, -10, -5] .. [10, -5, 10]``; the wrapper shifts them by ``offset``.
    """
    offset = 100
    base_low = np.array([0, -10, -5], dtype=np.float32)
    base_high = np.array([10, -5, 10], dtype=np.float32)
    shifted_low = base_low + offset
    shifted_high = base_high + offset

    vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    vec_env = wrappers.vector.TransformObservation(
        vec_env,
        func=lambda x: x + offset,
        # Declare the space of the *transformed* observations; passing the
        # untransformed bounds would advertise a space that no observation
        # produced by this wrapper actually belongs to.
        single_observation_space=spaces.Box(low=shifted_low, high=shifted_high),
    )

    obs, _ = vec_env.reset(seed=123)
    vec_env.observation_space.seed(123)
    vec_env.action_space.seed(123)

    assert (obs >= shifted_low).all()
    assert (obs <= shifted_high).all()
    assert obs in vec_env.observation_space

    obs, *_ = vec_env.step(vec_env.action_space.sample())

    assert (obs >= shifted_low).all()
    assert (obs <= shifted_high).all()
    assert obs in vec_env.observation_space


def test_observation_space_from_single_observation_space(
    n_envs: int = 5,
):
    """Check that the batched observation space is derived from ``single_observation_space``."""
    single_low = np.array([0, -10, -5], dtype=np.float32) + 100
    single_high = np.array([10, -5, 10], dtype=np.float32) + 100

    base_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    vec_env = wrappers.vector.TransformObservation(
        base_env,
        func=lambda x: x + 100,
        single_observation_space=spaces.Box(low=single_low, high=single_high),
    )

    # The batched space must stack the single space once per sub-environment.
    batched_space = vec_env.observation_space
    expected_batched_low = np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
    expected_batched_high = np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
    assert isinstance(batched_space, spaces.Box)
    assert batched_space.shape == (n_envs, 3)
    assert batched_space.dtype == np.float32
    assert (batched_space.low == expected_batched_low).all()
    assert (batched_space.high == expected_batched_high).all()

    # The single space must be exactly the one passed to the wrapper.
    single_space = vec_env.single_observation_space
    expected_single_low = np.array([100, 90, 95], dtype=np.float32)
    expected_single_high = np.array([110, 95, 110], dtype=np.float32)
    assert isinstance(single_space, spaces.Box)
    assert single_space.shape == (3,)
    assert single_space.dtype == np.float32
    assert (single_space.low == expected_single_low).all()
    assert (single_space.high == expected_single_high).all()


def test_error_on_unspecified_single_observation_space(
    n_envs: int = 5,
):
    """Check that ``single_observation_space`` raises when only ``observation_space`` is given."""
    batched_low = np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100
    batched_high = np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100

    base_env = SyncVectorEnv([create_env for _ in range(n_envs)])
    vec_env = wrappers.vector.TransformObservation(
        base_env,
        func=lambda x: x + 100,
        observation_space=spaces.Box(low=batched_low, high=batched_high),
    )

    # Resetting and stepping do not need the single observation space.
    obs, _ = vec_env.reset()
    obs, *_ = vec_env.step(vec_env.action_space.sample())

    # Reading the property is what triggers the error.
    with pytest.raises(AttributeError):
        _ = vec_env.single_observation_space
Loading