Add support for gymnasium v1.0 #475

Merged: 8 commits, Nov 5, 2024
21 changes: 13 additions & 8 deletions .github/workflows/ci.yml
@@ -20,7 +20,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
with:
@@ -32,22 +37,22 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip

# Use uv for faster downloads
pip install uv
# Install Atari Roms
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]

- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1

- name: Lint with ruff
run: |
make lint
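The `include` entries in the matrix above are what pin gymnasium 0.29.1 onto the Python 3.10 job while every other job gets 1.0.0. A simplified model of that GitHub Actions semantics (the `apply_includes` helper and variable names are hypothetical, for illustration only, not the real engine):

```python
# Simplified model of GitHub Actions `matrix.include` semantics:
# an include entry whose original-matrix keys match a combination
# overlays its extra values onto it; an entry with no original-matrix
# keys applies to every combination.
def apply_includes(combos, includes):
    orig_keys = set(combos[0])  # keys of the original matrix axes
    for inc in includes:
        match = {k: v for k, v in inc.items() if k in orig_keys}
        for combo in combos:
            if all(combo[k] == v for k, v in match.items()):
                combo.update(inc)
    return combos

combos = [{"python-version": p} for p in ["3.8", "3.9", "3.10", "3.11"]]
apply_includes(combos, [
    {"gymnasium-version": "1.0.0"},                             # default for all
    {"python-version": "3.10", "gymnasium-version": "0.29.1"},  # 3.10 override
])
print(combos[2])  # -> {'python-version': '3.10', 'gymnasium-version': '0.29.1'}
```

Only four jobs run, not five: the second include matches an existing combination, so it overrides rather than appends.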
21 changes: 14 additions & 7 deletions .github/workflows/trained_agents.yml
@@ -21,7 +21,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
with:
@@ -36,19 +41,21 @@

# Use uv for faster downloads
pip install uv
# Install Atari Roms
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install full requirements (for additional envs and test tools)
uv pip install --system -r requirements.txt
# Use headless version
uv pip install --system opencv-python-headless
uv pip install --system -e .[plots,tests]

- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1

- name: Check trained agents
run: |
make check-trained-agents
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -1,13 +1,16 @@
## Release 2.4.0a10 (WIP)
## Release 2.4.0a11 (WIP)

**New algorithm: CrossQ, and better defaults for SAC/TQC on Swimmer-v4 env**
**New algorithm: CrossQ, Gymnasium v1.0 support, and better defaults for SAC/TQC on Swimmer-v4 env**

### Breaking Changes
- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) (@JacobHA) [W&B report](https://wandb.ai/openrlbenchmark/sbx/reports/SAC-MuJoCo-Swimmer-v4--Vmlldzo3NzM5OTk2)
- Upgraded to SB3 >= 2.4.0
- Renamed `LunarLander-v2` to `LunarLander-v3` in hyperparameters

### New Features
- Added `CrossQ` hyperparameters for SB3-contrib (@danielpalen)
- Added Gymnasium v1.0 support
- `--custom-objects` in `enjoy.py` now also patches obs space (when bounds are changed) to solve "Observation spaces do not match" errors

### Bug fixes
- Replaced deprecated `huggingface_hub.Repository` when pushing to Hugging Face Hub by the recommended `HfApi` (see https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http) (@cochaviz)
2 changes: 1 addition & 1 deletion hyperparams/a2c.yml
@@ -61,7 +61,7 @@ Pendulum-v1:
policy_kwargs: "dict(log_std_init=-2, ortho_init=False)"

# Tuned
LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
normalize: true
n_envs: 4
n_timesteps: !!float 5e6
3 changes: 1 addition & 2 deletions hyperparams/ars.yml
@@ -26,7 +26,7 @@ LunarLander-v2:
n_timesteps: !!float 2e6

# Tuned
LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
<<: *pendulum-params
n_timesteps: !!float 2e6

@@ -215,4 +215,3 @@ A1Jumping-v0:
# alive_bonus_offset: -1
normalize: "dict(norm_obs=True, norm_reward=False)"
# policy_kwargs: "dict(net_arch=[16])"

2 changes: 1 addition & 1 deletion hyperparams/crossq.yml
@@ -18,7 +18,7 @@ Pendulum-v1:
policy_kwargs: "dict(net_arch=[256, 256])"


LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 2e5
policy: 'MlpPolicy'
buffer_size: 1000000
2 changes: 1 addition & 1 deletion hyperparams/ddpg.yml
@@ -23,7 +23,7 @@ Pendulum-v1:
learning_rate: !!float 1e-3
policy_kwargs: "dict(net_arch=[400, 300])"

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
gamma: 0.98
2 changes: 1 addition & 1 deletion hyperparams/ppo.yml
@@ -122,7 +122,7 @@ LunarLander-v2:
n_epochs: 4
ent_coef: 0.01

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_envs: 16
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
2 changes: 1 addition & 1 deletion hyperparams/sac.yml
@@ -22,7 +22,7 @@ Pendulum-v1:
learning_rate: !!float 1e-3


LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 5e5
policy: 'MlpPolicy'
batch_size: 256
2 changes: 1 addition & 1 deletion hyperparams/td3.yml
@@ -23,7 +23,7 @@ Pendulum-v1:
learning_rate: !!float 1e-3
policy_kwargs: "dict(net_arch=[400, 300])"

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
gamma: 0.98
2 changes: 1 addition & 1 deletion hyperparams/tqc.yml
@@ -19,7 +19,7 @@ Pendulum-v1:
policy: 'MlpPolicy'
learning_rate: !!float 1e-3

LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
n_timesteps: !!float 5e5
policy: 'MlpPolicy'
learning_rate: lin_7.3e-4
2 changes: 1 addition & 1 deletion hyperparams/trpo.yml
@@ -35,7 +35,7 @@ LunarLander-v2:
n_critic_updates: 15

# Tuned
LunarLanderContinuous-v2:
LunarLanderContinuous-v3:
normalize: true
n_envs: 2
n_timesteps: !!float 1e5
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,12 +1,12 @@
gym==0.26.2
stable-baselines3[extra_no_roms,tests,docs]>=2.4.0a10,<3.0
stable-baselines3[extra,tests,docs]>=2.4.0a11,<3.0
box2d-py==2.3.8
pybullet_envs_gymnasium>=0.4.0
pybullet_envs_gymnasium>=0.5.0
# minigrid
cloudpickle>=2.2.1
# optuna plots:
plotly
# need to upgrade to gymnasium:
# panda-gym~=3.0.1
wandb
moviepy
moviepy>=1.0.0
11 changes: 11 additions & 0 deletions rl_zoo3/enjoy.py
@@ -184,12 +184,23 @@ def enjoy() -> None:  # noqa: C901
"learning_rate": 0.0,
"lr_schedule": lambda _: 0.0,
"clip_range": lambda _: 0.0,
# load models with different obs bounds
# Note: doesn't work with channel last envs
# "observation_space": env.observation_space,
}

if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""):
kwargs["env"] = env

model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs)
# Uncomment to save patched file (for instance gym -> gymnasium)
# model.save(model_path)
# Patch VecNormalize (gym -> gymnasium)
# from pathlib import Path
# env.observation_space = model.observation_space
# env.action_space = model.action_space
# env.save(Path(model_path).parent / env_name / "vecnormalize.pkl")

obs = env.reset()

# Deterministic by default except for atari games
5 changes: 4 additions & 1 deletion rl_zoo3/gym_patches.py
@@ -39,5 +39,8 @@ def step(self, action):

# Patch Gymnasium TimeLimit
gymnasium.wrappers.TimeLimit = PatchedTimeLimit # type: ignore[misc]
gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc]
try:
gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc]
except AttributeError:
gymnasium.wrappers.common.TimeLimit = PatchedTimeLimit # type: ignore
gymnasium.envs.registration.TimeLimit = PatchedTimeLimit # type: ignore[misc,attr-defined]
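The try/except added to gym_patches.py handles gymnasium moving `TimeLimit` from `wrappers.time_limit` to `wrappers.common` in v1.0. The fallback pattern can be sketched in isolation, using a stand-in namespace instead of the real gymnasium package (`patch_time_limit` and `fake_gym` are hypothetical names for illustration):

```python
from types import SimpleNamespace

class PatchedTimeLimit:
    """Stand-in for the patched wrapper installed by gym_patches."""

def patch_time_limit(pkg):
    # Always patch the public location, then try the pre-1.0 module
    # path and fall back to the 1.0+ one when it no longer exists.
    pkg.wrappers.TimeLimit = PatchedTimeLimit
    try:
        pkg.wrappers.time_limit.TimeLimit = PatchedTimeLimit  # gymnasium < 1.0
        return "wrappers.time_limit"
    except AttributeError:
        pkg.wrappers.common.TimeLimit = PatchedTimeLimit      # gymnasium >= 1.0
        return "wrappers.common"

# Simulate a gymnasium >= 1.0 layout: no `time_limit` submodule.
fake_gym = SimpleNamespace(wrappers=SimpleNamespace(common=SimpleNamespace()))
print(patch_time_limit(fake_gym))  # -> wrappers.common
```

The same call against a namespace that still exposes `wrappers.time_limit` takes the first branch, so one code path supports both dependency ranges allowed by setup.py.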
10 changes: 9 additions & 1 deletion rl_zoo3/import_envs.py
@@ -1,7 +1,7 @@
from typing import Callable, Optional

import gymnasium as gym
from gymnasium.envs.registration import register
from gymnasium.envs.registration import register, register_envs

from rl_zoo3.wrappers import MaskVelocityWrapper

@@ -10,6 +10,14 @@
except ImportError:
pass

try:
import ale_py

# no-op
gym.register_envs(ale_py)
except ImportError:
pass

try:
import highway_env
except ImportError:
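import_envs.py wraps each optional env package (ale_py, highway_env, ...) in try/except ImportError so a missing extra never breaks startup. The same guard can be expressed with importlib; the `optional_envs` helper below is a hypothetical illustration of the pattern, not part of the zoo:

```python
import importlib.util

def optional_envs(*module_names):
    """Return only the packages that are actually importable,
    mirroring the try/except ImportError guards in import_envs.py."""
    return [name for name in module_names
            if importlib.util.find_spec(name) is not None]

# `json` ships with Python; the second name stands in for a missing extra.
print(optional_envs("json", "no_such_env_package"))  # -> ['json']
```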
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
@@ -1 +1 @@
2.4.0a10
2.4.0a11
9 changes: 5 additions & 4 deletions setup.py
@@ -15,21 +15,22 @@
See https://github.com/DLR-RM/rl-baselines3-zoo
"""
install_requires = [
"sb3_contrib>=2.4.0a10,<3.0",
"gymnasium~=0.29.1",
"sb3_contrib>=2.4.0a11,<3.0",
"gymnasium>=0.29.1,<1.1.0",
"huggingface_sb3>=3.0,<4.0",
"tqdm",
"rich",
"optuna>=3.0",
"pyyaml>=5.1",
"pytablewriter~=1.2",
"shimmy~=2.0",
]
plots_requires = ["seaborn", "rliable~=1.2.0", "scipy~=1.10"]
test_requires = [
# for MuJoCo envs v4:
"mujoco~=2.3",
"mujoco>=2.3,<4",
# install parking-env to test HER
"highway-env==1.8.2",
"highway-env>=1.10.1,<1.11.0",
]

setup(
6 changes: 6 additions & 0 deletions tests/test_enjoy.py
@@ -1,6 +1,7 @@
import os
import shlex
import subprocess
from importlib.metadata import version

import pytest

@@ -40,6 +41,11 @@ def test_trained_agents(trained_model):
if "Panda" in env_id:
return

# TODO: rename trained agents once we drop support for gymnasium v0.29
if "Lander" in env_id and version("gymnasium") > "0.29.1":
# LunarLander-v2 is now LunarLander-v3
return

# Skip mujoco envs
if "Fetch" in trained_model or "-v3" in trained_model:
return
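The skip condition compares `version("gymnasium")` lexically against "0.29.1", which happens to order 1.0.0 after 0.29.1 but is not a general version comparison. A numeric-tuple sketch shows where plain string comparison breaks (the `version_tuple` helper is hypothetical; pre-release suffixes are dropped):

```python
def version_tuple(v):
    """Parse 'X.Y.Z' into a tuple of ints for numeric comparison.
    Non-numeric parts (pre-release tags) are dropped; a real project
    would use packaging.version.Version instead."""
    return tuple(int(part) for part in v.split(".") if part.isdigit())

# The pair used in the tests happens to order correctly either way:
assert "1.0.0" > "0.29.1"
assert version_tuple("1.0.0") > version_tuple("0.29.1")

# But lexical comparison misorders other pairs:
assert "0.9.0" > "0.29.1"                                # wrong numerically
assert version_tuple("0.9.0") < version_tuple("0.29.1")  # correct
```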
12 changes: 10 additions & 2 deletions tests/test_train.py
@@ -1,6 +1,7 @@
import os
import shlex
import subprocess
from importlib.metadata import version

import pytest

@@ -36,10 +37,17 @@ def test_train(tmp_path, experiment):


def test_continue_training(tmp_path):
algo, env_id = "a2c", "CartPole-v1"
algo = "a2c"
if version("gymnasium") > "0.29.1":
# See https://github.com/DLR-RM/stable-baselines3/pull/1837#issuecomment-2457322341
# obs bounds have changed...
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v1"

cmd = (
f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} "
"-i rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip"
f"-i rl-trained-agents/a2c/{env_id}_1/{env_id}.zip"
)
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)