Skip to content

Commit 4b0cfc3

Browse files
committed
Cast at sample time only
1 parent 0ca5017 commit 4b0cfc3

File tree

2 files changed

+10
-26
lines changed

2 files changed

+10
-26
lines changed

stable_baselines3/common/buffers.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,7 @@ def __init__(
390390

391391
def reset(self) -> None:
392392
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.observation_space.dtype)
393-
self.actions = np.zeros(
394-
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.action_space.dtype)
395-
)
393+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
396394
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
397395
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
398396
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -514,29 +512,15 @@ def _get_samples(
514512
) -> RolloutBufferSamples:
515513
data = (
516514
self.observations[batch_inds],
517-
self.actions[batch_inds],
515+
# Cast to float32 (backward compatible); otherwise this would lead to a RuntimeError for MultiBinary space
516+
self.actions[batch_inds].astype(np.float32, copy=False),
518517
self.values[batch_inds].flatten(),
519518
self.log_probs[batch_inds].flatten(),
520519
self.advantages[batch_inds].flatten(),
521520
self.returns[batch_inds].flatten(),
522521
)
523522
return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
524523

525-
@staticmethod
526-
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
527-
"""
528-
Cast `np.int8` action datatype to `np.float32`, keep the others dtype unchanged.
529-
Otherwise, this would lead to
530-
"RuntimeError: result type Float can't be cast to the desired output type Char"
531-
when trying to compute the log prob for MultiBinary space.
532-
533-
:param dtype: The original action space dtype
534-
:return: ``np.float32`` if the dtype was int8, the original dtype otherwise.
535-
"""
536-
if dtype == np.int8:
537-
return np.float32
538-
return dtype
539-
540524

541525
class DictReplayBuffer(ReplayBuffer):
542526
"""
@@ -765,9 +749,7 @@ def reset(self) -> None:
765749
self.observations[key] = np.zeros(
766750
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space[key].dtype
767751
)
768-
self.actions = np.zeros(
769-
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.action_space.dtype)
770-
)
752+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
771753
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
772754
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
773755
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -853,7 +835,8 @@ def _get_samples( # type: ignore[override]
853835
) -> DictRolloutBufferSamples:
854836
return DictRolloutBufferSamples(
855837
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
856-
actions=self.to_torch(self.actions[batch_inds]),
838+
# Cast to float32 (backward compatible); otherwise this would lead to a RuntimeError for MultiBinary space
839+
actions=self.to_torch(self.actions[batch_inds].astype(np.float32, copy=False)),
857840
old_values=self.to_torch(self.values[batch_inds].flatten()),
858841
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
859842
advantages=self.to_torch(self.advantages[batch_inds].flatten()),

tests/test_buffers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,9 @@ def test_buffer_dtypes(obs_dtype, use_dict, action_space):
193193
buffer_params = dict(buffer_size=1, action_space=action_space)
194194
# For off-policy algorithms, we cast float64 actions to float32, see GH#1145
195195
actual_replay_action_dtype = ReplayBuffer._maybe_cast_dtype(action_space.dtype)
196-
# For on-policy, we cast int8 to int64 to avoid issue computing log prob
197-
actual_rollout_action_dtype = RolloutBuffer._maybe_cast_dtype(action_space.dtype)
196+
# For on-policy, we cast at sample time to float32 for backward compat
197+
# and to avoid an issue computing the log prob with MultiBinary spaces
198+
actual_rollout_action_dtype = np.float32
198199

199200
if use_dict:
200201
dict_obs_space = spaces.Dict({"obs": obs_space, "obs_2": spaces.Box(0, 100, dtype=np.uint8)})
@@ -212,7 +213,7 @@ def test_buffer_dtypes(obs_dtype, use_dict, action_space):
212213
assert rollout_buffer.observations.dtype == obs_dtype
213214
assert replay_buffer.observations.dtype == obs_dtype
214215

215-
assert rollout_buffer.actions.dtype == actual_rollout_action_dtype
216+
assert rollout_buffer.actions.dtype == action_space.dtype
216217
assert replay_buffer.actions.dtype == actual_replay_action_dtype
217218
# Check that sampled types are correct
218219
rollout_buffer.full = True

0 commit comments

Comments
 (0)