Skip to content

Commit 69ad231

Browse files
committed
Fixed oversight in dictrolloutbuffer dtype
1 parent 15ffb45 commit 69ad231

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

stable_baselines3/common/buffers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def reset(self) -> None:
746746
self.observations = {}
747747
for key, obs_input_shape in self.obs_shape.items():
748748
self.observations[key] = np.zeros(
749-
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space.dtype
749+
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space[key].dtype
750750
)
751751
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
752752
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

tests/test_buffers.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,18 @@ def test_buffer_dtypes(obs_dtype: Union[type[np.integer], type[np.floating]], us
194194
buffer_params["observation_space"] = dict_obs_space
195195
rollout_buffer = DictRolloutBuffer(**buffer_params) # type: ignore[arg-type]
196196
replay_buffer = DictReplayBuffer(**buffer_params) # type: ignore[arg-type]
197-
assert rollout_buffer.observations["obs"].dtype == replay_buffer.observations["obs"].dtype == obs_dtype
198-
assert rollout_buffer.observations["obs_2"].dtype == replay_buffer.observations["obs_2"].dtype == np.uint8
197+
assert rollout_buffer.observations["obs"].dtype == obs_dtype
198+
assert replay_buffer.observations["obs"].dtype == obs_dtype
199+
assert rollout_buffer.observations["obs_2"].dtype == np.uint8
200+
assert replay_buffer.observations["obs_2"].dtype == np.uint8
199201
else:
200202
buffer_params["observation_space"] = obs_space
201203
rollout_buffer = RolloutBuffer(**buffer_params) # type: ignore[arg-type]
202204
replay_buffer = ReplayBuffer(**buffer_params) # type: ignore[arg-type]
203-
assert rollout_buffer.observations.dtype == replay_buffer.observations.dtype == obs_dtype
205+
assert rollout_buffer.observations.dtype == obs_dtype
206+
assert replay_buffer.observations.dtype == obs_dtype
204207

205-
assert rollout_buffer.actions.dtype == np.float32, "RolloutBuffer action dtype must be np.float32"
208+
assert rollout_buffer.actions.dtype == act_space.dtype
206209
assert replay_buffer.actions.dtype == act_space.dtype
207210

208211

0 commit comments

Comments
 (0)