Commit 9bd4bca

Add flatten layer and update dependencies (#18)
* Add flatten layer and update dependencies
* Reformat
1 parent f662613 commit 9bd4bca

12 files changed, +62 -23 lines

.github/workflows/ci.yml

+1 -7

@@ -32,13 +32,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           # cpu version of pytorch
-          pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html
-
-          # # Install Atari Roms
-          # pip install autorom
-          # wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-          # base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-          # AutoROM --accept-license --source-file Roms.tar.gz
+          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu

           pip install .[tests]
           # Use headless version

Makefile

+2 -2

@@ -18,13 +18,13 @@ lint:

 format:
 	# Sort imports
-	isort ${LINT_PATHS}
+	ruff --select I ${LINT_PATHS} --fix
 	# Reformat using black
 	black ${LINT_PATHS}

 check-codestyle:
 	# Sort imports
-	isort --check ${LINT_PATHS}
+	ruff --select I ${LINT_PATHS}
 	# Reformat using black
 	black --check ${LINT_PATHS}
pyproject.toml

-5

@@ -16,11 +16,6 @@ max-complexity = 15
 [tool.black]
 line-length = 127

-[tool.isort]
-profile = "black"
-line_length = 127
-src_paths = ["sbx"]
-
 [tool.mypy]
 ignore_missing_imports = true
 follow_imports = "silent"

sbx/common/policies.py

+12

@@ -1,14 +1,26 @@
 # import copy
 from typing import Dict, Optional, Tuple, Union, no_type_check

+import flax.linen as nn
 import jax
+import jax.numpy as jnp
 import numpy as np
 from gymnasium import spaces
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose
 from stable_baselines3.common.utils import is_vectorized_observation


+class Flatten(nn.Module):
+    """
+    Equivalent to PyTorch nn.Flatten() layer.
+    """
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        return x.reshape((x.shape[0], -1))
+
+
 class BaseJaxPolicy(BasePolicy):
     def __init__(self, *args, **kwargs):
         super().__init__(
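For reference, a minimal sketch (not part of the commit) of what the new Flatten module does when applied on its own; since it defines no parameters, it can be applied with an empty variables dict:

import jax.numpy as jnp

from sbx.common.policies import Flatten

obs = jnp.zeros((4, 2, 1))        # a batch of 4 observations, each of shape (2, 1)
flat = Flatten().apply({}, obs)   # no parameters, so the variables dict is empty
assert flat.shape == (4, 2)       # everything except the batch dimension is collapsed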

sbx/dqn/policies.py

+2 -1

@@ -8,7 +8,7 @@
 from gymnasium import spaces
 from stable_baselines3.common.type_aliases import Schedule

-from sbx.common.policies import BaseJaxPolicy
+from sbx.common.policies import BaseJaxPolicy, Flatten
 from sbx.common.type_aliases import RLTrainState


@@ -18,6 +18,7 @@ class QNetwork(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        x = Flatten()(x)
         x = nn.Dense(self.n_units)(x)
         x = nn.relu(x)
         x = nn.Dense(self.n_units)(x)
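A short illustrative sketch (not from the commit) of why the extra Flatten() call matters before the Dense layers: flax.linen.Dense only contracts the last axis, so a non-flat observation would otherwise carry its extra dimensions through the whole network.

import flax.linen as nn
import jax
import jax.numpy as jnp

dense = nn.Dense(64)
obs = jnp.zeros((4, 2, 1))               # non-flat observations, as in the new test below
flat = obs.reshape((obs.shape[0], -1))   # what Flatten()(x) produces: shape (4, 2)

# Dense contracts only the last axis, so parameter shapes and outputs differ:
out_raw = dense.apply(dense.init(jax.random.PRNGKey(0), obs), obs)
out_flat = dense.apply(dense.init(jax.random.PRNGKey(0), flat), flat)
print(out_raw.shape)   # (4, 2, 64) -- the extra axis survives without flattening
print(out_flat.shape)  # (4, 64)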

sbx/ppo/policies.py

+3 -1

@@ -12,7 +12,7 @@
 from gymnasium import spaces
 from stable_baselines3.common.type_aliases import Schedule

-from sbx.common.policies import BaseJaxPolicy
+from sbx.common.policies import BaseJaxPolicy, Flatten

 tfp = tensorflow_probability.substrates.jax
 tfd = tfp.distributions
@@ -24,6 +24,7 @@ class Critic(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+        x = Flatten()(x)
         x = nn.Dense(self.n_units)(x)
         x = self.activation_fn(x)
         x = nn.Dense(self.n_units)(x)
@@ -45,6 +46,7 @@ def get_std(self):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
+        x = Flatten()(x)
         x = nn.Dense(self.n_units)(x)
         x = self.activation_fn(x)
         x = nn.Dense(self.n_units)(x)

sbx/sac/policies.py

+3 -1

@@ -11,7 +11,7 @@
 from stable_baselines3.common.type_aliases import Schedule

 from sbx.common.distributions import TanhTransformedDistribution
-from sbx.common.policies import BaseJaxPolicy
+from sbx.common.policies import BaseJaxPolicy, Flatten
 from sbx.common.type_aliases import RLTrainState

 tfp = tensorflow_probability.substrates.jax
@@ -25,6 +25,7 @@ class Critic(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
+        x = Flatten()(x)
         x = jnp.concatenate([x, action], -1)
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
@@ -75,6 +76,7 @@ def get_std(self):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
+        x = Flatten()(x)
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
             x = nn.relu(x)

sbx/td3/policies.py

+3 -1

@@ -8,7 +8,7 @@
 from gymnasium import spaces
 from stable_baselines3.common.type_aliases import Schedule

-from sbx.common.policies import BaseJaxPolicy
+from sbx.common.policies import BaseJaxPolicy, Flatten
 from sbx.common.type_aliases import RLTrainState


@@ -19,6 +19,7 @@ class Critic(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
+        x = Flatten()(x)
         x = jnp.concatenate([x, action], -1)
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
@@ -63,6 +64,7 @@ class Actor(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # type: ignore[name-defined]
+        x = Flatten()(x)
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
             x = nn.relu(x)

sbx/tqc/policies.py

+3 -1

@@ -11,7 +11,7 @@
 from stable_baselines3.common.type_aliases import Schedule

 from sbx.common.distributions import TanhTransformedDistribution
-from sbx.common.policies import BaseJaxPolicy
+from sbx.common.policies import BaseJaxPolicy, Flatten
 from sbx.common.type_aliases import RLTrainState

 tfp = tensorflow_probability.substrates.jax
@@ -26,6 +26,7 @@ class Critic(nn.Module):

     @nn.compact
     def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False) -> jnp.ndarray:
+        x = Flatten()(x)
         x = jnp.concatenate([x, a], -1)
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
@@ -50,6 +51,7 @@ def get_std(self):

     @nn.compact
     def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
+        x = Flatten()(x)
         for n_units in self.net_arch:
             x = nn.Dense(n_units)(x)
             x = nn.relu(x)

sbx/version.txt

+1 -1

@@ -1 +1 @@
-0.8.0
+0.9.0

setup.py

+1 -3

@@ -39,7 +39,7 @@
     packages=[package for package in find_packages() if package.startswith("sbx")],
     package_data={"sbx": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=2.1.0",
+        "stable_baselines3>=2.2.0a9",
         "jax",
         "jaxlib",
         "flax",
@@ -59,8 +59,6 @@
         "mypy",
         # Lint code
         "ruff",
-        # Sort imports
-        "isort>=5.0",
         # Reformat
         "black",
     ],

tests/test_flatten.py

+31

@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import gymnasium as gym
+import numpy as np
+import pytest
+from gymnasium import spaces
+
+from sbx import DQN, PPO, SAC, TD3, TQC
+
+
+@dataclass
+class DummyEnv(gym.Env):
+    observation_space: spaces.Space
+    action_space: spaces.Space
+
+    def step(self, action):
+        return self.observation_space.sample(), 0.0, False, False, {}
+
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+        if seed is not None:
+            super().reset(seed=seed)
+        return self.observation_space.sample(), {}
+
+
+@pytest.mark.parametrize("model_class", [DQN, PPO, SAC, TD3, TQC])
+def test_flatten(model_class) -> None:
+    action_space = spaces.Discrete(15) if model_class == DQN else spaces.Box(-1, 1, shape=(2,), dtype=np.float32)
+    env = DummyEnv(spaces.Box(-1, 1, shape=(2, 1), dtype=np.float32), action_space)
+
+    model_class("MlpPolicy", env).learn(150)
