Sb3 supersuit bugfix (#1031)
elliottower authored Jul 21, 2023
1 parent aa1a57f commit 6d80324
Showing 9 changed files with 36 additions and 35 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/build-publish.yml
@@ -18,9 +18,6 @@ jobs:
     strategy:
       matrix:
         include:
-        - os: ubuntu-latest
-          python: 37
-          platform: manylinux_x86_64
         - os: ubuntu-latest
           python: 38
           platform: manylinux_x86_64
@@ -41,9 +38,9 @@ jobs:
       with:
         python-version: '3.x'
     - name: Install dependencies
-      run: python -m pip install --upgrade setuptools wheel
+      run: python -m pip install --upgrade setuptools wheel build
     - name: Build wheels
-      run: python setup.py sdist bdist_wheel
+      run: python -m build --sdist --wheel
     - name: Store wheels
       uses: actions/upload-artifact@v2
       with:
4 changes: 2 additions & 2 deletions .github/workflows/linux-test.yml
@@ -33,8 +33,8 @@ jobs:
         AutoROM -v
     - name: Source distribution test
       run: |
-        pip install build
-        python -m build
+        python -m pip install --upgrade build
+        python -m build --sdist
         pip install dist/*.tar.gz
     - name: Release Test
       run: |
2 changes: 1 addition & 1 deletion docs/tutorials/sb3/index.md
@@ -6,7 +6,7 @@ title: "Stable-Baselines3"

 These tutorials show you how to use the [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) (SB3) library to train agents in PettingZoo environments.

-For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/).
+For environments with visual observation spaces, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/).

 * [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train agents using PPO in a vectorized environment with visual observations_

2 changes: 2 additions & 0 deletions docs/tutorials/sb3/kaz.md
@@ -10,6 +10,8 @@ We use SuperSuit to create vectorized environments, leveraging multithreading to

 After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [model saving documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html)).

+If the observation space is visual (`vector_state=False` in `env_kwargs`), we pre-process using color reduction, resizing, and frame stacking, and use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy.
+
 ```{eval-rst}
 .. note::
3 changes: 2 additions & 1 deletion tutorials/SB3/connect_four/requirements.txt
@@ -1,3 +1,4 @@
 pettingzoo[classic]>=1.23.1
-stable-baselines3>=2.0.0
+# TODO: update to v2.1.0 once it is released
+stable-baselines3 @ git+https://github.com/DLR-RM/stable-baselines3
 sb3-contrib>=2.0.0
3 changes: 2 additions & 1 deletion tutorials/SB3/kaz/requirements.txt
@@ -1,3 +1,4 @@
 pettingzoo[butterfly]>=1.23.1
 stable-baselines3>=2.0.0
-supersuit>=3.8.1
+# TODO: update to SS release once it's released
+supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git
43 changes: 20 additions & 23 deletions tutorials/SB3/kaz/sb3_kaz_vector.py
@@ -14,7 +14,7 @@

 import supersuit as ss
 from stable_baselines3 import PPO
-from stable_baselines3.ppo import MlpPolicy
+from stable_baselines3.ppo import CnnPolicy, MlpPolicy

 from pettingzoo.butterfly import knights_archers_zombies_v10

@@ -27,31 +27,26 @@ def train(env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs):
     # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True
     env = ss.black_death_v3(env)

-    # Pre-process using SuperSuit (color reduction, resizing and frame stacking)
-    env = ss.resize_v1(env, x_size=84, y_size=84)
-    env = ss.frame_stack_v1(env, 3)
+    # Pre-process using SuperSuit
+    visual_observation = not env.unwrapped.vector_state
+    if visual_observation:
+        # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking
+        env = ss.color_reduction_v0(env, mode="B")
+        env = ss.resize_v1(env, x_size=84, y_size=84)
+        env = ss.frame_stack_v1(env, 3)

     env.reset(seed=seed)

     print(f"Starting training on {str(env.metadata['name'])}.")

     env = ss.pettingzoo_env_to_vec_env_v1(env)
-    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")
+    env = ss.concat_vec_envs_v1(env, 8, num_cpus=1, base_class="stable_baselines3")

-    # TODO: test different hyperparameters
+    # Use a CNN policy if the observation space is visual
     model = PPO(
-        MlpPolicy,
+        CnnPolicy if visual_observation else MlpPolicy,
         env,
         verbose=3,
-        gamma=0.95,
-        n_steps=256,
-        ent_coef=0.0905168,
-        learning_rate=0.00062211,
-        vf_coef=0.042202,
-        max_grad_norm=0.9,
-        gae_lambda=0.99,
-        n_epochs=5,
-        clip_range=0.3,
         batch_size=256,
     )

@@ -70,9 +65,13 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa
     # Evaluate a trained agent vs a random agent
     env = env_fn.env(render_mode=render_mode, **env_kwargs)

-    # Pre-process using SuperSuit (color reduction, resizing and frame stacking)
-    env = ss.resize_v1(env, x_size=84, y_size=84)
-    env = ss.frame_stack_v1(env, 3)
+    # Pre-process using SuperSuit
+    visual_observation = not env.unwrapped.vector_state
+    if visual_observation:
+        # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking
+        env = ss.color_reduction_v0(env, mode="B")
+        env = ss.resize_v1(env, x_size=84, y_size=84)
+        env = ss.frame_stack_v1(env, 3)

     print(
         f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
@@ -125,10 +124,8 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa
 if __name__ == "__main__":
     env_fn = knights_archers_zombies_v10

-    # Notes on environment configuration:
-    # max_cycles 100, max_zombies 4, seems to work well (13 points over 10 games)
-    # max_cycles 900 (default) allowed the knights to get kills 1/10 games, but worse archer performance (6 points)
-    env_kwargs = dict(max_cycles=100, max_zombies=4)
+    # Set vector_state to false in order to use visual observations (significantly longer training time)
+    env_kwargs = dict(max_cycles=100, max_zombies=4, vector_state=True)

     # Train a model (takes ~5 minutes on a laptop CPU)
     train(env_fn, steps=81_920, seed=0, **env_kwargs)
3 changes: 2 additions & 1 deletion tutorials/SB3/pistonball/requirements.txt
@@ -1,3 +1,4 @@
 pettingzoo[butterfly]>=1.23.1
 stable-baselines3>=2.0.0
-supersuit>=3.8.1
+# TODO: update to SS release once it's released
+supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git
4 changes: 3 additions & 1 deletion tutorials/SB3/waterworld/requirements.txt
@@ -1,4 +1,6 @@
 pettingzoo[sisl]>=1.23.1
 stable-baselines3>=2.0.0
-supersuit>=3.8.1
+# TODO: update to SS release once it's done
+supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git
+# TODO: remove pymunk requirement before 1.24.0 PZ release, as it will be added as a requirement to sisl
pymunk
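
Note on the `sb3_kaz_vector.py` change: the sketch below traces how the observation shape might evolve through the three wrappers the commit applies when `vector_state=False`, and which policy class name is then selected. It uses only the standard library. The helper names (`color_reduction`, `resize`, `frame_stack`, `preprocess_shape`, `policy_name`) are hypothetical, and the shape rules are assumptions about how SuperSuit's `color_reduction_v0`, `resize_v1` and `frame_stack_v1` behave, not a copy of their implementation. Only the wrapper order, the 512px to 84px resize, the stack size of 3, and the policy choice come from the diff.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def color_reduction(shape: Shape) -> Shape:
    """Assumed rule: keep one colour channel, (H, W, 3) -> (H, W)."""
    if len(shape) != 3 or shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got {shape}")
    return shape[:2]


def resize(shape: Shape, x_size: int, y_size: int) -> Shape:
    """Assumed rule: replace the two spatial dims, keep any channel dim."""
    return (y_size, x_size) + tuple(shape[2:])


def frame_stack(shape: Shape, num_frames: int) -> Shape:
    """Assumed rule: 1D is concatenated, 2D gains a trailing axis, 3D grows its last axis."""
    if len(shape) == 1:
        return (shape[0] * num_frames,)
    if len(shape) == 2:
        return shape + (num_frames,)
    if len(shape) == 3:
        return shape[:2] + (shape[2] * num_frames,)
    raise ValueError(f"unsupported observation shape {shape}")


def preprocess_shape(shape: Shape, vector_state: bool) -> Shape:
    """Mirror the branch in the diff: wrappers apply only to visual observations."""
    if vector_state:
        return shape
    shape = color_reduction(shape)
    shape = resize(shape, x_size=84, y_size=84)
    return frame_stack(shape, 3)


def policy_name(vector_state: bool) -> str:
    """Mirror `CnnPolicy if visual_observation else MlpPolicy` from the diff."""
    return "MlpPolicy" if vector_state else "CnnPolicy"


if __name__ == "__main__":
    print(preprocess_shape((512, 512, 3), vector_state=False))  # (84, 84, 3)
    print(policy_name(vector_state=False))  # CnnPolicy
    # (10, 5) is a made-up vector-state shape, used only to show the pass-through.
    print(preprocess_shape((10, 5), vector_state=True))  # (10, 5)
```

Under these assumed rules, resizing and stacking without the colour reduction step would leave a 9-channel observation, `frame_stack(resize((512, 512, 3), 84, 84), 3) == (84, 84, 9)`, whereas the order used in the commit yields 3 channels.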
