From 6d803241da70390fe88da5f54c349a52e7ac8e31 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Thu, 20 Jul 2023 20:41:01 -0400 Subject: [PATCH] Sb3 supersuit bugfix (#1031) --- .github/workflows/build-publish.yml | 7 +--- .github/workflows/linux-test.yml | 4 +- docs/tutorials/sb3/index.md | 2 +- docs/tutorials/sb3/kaz.md | 2 + tutorials/SB3/connect_four/requirements.txt | 3 +- tutorials/SB3/kaz/requirements.txt | 3 +- tutorials/SB3/kaz/sb3_kaz_vector.py | 43 ++++++++++----------- tutorials/SB3/pistonball/requirements.txt | 3 +- tutorials/SB3/waterworld/requirements.txt | 4 +- 9 files changed, 36 insertions(+), 35 deletions(-) diff --git a/.github/workflows/build-publish.yml b/.github/workflows/build-publish.yml index 2bd40ec23..799fda065 100644 --- a/.github/workflows/build-publish.yml +++ b/.github/workflows/build-publish.yml @@ -18,9 +18,6 @@ jobs: strategy: matrix: include: - - os: ubuntu-latest - python: 37 - platform: manylinux_x86_64 - os: ubuntu-latest python: 38 platform: manylinux_x86_64 @@ -41,9 +38,9 @@ jobs: with: python-version: '3.x' - name: Install dependencies - run: python -m pip install --upgrade setuptools wheel + run: python -m pip install --upgrade setuptools wheel build - name: Build wheels - run: python setup.py sdist bdist_wheel + run: python -m build --sdist --wheel - name: Store wheels uses: actions/upload-artifact@v2 with: diff --git a/.github/workflows/linux-test.yml b/.github/workflows/linux-test.yml index 0cf939047..30cdbe6b8 100644 --- a/.github/workflows/linux-test.yml +++ b/.github/workflows/linux-test.yml @@ -33,8 +33,8 @@ jobs: AutoROM -v - name: Source distribution test run: | - pip install build - python -m build + python -m pip install --upgrade build + python -m build --sdist pip install dist/*.tar.gz - name: Release Test run: | diff --git a/docs/tutorials/sb3/index.md b/docs/tutorials/sb3/index.md index 8c07ca496..2e48bd460 100644 --- a/docs/tutorials/sb3/index.md +++ b/docs/tutorials/sb3/index.md @@ -6,7 +6,7 @@ title: "Stable-Baselines3" These tutorials show you how to use the [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) (SB3) library to train agents in PettingZoo environments. -For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/). +For environments with visual observation spaces, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/). * [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train agents using PPO in a vectorized environment with visual observations_ diff --git a/docs/tutorials/sb3/kaz.md b/docs/tutorials/sb3/kaz.md index e6522caf9..5907d23de 100644 --- a/docs/tutorials/sb3/kaz.md +++ b/docs/tutorials/sb3/kaz.md @@ -10,6 +10,8 @@ We use SuperSuit to create vectorized environments, leveraging multithreading to After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [model saving documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html)). +If the observation space is visual (`vector_state=False` in `env_kwargs`), we pre-process using color reduction, resizing, and frame stacking, and use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy. + ```{eval-rst} .. note:: diff --git a/tutorials/SB3/connect_four/requirements.txt b/tutorials/SB3/connect_four/requirements.txt index 30917f7b2..7cb55ab3e 100644 --- a/tutorials/SB3/connect_four/requirements.txt +++ b/tutorials/SB3/connect_four/requirements.txt @@ -1,3 +1,4 @@ pettingzoo[classic]>=1.23.1 -stable-baselines3>=2.0.0 +# TODO: update to v2.1.0 once it is released +stable-baselines3 @ git+https://github.com/DLR-RM/stable-baselines3 sb3-contrib>=2.0.0 diff --git a/tutorials/SB3/kaz/requirements.txt b/tutorials/SB3/kaz/requirements.txt index 6199e7131..7596f648d 100644 --- a/tutorials/SB3/kaz/requirements.txt +++ b/tutorials/SB3/kaz/requirements.txt @@ -1,3 +1,4 @@ pettingzoo[butterfly]>=1.23.1 stable-baselines3>=2.0.0 -supersuit>=3.8.1 +# TODO: update to SS release once it's released +supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git diff --git a/tutorials/SB3/kaz/sb3_kaz_vector.py b/tutorials/SB3/kaz/sb3_kaz_vector.py index a8dace475..b4336d7de 100644 --- a/tutorials/SB3/kaz/sb3_kaz_vector.py +++ b/tutorials/SB3/kaz/sb3_kaz_vector.py @@ -14,7 +14,7 @@ import supersuit as ss from stable_baselines3 import PPO -from stable_baselines3.ppo import MlpPolicy +from stable_baselines3.ppo import CnnPolicy, MlpPolicy from pettingzoo.butterfly import knights_archers_zombies_v10 @@ -27,31 +27,26 @@ def train(env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs): # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True env = ss.black_death_v3(env) - # Pre-process using SuperSuit (color reduction, resizing and frame stacking) - env = ss.resize_v1(env, x_size=84, y_size=84) - env = ss.frame_stack_v1(env, 3) + # Pre-process using SuperSuit + visual_observation = not env.unwrapped.vector_state + if visual_observation: + # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking + env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) env.reset(seed=seed) print(f"Starting training on {str(env.metadata['name'])}.") env = ss.pettingzoo_env_to_vec_env_v1(env) - env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3") + env = ss.concat_vec_envs_v1(env, 8, num_cpus=1, base_class="stable_baselines3") - # TODO: test different hyperparameters + # Use a CNN policy if the observation space is visual model = PPO( - MlpPolicy, + CnnPolicy if visual_observation else MlpPolicy, env, verbose=3, - gamma=0.95, - n_steps=256, - ent_coef=0.0905168, - learning_rate=0.00062211, - vf_coef=0.042202, - max_grad_norm=0.9, - gae_lambda=0.99, - n_epochs=5, - clip_range=0.3, batch_size=256, ) @@ -70,9 +65,13 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa # Evaluate a trained agent vs a random agent env = env_fn.env(render_mode=render_mode, **env_kwargs) - # Pre-process using SuperSuit (color reduction, resizing and frame stacking) - env = ss.resize_v1(env, x_size=84, y_size=84) - env = ss.frame_stack_v1(env, 3) + # Pre-process using SuperSuit + visual_observation = not env.unwrapped.vector_state + if visual_observation: + # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking + env = ss.color_reduction_v0(env, mode="B") + env = ss.resize_v1(env, x_size=84, y_size=84) + env = ss.frame_stack_v1(env, 3) print( f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})" @@ -125,10 +124,8 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa if __name__ == "__main__": env_fn = knights_archers_zombies_v10 - # Notes on environment configuration: - # max_cycles 100, max_zombies 4, seems to work well (13 points over 10 games) - # max_cycles 900 (default) allowed the knights to get kills 1/10 games, but worse archer performance (6 points) - env_kwargs = dict(max_cycles=100, max_zombies=4) + # Set vector_state to false in order to use visual observations (significantly longer training time) + env_kwargs = dict(max_cycles=100, max_zombies=4, vector_state=True) # Train a model (takes ~5 minutes on a laptop CPU) train(env_fn, steps=81_920, seed=0, **env_kwargs) diff --git a/tutorials/SB3/pistonball/requirements.txt b/tutorials/SB3/pistonball/requirements.txt index 6199e7131..7596f648d 100644 --- a/tutorials/SB3/pistonball/requirements.txt +++ b/tutorials/SB3/pistonball/requirements.txt @@ -1,3 +1,4 @@ pettingzoo[butterfly]>=1.23.1 stable-baselines3>=2.0.0 -supersuit>=3.8.1 +# TODO: update to SS release once it's released +supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git diff --git a/tutorials/SB3/waterworld/requirements.txt b/tutorials/SB3/waterworld/requirements.txt index cb7dac213..5c00ddc4a 100644 --- a/tutorials/SB3/waterworld/requirements.txt +++ b/tutorials/SB3/waterworld/requirements.txt @@ -1,4 +1,6 @@ pettingzoo[sisl]>=1.23.1 stable-baselines3>=2.0.0 -supersuit>=3.8.1 +# TODO: update to SS release once it's done +supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git +# TODO: remove pymunk requirement before 1.24.0 PZ release, as it will be added as a requirement to sisl pymunk