Skip to content

Commit 15ffb45

Browse files
committed
Removed BufferDTypes
1 parent b60074d commit 15ffb45

File tree

5 files changed

+22
-93
lines changed

5 files changed

+22
-93
lines changed

docs/misc/changelog.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ New Features:
1313

1414
Others:
1515
^^^^^^^
16-
- Added an additional ``BufferDTypes`` dataclass to ``stable_baselines3/common/buffers.py`` for representing buffer datatypes (@Trenza1ore)
1716
- Added an additional ``test_buffers.py::test_buffer_dtypes`` which tests the ``dtype`` of ``RolloutBuffer`` and ``DictRolloutBuffer`` (@Trenza1ore)
1817

1918
Release 2.7.0 (2025-07-25)

stable_baselines3/common/buffers.py

Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import warnings
22
from abc import ABC, abstractmethod
3-
from collections.abc import Generator, Mapping
4-
from dataclasses import InitVar, dataclass, field
5-
from types import MappingProxyType
6-
from typing import Any, ClassVar, Optional, Union
3+
from collections.abc import Generator
4+
from typing import Any, Optional, Union
75

86
import numpy as np
97
import torch as th
@@ -13,7 +11,6 @@
1311
from stable_baselines3.common.type_aliases import (
1412
DictReplayBufferSamples,
1513
DictRolloutBufferSamples,
16-
DTypeLike,
1714
ReplayBufferSamples,
1815
RolloutBufferSamples,
1916
)
@@ -27,55 +24,6 @@
2724
psutil = None
2825

2926

30-
@dataclass
31-
class BufferDTypes:
32-
"""
33-
Data class representing the data types used by a buffer.
34-
35-
:param observations: Datatype of observation space
36-
:param actions: Datatype of action space
37-
"""
38-
39-
MAP_TORCH_DTYPES: ClassVar[dict] = dict(complex32="complex64", float="float32", bfloat16="float32", bool="bool_")
40-
41-
observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
42-
actions: InitVar[DTypeLike]
43-
44-
dict_obs: MappingProxyType[str, np.dtype] = field(default_factory=lambda: MappingProxyType({}), init=False)
45-
obs: Optional[np.dtype] = field(default=None, init=False)
46-
act: Optional[np.dtype] = field(default=None, init=False)
47-
48-
def __post_init__(self, observations: Union[DTypeLike, Mapping[str, DTypeLike]], actions: DTypeLike):
49-
if isinstance(observations, Mapping):
50-
self.dict_obs = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in observations.items()})
51-
else:
52-
self.obs = self.to_numpy_dtype(observations)
53-
self.act = self.to_numpy_dtype(actions)
54-
55-
def __getstate__(self):
56-
state = self.__dict__.copy()
57-
if isinstance(self.dict_obs, MappingProxyType):
58-
state["dict_obs"] = dict(self.dict_obs)
59-
return state
60-
61-
def __setstate__(self, state: Mapping[str, Any]):
62-
state = dict(state)
63-
if state.get("dict_obs"):
64-
state["dict_obs"] = MappingProxyType(state["dict_obs"].copy())
65-
self.__dict__.update(state)
66-
67-
@classmethod
68-
def to_numpy_dtype(cls, dtype_like: DTypeLike) -> np.dtype:
69-
if isinstance(dtype_like, th.dtype):
70-
torch_dtype_name = repr(dtype_like).removeprefix("torch.")
71-
numpy_dtype_name = cls.MAP_TORCH_DTYPES.get(torch_dtype_name, torch_dtype_name)
72-
try:
73-
return np.dtype(getattr(np, numpy_dtype_name))
74-
except AttributeError as e:
75-
raise TypeError(f"Cannot cast torch dtype '{torch_dtype_name}' to numpy.dtype implicitly.") from e
76-
return np.dtype(dtype_like)
77-
78-
7927
class BaseBuffer(ABC):
8028
"""
8129
Base class that represents a buffer (rollout or replay)
@@ -111,16 +59,6 @@ def __init__(
11159
self.device = get_device(device)
11260
self.n_envs = n_envs
11361

114-
# unify the dtype decision logic for all buffer classes
115-
# see https://github.com/DLR-RM/stable-baselines3/issues/2162
116-
if isinstance(observation_space, spaces.Dict):
117-
self.dtypes = BufferDTypes(
118-
observations={key: space.dtype for (key, space) in observation_space.spaces.items()},
119-
actions=action_space.dtype,
120-
)
121-
else:
122-
self.dtypes = BufferDTypes(observations=observation_space.dtype, actions=action_space.dtype)
123-
12462
@staticmethod
12563
def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
12664
"""
@@ -451,8 +389,8 @@ def __init__(
451389
self.reset()
452390

453391
def reset(self) -> None:
454-
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs)
455-
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act)
392+
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.observation_space.dtype)
393+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
456394
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
457395
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
458396
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -808,9 +746,9 @@ def reset(self) -> None:
808746
self.observations = {}
809747
for key, obs_input_shape in self.obs_shape.items():
810748
self.observations[key] = np.zeros(
811-
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes.dict_obs[key]
749+
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space.dtype
812750
)
813-
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act)
751+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype)
814752
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
815753
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
816754
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from gymnasium import spaces
1212

1313
from stable_baselines3.common.base_class import BaseAlgorithm
14-
from stable_baselines3.common.buffers import BufferDTypes, DictReplayBuffer, NStepReplayBuffer, ReplayBuffer
14+
from stable_baselines3.common.buffers import DictReplayBuffer, NStepReplayBuffer, ReplayBuffer
1515
from stable_baselines3.common.callbacks import BaseCallback
1616
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
1717
from stable_baselines3.common.policies import BasePolicy
@@ -247,19 +247,6 @@ def load_replay_buffer(
247247
self.replay_buffer.handle_timeout_termination = False
248248
self.replay_buffer.timeouts = np.zeros_like(self.replay_buffer.dones)
249249

250-
# Backward compatibility with SB3 < 2.7.1 replay buffer
251-
if not hasattr(self.replay_buffer, "dtypes"):
252-
if isinstance(self.replay_buffer, DictReplayBuffer):
253-
self.replay_buffer.dtypes = BufferDTypes(
254-
observations={key: obs.dtype for (key, obs) in self.replay_buffer.observations.items()},
255-
actions=self.replay_buffer.actions.dtype,
256-
)
257-
else:
258-
self.replay_buffer.dtypes = BufferDTypes(
259-
observations=self.replay_buffer.observations.dtype,
260-
actions=self.replay_buffer.actions.dtype,
261-
)
262-
263250
if isinstance(self.replay_buffer, HerReplayBuffer):
264251
assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
265252
self.replay_buffer.set_env(self.env)

stable_baselines3/common/type_aliases.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
OptimizerStateDict = dict[str, Any]
2323
MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"]
2424
PyTorchObs = Union[th.Tensor, TensorDict]
25-
DTypeLike = Union[None, np.dtype, th.dtype, type, str]
2625

2726
# A schedule takes the remaining progress as input
2827
# and outputs a scalar (e.g. learning rate, clip range, ...)

tests/test_buffers.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,20 +184,26 @@ def test_device_buffer(replay_buffer_cls, device):
184184
@pytest.mark.parametrize("use_dict", [False, True])
185185
def test_buffer_dtypes(obs_dtype: Union[type[np.integer], type[np.floating]], use_dict: bool):
186186
rollout_buffer: Union[RolloutBuffer, DictRolloutBuffer]
187+
replay_buffer: Union[ReplayBuffer, DictReplayBuffer]
187188
obs_space = spaces.Box(0, 100, dtype=obs_dtype)
188-
action_space = spaces.Discrete(10)
189+
act_space = spaces.Discrete(10)
190+
buffer_params = dict(buffer_size=1, action_space=act_space)
189191

190192
if use_dict:
191-
obs_space_2 = spaces.Box(0, 100, dtype=np.uint8)
192-
observation_space = spaces.Dict({"obs": obs_space, "obs_2": obs_space_2})
193-
rollout_buffer = DictRolloutBuffer(buffer_size=1, observation_space=observation_space, action_space=action_space)
194-
assert rollout_buffer.observations["obs"].dtype == obs_dtype
195-
assert rollout_buffer.observations["obs_2"].dtype == np.uint8
193+
dict_obs_space = spaces.Dict({"obs": obs_space, "obs_2": spaces.Box(0, 100, dtype=np.uint8)})
194+
buffer_params["observation_space"] = dict_obs_space
195+
rollout_buffer = DictRolloutBuffer(**buffer_params) # type: ignore[arg-type]
196+
replay_buffer = DictReplayBuffer(**buffer_params) # type: ignore[arg-type]
197+
assert rollout_buffer.observations["obs"].dtype == replay_buffer.observations["obs"].dtype == obs_dtype
198+
assert rollout_buffer.observations["obs_2"].dtype == replay_buffer.observations["obs_2"].dtype == np.uint8
196199
else:
197-
rollout_buffer = RolloutBuffer(buffer_size=1, observation_space=obs_space, action_space=action_space)
198-
assert rollout_buffer.observations.dtype == obs_dtype
200+
buffer_params["observation_space"] = obs_space
201+
rollout_buffer = RolloutBuffer(**buffer_params) # type: ignore[arg-type]
202+
replay_buffer = ReplayBuffer(**buffer_params) # type: ignore[arg-type]
203+
assert rollout_buffer.observations.dtype == replay_buffer.observations.dtype == obs_dtype
199204

200-
assert rollout_buffer.actions.dtype == np.int64
205+
assert rollout_buffer.actions.dtype == np.float32, "RolloutBuffer action dtype must be np.float32"
206+
assert replay_buffer.actions.dtype == act_space.dtype
201207

202208

203209
def test_custom_rollout_buffer():

0 commit comments

Comments
 (0)