Commit 546c040

Add Training Tutorial and CITATION.cff (#379)
1 parent e103104 commit 546c040

3 files changed: +119 -0 lines changed

CITATION.cff

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
cff-version: 1.2.0
title: Minigrid
message: >-
  If you use this software, please cite it using the
  metadata from this file.
authors:
  - family-names: Chevalier-Boisvert
    given-names: Maxime
  - family-names: Dai
    given-names: Bolun
  - family-names: Towers
    given-names: Mark
  - family-names: de Lazcano
    given-names: Rodrigo
  - family-names: Willems
    given-names: Lucas
  - family-names: Lahlou
    given-names: Salem
  - family-names: Pal
    given-names: Suman
  - family-names: Castro
    given-names: Pablo Samuel
  - family-names: Terry
    given-names: Jordan
url: "https://github.com/Farama-Foundation/Minigrid"

preferred-citation:
  type: article
  authors:
    - family-names: Chevalier-Boisvert
      given-names: Maxime
    - family-names: Dai
      given-names: Bolun
    - family-names: Towers
      given-names: Mark
    - family-names: de Lazcano
      given-names: Rodrigo
    - family-names: Willems
      given-names: Lucas
    - family-names: Lahlou
      given-names: Salem
    - family-names: Pal
      given-names: Suman
    - family-names: Castro
      given-names: Pablo Samuel
    - family-names: Terry
      given-names: Jordan
  journal: CoRR
  title: Minigrid
  volume: abs/2306.13831
  year: 2023

docs/content/training.md

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
---
layout: "contents"
title: Training Minigrid Environments
firstpage:
---

## Training Minigrid Environments

Agents can be trained on the environments in the Minigrid library using [Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/). This tutorial shows how to train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment.

## Create Custom Feature Extractor

Although Stable-Baselines3 is fully compatible with `Gymnasium`-based environments, including Minigrid, its default CNN architecture does not directly support the Minigrid observation space. To train an agent on Minigrid environments, we therefore need a custom feature extractor, which can be written as a class that inherits from `stable_baselines3.common.torch_layers.BaseFeaturesExtractor`.

```python
import gymnasium as gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class MinigridFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
        super().__init__(observation_space, features_dim)
        # Channel count (assumes a channel-first image space)
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with torch.no_grad():
            n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]

        # Project the flattened CNN output to the requested feature size
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(observations))
```

This class follows the custom feature extractor [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-feature-extractor:~:text=Custom%20Feature%20Extractor-,%EF%83%81,-If%20you%20want). The CNN architecture is copied from Lucas Willems' [rl-starter-files](https://github.com/lcswillems/rl-starter-files/blob/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/model.py#L18).
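
To see where `n_flatten` comes from, the output size of the convolutional stack can be worked out by hand. The sketch below assumes Minigrid's default 7x7 agent view and the default `Conv2d` settings (stride 1, no padding); the helper names `conv_out_size` and `flatten_size` are hypothetical and not part of either library.

```python
from typing import Sequence


def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of one convolution along a single dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def flatten_size(height: int, width: int,
                 kernels: Sequence[int] = (2, 2, 2),
                 out_channels: int = 64) -> int:
    """Flattened feature count after a stack of square convolutions."""
    for kernel in kernels:
        height = conv_out_size(height, kernel)
        width = conv_out_size(width, kernel)
    return out_channels * height * width


# Assumed 7x7 view: each 2x2 convolution shrinks a side by one (7 -> 6 -> 5 -> 4)
print(flatten_size(7, 7))  # 64 * 4 * 4 = 1024
```

Under these assumptions the linear layer receives 1024 inputs. The class itself runs one forward pass instead of using a formula, which stays correct if the view size or the layers change.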

## Train a PPO Agent

Using the custom feature extractor, we can train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment. The following code snippet shows how this can be done.

```python
import gymnasium as gym
import minigrid
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO

# Plug the custom feature extractor into the policy
policy_kwargs = dict(
    features_extractor_class=MinigridFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=128),
)

env = gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array")
# Keep only the image part of the dictionary observation
env = ImgObsWrapper(env)

model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(int(2e5))  # total number of environment steps
```

By default, the observations of Minigrid environments are dictionaries. Since the `CnnPolicy` from Stable-Baselines3 takes image observations by default, we need to wrap the environment with the `ImgObsWrapper` from the Minigrid library. This wrapper converts the dictionary observation into an image observation.
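
The effect of the wrapper can be illustrated in plain Python. The sketch below assumes the dictionary observation holds `image`, `direction`, and `mission` entries; the sample values and the `to_image_obs` helper are made up for illustration and only mimic what `ImgObsWrapper` does to each observation.

```python
from typing import Any, Dict, List

Image = List[List[List[int]]]  # height x width x channels


def to_image_obs(obs: Dict[str, Any]) -> Image:
    """Keep the image entry and drop everything else (hypothetical helper)."""
    return obs["image"]


# Made-up 2x2 observation with 3 values per cell, for illustration only
dict_obs = {
    "image": [[[2, 5, 0], [1, 0, 0]], [[1, 0, 0], [8, 1, 0]]],
    "direction": 0,
    "mission": "get to the green goal square",
}

image_obs = to_image_obs(dict_obs)
print(len(image_obs), len(image_obs[0]), len(image_obs[0][0]))  # 2 2 3
```

Note that the `direction` and `mission` entries are dropped, so the policy sees only the image.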

## Further Reading

Dictionary observations can also be passed to Stable-Baselines3 policies; for a walkthrough of the process, see [here](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations). An implementation that uses this functionality can be found [here](https://github.com/BolunDai0216/MinigridMiniworldTransfer/blob/main/minigrid_gotoobj_train.py).

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ To cite this project please use:
 content/basic_usage
 content/publications
 content/create_env_tutorial
+content/training
 ```

 ```{toctree}
