 from abc import ABC, abstractmethod
 from collections import deque
 from collections.abc import Iterable
-from typing import Any, ClassVar, Optional, TypeVar, Union
+from typing import Any, ClassVar, TypeVar

 import gymnasium as gym
 import numpy as np
@@ -45,7 +45,7 @@
 SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")


-def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
+def maybe_make_env(env: GymEnv | str, verbose: int) -> GymEnv:
     """If env is a string, make the environment; otherwise, return env.

     :param env: The environment to learn from.
@@ -105,20 +105,20 @@ class BaseAlgorithm(ABC):

     def __init__(
         self,
-        policy: Union[str, type[BasePolicy]],
-        env: Union[GymEnv, str, None],
-        learning_rate: Union[float, Schedule],
-        policy_kwargs: Optional[dict[str, Any]] = None,
+        policy: str | type[BasePolicy],
+        env: GymEnv | str | None,
+        learning_rate: float | Schedule,
+        policy_kwargs: dict[str, Any] | None = None,
         stats_window_size: int = 100,
-        tensorboard_log: Optional[str] = None,
+        tensorboard_log: str | None = None,
         verbose: int = 0,
-        device: Union[th.device, str] = "auto",
+        device: th.device | str = "auto",
         support_multi_env: bool = False,
         monitor_wrapper: bool = True,
-        seed: Optional[int] = None,
+        seed: int | None = None,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
+        supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
     ) -> None:
         if isinstance(policy, str):
             self.policy_class = self._get_policy_from_name(policy)
@@ -138,14 +138,14 @@ def __init__(
         # Used for computing fps, it is updated at each call of learn()
         self._num_timesteps_at_start = 0
         self.seed = seed
-        self.action_noise: Optional[ActionNoise] = None
+        self.action_noise: ActionNoise | None = None
         self.start_time = 0.0
         self.learning_rate = learning_rate
         self.tensorboard_log = tensorboard_log
-        self._last_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
-        self._last_episode_starts = None  # type: Optional[np.ndarray]
+        self._last_obs = None  # type: np.ndarray | dict[str, np.ndarray] | None
+        self._last_episode_starts = None  # type: np.ndarray | None
         # When using VecNormalize:
-        self._last_original_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
+        self._last_original_obs = None  # type: np.ndarray | dict[str, np.ndarray] | None
         self._episode_num = 0
         # Used for gSDE only
         self.use_sde = use_sde
@@ -155,14 +155,14 @@ def __init__(
         self._current_progress_remaining = 1.0
         # Buffers for logging
         self._stats_window_size = stats_window_size
-        self.ep_info_buffer = None  # type: Optional[deque]
-        self.ep_success_buffer = None  # type: Optional[deque]
+        self.ep_info_buffer = None  # type: deque | None
+        self.ep_success_buffer = None  # type: deque | None
         # For logging (and TD3 delayed updates)
         self._n_updates = 0  # type: int
         # Whether the user passed a custom logger or not
         self._custom_logger = False
-        self.env: Optional[VecEnv] = None
-        self._vec_normalize_env: Optional[VecNormalize] = None
+        self.env: VecEnv | None = None
+        self._vec_normalize_env: VecNormalize | None = None

         # Create and wrap the env if needed
         if env is not None:
@@ -284,7 +284,7 @@ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps
         """
         self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)

-    def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None:
+    def _update_learning_rate(self, optimizers: list[th.optim.Optimizer] | th.optim.Optimizer) -> None:
         """
         Update the optimizers learning rate using the current learning rate schedule
         and the current progress remaining (from 1 to 0).
@@ -435,7 +435,7 @@ def _setup_learn(

         return total_timesteps, callback

-    def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+    def _update_info_buffer(self, infos: list[dict[str, Any]], dones: np.ndarray | None = None) -> None:
         """
         Retrieve reward, episode length, episode success and update the buffer
         if using Monitor wrapper or a GoalEnv.
@@ -456,15 +456,15 @@ def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.nd
             if maybe_is_success is not None and dones[idx]:
                 self.ep_success_buffer.append(maybe_is_success)

-    def get_env(self) -> Optional[VecEnv]:
+    def get_env(self) -> VecEnv | None:
         """
         Returns the current environment (can be None if not defined).

         :return: The current environment
         """
         return self.env

-    def get_vec_normalize_env(self) -> Optional[VecNormalize]:
+    def get_vec_normalize_env(self) -> VecNormalize | None:
         """
         Return the ``VecNormalize`` wrapper of the training env
         if it exists.
@@ -536,11 +536,11 @@ def learn(

     def predict(
         self,
-        observation: Union[np.ndarray, dict[str, np.ndarray]],
-        state: Optional[tuple[np.ndarray, ...]] = None,
-        episode_start: Optional[np.ndarray] = None,
+        observation: np.ndarray | dict[str, np.ndarray],
+        state: tuple[np.ndarray, ...] | None = None,
+        episode_start: np.ndarray | None = None,
         deterministic: bool = False,
-    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
         """
         Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -556,7 +556,7 @@ def predict(
         """
         return self.policy.predict(observation, state, episode_start, deterministic)

-    def set_random_seed(self, seed: Optional[int] = None) -> None:
+    def set_random_seed(self, seed: int | None = None) -> None:
         """
         Set the seed of the pseudo-random generators
         (python, numpy, pytorch, gym, action_space)
@@ -573,9 +573,9 @@ def set_random_seed(self, seed: Optional[int] = None) -> None:

     def set_parameters(
         self,
-        load_path_or_dict: Union[str, TensorDict],
+        load_path_or_dict: str | TensorDict,
         exact_match: bool = True,
-        device: Union[th.device, str] = "auto",
+        device: th.device | str = "auto",
     ) -> None:
         """
         Load parameters from a given zip-file or a nested dictionary containing parameters for
@@ -642,10 +642,10 @@ def set_parameters(
     @classmethod
     def load(  # noqa: C901
         cls: type[SelfBaseAlgorithm],
-        path: Union[str, pathlib.Path, io.BufferedIOBase],
-        env: Optional[GymEnv] = None,
-        device: Union[th.device, str] = "auto",
-        custom_objects: Optional[dict[str, Any]] = None,
+        path: str | pathlib.Path | io.BufferedIOBase,
+        env: GymEnv | None = None,
+        device: th.device | str = "auto",
+        custom_objects: dict[str, Any] | None = None,
         print_system_info: bool = False,
         force_reset: bool = True,
         **kwargs,
@@ -818,9 +818,9 @@ def get_parameters(self) -> dict[str, dict]:

     def save(
         self,
-        path: Union[str, pathlib.Path, io.BufferedIOBase],
-        exclude: Optional[Iterable[str]] = None,
-        include: Optional[Iterable[str]] = None,
+        path: str | pathlib.Path | io.BufferedIOBase,
+        exclude: Iterable[str] | None = None,
+        include: Iterable[str] | None = None,
     ) -> None:
         """
         Save all the attributes of the object and the model parameters in a zip-file.
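
The pattern throughout this diff is uniform: `typing.Optional[X]` becomes `X | None` and `typing.Union[X, Y]` becomes `X | Y` (PEP 604), which is why `Optional` and `Union` drop out of the `typing` import at the top. A minimal, self-contained sketch of the equivalence; the function below is made up for illustration and does not appear in the diff:

import pathlib

# Old spelling (requires: from typing import Optional, Union):
#     def load_model(path: Union[str, pathlib.Path], device: Optional[str] = None) -> None: ...

# Equivalent PEP 604 spelling: native on Python 3.10+, and valid inside
# annotations on older versions when `from __future__ import annotations` is used.
def load_model(path: str | pathlib.Path, device: str | None = None) -> None:
    # Hypothetical helper, used only to demonstrate the annotation change.
    print(f"loading {path} on {device or 'cpu'}")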