Commit 208b6fd

Add support for gymnasium v1.0 (#475)
* Add support for gymnasium v1.0
* Update versions
* Fix requirements
* Ignore mypy for gym 0.29
* Add explicit shimmy dep
* Patch obs space and update trained agents
* Comment out auto-fix obs space
* Fix vecnormalize stats
1 parent b1288ed commit 208b6fd

21 files changed, +91 -40 lines changed
.github/workflows/ci.yml

+13-8
@@ -20,7 +20,12 @@ jobs:
     strategy:
       matrix:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
-
+        include:
+          # Default version
+          - gymnasium-version: "1.0.0"
+          # Add a new config to test gym<1.0
+          - python-version: "3.10"
+            gymnasium-version: "0.29.1"
     steps:
       - uses: actions/checkout@v3
         with:
@@ -32,22 +37,22 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-
           # Use uv for faster downloads
           pip install uv
-          # Install Atari Roms
-          uv pip install --system autorom
-          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-          AutoROM --accept-license --source-file Roms.tar.gz
-
+          # cpu version of pytorch
           # See https://github.com/astral-sh/uv/issues/1497
           uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
           # Install full requirements (for additional envs and test tools)
           uv pip install --system -r requirements.txt
           # Use headless version
           uv pip install --system opencv-python-headless
           uv pip install --system -e .[plots,tests]
+
+      - name: Install specific version of gym
+        run: |
+          uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
+        # Only run for python 3.10, downgrade gym to 0.29.1
+
       - name: Lint with ruff
         run: |
           make lint
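
The `include` entries in this matrix interact with the `python-version` axis in a way that is easy to misread. Below is a simplified Python model of the expansion rules as described in the GitHub Actions documentation: an `include` entry is merged into every combination whose original axis values it does not contradict, and values added by an earlier entry may be overwritten by a later one. The `expand_matrix` helper is hypothetical and only illustrates which jobs this configuration should produce under those rules.

```python
from itertools import product


def expand_matrix(axes, include):
    """Simplified model of GitHub Actions matrix expansion with `include`."""
    original_keys = list(axes)
    combos = [dict(zip(original_keys, values)) for values in product(*axes.values())]
    extra = []
    for entry in include:
        merged = False
        for combo in combos:
            # An entry may only be merged if it agrees with the original axis values.
            if all(combo[key] == value for key, value in entry.items() if key in original_keys):
                combo.update(entry)  # values added by earlier entries can be overwritten
                merged = True
        if not merged:
            # No compatible combination: the entry becomes a job of its own.
            extra.append(dict(entry))
    return combos + extra


jobs = expand_matrix(
    {"python-version": ["3.8", "3.9", "3.10", "3.11"]},
    [
        {"gymnasium-version": "1.0.0"},
        {"python-version": "3.10", "gymnasium-version": "0.29.1"},
    ],
)
for job in jobs:
    print(job)
```

Under this model the workflow runs four jobs: Python 3.8, 3.9 and 3.11 with Gymnasium 1.0.0, and Python 3.10 with Gymnasium 0.29.1. That reading agrees with the "Only run for python 3.10, downgrade gym to 0.29.1" comment in the workflow.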

.github/workflows/trained_agents.yml

+14-7
@@ -21,7 +21,12 @@ jobs:
     strategy:
       matrix:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
-
+        include:
+          # Default version
+          - gymnasium-version: "1.0.0"
+          # Add a new config to test gym<1.0
+          - python-version: "3.10"
+            gymnasium-version: "0.29.1"
     steps:
       - uses: actions/checkout@v3
         with:
@@ -36,19 +41,21 @@ jobs:

           # Use uv for faster downloads
           pip install uv
-          # Install Atari Roms
-          uv pip install --system autorom
-          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-          AutoROM --accept-license --source-file Roms.tar.gz
-
+          # cpu version of pytorch
           # See https://github.com/astral-sh/uv/issues/1497
           uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
           # Install full requirements (for additional envs and test tools)
+          # Install full requirements (for additional envs and test tools)
           uv pip install --system -r requirements.txt
           # Use headless version
           uv pip install --system opencv-python-headless
           uv pip install --system -e .[plots,tests]
+
+      - name: Install specific version of gym
+        run: |
+          uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
+        # Only run for python 3.10, downgrade gym to 0.29.1
+
       - name: Check trained agents
         run: |
           make check-trained-agents

CHANGELOG.md

+5-2
@@ -1,13 +1,16 @@
-## Release 2.4.0a10 (WIP)
+## Release 2.4.0a11 (WIP)

-**New algorithm: CrossQ, and better defaults for SAC/TQC on Swimmer-v4 env**
+**New algorithm: CrossQ, Gymnasium v1.0 support, and better defaults for SAC/TQC on Swimmer-v4 env**

 ### Breaking Changes
 - Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) (@JacobHA) [W&B report](https://wandb.ai/openrlbenchmark/sbx/reports/SAC-MuJoCo-Swimmer-v4--Vmlldzo3NzM5OTk2)
 - Upgraded to SB3 >= 2.4.0
+- Renamed `LunarLander-v2` to `LunarLander-v3` in hyperparameters

 ### New Features
 - Added `CrossQ` hyperparameters for SB3-contrib (@danielpalen)
+- Added Gymnasium v1.0 support
+- `--custom-objects` in `enjoy.py` now also patches obs space (when bounds are changed) to solve "Observation spaces do not match" errors

 ### Bug fixes
 - Replaced deprecated `huggingface_hub.Repository` when pushing to Hugging Face Hub by the recommended `HfApi` (see https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http) (@cochaviz)

hyperparams/a2c.yml

+1-1
@@ -61,7 +61,7 @@ Pendulum-v1:
   policy_kwargs: "dict(log_std_init=-2, ortho_init=False)"

 # Tuned
-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   normalize: true
   n_envs: 4
   n_timesteps: !!float 5e6

hyperparams/ars.yml

+1-2
@@ -26,7 +26,7 @@ LunarLander-v2:
   n_timesteps: !!float 2e6

 # Tuned
-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   <<: *pendulum-params
   n_timesteps: !!float 2e6

@@ -215,4 +215,3 @@ A1Jumping-v0:
   # alive_bonus_offset: -1
   normalize: "dict(norm_obs=True, norm_reward=False)"
   # policy_kwargs: "dict(net_arch=[16])"
-

hyperparams/crossq.yml

+1-1
@@ -18,7 +18,7 @@ Pendulum-v1:
   policy_kwargs: "dict(net_arch=[256, 256])"


-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   n_timesteps: !!float 2e5
   policy: 'MlpPolicy'
   buffer_size: 1000000

hyperparams/ddpg.yml

+1-1
@@ -23,7 +23,7 @@ Pendulum-v1:
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"

-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   n_timesteps: !!float 3e5
   policy: 'MlpPolicy'
   gamma: 0.98

hyperparams/ppo.yml

+1-1
@@ -122,7 +122,7 @@ LunarLander-v2:
   n_epochs: 4
   ent_coef: 0.01

-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   n_envs: 16
   n_timesteps: !!float 1e6
   policy: 'MlpPolicy'

hyperparams/sac.yml

+1-1
@@ -22,7 +22,7 @@ Pendulum-v1:
   learning_rate: !!float 1e-3


-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   n_timesteps: !!float 5e5
   policy: 'MlpPolicy'
   batch_size: 256

hyperparams/td3.yml

+1-1
@@ -23,7 +23,7 @@ Pendulum-v1:
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"

-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   n_timesteps: !!float 3e5
   policy: 'MlpPolicy'
   gamma: 0.98

hyperparams/tqc.yml

+1-1
@@ -19,7 +19,7 @@ Pendulum-v1:
   policy: 'MlpPolicy'
   learning_rate: !!float 1e-3

-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   n_timesteps: !!float 5e5
   policy: 'MlpPolicy'
   learning_rate: lin_7.3e-4

hyperparams/trpo.yml

+1-1
@@ -35,7 +35,7 @@ LunarLander-v2:
   n_critic_updates: 15

 # Tuned
-LunarLanderContinuous-v2:
+LunarLanderContinuous-v3:
   normalize: true
   n_envs: 2
   n_timesteps: !!float 1e5

requirements.txt

+3-3
@@ -1,12 +1,12 @@
 gym==0.26.2
-stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a10,<3.0
+stable-baselines3[extra,tests,docs]>=2.4.0a11,<3.0
 box2d-py==2.3.8
-pybullet_envs_gymnasium>=0.4.0
+pybullet_envs_gymnasium>=0.5.0
 # minigrid
 cloudpickle>=2.2.1
 # optuna plots:
 plotly
 # need to upgrade to gymnasium:
 # panda-gym~=3.0.1
 wandb
-moviepy
+moviepy>=1.0.0

rl_zoo3/enjoy.py

+11
@@ -184,12 +184,23 @@ def enjoy() -> None: # noqa: C901
             "learning_rate": 0.0,
             "lr_schedule": lambda _: 0.0,
             "clip_range": lambda _: 0.0,
+            # load models with different obs bounds
+            # Note: doesn't work with channel last envs
+            # "observation_space": env.observation_space,
         }

     if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""):
         kwargs["env"] = env

     model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs)
+    # Uncomment to save patched file (for instance gym -> gymnasium)
+    # model.save(model_path)
+    # Patch VecNormalize (gym -> gymnasium)
+    # from pathlib import Path
+    # env.observation_space = model.observation_space
+    # env.action_space = model.action_space
+    # env.save(Path(model_path).parent / env_name / "vecnormalize.pkl")
+
     obs = env.reset()

     # Deterministic by default except for atari games
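
The commented-out `"observation_space"` entry relies on how `custom_objects` behaves at load time: a key present in `custom_objects` is used in place of the corresponding saved attribute instead of being restored from the file. Here is a minimal stand-alone sketch of that idea, using plain dictionaries rather than the real Stable-Baselines3 loader. The `restore_attributes` helper and the bound values are hypothetical.

```python
def restore_attributes(saved, custom_objects=None):
    """Return the attributes to restore, preferring entries from custom_objects."""
    custom_objects = custom_objects or {}
    restored = {}
    for name, value in saved.items():
        # A key present in custom_objects wins over the saved value.
        restored[name] = custom_objects[name] if name in custom_objects else value
    return restored


# Hypothetical saved data: observation bounds recorded with an older Gymnasium.
saved = {"learning_rate": 3e-4, "observation_space": {"low": -4.8, "high": 4.8}}
# Bounds reported by the current environment (illustrative values only).
current_space = {"low": -9.6, "high": 9.6}

patched = restore_attributes(saved, {"learning_rate": 0.0, "observation_space": current_space})
print(patched["observation_space"])
```

With the override in place the loaded model would carry the environment's current bounds, which is what avoids an "Observation spaces do not match" check failing on bounds alone.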

rl_zoo3/gym_patches.py

+4-1
@@ -39,5 +39,8 @@ def step(self, action):

 # Patch Gymnasium TimeLimit
 gymnasium.wrappers.TimeLimit = PatchedTimeLimit  # type: ignore[misc]
-gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit  # type: ignore[misc]
+try:
+    gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit  # type: ignore[misc]
+except AttributeError:
+    gymnasium.wrappers.common.TimeLimit = PatchedTimeLimit  # type: ignore
 gymnasium.envs.registration.TimeLimit = PatchedTimeLimit  # type: ignore[misc,attr-defined]
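
The `try`/`except AttributeError` in this hunk covers two module layouts: it patches `gymnasium.wrappers.time_limit.TimeLimit` when that attribute path exists and falls back to `gymnasium.wrappers.common.TimeLimit` otherwise. The same pattern can be sketched without Gymnasium installed, using `types.SimpleNamespace` objects as stand-ins for the two layouts. All names below are illustrative, not the real packages.

```python
from types import SimpleNamespace


class PatchedTimeLimit:
    """Stand-in for the patched wrapper class."""


def patch_time_limit(wrappers):
    """Replace TimeLimit wherever this (fake) wrappers namespace exposes it."""
    wrappers.TimeLimit = PatchedTimeLimit
    try:
        # First layout: wrappers.time_limit.TimeLimit
        wrappers.time_limit.TimeLimit = PatchedTimeLimit
    except AttributeError:
        # Fallback layout: wrappers.common.TimeLimit
        wrappers.common.TimeLimit = PatchedTimeLimit


old_layout = SimpleNamespace(TimeLimit=object, time_limit=SimpleNamespace(TimeLimit=object))
new_layout = SimpleNamespace(TimeLimit=object, common=SimpleNamespace(TimeLimit=object))

patch_time_limit(old_layout)
patch_time_limit(new_layout)
```

Catching only `AttributeError` keeps the fallback narrow: any other failure while patching still surfaces instead of being silently swallowed.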

rl_zoo3/import_envs.py

+9-1
@@ -1,7 +1,7 @@
 from typing import Callable, Optional

 import gymnasium as gym
-from gymnasium.envs.registration import register
+from gymnasium.envs.registration import register, register_envs

 from rl_zoo3.wrappers import MaskVelocityWrapper

@@ -10,6 +10,14 @@
 except ImportError:
     pass

+try:
+    import ale_py
+
+    # no-op
+    gym.register_envs(ale_py)
+except ImportError:
+    pass
+
 try:
     import highway_env
 except ImportError:

rl_zoo3/version.txt

+1-1
@@ -1 +1 @@
-2.4.0a10
+2.4.0a11

setup.py

+5-4
@@ -15,21 +15,22 @@
 See https://github.com/DLR-RM/rl-baselines3-zoo
 """
 install_requires = [
-    "sb3_contrib>=2.4.0a10,<3.0",
-    "gymnasium~=0.29.1",
+    "sb3_contrib>=2.4.0a11,<3.0",
+    "gymnasium>=0.29.1,<1.1.0",
     "huggingface_sb3>=3.0,<4.0",
     "tqdm",
     "rich",
     "optuna>=3.0",
     "pyyaml>=5.1",
     "pytablewriter~=1.2",
+    "shimmy~=2.0",
 ]
 plots_requires = ["seaborn", "rliable~=1.2.0", "scipy~=1.10"]
 test_requires = [
     # for MuJoCo envs v4:
-    "mujoco~=2.3",
+    "mujoco>=2.3,<4",
     # install parking-env to test HER
-    "highway-env==1.8.2",
+    "highway-env>=1.10.1,<1.11.0",
 ]

 setup(
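
Two specifiers in this hunk move from the compatible-release operator `~=` to explicit ranges, which widens what can be installed. The following rough stdlib sketch shows the difference. It assumes plain dotted numeric versions and does not handle pre-releases or zero-padding. The `compatible_release` and `in_range` helpers are hypothetical and are not a substitute for a real PEP 440 implementation.

```python
def parse(version_string):
    """Turn '0.29.1' into (0, 29, 1); only plain numeric versions are supported."""
    return tuple(int(part) for part in version_string.split("."))


def compatible_release(candidate, base):
    """Approximate `~=base`: at least `base`, with all but the last component fixed."""
    cand, ref = parse(candidate), parse(base)
    prefix = len(ref) - 1
    return cand >= ref and cand[:prefix] == ref[:prefix]


def in_range(candidate, lower, upper):
    """Approximate `>=lower,<upper`."""
    return parse(lower) <= parse(candidate) < parse(upper)


# Old pin `gymnasium~=0.29.1` versus new range `gymnasium>=0.29.1,<1.1.0`
print(compatible_release("1.0.0", "0.29.1"), in_range("1.0.0", "0.29.1", "1.1.0"))
# Old pin `mujoco~=2.3` versus new range `mujoco>=2.3,<4`
print(compatible_release("3.1.0", "2.3"), in_range("3.1.0", "2.3", "4"))
```

In both cases the old pin rejects the newer major version while the new range accepts it, which is the point of the change.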

tests/test_enjoy.py

+6
@@ -1,6 +1,7 @@
 import os
 import shlex
 import subprocess
+from importlib.metadata import version

 import pytest

@@ -40,6 +41,11 @@ def test_trained_agents(trained_model):
     if "Panda" in env_id:
         return

+    # TODO: rename trained agents once we drop support for gymnasium v0.29
+    if "Lander" in env_id and version("gymnasium") > "0.29.1":
+        # LunarLander-v2 is now LunarLander-v3
+        return
+
     # Skip mujoco envs
     if "Fetch" in trained_model or "-v3" in trained_model:
         return

tests/test_train.py

+10-2
@@ -1,6 +1,7 @@
 import os
 import shlex
 import subprocess
+from importlib.metadata import version

 import pytest

@@ -36,10 +37,17 @@ def test_train(tmp_path, experiment):


 def test_continue_training(tmp_path):
-    algo, env_id = "a2c", "CartPole-v1"
+    algo = "a2c"
+    if version("gymnasium") > "0.29.1":
+        # See https://github.com/DLR-RM/stable-baselines3/pull/1837#issuecomment-2457322341
+        # obs bounds have changed...
+        env_id = "CartPole-v1"
+    else:
+        env_id = "Pendulum-v1"
+
     cmd = (
         f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} "
-        "-i rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip"
+        f"-i rl-trained-agents/a2c/{env_id}_1/{env_id}.zip"
     )
     return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)
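
Both test files compare `version("gymnasium")` with `"0.29.1"` as strings. That gives the intended answer for the two versions in the CI matrix, but string comparison is lexicographic and would misorder some other version pairs. Below is a stdlib-only sketch of a numeric comparison, assuming plain `X.Y.Z` version strings with an optional pre-release suffix. The `release_tuple` helper is hypothetical and simply ignores the suffix, whereas a full version parser handles many more cases.

```python
import re


def release_tuple(version_string):
    """Parse the leading dotted-integer part of a version, e.g. '1.0.0a1' -> (1, 0, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"Unrecognised version: {version_string!r}")
    return tuple(int(part) for part in match.group().split("."))


# String and numeric comparison agree for the versions used in CI...
print("1.0.0" > "0.29.1", release_tuple("1.0.0") > release_tuple("0.29.1"))
# ...but not in general: "0.100.0" sorts before "0.29.1" as a string.
print("0.100.0" > "0.29.1", release_tuple("0.100.0") > release_tuple("0.29.1"))
```

A check such as `release_tuple(version("gymnasium")) >= (1, 0, 0)` would express the same intent as the tests without depending on character ordering.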
