Skip to content

Commit 9cbce82

Browse files
committed
Fix mypy issues
1 parent 550d0e2 commit 9cbce82

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def dump_logs(self) -> None:
423423
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
424424
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
425425
if self.use_sde:
426-
self.logger.record("train/std", (self.actor.get_std()).mean().item())
426+
self.logger.record("train/std", (self.actor.get_std()).mean().item()) # type: ignore[operator]
427427

428428
if len(self.ep_success_buffer) > 0:
429429
self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
@@ -544,14 +544,14 @@ def collect_rollouts(
544544
assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."
545545

546546
if self.use_sde:
547-
self.actor.reset_noise(env.num_envs)
547+
self.actor.reset_noise(env.num_envs) # type: ignore[operator]
548548

549549
callback.on_rollout_start()
550550
continue_training = True
551551
while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
552552
if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
553553
# Sample a new noise matrix
554-
self.actor.reset_noise(env.num_envs)
554+
self.actor.reset_noise(env.num_envs) # type: ignore[operator]
555555

556556
# Select action randomly or according to policy
557557
actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)

stable_baselines3/sac/sac.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
228228
# so we don't change it with other losses
229229
# see https://github.com/rail-berkeley/softlearning/issues/60
230230
ent_coef = th.exp(self.log_ent_coef.detach())
231+
assert isinstance(self.target_entropy, float)
231232
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
232233
ent_coef_losses.append(ent_coef_loss.item())
233234
else:

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.6.0a1
1+
2.6.0a2

0 commit comments

Comments
 (0)