---
layout: "contents"
title: Training Minigrid Environments
firstpage:
---

## Training Minigrid Environments

The environments in the Minigrid library can be trained easily using [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/). In this tutorial we show how a PPO agent can be trained on the `MiniGrid-Empty-16x16-v0` environment.
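
This tutorial assumes the `minigrid` and `stable-baselines3` packages (and their dependencies, including `gymnasium` and `torch`) are installed. As a quick sanity check before training, the following minimal sketch (not part of the training code itself) creates the environment and takes a few random steps:

```python
import gymnasium as gym
import minigrid  # noqa: F401  -- importing minigrid registers the MiniGrid-* environments

env = gym.make("MiniGrid-Empty-16x16-v0")
obs, info = env.reset(seed=0)
for _ in range(5):
    # Take random actions just to confirm that the environment resets and steps correctly
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```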

## Create Custom Feature Extractor

Although `StableBaselines3` is fully compatible with `Gymnasium`-based environments, including Minigrid, its default CNN architecture does not directly support the Minigrid observation space. To train an agent on Minigrid environments, we therefore need to create a custom feature extractor. This can be done by writing a feature extractor class that inherits from `stable_baselines3.common.torch_layers.BaseFeaturesExtractor`:

```python
import gymnasium as gym
import torch
import torch.nn as nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


# Feature extractor for Minigrid image observations; `normalized_image` is accepted
# for signature compatibility with SB3's built-in extractors but is not used below.
class MinigridFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute the number of flattened features by doing one forward pass
        with torch.no_grad():
            n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(observations))
```

This class is based on the custom feature extractor [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-feature-extractor:~:text=Custom%20Feature%20Extractor-,%EF%83%81,-If%20you%20want); the CNN architecture is copied from Lucas Willems' [rl-starter-files](https://github.com/lcswillems/rl-starter-files/blob/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/model.py#L18).
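
Note that when this extractor is created inside a StableBaselines3 policy, the Minigrid image observation has already been transposed to a channels-first layout (SB3 applies its `VecTransposeImage` wrapper to channels-last image observations), which is why `observation_space.shape[0]` above corresponds to the 3 colour channels. A quick standalone check of the extractor (a minimal sketch, not part of the tutorial itself, reusing the class and the `torch` import from above) can mimic that layout explicitly:

```python
import numpy as np
from gymnasium import spaces

# Channels-first stand-in for the transposed Minigrid image observation space
# (the default 7x7 agent view gives a (3, 7, 7) image after transposition).
channels_first_space = spaces.Box(low=0, high=255, shape=(3, 7, 7), dtype=np.uint8)
extractor = MinigridFeaturesExtractor(channels_first_space, features_dim=128)

dummy_obs = torch.as_tensor(channels_first_space.sample()[None]).float()
print(extractor(dummy_obs).shape)  # torch.Size([1, 128])
```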

## Train a PPO Agent

Using the custom feature extractor, we can train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment. The following code snippet shows how this can be done.

```python
import gymnasium as gym
import minigrid  # noqa: F401  -- importing minigrid registers the MiniGrid-* environments
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO

policy_kwargs = dict(
    features_extractor_class=MinigridFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=128),
)

env = gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array")
env = ImgObsWrapper(env)  # keep only the image part of the dictionary observation

model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(int(2e5))
```
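
Once `learn` returns, the trained agent can be evaluated and saved with StableBaselines3's standard utilities. The snippet below is a minimal sketch (not part of the original tutorial); the number of evaluation episodes and the file name are arbitrary choices:

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the trained policy on the wrapped training environment
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Save the model so it can be reloaded later with PPO.load("ppo_minigrid_empty_16x16")
model.save("ppo_minigrid_empty_16x16")
```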

By default, the observations of Minigrid environments are dictionaries. Since the `CnnPolicy` from StableBaselines3 expects image observations, we need to wrap the environment using the `ImgObsWrapper` from the Minigrid library, which converts the dictionary observation to an image observation.
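
To see what the wrapper changes, one can compare the observation spaces before and after wrapping (a minimal sketch reusing the imports from the training snippet; the exact image shape corresponds to the default 7x7 agent view):

```python
env = gym.make("MiniGrid-Empty-16x16-v0")
print(env.observation_space)
# Dict observation space with 'image', 'direction' and 'mission' keys

print(ImgObsWrapper(env).observation_space)
# Box(0, 255, (7, 7, 3), uint8) -- just the image
```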

## Further Reading

One can also pass dictionary observations to StableBaselines3 policies; for a walkthrough of how to do so, see [here](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations). An implementation utilizing this functionality can be found [here](https://github.com/BolunDai0216/MinigridMiniworldTransfer/blob/main/minigrid_gotoobj_train.py).