Commit 6d80324

Sb3 supersuit bugfix (#1031)
1 parent aa1a57f commit 6d80324

9 files changed: 36 additions and 35 deletions
.github/workflows/build-publish.yml

Lines changed: 2 additions & 5 deletions
@@ -18,9 +18,6 @@ jobs:
     strategy:
       matrix:
         include:
-        - os: ubuntu-latest
-          python: 37
-          platform: manylinux_x86_64
         - os: ubuntu-latest
           python: 38
           platform: manylinux_x86_64
@@ -41,9 +38,9 @@ jobs:
         with:
           python-version: '3.x'
       - name: Install dependencies
-        run: python -m pip install --upgrade setuptools wheel
+        run: python -m pip install --upgrade setuptools wheel build
       - name: Build wheels
-        run: python setup.py sdist bdist_wheel
+        run: python -m build --sdist --wheel
       - name: Store wheels
         uses: actions/upload-artifact@v2
         with:

.github/workflows/linux-test.yml

Lines changed: 2 additions & 2 deletions
@@ -33,8 +33,8 @@ jobs:
           AutoROM -v
       - name: Source distribution test
         run: |
-          pip install build
-          python -m build
+          python -m pip install --upgrade build
+          python -m build --sdist
           pip install dist/*.tar.gz
       - name: Release Test
         run: |

docs/tutorials/sb3/index.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ title: "Stable-Baselines3"
 
 These tutorials show you how to use the [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) (SB3) library to train agents in PettingZoo environments.
 
-For environments with visual observations, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking, color reduction, and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/).
+For environments with visual observation spaces, we use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy and perform pre-processing steps such as frame-stacking and resizing using [SuperSuit](/api/wrappers/supersuit_wrappers/).
 
 * [PPO for Knights-Archers-Zombies](/tutorials/sb3/kaz/) _Train agents using PPO in a vectorized environment with visual observations_
 

docs/tutorials/sb3/kaz.md

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ We use SuperSuit to create vectorized environments, leveraging multithreading to
 
 After training and evaluation, this script will launch a demo game using human rendering. Trained models are saved and loaded from disk (see SB3's [model saving documentation](https://stable-baselines3.readthedocs.io/en/master/guide/save_format.html)).
 
+If the observation space is visual (`vector_state=False` in `env_kwargs`), we pre-process using color reduction, resizing, and frame stacking, and use a [CNN](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#stable_baselines3.ppo.CnnPolicy) policy.
+
 ```{eval-rst}
 .. note::
 

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 pettingzoo[classic]>=1.23.1
-stable-baselines3>=2.0.0
+# TODO: update to v2.1.0 once it is released
+stable-baselines3 @ git+https://github.com/DLR-RM/stable-baselines3
 sb3-contrib>=2.0.0

tutorials/SB3/kaz/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 pettingzoo[butterfly]>=1.23.1
 stable-baselines3>=2.0.0
-supersuit>=3.8.1
+# TODO: update to SS release once it's released
+supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git

tutorials/SB3/kaz/sb3_kaz_vector.py

Lines changed: 20 additions & 23 deletions
@@ -14,7 +14,7 @@
 
 import supersuit as ss
 from stable_baselines3 import PPO
-from stable_baselines3.ppo import MlpPolicy
+from stable_baselines3.ppo import CnnPolicy, MlpPolicy
 
 from pettingzoo.butterfly import knights_archers_zombies_v10
 
@@ -27,31 +27,26 @@ def train(env_fn, steps: int = 10_000, seed: int | None = 0, **env_kwargs):
     # MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True
     env = ss.black_death_v3(env)
 
-    # Pre-process using SuperSuit (color reduction, resizing and frame stacking)
-    env = ss.resize_v1(env, x_size=84, y_size=84)
-    env = ss.frame_stack_v1(env, 3)
+    # Pre-process using SuperSuit
+    visual_observation = not env.unwrapped.vector_state
+    if visual_observation:
+        # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking
+        env = ss.color_reduction_v0(env, mode="B")
+        env = ss.resize_v1(env, x_size=84, y_size=84)
+        env = ss.frame_stack_v1(env, 3)
 
     env.reset(seed=seed)
 
     print(f"Starting training on {str(env.metadata['name'])}.")
 
     env = ss.pettingzoo_env_to_vec_env_v1(env)
-    env = ss.concat_vec_envs_v1(env, 8, num_cpus=2, base_class="stable_baselines3")
+    env = ss.concat_vec_envs_v1(env, 8, num_cpus=1, base_class="stable_baselines3")
 
-    # TODO: test different hyperparameters
+    # Use a CNN policy if the observation space is visual
     model = PPO(
-        MlpPolicy,
+        CnnPolicy if visual_observation else MlpPolicy,
         env,
         verbose=3,
-        gamma=0.95,
-        n_steps=256,
-        ent_coef=0.0905168,
-        learning_rate=0.00062211,
-        vf_coef=0.042202,
-        max_grad_norm=0.9,
-        gae_lambda=0.99,
-        n_epochs=5,
-        clip_range=0.3,
         batch_size=256,
     )
 
@@ -70,9 +65,13 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa
     # Evaluate a trained agent vs a random agent
     env = env_fn.env(render_mode=render_mode, **env_kwargs)
 
-    # Pre-process using SuperSuit (color reduction, resizing and frame stacking)
-    env = ss.resize_v1(env, x_size=84, y_size=84)
-    env = ss.frame_stack_v1(env, 3)
+    # Pre-process using SuperSuit
+    visual_observation = not env.unwrapped.vector_state
+    if visual_observation:
+        # If the observation space is visual, reduce the color channels, resize from 512px to 84px, and apply frame stacking
+        env = ss.color_reduction_v0(env, mode="B")
+        env = ss.resize_v1(env, x_size=84, y_size=84)
+        env = ss.frame_stack_v1(env, 3)
 
     print(
         f"\nStarting evaluation on {str(env.metadata['name'])} (num_games={num_games}, render_mode={render_mode})"
@@ -125,10 +124,8 @@ def eval(env_fn, num_games: int = 100, render_mode: str | None = None, **env_kwa
 if __name__ == "__main__":
     env_fn = knights_archers_zombies_v10
 
-    # Notes on environment configuration:
-    # max_cycles 100, max_zombies 4, seems to work well (13 points over 10 games)
-    # max_cycles 900 (default) allowed the knights to get kills 1/10 games, but worse archer performance (6 points)
-    env_kwargs = dict(max_cycles=100, max_zombies=4)
+    # Set vector_state to false in order to use visual observations (significantly longer training time)
+    env_kwargs = dict(max_cycles=100, max_zombies=4, vector_state=True)
 
     # Train a model (takes ~5 minutes on a laptop CPU)
     train(env_fn, steps=81_920, seed=0, **env_kwargs)
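
The branch this commit adds to `train` and `eval` can be traced without installing SuperSuit or SB3. The sketch below is illustrative only and not part of the commit: `preprocess_shape` and `choose_policy` are hypothetical helpers. It assumes a 512x512 RGB frame (the diff comment says the resize goes "from 512px to 84px"), that `color_reduction_v0` with `mode="B"` drops the channel axis, and that `frame_stack_v1` stacks frames along a new trailing axis.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def preprocess_shape(shape: Shape, size: int = 84, stack_size: int = 3) -> Shape:
    """Trace the observation shape through the three visual wrappers (hypothetical helper)."""
    # color_reduction_v0(env, mode="B"): keep a single colour channel.
    # Assumption: (H, W, 3) becomes (H, W).
    if len(shape) == 3:
        shape = shape[:2]
    # resize_v1(env, x_size=84, y_size=84): (H, W) becomes (84, 84).
    shape = (size, size)
    # frame_stack_v1(env, 3): assumption: frames are stacked on a new last axis.
    return shape + (stack_size,)


def choose_policy(vector_state: bool) -> str:
    """Mirror `CnnPolicy if visual_observation else MlpPolicy` using SB3's policy names."""
    visual_observation = not vector_state
    return "CnnPolicy" if visual_observation else "MlpPolicy"


if __name__ == "__main__":
    print(preprocess_shape((512, 512, 3)))
    print(choose_policy(vector_state=True))
    print(choose_policy(vector_state=False))
```

With `vector_state=True`, the default in the updated `env_kwargs`, none of the three wrappers is applied and the MLP policy is used. The image pipeline and CNN policy apply only when `vector_state=False`.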
Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 pettingzoo[butterfly]>=1.23.1
 stable-baselines3>=2.0.0
-supersuit>=3.8.1
+# TODO: update to SS release once it's released
+supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git
Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
 pettingzoo[sisl]>=1.23.1
 stable-baselines3>=2.0.0
-supersuit>=3.8.1
+# TODO: update to SS release once it's done
+supersuit @ git+https://github.com/Farama-Foundation/SuperSuit.git
+# TODO: remove pymunk requirement before 1.24.0 PZ release, as it will be added as a requirement to sisl
 pymunk
