Skip to content

Commit 66f8300

Browse files
committed
Fixed DictRolloutBuffer dtype assignment
1 parent 2807085 commit 66f8300

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

stable_baselines3/common/buffers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def reset(self) -> None:
788788
self.observations = {}
789789
for key, obs_input_shape in self.obs_shape.items():
790790
self.observations[key] = np.zeros(
791-
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes["observations"]
791+
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes["observations"][key]
792792
)
793793
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes["actions"])
794794
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

0 commit comments

Comments
 (0)