Skip to content

Commit 9856423

Browse files
authored
Add support for gymnasium v1.0 (#261)
* Add support for gymnasium v1.0 * Fix for gym v1.0 * Update CI matrix * Update SB3 min version * Fix warning
1 parent e05ee42 commit 9856423

File tree

9 files changed

+28
-16
lines changed

9 files changed

+28
-16
lines changed

.github/workflows/ci.yml

+12-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ jobs:
2020
strategy:
2121
matrix:
2222
python-version: ["3.8", "3.9", "3.10", "3.11"]
23-
23+
include:
24+
# Default version
25+
- gymnasium-version: "1.0.0"
26+
# Add a new config to test gym<1.0
27+
- python-version: "3.10"
28+
gymnasium-version: "0.29.1"
2429
steps:
2530
- uses: actions/checkout@v3
2631
- name: Set up Python ${{ matrix.python-version }}
@@ -36,19 +41,18 @@ jobs:
3641
# See https://github.com/astral-sh/uv/issues/1497
3742
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
3843
39-
# Install Atari Roms
40-
uv pip install --system autorom
41-
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
42-
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
43-
AutoROM --accept-license --source-file Roms.tar.gz
44-
4544
# Install master version
4645
# and dependencies for docs and tests
47-
uv pip install --system "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
46+
uv pip install --system "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
4847
uv pip install --system .
4948
# Use headless version
5049
uv pip install --system opencv-python-headless
5150
51+
- name: Install specific version of gym
52+
run: |
53+
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
54+
# Only run for python 3.10, downgrade gym to 0.29.1
55+
5256
- name: Lint with ruff
5357
run: |
5458
make lint

docs/conda_env.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- python=3.11
99
- pytorch=2.5.0=py3.11_cpu_0
1010
- pip:
11-
- gymnasium>=0.28.1,<0.30
11+
- gymnasium>=0.29.1,<1.1.0
1212
- stable-baselines3>=2.0.0,<3.0
1313
- cloudpickle
1414
- opencv-python-headless

docs/misc/changelog.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a10 (WIP)
6+
Release 2.4.0a11 (WIP)
77
--------------------------
88

99
**New algorithm: added CrossQ**
@@ -16,6 +16,7 @@ New Features:
1616
^^^^^^^^^^^^^
1717
- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
1818
- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
19+
- Added support for Gymnasium v1.0
1920

2021
Bug Fixes:
2122
^^^^^^^^^^

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ env = ["PYTHONHASHSEED=0"]
4040
filterwarnings = [
4141
# Tensorboard warnings
4242
"ignore::DeprecationWarning:tensorboard",
43+
# tqdm warning about rich being experimental
44+
"ignore:rich is experimental",
4345
]
4446
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
4547

sb3_contrib/common/maskable/utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def get_action_masks(env: GymEnv) -> np.ndarray:
1616
if isinstance(env, VecEnv):
1717
return np.stack(env.env_method(EXPECTED_METHOD_NAME))
1818
else:
19-
return getattr(env, EXPECTED_METHOD_NAME)()
19+
return env.get_wrapper_attr(EXPECTED_METHOD_NAME)()
2020

2121

2222
def is_masking_supported(env: GymEnv) -> bool:
@@ -35,4 +35,8 @@ def is_masking_supported(env: GymEnv) -> bool:
3535
except AttributeError:
3636
return False
3737
else:
38-
return hasattr(env, EXPECTED_METHOD_NAME)
38+
try:
39+
env.get_wrapper_attr(EXPECTED_METHOD_NAME)
40+
return True
41+
except AttributeError:
42+
return False

sb3_contrib/common/wrappers/time_feature.py

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False)
4444
low, high = obs_space.low, obs_space.high
4545
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) # type: ignore[arg-type]
4646
self.dtype = obs_space.dtype
47+
low, high = low.astype(self.dtype), high.astype(self.dtype)
4748

4849
if isinstance(env.observation_space, spaces.Dict):
4950
env.observation_space.spaces["observation"] = spaces.Box(

sb3_contrib/version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.4.0a10
1+
2.4.0a11

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
6868
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
6969
install_requires=[
70-
"stable_baselines3>=2.4.0a6,<3.0",
70+
"stable_baselines3>=2.4.0a11,<3.0",
7171
],
7272
description="Contrib package of Stable Baselines3, experimental code.",
7373
author="Antonin Raffin",

tests/test_lstm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from gymnasium import spaces
77
from gymnasium.envs.classic_control import CartPoleEnv
8-
from gymnasium.wrappers.time_limit import TimeLimit
8+
from gymnasium.wrappers import TimeLimit
99
from stable_baselines3.common.callbacks import EvalCallback
1010
from stable_baselines3.common.env_checker import check_env
1111
from stable_baselines3.common.env_util import make_vec_env
@@ -43,7 +43,7 @@ def __init__(self):
4343
self.x_threshold * 2,
4444
self.theta_threshold_radians * 2,
4545
]
46-
)
46+
).astype(np.float32)
4747
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
4848

4949
@staticmethod

0 commit comments

Comments
 (0)