Skip to content

Commit 7883ed4

Browse files
Copilotaraffin
andauthored
Fix env checker to handle Sequence spaces in composite spaces (Dict, Tuple, OneOf) (#2174)
* Initial plan * Fix env checker to handle Sequence spaces in composite spaces Co-authored-by: araffin <[email protected]> * Simplify tests and checks * Add fix for mypy for gym<1.0 --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: araffin <[email protected]> Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent c40b5e4 commit 7883ed4

File tree

3 files changed

+89
-12
lines changed

3 files changed

+89
-12
lines changed

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ New Features:
1515

1616
Bug Fixes:
1717
^^^^^^^^^^
18+
- Fixed env checker to properly handle ``Sequence`` observation spaces when nested inside composite spaces (``Dict``, ``Tuple``, ``OneOf``) (@copilot)
19+
- Update env checker to warn users when using Graph space (@dhruvmalik007).
1820

1921
`SB3-Contrib`_
2022
^^^^^^^^^^^^^^
@@ -57,7 +59,6 @@ Bug Fixes:
5759
^^^^^^^^^^
5860
- Fixed docker GPU image (PyTorch GPU was not installed)
5961
- Fixed segmentation faults caused by non-portable schedules during model loading (@akanto)
60-
- Update env checker to warn users when using Graph space (@dhruvmalik007).
6162

6263
`SB3-Contrib`_
6364
^^^^^^^^^^^^^^

stable_baselines3/common/env_checker.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@
99
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
1010

1111

12+
def _is_oneof_space(space: spaces.Space) -> bool:
13+
"""
14+
Return True if the provided space is a OneOf space,
15+
False if not or if the current version of Gym doesn't support this space.
16+
"""
17+
try:
18+
return isinstance(space, spaces.OneOf) # type: ignore[attr-defined]
19+
except AttributeError:
20+
# Gym < v1.0
21+
return False
22+
23+
1224
def _is_numpy_array_space(space: spaces.Space) -> bool:
1325
"""
1426
Returns False if provided space is not representable as a single numpy array
@@ -80,21 +92,23 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
8092
)
8193

8294

83-
def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> bool:
95+
def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> bool: # noqa: C901
8496
"""
8597
Emit warnings when the observation space or action space used is not supported by Stable-Baselines.
8698
8799
:return: True if return value tests should be skipped.
88100
"""
89101

90-
should_skip = graph_space = False
102+
should_skip = graph_space = sequence_space = False
91103
if isinstance(observation_space, spaces.Dict):
92104
nested_dict = False
93105
for key, space in observation_space.spaces.items():
94106
if isinstance(space, spaces.Dict):
95107
nested_dict = True
96-
if isinstance(space, spaces.Graph):
108+
elif isinstance(space, spaces.Graph):
97109
graph_space = True
110+
elif isinstance(space, spaces.Sequence):
111+
sequence_space = True
98112
_check_non_zero_start(space, "observation", key)
99113

100114
if nested_dict:
@@ -122,10 +136,24 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
122136
"(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). "
123137
"which is supported by SB3."
124138
)
139+
# Check for Sequence spaces inside Tuple
140+
for space in observation_space.spaces:
141+
if isinstance(space, spaces.Sequence):
142+
sequence_space = True
143+
elif isinstance(space, spaces.Graph):
144+
graph_space = True
145+
146+
# Check for Sequence spaces inside OneOf
147+
if _is_oneof_space(observation_space):
148+
warnings.warn(
149+
"OneOf observation space is not supported by Stable-Baselines3. "
150+
"Note: The checks for returned values are skipped."
151+
)
152+
should_skip = True
125153

126154
_check_non_zero_start(observation_space, "observation")
127155

128-
if isinstance(observation_space, spaces.Sequence):
156+
if isinstance(observation_space, spaces.Sequence) or sequence_space:
129157
warnings.warn(
130158
"Sequence observation space is not supported by Stable-Baselines3. "
131159
"You can pad your observation to have a fixed size instead.\n"

tests/test_env_checker.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,6 @@ def step(self, action):
5353
return self.observation_space.sample(), 1.0, False, False, {}
5454

5555

56-
def test_check_env_sequence_obs():
57-
test_env = CustomEnv()
58-
59-
with pytest.warns(Warning, match="Sequence.*not supported"):
60-
check_env(env=test_env, warn=True)
61-
62-
6356
@pytest.mark.parametrize(
6457
"obs_tuple",
6558
[
@@ -223,3 +216,58 @@ def test_check_env_graph_space():
223216

224217
with pytest.warns(UserWarning, match="Graph.*not supported"):
225218
check_env(SimpleDictGraphEnv(), warn=True)
219+
220+
221+
class SequenceInDictEnv(CustomEnv):
222+
"""Test env with Sequence space inside Dict space."""
223+
224+
def __init__(self):
225+
self.action_space = spaces.Discrete(2)
226+
self.observation_space = spaces.Dict(
227+
{"seq": spaces.Sequence(spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32))}
228+
)
229+
230+
231+
class SequenceInTupleEnv(CustomEnv):
232+
"""Test env with Sequence space inside Tuple space."""
233+
234+
def __init__(self):
235+
self.action_space = spaces.Discrete(2)
236+
self.observation_space = spaces.Tuple((spaces.Sequence(spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32)),))
237+
238+
239+
class SequenceInOneOfEnv(CustomEnv):
240+
"""Test env with Sequence space inside OneOf space."""
241+
242+
def __init__(self):
243+
self.action_space = spaces.Discrete(2)
244+
self.observation_space = spaces.OneOf(
245+
(
246+
spaces.Sequence(spaces.Box(low=-100, high=100, shape=(1,), dtype=np.float32)),
247+
spaces.Discrete(3),
248+
)
249+
)
250+
251+
252+
@pytest.mark.parametrize("env_class", [CustomEnv, SequenceInDictEnv])
253+
def test_check_env_sequence_obs(env_class):
254+
with pytest.warns(Warning, match="Sequence.*not supported"):
255+
check_env(env_class(), warn=True)
256+
257+
258+
def test_check_env_sequence_tuple():
259+
with (
260+
pytest.warns(Warning, match="Sequence.*not supported"),
261+
pytest.warns(Warning, match="Tuple.*not supported"),
262+
):
263+
check_env(SequenceInTupleEnv(), warn=True)
264+
265+
266+
def test_check_env_oneof():
267+
try:
268+
env = SequenceInOneOfEnv()
269+
except AttributeError:
270+
pytest.skip("OneOf not supported by current Gymnasium version")
271+
272+
with pytest.warns(Warning, match="OneOf.*not supported"):
273+
check_env(env, warn=True)

0 commit comments

Comments
 (0)