Skip to content

Commit dd7f5bf

Browse files
Trenza1orearaffin
andauthored
Use proper dtype for RolloutBuffer storage (#2163)
* Initial implementation of dtype-decision logic * Fixed init logic * Updated changelog * Added a test * Reformatted using make format * Ensure make type passes * Fixed DictRolloutBuffer dtype assignment * Updated to create a BufferDTypes dataclass and updated pytests * Fix type check errors on Github, separate dict_obs and obs, honor _normalize_obs for rollout buffers * Revert _normalize_obs calls in rollout buffers * Updated docs * Updated docs * Added save / load support with backward compatibility * Cast sampled actions of rollout buffers to float32 to avoid breaking changes * Fixed pickle loading of BufferDTypes * Use default_factory instead of default for BufferDTypes.dict_obs * Simplified BufferDTypes and reverted changes on replay buffers as requested * Removed BufferDTypes * Fixed oversight in dictrolloutbuffer dtype * Update changelog and version * Remove cast to float32 * Update tests * Remove cast to long * Revert "Remove cast to long" This reverts commit 216d757. * Revert "Remove cast to float32" This reverts commit d1e5221. * Reapply "Remove cast to float32" This reverts commit 3511452. * Reapply "Remove cast to long" This reverts commit c2d532c. * Cast int8 to float32 to avoid PyTorch issues (MultiBinary) * Revert "Reapply "Remove cast to long"" This reverts commit 88e6b68. * Cast at sample time only * Update changelog.rst --------- Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent bf51a62 commit dd7f5bf

File tree

4 files changed

+113
-8
lines changed

4 files changed

+113
-8
lines changed

docs/misc/changelog.rst

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,38 @@
33
Changelog
44
==========
55

6+
Release 2.7.1a1 (WIP)
7+
--------------------------
8+
9+
Breaking Changes:
10+
^^^^^^^^^^^^^^^^^
11+
12+
New Features:
13+
^^^^^^^^^^^^^
14+
- ``RolloutBuffer`` and ``DictRolloutBuffer`` now uses the actual observation / action space ``dtype`` (instead of float32), this should save memory (@Trenza1ore)
15+
16+
Bug Fixes:
17+
^^^^^^^^^^
18+
19+
`SB3-Contrib`_
20+
^^^^^^^^^^^^^^
21+
22+
`RL Zoo`_
23+
^^^^^^^^^
24+
25+
`SBX`_ (SB3 + Jax)
26+
^^^^^^^^^^^^^^^^^^
27+
28+
Deprecations:
29+
^^^^^^^^^^^^^
30+
31+
Others:
32+
^^^^^^^
33+
34+
Documentation:
35+
^^^^^^^^^^^^^^
36+
37+
638
Release 2.7.0 (2025-07-25)
739
--------------------------
840

@@ -1857,4 +1889,4 @@ And all the contributors:
18571889
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
18581890
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
18591891
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
1860-
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto
1892+
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore

stable_baselines3/common/buffers.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,8 @@ def __init__(
389389
self.reset()
390390

391391
def reset(self) -> None:
392-
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
393-
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
392+
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.observation_space.dtype)
393+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
394394
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
395395
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
396396
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -512,7 +512,8 @@ def _get_samples(
512512
) -> RolloutBufferSamples:
513513
data = (
514514
self.observations[batch_inds],
515-
self.actions[batch_inds],
515+
# Cast to float32 (backward compatible), this would lead to RuntimeError for MultiBinary space
516+
self.actions[batch_inds].astype(np.float32, copy=False),
516517
self.values[batch_inds].flatten(),
517518
self.log_probs[batch_inds].flatten(),
518519
self.advantages[batch_inds].flatten(),
@@ -745,8 +746,10 @@ def __init__(
745746
def reset(self) -> None:
746747
self.observations = {}
747748
for key, obs_input_shape in self.obs_shape.items():
748-
self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
749-
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
749+
self.observations[key] = np.zeros(
750+
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space[key].dtype
751+
)
752+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
750753
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
751754
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
752755
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -832,7 +835,8 @@ def _get_samples( # type: ignore[override]
832835
) -> DictRolloutBufferSamples:
833836
return DictRolloutBufferSamples(
834837
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
835-
actions=self.to_torch(self.actions[batch_inds]),
838+
# Cast to float32 (backward compatible), this would lead to RuntimeError for MultiBinary space
839+
actions=self.to_torch(self.actions[batch_inds].astype(np.float32, copy=False)),
836840
old_values=self.to_torch(self.values[batch_inds].flatten()),
837841
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
838842
advantages=self.to_torch(self.advantages[batch_inds].flatten()),

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.7.0
1+
2.7.1a0

tests/test_buffers.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,75 @@ def test_device_buffer(replay_buffer_cls, device):
163163
raise TypeError(f"Unknown value type: {type(value)}")
164164

165165

166+
@pytest.mark.parametrize(
167+
"obs_dtype",
168+
[
169+
np.dtype(np.uint8),
170+
np.dtype(np.int8),
171+
np.dtype(np.uint16),
172+
np.dtype(np.int16),
173+
np.dtype(np.uint32),
174+
np.dtype(np.int32),
175+
np.dtype(np.uint64),
176+
np.dtype(np.int64),
177+
np.dtype(np.float16),
178+
np.dtype(np.float32),
179+
np.dtype(np.float64),
180+
],
181+
)
182+
@pytest.mark.parametrize("use_dict", [False, True])
183+
@pytest.mark.parametrize(
184+
"action_space",
185+
[
186+
spaces.Discrete(10),
187+
spaces.Box(low=-1.0, high=1.0, dtype=np.float32),
188+
spaces.Box(low=-1.0, high=1.0, dtype=np.float64),
189+
],
190+
)
191+
def test_buffer_dtypes(obs_dtype, use_dict, action_space):
192+
obs_space = spaces.Box(0, 100, dtype=obs_dtype)
193+
buffer_params = dict(buffer_size=1, action_space=action_space)
194+
# For off-policy algorithms, we cast float64 actions to float32, see GH#1145
195+
actual_replay_action_dtype = ReplayBuffer._maybe_cast_dtype(action_space.dtype)
196+
# For on-policy, we cast at sample time to float32 for backward compat
197+
# and to avoid issue computing log prob with multibinary
198+
actual_rollout_action_dtype = np.float32
199+
200+
if use_dict:
201+
dict_obs_space = spaces.Dict({"obs": obs_space, "obs_2": spaces.Box(0, 100, dtype=np.uint8)})
202+
buffer_params["observation_space"] = dict_obs_space
203+
rollout_buffer = DictRolloutBuffer(**buffer_params)
204+
replay_buffer = DictReplayBuffer(**buffer_params)
205+
assert rollout_buffer.observations["obs"].dtype == obs_dtype
206+
assert replay_buffer.observations["obs"].dtype == obs_dtype
207+
assert rollout_buffer.observations["obs_2"].dtype == np.uint8
208+
assert replay_buffer.observations["obs_2"].dtype == np.uint8
209+
else:
210+
buffer_params["observation_space"] = obs_space
211+
rollout_buffer = RolloutBuffer(**buffer_params)
212+
replay_buffer = ReplayBuffer(**buffer_params)
213+
assert rollout_buffer.observations.dtype == obs_dtype
214+
assert replay_buffer.observations.dtype == obs_dtype
215+
216+
assert rollout_buffer.actions.dtype == action_space.dtype
217+
assert replay_buffer.actions.dtype == actual_replay_action_dtype
218+
# Check that sampled types are corrects
219+
rollout_buffer.full = True
220+
replay_buffer.full = True
221+
rollout_data = next(rollout_buffer.get(batch_size=64))
222+
buffer_data = replay_buffer.sample(batch_size=64)
223+
assert rollout_data.actions.numpy().dtype == actual_rollout_action_dtype
224+
assert buffer_data.actions.numpy().dtype == actual_replay_action_dtype
225+
if use_dict:
226+
assert buffer_data.observations["obs"].numpy().dtype == obs_dtype
227+
assert buffer_data.observations["obs_2"].numpy().dtype == np.uint8
228+
assert rollout_data.observations["obs"].numpy().dtype == obs_dtype
229+
assert rollout_data.observations["obs_2"].numpy().dtype == np.uint8
230+
else:
231+
assert buffer_data.observations.numpy().dtype == obs_dtype
232+
assert rollout_data.observations.numpy().dtype == obs_dtype
233+
234+
166235
def test_custom_rollout_buffer():
167236
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict())
168237

0 commit comments

Comments
 (0)