Skip to content

Commit 2819d0d

Browse files
committed
Updated to create a BufferDTypes dataclass and updated pytests
1 parent 66f8300 commit 2819d0d

File tree

5 files changed

+107
-85
lines changed

5 files changed

+107
-85
lines changed

stable_baselines3/common/buffers.py

Lines changed: 69 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import warnings
22
from abc import ABC, abstractmethod
3-
from collections.abc import Generator
4-
from typing import Any, Optional, Union
3+
from collections.abc import Generator, Mapping
4+
from dataclasses import InitVar, dataclass, field
5+
from types import MappingProxyType
6+
from typing import Any, ClassVar, Optional, Union
57

68
import numpy as np
79
import torch as th
@@ -11,6 +13,7 @@
1113
from stable_baselines3.common.type_aliases import (
1214
DictReplayBufferSamples,
1315
DictRolloutBufferSamples,
16+
DTypeLike,
1417
ReplayBufferSamples,
1518
RolloutBufferSamples,
1619
)
@@ -24,6 +27,51 @@
2427
psutil = None
2528

2629

30+
@dataclass
31+
class BufferDTypes:
32+
"""
33+
Data class representing the data types used by a buffer.
34+
35+
:param _observations: Data type of the observation space, or a mapping from key to data type for dictionary observation spaces
36+
:param _actions: Data type of the action space
37+
"""
38+
39+
MAP_TORCH_DTYPES: ClassVar[dict] = dict(complex32="complex64", float="float32", bfloat16="float32", bool="bool_")
40+
_observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
41+
_actions: InitVar[DTypeLike]
42+
observations: Union[np.dtype, MappingProxyType[str, np.dtype]] = field(init=False)
43+
actions: np.dtype = field(init=False)
44+
45+
def __post_init__(self, _observations: Union[DTypeLike, Mapping[str, DTypeLike]], _actions: DTypeLike):
46+
if isinstance(_observations, Mapping):
47+
self.observations = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in _observations.items()})
48+
else:
49+
self.observations = self.to_numpy_dtype(_observations)
50+
self.actions = self.to_numpy_dtype(_actions)
51+
52+
@classmethod
53+
def to_numpy_dtype(cls, dtype_like: DTypeLike) -> np.dtype:
54+
if isinstance(dtype_like, np.dtype):
55+
return dtype_like
56+
elif isinstance(dtype_like, th.dtype):
57+
torch_dtype_name = repr(dtype_like).removeprefix("torch.")
58+
numpy_dtype_name = cls.MAP_TORCH_DTYPES.get(torch_dtype_name, torch_dtype_name)
59+
try:
60+
return np.dtype(getattr(np, numpy_dtype_name))
61+
except AttributeError as e:
62+
raise TypeError(f"Cannot cast torch dtype '{torch_dtype_name}' to numpy.dtype implicitly.") from e
63+
elif isinstance(dtype_like, type) and issubclass(dtype_like, np.generic):
64+
return np.dtype(dtype_like)
65+
elif isinstance(dtype_like, str):
66+
try:
67+
return np.dtype(dtype_like)
68+
except TypeError as e:
69+
raise TypeError(f"Cannot interpret str '{dtype_like}' as a valid numpy datatype.") from e
70+
elif dtype_like is None:
71+
return np.dtype(dtype_like)
72+
raise TypeError(f"Cannot interpret unknown object '{dtype_like}' as a valid numpy datatype.")
73+
74+
2775
class BaseBuffer(ABC):
2876
"""
2977
Base class that represents a buffer (rollout or replay)
@@ -46,7 +94,6 @@ def __init__(
4694
action_space: spaces.Space,
4795
device: Union[th.device, str] = "auto",
4896
n_envs: int = 1,
49-
dtypes: Optional[dict] = None,
5097
):
5198
super().__init__()
5299
self.buffer_size = buffer_size
@@ -62,38 +109,13 @@ def __init__(
62109

63110
# unify the dtype decision logic for all buffer classes
64111
# see https://github.com/DLR-RM/stable-baselines3/issues/2162
65-
dtypes = dtypes or dict()
66-
dtypes = dtypes.copy()
67-
object_dtype = np.dtype(object)
68-
69-
# Ensure dtypes override is valid for dict observations
70-
obs_dtype: Union[dict, np.dtype]
71112
if isinstance(observation_space, spaces.Dict):
72-
if dtypes.get("observations"):
73-
if not isinstance(dtypes["observations"], dict):
74-
dtypes["observations"] = {key: np.dtype(dtypes["observations"]) for key in self.obs_shape}
75-
else:
76-
dtypes["observations"] = {key: np.dtype(dtype) for (key, dtype) in dtypes["observations"].items()}
77-
obs_dtype = {key: np.dtype(space.dtype) for (key, space) in observation_space.spaces.items()} # type: ignore[misc]
113+
self.dtypes = BufferDTypes(
114+
{key: space.dtype for (key, space) in observation_space.spaces.items()},
115+
action_space.dtype,
116+
)
78117
else:
79-
obs_dtype = np.dtype(observation_space.dtype)
80-
81-
# Validate the dtypes
82-
self.dtypes = dict(
83-
observations=dtypes.get("observations", obs_dtype), actions=np.dtype(dtypes.get("actions", action_space.dtype))
84-
)
85-
for space, dtype in self.dtypes.items():
86-
if not isinstance(dtype, dict):
87-
dtype = {"": dtype}
88-
for key, subspace_dtype in dtype.items():
89-
if subspace_dtype == object_dtype:
90-
if key:
91-
key = f"[{key}]"
92-
warnings.warn(
93-
f"An object dtype has been assigned to {space}{key}, you are likely using a custom "
94-
f"environment, please use it with caution and ensure that {space}{key} is properly "
95-
"dereferenced / copied within each step to avoid unwanted consequences."
96-
)
118+
self.dtypes = BufferDTypes(observation_space.dtype, action_space.dtype)
97119

98120
@staticmethod
99121
def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
@@ -227,9 +249,8 @@ def __init__(
227249
n_envs: int = 1,
228250
optimize_memory_usage: bool = False,
229251
handle_timeout_termination: bool = True,
230-
dtypes: Optional[dict] = None,
231252
):
232-
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs, dtypes=dtypes)
253+
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
233254

234255
# Adjust buffer size
235256
self.buffer_size = max(buffer_size // n_envs, 1)
@@ -247,16 +268,14 @@ def __init__(
247268
)
248269
self.optimize_memory_usage = optimize_memory_usage
249270

250-
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes["observations"])
271+
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
251272

252273
if not optimize_memory_usage:
253274
# When optimizing memory, `observations` also contains the next observation
254-
self.next_observations = np.zeros(
255-
(self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes["observations"]
256-
)
275+
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
257276

258277
self.actions = np.zeros(
259-
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes["actions"])
278+
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.actions)
260279
)
261280

262281
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -420,17 +439,16 @@ def __init__(
420439
gae_lambda: float = 1,
421440
gamma: float = 0.99,
422441
n_envs: int = 1,
423-
dtypes: Optional[dict] = None,
424442
):
425-
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs, dtypes=dtypes)
443+
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
426444
self.gae_lambda = gae_lambda
427445
self.gamma = gamma
428446
self.generator_ready = False
429447
self.reset()
430448

431449
def reset(self) -> None:
432-
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes["observations"])
433-
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes["actions"])
450+
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
451+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.actions)
434452
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
435453
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
436454
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -592,9 +610,8 @@ def __init__(
592610
n_envs: int = 1,
593611
optimize_memory_usage: bool = False,
594612
handle_timeout_termination: bool = True,
595-
dtypes: Optional[dict] = None,
596613
):
597-
super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs, dtypes=dtypes)
614+
super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
598615

599616
assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
600617
self.buffer_size = max(buffer_size // n_envs, 1)
@@ -609,16 +626,16 @@ def __init__(
609626
self.optimize_memory_usage = optimize_memory_usage
610627

611628
self.observations = {
612-
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes["observations"][key])
629+
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.observations[key])
613630
for key, _obs_shape in self.obs_shape.items()
614631
}
615632
self.next_observations = {
616-
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes["observations"][key])
633+
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.observations[key])
617634
for key, _obs_shape in self.obs_shape.items()
618635
}
619636

620637
self.actions = np.zeros(
621-
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes["actions"])
638+
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.actions)
622639
)
623640
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
624641
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -772,9 +789,8 @@ def __init__(
772789
gae_lambda: float = 1,
773790
gamma: float = 0.99,
774791
n_envs: int = 1,
775-
dtypes: Optional[dict] = None,
776792
):
777-
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs, dtypes=dtypes)
793+
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
778794

779795
assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
780796

@@ -788,9 +804,9 @@ def reset(self) -> None:
788804
self.observations = {}
789805
for key, obs_input_shape in self.obs_shape.items():
790806
self.observations[key] = np.zeros(
791-
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes["observations"][key]
807+
(self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes.observations[key]
792808
)
793-
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes["actions"])
809+
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.actions)
794810
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
795811
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
796812
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

stable_baselines3/common/type_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
OptimizerStateDict = dict[str, Any]
2323
MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"]
2424
PyTorchObs = Union[th.Tensor, TensorDict]
25+
DTypeLike = Union[None, np.dtype, th.dtype, type, str]
2526

2627
# A schedule takes the remaining progress as input
2728
# and outputs a scalar (e.g. learning rate, clip range, ...)

stable_baselines3/her/her_replay_buffer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(
6060
n_sampled_goal: int = 4,
6161
goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future",
6262
copy_info_dict: bool = False,
63-
dtypes: Optional[dict] = None,
6463
):
6564
super().__init__(
6665
buffer_size,
@@ -70,7 +69,6 @@ def __init__(
7069
n_envs=n_envs,
7170
optimize_memory_usage=optimize_memory_usage,
7271
handle_timeout_termination=handle_timeout_termination,
73-
dtypes=dtypes,
7472
)
7573
self.env = env
7674
self.copy_info_dict = copy_info_dict

tests/test_buffer_dtypes.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

tests/test_buffers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
import gymnasium as gym
24
import numpy as np
35
import pytest
@@ -163,6 +165,41 @@ def test_device_buffer(replay_buffer_cls, device):
163165
raise TypeError(f"Unknown value type: {type(value)}")
164166

165167

168+
@pytest.mark.parametrize(
169+
"obs_dtype",
170+
[
171+
np.dtype(np.uint8),
172+
np.dtype(np.int8),
173+
np.dtype(np.uint16),
174+
np.dtype(np.int16),
175+
np.dtype(np.uint32),
176+
np.dtype(np.int32),
177+
np.dtype(np.uint64),
178+
np.dtype(np.int64),
179+
np.dtype(np.float16),
180+
np.dtype(np.float32),
181+
np.dtype(np.float64),
182+
],
183+
)
184+
@pytest.mark.parametrize("use_dict", [False, True])
185+
def test_buffer_dtypes(obs_dtype: Union[type[np.integer], type[np.floating]], use_dict: bool):
186+
rollout_buffer: Union[RolloutBuffer, DictRolloutBuffer]
187+
obs_space = spaces.Box(0, 100, dtype=obs_dtype)
188+
action_space = spaces.Discrete(10)
189+
190+
if use_dict:
191+
obs_space_2 = spaces.Box(0, 100, dtype=np.uint8)
192+
observation_space = spaces.Dict({"obs": obs_space, "obs_2": obs_space_2})
193+
rollout_buffer = DictRolloutBuffer(buffer_size=1, observation_space=observation_space, action_space=action_space)
194+
assert rollout_buffer.observations["obs"].dtype == obs_dtype
195+
assert rollout_buffer.observations["obs_2"].dtype == np.uint8
196+
else:
197+
rollout_buffer = RolloutBuffer(buffer_size=1, observation_space=obs_space, action_space=action_space)
198+
assert rollout_buffer.observations.dtype == obs_dtype
199+
200+
assert rollout_buffer.actions.dtype == np.int64
201+
202+
166203
def test_custom_rollout_buffer():
167204
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict())
168205

0 commit comments

Comments
 (0)