Skip to content

Commit 749a722

Browse files
committed
Fix inline types
1 parent cc88a25 commit 749a722

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

stable_baselines3/common/base_class.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,10 @@ def __init__(
142142
self.start_time = 0.0
143143
self.learning_rate = learning_rate
144144
self.tensorboard_log = tensorboard_log
145-
self._last_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
146-
self._last_episode_starts = None # type: Optional[np.ndarray]
145+
self._last_obs = None # type: np.ndarray | dict[str, np.ndarray] | None
146+
self._last_episode_starts = None # type: np.ndarray | None
147147
# When using VecNormalize:
148-
self._last_original_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
148+
self._last_original_obs = None # type: np.ndarray | dict[str, np.ndarray] | None
149149
self._episode_num = 0
150150
# Used for gSDE only
151151
self.use_sde = use_sde
@@ -155,8 +155,8 @@ def __init__(
155155
self._current_progress_remaining = 1.0
156156
# Buffers for logging
157157
self._stats_window_size = stats_window_size
158-
self.ep_info_buffer = None # type: Optional[deque]
159-
self.ep_success_buffer = None # type: Optional[deque]
158+
self.ep_info_buffer = None # type: deque | None
159+
self.ep_success_buffer = None # type: deque | None
160160
# For logging (and TD3 delayed updates)
161161
self._n_updates = 0 # type: int
162162
# Whether the user passed a custom logger or not

stable_baselines3/common/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self, verbose: int = 0):
5050
self.globals: dict[str, Any] = {}
5151
# Sometimes, for event callback, it is useful
5252
# to have access to the parent object
53-
self.parent = None # type: Optional[BaseCallback]
53+
self.parent = None # type: BaseCallback | None
5454

5555
@property
5656
def training_env(self) -> VecEnv:

stable_baselines3/common/vec_env/stacked_observations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
}
4444
self.stacked_observation_space = spaces.Dict(
4545
{key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()}
46-
) # type: Union[spaces.Dict, spaces.Box] # make mypy happy
46+
) # type: spaces.Dict | spaces.Box # make mypy happy
4747
elif isinstance(observation_space, spaces.Box):
4848
if isinstance(channels_order, Mapping):
4949
raise TypeError("When the observation space is Box, channels_order can't be a dict.")

stable_baselines3/sac/sac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(
149149
)
150150

151151
self.target_entropy = target_entropy
152-
self.log_ent_coef = None # type: Optional[th.Tensor]
152+
self.log_ent_coef = None # type: th.Tensor | None
153153
# Entropy coefficient / Entropy temperature
154154
# Inverse of the reward scale
155155
self.ent_coef = ent_coef

0 commit comments

Comments
 (0)