Skip to content

Commit 77d6ee1

Browse files
committed
Cast sampled actions of rollout buffers to float32 to avoid breaking changes
1 parent e31b0e4 commit 77d6ee1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

stable_baselines3/common/buffers.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -585,7 +585,7 @@ def _get_samples(
585 585       ) -> RolloutBufferSamples:
586 586           data = (
587 587               self.observations[batch_inds],
588     -             self.actions[batch_inds],
    588 +             self.actions[batch_inds].astype(np.float32, copy=False),
589 589               self.values[batch_inds].flatten(),
590 590               self.log_probs[batch_inds].flatten(),
591 591               self.advantages[batch_inds].flatten(),
@@ -907,7 +907,7 @@ def _get_samples( # type: ignore[override]
907 907       ) -> DictRolloutBufferSamples:
908 908           return DictRolloutBufferSamples(
909 909               observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
910     -             actions=self.to_torch(self.actions[batch_inds]),
    910 +             actions=self.to_torch(self.actions[batch_inds].astype(np.float32, copy=False)),
911 911               old_values=self.to_torch(self.values[batch_inds].flatten()),
912 912               old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
913 913               advantages=self.to_torch(self.advantages[batch_inds].flatten()),

0 commit comments

Comments
 (0)