@@ -1,6 +1,6 @@
 import io
 import pathlib
-from typing import Any, Optional, Union
+from typing import Any

 import jax
 import numpy as np
@@ -21,35 +21,35 @@ class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
     def __init__(
         self,
         policy: type[BasePolicy],
-        env: Union[GymEnv, str],
-        learning_rate: Union[float, Schedule],
-        qf_learning_rate: Optional[float] = None,
+        env: GymEnv | str,
+        learning_rate: float | Schedule,
+        qf_learning_rate: float | None = None,
         buffer_size: int = 1_000_000,  # 1e6
         learning_starts: int = 100,
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, tuple[int, str]] = (1, "step"),
+        train_freq: int | tuple[int, str] = (1, "step"),
         gradient_steps: int = 1,
-        action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
+        action_noise: ActionNoise | None = None,
+        replay_buffer_class: type[ReplayBuffer] | None = None,
+        replay_buffer_kwargs: dict[str, Any] | None = None,
         optimize_memory_usage: bool = False,
         n_steps: int = 1,
-        policy_kwargs: Optional[dict[str, Any]] = None,
-        tensorboard_log: Optional[str] = None,
+        policy_kwargs: dict[str, Any] | None = None,
+        tensorboard_log: str | None = None,
         verbose: int = 0,
         device: str = "auto",
         support_multi_env: bool = False,
         monitor_wrapper: bool = True,
-        seed: Optional[int] = None,
+        seed: int | None = None,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
         stats_window_size: int = 100,
-        param_resets: Optional[list[int]] = None,
-        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
+        param_resets: list[int] | None = None,
+        supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
     ):
         super().__init__(
             policy=policy,
@@ -108,7 +108,7 @@ def _excluded_save_params(self) -> list[str]:

     def _update_learning_rate(  # type: ignore[override]
         self,
-        optimizers: Union[list[optax.OptState], optax.OptState],
+        optimizers: list[optax.OptState] | optax.OptState,
         learning_rate: float,
         name: str = "learning_rate",
     ) -> None:
@@ -129,7 +129,7 @@ def _update_learning_rate(  # type: ignore[override]
             # Note: the optimizer must have been defined with inject_hyperparams
             optimizer.hyperparams["learning_rate"] = learning_rate

-    def set_random_seed(self, seed: Optional[int]) -> None:  # type: ignore[override]
+    def set_random_seed(self, seed: int | None) -> None:  # type: ignore[override]
         super().set_random_seed(seed)
         if seed is None:
             # Sample random seed
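
For context on the hunk above: `optimizer.hyperparams` only exists when the optimizer was built with optax's `optax.inject_hyperparams` wrapper, which moves hyperparameters into the optimizer state so they can be changed between update steps. A minimal standalone sketch of that pattern (not code from this commit):

    import jax.numpy as jnp
    import optax

    params = {"w": jnp.zeros(3)}
    # Wrapping the optimizer factory stores its hyperparameters in the state.
    tx = optax.inject_hyperparams(optax.adam)(learning_rate=3e-4)
    opt_state = tx.init(params)
    # The state now carries a mutable `hyperparams` dict, so the learning
    # rate can be updated in place, as the loop in the diff above does:
    opt_state.hyperparams["learning_rate"] = 1e-4
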
@@ -173,7 +173,7 @@ def _setup_model(self) -> None:

     def load_replay_buffer(
         self,
-        path: Union[str, pathlib.Path, io.BufferedIOBase],
+        path: str | pathlib.Path | io.BufferedIOBase,
         truncate_last_traj: bool = True,
     ) -> None:
         super().load_replay_buffer(path, truncate_last_traj)
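
The whole commit is a mechanical migration from `typing.Union`/`typing.Optional` to PEP 604 union syntax, available in annotations since Python 3.10 (and on older versions via `from __future__ import annotations`). `Optional[X]` is exactly `Union[X, None]`, so both map onto `X | None`. A minimal illustration with hypothetical functions, equivalent line for line:

    from typing import Optional, Union

    def old_style(x: Union[int, str], y: Optional[float] = None) -> None: ...

    # PEP 604 equivalent of the annotations above:
    def new_style(x: int | str, y: float | None = None) -> None: ...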