Commit 2da14c7

Fix type check errors on GitHub, separate dict_obs and obs, honor _normalize_obs for rollout buffers
1 parent 2819d0d commit 2da14c7

File tree

1 file changed

+33 -29 lines changed

stable_baselines3/common/buffers.py

Lines changed: 33 additions & 29 deletions
```diff
@@ -37,17 +37,20 @@ class BufferDTypes:
     """
 
     MAP_TORCH_DTYPES: ClassVar[dict] = dict(complex32="complex64", float="float32", bfloat16="float32", bool="bool_")
-    _observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
-    _actions: InitVar[DTypeLike]
-    observations: Union[np.dtype, MappingProxyType[str, np.dtype]] = field(init=False)
-    actions: np.dtype = field(init=False)
-
-    def __post_init__(self, _observations: Union[DTypeLike, Mapping[str, DTypeLike]], _actions: DTypeLike):
-        if isinstance(_observations, Mapping):
-            self.observations = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in _observations.items()})
+
+    observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
+    actions: InitVar[DTypeLike]
+
+    dict_obs: MappingProxyType[str, np.dtype] = field(init=False)
+    obs: Optional[np.dtype] = field(default=None, init=False)
+    act: Optional[np.dtype] = field(default=None, init=False)
+
+    def __post_init__(self, observations: Union[DTypeLike, Mapping[str, DTypeLike]], actions: DTypeLike):
+        if isinstance(observations, Mapping):
+            self.dict_obs = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in observations.items()})
         else:
-            self.observations = self.to_numpy_dtype(_observations)
-        self.actions = self.to_numpy_dtype(_actions)
+            self.obs = self.to_numpy_dtype(observations)
+        self.act = self.to_numpy_dtype(actions)
 
     @classmethod
     def to_numpy_dtype(cls, dtype_like: DTypeLike) -> np.dtype:
```
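The reason for the split is easiest to see outside the diff: the old single `observations` field had the union type `Union[np.dtype, MappingProxyType[str, np.dtype]]`, which forced type-narrowing at every use site. Below is a minimal, self-contained sketch of the reworked dataclass (dtype conversion is simplified to `np.dtype` here; the real class goes through `to_numpy_dtype`), showing how each new field now carries a single static type:

```python
from dataclasses import InitVar, dataclass, field
from types import MappingProxyType
from typing import Mapping, Optional, Union

import numpy as np
from numpy.typing import DTypeLike


@dataclass
class BufferDTypesSketch:
    """Simplified stand-in for the reworked BufferDTypes above."""

    observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
    actions: InitVar[DTypeLike]

    # One statically-typed field per case instead of one Union-typed field:
    dict_obs: MappingProxyType = field(init=False)  # only set for Dict spaces
    obs: Optional[np.dtype] = field(default=None, init=False)
    act: Optional[np.dtype] = field(default=None, init=False)

    def __post_init__(self, observations, actions) -> None:
        if isinstance(observations, Mapping):
            self.dict_obs = MappingProxyType({k: np.dtype(v) for k, v in observations.items()})
        else:
            self.obs = np.dtype(observations)
        self.act = np.dtype(actions)


# Flat observation space: `obs` and `act` are set, `dict_obs` stays unset.
flat = BufferDTypesSketch(observations=np.uint8, actions=np.float32)
assert flat.obs == np.dtype("uint8") and flat.act == np.dtype("float32")

# Dict observation space: per-key dtypes land in `dict_obs` instead.
composite = BufferDTypesSketch(observations={"image": np.uint8, "vec": np.float64}, actions=np.int64)
assert composite.dict_obs["vec"] == np.dtype("float64") and composite.obs is None
```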
```diff
@@ -111,11 +114,11 @@ def __init__(
         # see https://github.com/DLR-RM/stable-baselines3/issues/2162
         if isinstance(observation_space, spaces.Dict):
             self.dtypes = BufferDTypes(
-                {key: space.dtype for (key, space) in observation_space.spaces.items()},
-                action_space.dtype,
+                observations={key: space.dtype for (key, space) in observation_space.spaces.items()},
+                actions=action_space.dtype,
             )
         else:
-            self.dtypes = BufferDTypes(observation_space.dtype, action_space.dtype)
+            self.dtypes = BufferDTypes(observations=observation_space.dtype, actions=action_space.dtype)
 
     @staticmethod
     def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
```
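For context, the two branches above correspond to the two kinds of observation spaces. A short illustration with gymnasium spaces (the spaces here are hypothetical, invented for the example) of what gets passed as the `observations` keyword in each branch:

```python
import numpy as np
from gymnasium import spaces

# Hypothetical spaces mirroring the two branches of the diff above.
dict_space = spaces.Dict(
    {
        "image": spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8),
        "vector": spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32),
    }
)
box_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

# Dict branch: one native dtype per key ends up in BufferDTypes.dict_obs.
print({key: space.dtype for (key, space) in dict_space.spaces.items()})
# {'image': dtype('uint8'), 'vector': dtype('float32')}

# Flat branch: a single native dtype ends up in BufferDTypes.obs.
print(box_space.dtype)  # float32
```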
```diff
@@ -268,14 +271,14 @@ def __init__(
         )
         self.optimize_memory_usage = optimize_memory_usage
 
-        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
+        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs)
 
         if not optimize_memory_usage:
             # When optimizing memory, `observations` contains also the next observation
-            self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
+            self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs)
 
         self.actions = np.zeros(
-            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.actions)
+            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.act)
         )
 
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
```
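This allocation is what makes the dtype plumbing worthwhile: the replay buffer allocates with the space's native dtype instead of an implicit float32 (see issue #2162 linked above). A back-of-the-envelope illustration; the buffer dimensions below are invented for the example, not taken from the commit:

```python
import numpy as np

# Footprint of the `observations` array for an Atari-style image observation.
buffer_size, n_envs, obs_shape = 1_000_000, 1, (84, 84, 4)
n_elements = buffer_size * n_envs * int(np.prod(obs_shape))

gib = 2**30
print(n_elements * np.dtype(np.uint8).itemsize / gib)    # ~26.3 GiB as native uint8
print(n_elements * np.dtype(np.float32).itemsize / gib)  # ~105.1 GiB if forced to float32
```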
```diff
@@ -447,8 +450,8 @@ def __init__(
         self.reset()
 
     def reset(self) -> None:
-        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
-        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.actions)
+        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs)
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
```
```diff
@@ -569,7 +572,7 @@ def _get_samples(
         env: Optional[VecNormalize] = None,
     ) -> RolloutBufferSamples:
         data = (
-            self.observations[batch_inds].astype(np.float32, copy=False),
+            self._normalize_obs(self.observations[batch_inds], env),
             self.actions[batch_inds].astype(np.float32, copy=False),
             self.values[batch_inds].flatten(),
             self.log_probs[batch_inds].flatten(),
```
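This swapped-in call is the behavioral fix named in the commit title: rollout samples previously went out as a plain float32 cast, ignoring any `VecNormalize` statistics passed via `env`. A simplified sketch of the behavior the line now relies on (modeled on the `_normalize_obs(obs, env)` call visible above; the real helper lives on `BaseBuffer`):

```python
import numpy as np

def normalize_obs_sketch(obs: np.ndarray, env=None) -> np.ndarray:
    """Simplified model of BaseBuffer._normalize_obs: a pass-through unless a
    VecNormalize wrapper is supplied, in which case its running statistics are
    applied via env.normalize_obs."""
    if env is not None:
        return env.normalize_obs(obs)
    return obs

# With env=None the observations come back untouched, in their native dtype.
# Note the old unconditional .astype(np.float32, copy=False) is gone from the
# changed line, so sampling no longer forces a float conversion at this point.
raw = np.arange(6, dtype=np.uint8).reshape(2, 3)
assert normalize_obs_sketch(raw) is raw
```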
```diff
@@ -626,16 +629,16 @@ def __init__(
         self.optimize_memory_usage = optimize_memory_usage
 
         self.observations = {
-            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.observations[key])
+            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.dict_obs[key])
             for key, _obs_shape in self.obs_shape.items()
         }
         self.next_observations = {
-            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.observations[key])
+            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.dict_obs[key])
             for key, _obs_shape in self.obs_shape.items()
         }
 
         self.actions = np.zeros(
-            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.actions)
+            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.act)
         )
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
```
```diff
@@ -804,9 +807,9 @@ def reset(self) -> None:
         self.observations = {}
         for key, obs_input_shape in self.obs_shape.items():
             self.observations[key] = np.zeros(
-                (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes.observations[key]
+                (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes.dict_obs[key]
             )
-        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.actions)
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
```
```diff
@@ -890,12 +893,13 @@ def _get_samples(  # type: ignore[override]
         batch_inds: np.ndarray,
         env: Optional[VecNormalize] = None,
     ) -> DictRolloutBufferSamples:
+        # Normalize if needed
+        observations: dict[str, np.ndarray] = self._normalize_obs(
+            obs={key: obs[batch_inds] for (key, obs) in self.observations.items()}, env=env
+        )  # type: ignore[assignment]
         return DictRolloutBufferSamples(
-            observations={
-                key: self.to_torch(obs[batch_inds].astype(dtype=np.float32, copy=False))
-                for (key, obs) in self.observations.items()
-            },
-            actions=self.to_torch(self.actions[batch_inds].astype(dtype=np.float32, copy=False)),
+            observations={key: self.to_torch(obs) for (key, obs) in observations.items()},
+            actions=self.to_torch(self.actions[batch_inds]),
             old_values=self.to_torch(self.values[batch_inds].flatten()),
             old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
             advantages=self.to_torch(self.advantages[batch_inds].flatten()),
```
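The same fix applied to the dict buffer: slice every key with the batch indices first, then normalize the whole dict in a single `_normalize_obs` call before `to_torch`. A small self-contained sketch of that sampling path, with a stand-in normalizer (the per-key standardization below is illustrative only; the real `VecNormalize.normalize_obs` uses its running statistics):

```python
import numpy as np

def normalize_obs_stub(obs: dict) -> dict:
    """Stand-in for VecNormalize.normalize_obs on dict observations:
    normalizes each key independently (illustrative statistics only)."""
    return {k: (v - v.mean()) / (v.std() + 1e-8) for k, v in obs.items()}

rng = np.random.default_rng(0)
observations = {  # (buffer_size, n_envs, *obs_shape) per key, as in the buffer
    "image": rng.integers(0, 255, size=(8, 2, 4, 4)).astype(np.float32),
    "vector": rng.normal(size=(8, 2, 3)).astype(np.float32),
}
batch_inds = np.array([0, 3, 5])

# Slice each key with the batch indices, then normalize the dict as a whole,
# mirroring the new _get_samples above.
batch = {key: obs[batch_inds] for (key, obs) in observations.items()}
normalized = normalize_obs_stub(batch)
assert normalized["image"].shape == (3, 2, 4, 4)
```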
