Skip to content

Commit 1d2da71

Browse files
authored
Add Python 3.13 support, drop Python 3.9 (#83)
* Drop Python 3.9 support * Apply autofixes for Python 3.10 * Reformat
1 parent 3e9e66d commit 1d2da71

26 files changed

+199
-201
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ jobs:
2020
runs-on: ubuntu-latest
2121
strategy:
2222
matrix:
23-
python-version: ["3.9", "3.10", "3.11", "3.12"]
23+
python-version: ["3.10", "3.11", "3.12", "3.13"]
2424

2525
steps:
26-
- uses: actions/checkout@v3
26+
- uses: actions/checkout@v6
2727
- name: Set up Python ${{ matrix.python-version }}
28-
uses: actions/setup-python@v4
28+
uses: actions/setup-python@v6
2929
with:
3030
python-version: ${{ matrix.python-version }}
3131
- name: Install dependencies
@@ -35,7 +35,7 @@ jobs:
3535
pip install uv
3636
# cpu version of pytorch
3737
# See https://github.com/astral-sh/uv/issues/1497
38-
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
38+
uv pip install --system torch==2.9.1+cpu --index https://download.pytorch.org/whl/cpu
3939
4040
uv pip install --system .[tests]
4141
# Use headless version

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
[tool.ruff]
22
# Same as Black.
33
line-length = 127
4-
# Assume Python 3.9
5-
target-version = "py39"
4+
# Assume Python 3.10
5+
target-version = "py310"
66

77
[tool.ruff.lint]
88
# See https://beta.ruff.rs/docs/rules/

sbx/common/distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional
1+
from typing import Any
22

33
import jax.numpy as jnp
44
import tensorflow_probability.substrates.jax as tfp
@@ -19,7 +19,7 @@ def mode(self) -> jnp.ndarray:
1919
return self.bijector.forward(self.distribution.mode())
2020

2121
@classmethod
22-
def _parameter_properties(cls, dtype: Optional[Any], num_classes=None):
22+
def _parameter_properties(cls, dtype: Any | None, num_classes=None):
2323
td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
2424
del td_properties["bijector"]
2525
return td_properties

sbx/common/jax_layers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from collections.abc import Sequence
2-
from typing import Any, Callable, Optional, Union
1+
from collections.abc import Callable, Sequence
2+
from typing import Any, Union
33

44
import flax.linen as nn
55
import jax
@@ -12,7 +12,7 @@
1212
Array = Any
1313
Shape = tuple[int, ...]
1414
Dtype = Any # this could be a real type?
15-
Axes = Union[int, Sequence[int]]
15+
Axes = Union[int, Sequence[int]] # noqa: UP007
1616

1717

1818
class BatchRenorm(Module):
@@ -78,26 +78,26 @@ class BatchRenorm(Module):
7878
calculation for the variance.
7979
"""
8080

81-
use_running_average: Optional[bool] = None
81+
use_running_average: bool | None = None
8282
axis: int = -1
8383
momentum: float = 0.99
8484
epsilon: float = 0.001
8585
warmup_steps: int = 100_000
86-
dtype: Optional[Dtype] = None
86+
dtype: Dtype | None = None
8787
param_dtype: Dtype = jnp.float32
8888
use_bias: bool = True
8989
use_scale: bool = True
9090
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
9191
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
92-
axis_name: Optional[str] = None
92+
axis_name: str | None = None
9393
axis_index_groups: Any = None
9494
# This parameter was added in flax.linen 0.7.2 (08/2023)
9595
# commented out to be compatible with a wider range of jax versions
9696
# TODO: re-activate in some months (04/2024)
9797
# use_fast_variance: bool = True
9898

9999
@compact
100-
def __call__(self, x, use_running_average: Optional[bool] = None):
100+
def __call__(self, x, use_running_average: bool | None = None):
101101
"""Normalizes the input using batch statistics.
102102
103103
NOTE:

sbx/common/off_policy_algorithm.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import io
22
import pathlib
3-
from typing import Any, Optional, Union
3+
from typing import Any
44

55
import jax
66
import numpy as np
@@ -21,35 +21,35 @@ class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
2121
def __init__(
2222
self,
2323
policy: type[BasePolicy],
24-
env: Union[GymEnv, str],
25-
learning_rate: Union[float, Schedule],
26-
qf_learning_rate: Optional[float] = None,
24+
env: GymEnv | str,
25+
learning_rate: float | Schedule,
26+
qf_learning_rate: float | None = None,
2727
buffer_size: int = 1_000_000, # 1e6
2828
learning_starts: int = 100,
2929
batch_size: int = 256,
3030
tau: float = 0.005,
3131
gamma: float = 0.99,
32-
train_freq: Union[int, tuple[int, str]] = (1, "step"),
32+
train_freq: int | tuple[int, str] = (1, "step"),
3333
gradient_steps: int = 1,
34-
action_noise: Optional[ActionNoise] = None,
35-
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
36-
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
34+
action_noise: ActionNoise | None = None,
35+
replay_buffer_class: type[ReplayBuffer] | None = None,
36+
replay_buffer_kwargs: dict[str, Any] | None = None,
3737
optimize_memory_usage: bool = False,
3838
n_steps: int = 1,
39-
policy_kwargs: Optional[dict[str, Any]] = None,
40-
tensorboard_log: Optional[str] = None,
39+
policy_kwargs: dict[str, Any] | None = None,
40+
tensorboard_log: str | None = None,
4141
verbose: int = 0,
4242
device: str = "auto",
4343
support_multi_env: bool = False,
4444
monitor_wrapper: bool = True,
45-
seed: Optional[int] = None,
45+
seed: int | None = None,
4646
use_sde: bool = False,
4747
sde_sample_freq: int = -1,
4848
use_sde_at_warmup: bool = False,
4949
sde_support: bool = True,
5050
stats_window_size: int = 100,
51-
param_resets: Optional[list[int]] = None,
52-
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
51+
param_resets: list[int] | None = None,
52+
supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
5353
):
5454
super().__init__(
5555
policy=policy,
@@ -108,7 +108,7 @@ def _excluded_save_params(self) -> list[str]:
108108

109109
def _update_learning_rate( # type: ignore[override]
110110
self,
111-
optimizers: Union[list[optax.OptState], optax.OptState],
111+
optimizers: list[optax.OptState] | optax.OptState,
112112
learning_rate: float,
113113
name: str = "learning_rate",
114114
) -> None:
@@ -129,7 +129,7 @@ def _update_learning_rate( # type: ignore[override]
129129
# Note: the optimizer must have been defined with inject_hyperparams
130130
optimizer.hyperparams["learning_rate"] = learning_rate
131131

132-
def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
132+
def set_random_seed(self, seed: int | None) -> None: # type: ignore[override]
133133
super().set_random_seed(seed)
134134
if seed is None:
135135
# Sample random seed
@@ -173,7 +173,7 @@ def _setup_model(self) -> None:
173173

174174
def load_replay_buffer(
175175
self,
176-
path: Union[str, pathlib.Path, io.BufferedIOBase],
176+
path: str | pathlib.Path | io.BufferedIOBase,
177177
truncate_last_traj: bool = True,
178178
) -> None:
179179
super().load_replay_buffer(path, truncate_last_traj)

sbx/common/on_policy_algorithm.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, TypeVar, Union
1+
from typing import Any, TypeVar
22

33
import gymnasium as gym
44
import jax
@@ -25,9 +25,9 @@ class OnPolicyAlgorithmJax(OnPolicyAlgorithm):
2525

2626
def __init__(
2727
self,
28-
policy: Union[str, type[BasePolicy]],
29-
env: Union[GymEnv, str],
30-
learning_rate: Union[float, Schedule],
28+
policy: str | type[BasePolicy],
29+
env: GymEnv | str,
30+
learning_rate: float | Schedule,
3131
n_steps: int,
3232
gamma: float,
3333
gae_lambda: float,
@@ -36,14 +36,14 @@ def __init__(
3636
max_grad_norm: float,
3737
use_sde: bool,
3838
sde_sample_freq: int,
39-
tensorboard_log: Optional[str] = None,
39+
tensorboard_log: str | None = None,
4040
monitor_wrapper: bool = True,
41-
policy_kwargs: Optional[dict[str, Any]] = None,
41+
policy_kwargs: dict[str, Any] | None = None,
4242
verbose: int = 0,
43-
seed: Optional[int] = None,
43+
seed: int | None = None,
4444
device: str = "auto",
4545
_init_setup_model: bool = True,
46-
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
46+
supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
4747
):
4848
super().__init__(
4949
policy=policy, # type: ignore[arg-type]
@@ -78,7 +78,7 @@ def _excluded_save_params(self) -> list[str]:
7878

7979
def _update_learning_rate( # type: ignore[override]
8080
self,
81-
optimizers: Union[list[optax.OptState], optax.OptState],
81+
optimizers: list[optax.OptState] | optax.OptState,
8282
learning_rate: float,
8383
) -> None:
8484
"""
@@ -97,7 +97,7 @@ def _update_learning_rate( # type: ignore[override]
9797
# Note: the optimizer must have been defined with inject_hyperparams
9898
optimizer.hyperparams["learning_rate"] = learning_rate
9999

100-
def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
100+
def set_random_seed(self, seed: int | None) -> None: # type: ignore[override]
101101
super().set_random_seed(seed)
102102
if seed is None:
103103
# Sample random seed

sbx/common/policies.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# import copy
2-
from collections.abc import Sequence
3-
from typing import Callable, Optional, Union, no_type_check
2+
from collections.abc import Callable, Sequence
3+
from typing import no_type_check
44

55
import flax.linen as nn
66
import jax
@@ -50,11 +50,11 @@ def select_action(actor_state, observations):
5050
@no_type_check
5151
def predict(
5252
self,
53-
observation: Union[np.ndarray, dict[str, np.ndarray]],
54-
state: Optional[tuple[np.ndarray, ...]] = None,
55-
episode_start: Optional[np.ndarray] = None,
53+
observation: np.ndarray | dict[str, np.ndarray],
54+
state: tuple[np.ndarray, ...] | None = None,
55+
episode_start: np.ndarray | None = None,
5656
deterministic: bool = False,
57-
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
57+
) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
5858
# self.set_training_mode(False)
5959

6060
observation, vectorized_env = self.prepare_obs(observation)
@@ -81,7 +81,7 @@ def predict(
8181

8282
return actions, state
8383

84-
def prepare_obs(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[np.ndarray, bool]:
84+
def prepare_obs(self, observation: np.ndarray | dict[str, np.ndarray]) -> tuple[np.ndarray, bool]:
8585
vectorized_env = False
8686
if isinstance(observation, dict):
8787
assert isinstance(self.observation_space, spaces.Dict)
@@ -132,7 +132,7 @@ def set_training_mode(self, mode: bool) -> None:
132132
class ContinuousCritic(nn.Module):
133133
net_arch: Sequence[int]
134134
use_layer_norm: bool = False
135-
dropout_rate: Optional[float] = None
135+
dropout_rate: float | None = None
136136
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
137137
output_dim: int = 1
138138

@@ -154,7 +154,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
154154
class SimbaContinuousCritic(nn.Module):
155155
net_arch: Sequence[int]
156156
use_layer_norm: bool = False # for consistency, not used
157-
dropout_rate: Optional[float] = None
157+
dropout_rate: float | None = None
158158
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
159159
output_dim: int = 1
160160
scale_factor: int = 4
@@ -179,7 +179,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
179179
class VectorCritic(nn.Module):
180180
net_arch: Sequence[int]
181181
use_layer_norm: bool = False
182-
dropout_rate: Optional[float] = None
182+
dropout_rate: float | None = None
183183
n_critics: int = 2
184184
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
185185
output_dim: int = 1
@@ -210,7 +210,7 @@ class SimbaVectorCritic(nn.Module):
210210
net_arch: Sequence[int]
211211
# Note: we have use_layer_norm for consistency but it is not used (always on)
212212
use_layer_norm: bool = True
213-
dropout_rate: Optional[float] = None
213+
dropout_rate: float | None = None
214214
n_critics: int = 2
215215
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
216216
output_dim: int = 1

sbx/common/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from dataclasses import dataclass
2-
from typing import Union
32

43
import jax
54
import jax.numpy as jnp
@@ -37,7 +36,7 @@ def mask_from_prefix(params: FrozenDict, prefix: str = "NatureCNN_") -> dict:
3736
if the top-level module name starts with `prefix`.
3837
"""
3938

40-
def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> Union[dict, bool]:
39+
def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> dict | bool:
4140
if isinstance(tree, dict):
4241
return {key: _traverse(value, (*path, key)) for key, value in tree.items()}
4342
# leaf

sbx/crossq/crossq.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Any, ClassVar, Literal, Optional, Union
2+
from typing import Any, ClassVar, Literal
33

44
import flax
55
import flax.linen as nn
@@ -53,31 +53,31 @@ class CrossQ(OffPolicyAlgorithmJax):
5353
def __init__(
5454
self,
5555
policy,
56-
env: Union[GymEnv, str],
57-
learning_rate: Union[float, Schedule] = 1e-3,
58-
qf_learning_rate: Optional[float] = None,
56+
env: GymEnv | str,
57+
learning_rate: float | Schedule = 1e-3,
58+
qf_learning_rate: float | None = None,
5959
buffer_size: int = 1_000_000, # 1e6
6060
learning_starts: int = 100,
6161
batch_size: int = 256,
6262
gamma: float = 0.99,
63-
train_freq: Union[int, tuple[int, str]] = 1,
63+
train_freq: int | tuple[int, str] = 1,
6464
gradient_steps: int = 1,
6565
policy_delay: int = 3,
66-
action_noise: Optional[ActionNoise] = None,
67-
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
68-
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
66+
action_noise: ActionNoise | None = None,
67+
replay_buffer_class: type[ReplayBuffer] | None = None,
68+
replay_buffer_kwargs: dict[str, Any] | None = None,
6969
n_steps: int = 1,
70-
ent_coef: Union[str, float] = "auto",
71-
target_entropy: Union[Literal["auto"], float] = "auto",
70+
ent_coef: str | float = "auto",
71+
target_entropy: Literal["auto"] | float = "auto",
7272
use_sde: bool = False,
7373
sde_sample_freq: int = -1,
7474
use_sde_at_warmup: bool = False,
7575
stats_window_size: int = 100,
76-
tensorboard_log: Optional[str] = None,
77-
policy_kwargs: Optional[dict[str, Any]] = None,
78-
param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params
76+
tensorboard_log: str | None = None,
77+
policy_kwargs: dict[str, Any] | None = None,
78+
param_resets: list[int] | None = None, # List of timesteps after which to reset the params
7979
verbose: int = 0,
80-
seed: Optional[int] = None,
80+
seed: int | None = None,
8181
device: str = "auto",
8282
_init_setup_model: bool = True,
8383
) -> None:

0 commit comments

Comments
 (0)