Add support for pre and post linear modules in create_mlp
#1975
Changes from 3 commits
stable_baselines3/common/torch_layers.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import gymnasium as gym
 import torch as th
@@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module): | |||||||
| """ | ||||||||
| Base class that represents a features extractor. | ||||||||
|
|
||||||||
| :param observation_space: | ||||||||
| :param observation_space: The observation space of the environment | ||||||||
| :param features_dim: Number of features extracted. | ||||||||
| """ | ||||||||
|
|
||||||||
|
|
@@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None: | |||||||
|
|
||||||||
| @property | ||||||||
| def features_dim(self) -> int: | ||||||||
| """The number of features that the extractor outputs.""" | ||||||||
| return self._features_dim | ||||||||
|
|
||||||||
|
|
||||||||
|
|
@@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor): | |||||||
| Feature extract that flatten the input. | ||||||||
| Used as a placeholder when feature extraction is not needed. | ||||||||
|
|
||||||||
| :param observation_space: | ||||||||
| :param observation_space: The observation space of the environment | ||||||||
| """ | ||||||||
|
|
||||||||
| def __init__(self, observation_space: gym.Space) -> None: | ||||||||
|
|
@@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor): | |||||||
| "Human-level control through deep reinforcement learning." | ||||||||
| Nature 518.7540 (2015): 529-533. | ||||||||
|
|
||||||||
| :param observation_space: | ||||||||
| :param observation_space: The observation space of the environment | ||||||||
| :param features_dim: Number of features extracted. | ||||||||
| This corresponds to the number of unit for the last layer. | ||||||||
| :param normalized_image: Whether to assume that the image is already normalized | ||||||||
|
|
@@ -113,13 +114,15 @@ def create_mlp( | |||||||
| activation_fn: Type[nn.Module] = nn.ReLU, | ||||||||
| squash_output: bool = False, | ||||||||
| with_bias: bool = True, | ||||||||
| pre_linear_modules: Optional[List[Type[nn.Module]]] = None, | ||||||||
| post_linear_modules: Optional[List[Type[nn.Module]]] = None, | ||||||||
| ) -> List[nn.Module]: | ||||||||
| """ | ||||||||
| Create a multi layer perceptron (MLP), which is | ||||||||
| a collection of fully-connected layers each followed by an activation function. | ||||||||
|
|
||||||||
| :param input_dim: Dimension of the input vector | ||||||||
| :param output_dim: | ||||||||
| :param output_dim: Dimension of the output (last layer, for instance, the number of actions) | ||||||||
| :param net_arch: Architecture of the neural net | ||||||||
| It represents the number of units per layer. | ||||||||
| The length of this list is the number of layers. | ||||||||
|
|
@@ -128,20 +131,48 @@ def create_mlp( | |||||||
| :param squash_output: Whether to squash the output using a Tanh | ||||||||
| activation function | ||||||||
| :param with_bias: If set to False, the layers will not learn an additive bias | ||||||||
| :return: | ||||||||
| :param pre_linear_modules: List of nn.Module to add before the linear layers, | ||||||||
| for instance, BatchNorm layers. | ||||||||
| Compared to post_linear_modules, they are used before the output layer (output_dim > 0). | ||||||||
| The number of input features is passed to the module's constructor. | ||||||||
| :param post_linear_modules: List of nn.Module to add after the linear layers (and before the activation function), | ||||||||
| for instance, Dropout or LayerNorm layers. | ||||||||
| They are not used after the output layer (output_dim > 0). | ||||||||
| The number of input features is passed to the module's constructor. | ||||||||
| :return: The list of layers of the neural network | ||||||||
| """ | ||||||||
|
|
||||||||
| pre_linear_modules = pre_linear_modules or [] | ||||||||
| post_linear_modules = post_linear_modules or [] | ||||||||
|
|
||||||||
| modules = [] | ||||||||
| if len(net_arch) > 0: | ||||||||
| modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()] | ||||||||
| else: | ||||||||
| modules = [] | ||||||||
| for module in pre_linear_modules: | ||||||||
| modules.append(module(input_dim)) | ||||||||
|
Collaborator: Is the input dim the same for all modules?

Collaborator: I think I get it. It only allows modules that have the same input/output dimension.

Member (author): Maybe it is clearer with this test: stable-baselines3/tests/test_custom_policy.py, lines 102 to 104 in 3b84f71. Shall I add a comment to avoid confusion?

Collaborator: Yes! I've also suggested a new documentation: https://github.com/DLR-RM/stable-baselines3/pull/1975/files#r1686213647
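For illustration only (this is not the verbatim test referenced above): the constraint being discussed is that each module class receives the current feature count as its single constructor argument and must preserve that dimension, as nn.LayerNorm or nn.BatchNorm1d do.

import torch as th
import torch.nn as nn

# LayerNorm(n) maps (batch, n) -> (batch, n): same input/output dimension,
# so it can be inserted before or after any Linear layer without
# changing the shape of the tensor flowing through the MLP.
norm = nn.LayerNorm(16)
assert norm(th.zeros(5, 16)).shape == (5, 16)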
+        modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))
+
+        for module in post_linear_modules:
+            modules.append(module(net_arch[0]))
+
+        modules.append(activation_fn())
 
     for idx in range(len(net_arch) - 1):
+        for module in pre_linear_modules:
+            modules.append(module(net_arch[idx]))
+
         modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))
+
+        for module in post_linear_modules:
+            modules.append(module(net_arch[idx + 1]))
+
         modules.append(activation_fn())
 
     if output_dim > 0:
         last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
+        for module in pre_linear_modules:
+            modules.append(module(last_layer_dim))
+
         modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
     if squash_output:
         modules.append(nn.Tanh())
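A minimal usage sketch of the extended function, assuming the signature shown above (the choice of BatchNorm1d and LayerNorm here is illustrative, not part of the PR):

import torch.nn as nn
from stable_baselines3.common.torch_layers import create_mlp

layers = create_mlp(
    input_dim=4,
    output_dim=2,
    net_arch=[64, 64],
    pre_linear_modules=[nn.BatchNorm1d],  # before each Linear, including the output layer
    post_linear_modules=[nn.LayerNorm],   # after each hidden Linear, before the activation
)
mlp = nn.Sequential(*layers)
# Expanded sequence produced by the new code:
#   BatchNorm1d(4)  -> Linear(4, 64)  -> LayerNorm(64) -> ReLU()
#   BatchNorm1d(64) -> Linear(64, 64) -> LayerNorm(64) -> ReLU()
#   BatchNorm1d(64) -> Linear(64, 2)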
stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a5
+2.4.0a6