|
1 | 1 | import warnings |
2 | 2 | from abc import ABC, abstractmethod |
3 | | -from collections.abc import Generator, Mapping |
4 | | -from dataclasses import InitVar, dataclass, field |
5 | | -from types import MappingProxyType |
6 | | -from typing import Any, ClassVar, Optional, Union |
| 3 | +from collections.abc import Generator |
| 4 | +from typing import Any, Optional, Union |
7 | 5 |
|
8 | 6 | import numpy as np |
9 | 7 | import torch as th |
|
13 | 11 | from stable_baselines3.common.type_aliases import ( |
14 | 12 | DictReplayBufferSamples, |
15 | 13 | DictRolloutBufferSamples, |
16 | | - DTypeLike, |
17 | 14 | ReplayBufferSamples, |
18 | 15 | RolloutBufferSamples, |
19 | 16 | ) |
|
27 | 24 | psutil = None |
28 | 25 |
|
29 | 26 |
|
30 | | -@dataclass |
31 | | -class BufferDTypes: |
32 | | - """ |
33 | | - Data class representing the data types used by a buffer. |
34 | | -
|
35 | | - :param observations: Datatype of observation space |
36 | | - :param actions: Datatype of action space |
37 | | - """ |
38 | | - |
39 | | - MAP_TORCH_DTYPES: ClassVar[dict] = dict(complex32="complex64", float="float32", bfloat16="float32", bool="bool_") |
40 | | - |
41 | | - observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]] |
42 | | - actions: InitVar[DTypeLike] |
43 | | - |
44 | | - dict_obs: MappingProxyType[str, np.dtype] = field(default_factory=lambda: MappingProxyType({}), init=False) |
45 | | - obs: Optional[np.dtype] = field(default=None, init=False) |
46 | | - act: Optional[np.dtype] = field(default=None, init=False) |
47 | | - |
48 | | - def __post_init__(self, observations: Union[DTypeLike, Mapping[str, DTypeLike]], actions: DTypeLike): |
49 | | - if isinstance(observations, Mapping): |
50 | | - self.dict_obs = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in observations.items()}) |
51 | | - else: |
52 | | - self.obs = self.to_numpy_dtype(observations) |
53 | | - self.act = self.to_numpy_dtype(actions) |
54 | | - |
55 | | - def __getstate__(self): |
56 | | - state = self.__dict__.copy() |
57 | | - if isinstance(self.dict_obs, MappingProxyType): |
58 | | - state["dict_obs"] = dict(self.dict_obs) |
59 | | - return state |
60 | | - |
61 | | - def __setstate__(self, state: Mapping[str, Any]): |
62 | | - state = dict(state) |
63 | | - if state.get("dict_obs"): |
64 | | - state["dict_obs"] = MappingProxyType(state["dict_obs"].copy()) |
65 | | - self.__dict__.update(state) |
66 | | - |
67 | | - @classmethod |
68 | | - def to_numpy_dtype(cls, dtype_like: DTypeLike) -> np.dtype: |
69 | | - if isinstance(dtype_like, th.dtype): |
70 | | - torch_dtype_name = repr(dtype_like).removeprefix("torch.") |
71 | | - numpy_dtype_name = cls.MAP_TORCH_DTYPES.get(torch_dtype_name, torch_dtype_name) |
72 | | - try: |
73 | | - return np.dtype(getattr(np, numpy_dtype_name)) |
74 | | - except AttributeError as e: |
75 | | - raise TypeError(f"Cannot cast torch dtype '{torch_dtype_name}' to numpy.dtype implicitly.") from e |
76 | | - return np.dtype(dtype_like) |
77 | | - |
78 | | - |
79 | 27 | class BaseBuffer(ABC): |
80 | 28 | """ |
81 | 29 | Base class that represent a buffer (rollout or replay) |
@@ -111,16 +59,6 @@ def __init__( |
111 | 59 | self.device = get_device(device) |
112 | 60 | self.n_envs = n_envs |
113 | 61 |
|
114 | | - # unify the dtype decision logic for all buffer classes |
115 | | - # see https://github.com/DLR-RM/stable-baselines3/issues/2162 |
116 | | - if isinstance(observation_space, spaces.Dict): |
117 | | - self.dtypes = BufferDTypes( |
118 | | - observations={key: space.dtype for (key, space) in observation_space.spaces.items()}, |
119 | | - actions=action_space.dtype, |
120 | | - ) |
121 | | - else: |
122 | | - self.dtypes = BufferDTypes(observations=observation_space.dtype, actions=action_space.dtype) |
123 | | - |
124 | 62 | @staticmethod |
125 | 63 | def swap_and_flatten(arr: np.ndarray) -> np.ndarray: |
126 | 64 | """ |
@@ -451,8 +389,8 @@ def __init__( |
451 | 389 | self.reset() |
452 | 390 |
|
453 | 391 | def reset(self) -> None: |
454 | | - self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs) |
455 | | - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act) |
| 392 | + self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.observation_space.dtype) |
| 393 | + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype) |
456 | 394 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) |
457 | 395 | self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) |
458 | 396 | self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) |
@@ -808,9 +746,9 @@ def reset(self) -> None: |
808 | 746 | self.observations = {} |
809 | 747 | for key, obs_input_shape in self.obs_shape.items(): |
810 | 748 | self.observations[key] = np.zeros( |
811 | | - (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes.dict_obs[key] |
| 749 | + (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.observation_space.dtype |
812 | 750 | ) |
813 | | - self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act) |
| 751 | + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.action_space.dtype) |
814 | 752 | self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) |
815 | 753 | self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) |
816 | 754 | self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) |
|
0 commit comments