Skip to content

Commit 5ad53ff

Browse files
committed
Fixes and simplify checks
1 parent 12ca09a commit 5ad53ff

File tree

5 files changed

+58
-85
lines changed

5 files changed

+58
-85
lines changed

docs/misc/changelog.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Changelog
44
==========
55

6-
Release 2.7.1a1 (WIP)
6+
Release 2.7.1a2 (WIP)
77
--------------------------
88

99
Breaking Changes:
@@ -57,7 +57,7 @@ Bug Fixes:
5757
^^^^^^^^^^
5858
- Fixed docker GPU image (PyTorch GPU was not installed)
5959
- Fixed segmentation faults caused by non-portable schedules during model loading (@akanto)
60-
- Fixed NoneType error when importing Gymnasium.graph for sb3 (@dhruvmalik007).
60+
- Update env checker to warn users when using Graph space (@dhruvmalik007).
6161

6262
`SB3-Contrib`_
6363
^^^^^^^^^^^^^^

stable_baselines3/common/env_checker.py

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
def _is_numpy_array_space(space: spaces.Space) -> bool:
1313
"""
1414
Returns False if provided space is not representable as a single numpy array
15-
(e.g. Dict, Tuple spaces return False)
15+
(e.g. Dict and Tuple spaces return False)
1616
"""
1717
return not isinstance(space, (spaces.Dict, spaces.Tuple))
1818

@@ -80,14 +80,21 @@ def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
8080
)
8181

8282

83-
def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
84-
"""Emit warnings when the observation space or action space used is not supported by Stable-Baselines."""
83+
def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> bool:
84+
"""
85+
Emit warnings when the observation space or action space used is not supported by Stable-Baselines.
86+
87+
:return: True if return value tests should be skipped.
88+
"""
8589

90+
should_skip = graph_space = False
8691
if isinstance(observation_space, spaces.Dict):
8792
nested_dict = False
8893
for key, space in observation_space.spaces.items():
8994
if isinstance(space, spaces.Dict):
9095
nested_dict = True
96+
if isinstance(space, spaces.Graph):
97+
graph_space = True
9198
_check_non_zero_start(space, "observation", key)
9299

93100
if nested_dict:
@@ -124,6 +131,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
124131
"You can pad your observation to have a fixed size instead.\n"
125132
"Note: The checks for returned values are skipped."
126133
)
134+
should_skip = True
135+
136+
if isinstance(observation_space, spaces.Graph) or graph_space:
137+
warnings.warn(
138+
"Graph observation space is not supported by Stable-Baselines3. "
139+
"Note: The checks for returned values are skipped."
140+
)
141+
should_skip = True
127142

128143
_check_non_zero_start(action_space, "action")
129144

@@ -133,6 +148,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
133148
"This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the "
134149
"action using a wrapper."
135150
)
151+
return should_skip
136152

137153

138154
def _check_nan(env: gym.Env) -> None:
@@ -204,26 +220,10 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac
204220
Check that the observation returned by the environment
205221
correspond to the declared one.
206222
"""
207-
# If the observation space is a Graph, return early with a warning.
208-
if _is_graph_space(observation_space):
209-
warnings.warn(
210-
f"The observation space for `{method_name}()` is a Graph space, which is not supported the env checker.",
211-
"Skipping further observation checks.",
212-
UserWarning,
213-
)
214-
return
215-
216223
if not isinstance(observation_space, spaces.Tuple):
217224
assert not isinstance(
218225
obs, tuple
219226
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
220-
# Graph spaces are not fully supported by the env checker
221-
if isinstance(observation_space, spaces.Graph):
222-
warnings.warn(
223-
"Graph observation spaces are not fully supported by the env checker. "
224-
"Skipping further observation checks."
225-
)
226-
return
227227
# The check for a GoalEnv is done by the base class
228228
if isinstance(observation_space, spaces.Discrete):
229229
# Since https://github.com/Farama-Foundation/Gymnasium/pull/141,
@@ -371,7 +371,6 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
371371
assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info)
372372

373373

374-
375374
def _check_spaces(env: gym.Env) -> None:
376375
"""
377376
Check that the observation and action spaces are defined and inherit from spaces.Space. For
@@ -400,18 +399,6 @@ def _check_spaces(env: gym.Env) -> None:
400399
)
401400

402401

403-
def _is_graph_space(space: Any) -> bool:
404-
"""
405-
Check if the space is a gymnasium Graph space.
406-
Handles the case where gymnasium.spaces.graph may not be available.
407-
"""
408-
try:
409-
from gymnasium.spaces.graph import Graph
410-
except ImportError:
411-
return False
412-
return isinstance(space, Graph)
413-
414-
415402
# Check render cannot be covered by CI
416403
def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover
417404
"""
@@ -470,8 +457,10 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
470457
raise TypeError("The reset() method must accept a `seed` parameter") from e
471458

472459
# Warn the user if needed.
460+
# A warning means that the environment may run but not work properly with Stable Baselines algorithms
461+
should_skip = False
473462
if warn:
474-
_check_unsupported_spaces(env, observation_space, action_space)
463+
should_skip = _check_unsupported_spaces(env, observation_space, action_space)
475464

476465
obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space}
477466
for key, space in obs_spaces.items():
@@ -499,8 +488,8 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
499488
f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."
500489
)
501490

502-
# If Sequence observation space, do not check the observation any further
503-
if isinstance(observation_space, spaces.Sequence):
491+
# If Sequence or Graph observation space, do not check the observation any further
492+
if should_skip:
504493
return
505494

506495
# ============ Check the returned values ===============
@@ -512,12 +501,8 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
512501

513502
try:
514503
check_for_nested_spaces(env.observation_space)
515-
# if there is None defined , then skip the test
516-
if _is_graph_space(observation_space) or (
517-
isinstance(observation_space, spaces.Dict) and any(_is_graph_space(s) for s in observation_space.spaces.values())
518-
):
519-
pass # skip _check_nan for Graph spaces
520-
else:
521-
_check_nan(env)
504+
# The check doesn't support nested observations/dict actions
505+
# A warning about it has already been emitted
506+
_check_nan(env)
522507
except NotImplementedError:
523508
pass

stable_baselines3/common/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,8 @@ def obs_as_tensor(obs: Union[np.ndarray, dict[str, np.ndarray]], device: th.devi
574574
elif isinstance(obs, dict):
575575
return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
576576
else:
577-
raise Exception(f"Unrecognized type of observation {type(obs)}")
578-
return obs
577+
raise TypeError(f"Unrecognized type of observation {type(obs)}")
578+
579579

580580
def should_collect_more_steps(
581581
train_freq: TrainFreq,
@@ -604,6 +604,7 @@ def should_collect_more_steps(
604604
f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
605605
)
606606

607+
607608
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
608609
"""
609610
Retrieve system and python env info for the current system.
@@ -636,4 +637,4 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
636637
env_info_str += f"- {key}: {value}\n"
637638
if print_info:
638639
print(env_info_str)
639-
return env_info, env_info_str
640+
return env_info, env_info_str

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.7.1a0
1+
2.7.1a2

tests/test_env_checker.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def step(self, action):
2323
info = {}
2424
return observation, reward, terminated, truncated, info
2525

26-
def reset(self, seed=None, options=None):
26+
def reset(self, *, seed=None, options=None):
2727
return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {}
2828

2929
def render(self):
@@ -37,14 +37,15 @@ def test_check_env_dict_action():
3737
check_env(env=test_env, warn=True)
3838

3939

40-
class SequenceObservationEnv(gym.Env):
40+
class CustomEnv(gym.Env):
4141
metadata = {"render_modes": [], "render_fps": 2}
4242

4343
def __init__(self, render_mode=None):
44+
# Test Sequence obs
4445
self.observation_space = spaces.Sequence(spaces.Discrete(8))
4546
self.action_space = spaces.Discrete(4)
4647

47-
def reset(self, seed=None, options=None):
48+
def reset(self, *, seed=None, options=None):
4849
super().reset(seed=seed)
4950
return self.observation_space.sample(), {}
5051

@@ -53,7 +54,7 @@ def step(self, action):
5354

5455

5556
def test_check_env_sequence_obs():
56-
test_env = SequenceObservationEnv()
57+
test_env = CustomEnv()
5758

5859
with pytest.warns(Warning, match="Sequence.*not supported"):
5960
check_env(env=test_env, warn=True)
@@ -193,46 +194,32 @@ def test_check_env_single_step_env():
193194
check_env(env=test_env, warn=True)
194195

195196

196-
197-
class SimpleGraphEnv(gym.Env):
197+
class SimpleGraphEnv(CustomEnv):
198198
def __init__(self):
199199
self.action_space = spaces.Discrete(2)
200-
# Define a simple Graph observation space
201-
node_shape = (2,)
202-
edge_shape = (3,)
203200
self.observation_space = spaces.Graph(
204-
node_space=spaces.Box(low=0, high=1, shape=node_shape),
205-
edge_space=spaces.Box(low=0, high=1, shape=edge_shape),
201+
node_space=spaces.Box(low=0, high=1, shape=(2,)),
202+
edge_space=spaces.Box(low=0, high=1, shape=(3,)),
206203
)
207204

208-
def reset(self, seed=None, options=None):
209-
# Just sample from the obs space
210-
return self.observation_space.sample(), {}
211205

212-
def step(self, action):
213-
return self.observation_space.sample(), 1.0, False, False, {}
206+
class SimpleDictGraphEnv(CustomEnv):
207+
def __init__(self):
208+
self.action_space = spaces.Discrete(2)
209+
self.observation_space = spaces.Dict(
210+
{
211+
"test": spaces.Graph(
212+
node_space=spaces.Box(low=0, high=1, shape=(2,)),
213+
edge_space=spaces.Box(low=0, high=1, shape=(3,)),
214+
)
215+
}
216+
)
214217

215218

216-
def test_check_env_simple_graph_space():
217-
env = SimpleGraphEnv()
219+
def test_check_env_graph_space():
218220
# Should emit a warning about Graph space, but not fail
219-
with pytest.warns(UserWarning, match="Graph space, which is not supported by the env checker"):
220-
check_env(env, warn=True)
221-
222-
223-
def test_check_wrong_type_graph_space():
224-
# Create a Graph space with a wrong type
225-
node_shape = (2,)
226-
edge_shape = (3,)
227-
graph_space = spaces.Graph(
228-
node_space=spaces.Box(low=0, high=1, shape=node_shape),
229-
edge_space=spaces.Box(low=0, high=1, shape=edge_shape),
230-
)
231-
# Create an env with the wrong type
232-
env = SimpleGraphEnv()
233-
env.observation_space = graph_space
234-
235-
# Check that the env checker raises an error
236-
with pytest.raises(AssertionError, match="incompatible w/ graph-obs"):
237-
check_env(env, warn=True)
221+
with pytest.warns(UserWarning, match="Graph.*not supported"):
222+
check_env(SimpleGraphEnv(), warn=True)
238223

224+
with pytest.warns(UserWarning, match="Graph.*not supported"):
225+
check_env(SimpleDictGraphEnv(), warn=True)

0 commit comments

Comments
 (0)