Skip to content

Commit b499574

Browse files
committed
Fixed init logic
1 parent 182578d commit b499574

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

stable_baselines3/common/buffers.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,22 @@ def __init__(
6868

6969
# Ensure dtypes override is valid for dict observations
7070
if isinstance(observation_space, spaces.Dict):
71-
if dtypes.get("observations") and not hasattr(dtypes["observations"], "__getitem__"):
72-
dtypes["observations"] = {key: dtypes["observations"] for key in self.obs_shape}
73-
obs_dtype = {key: space.dtype for (key, space) in observation_space.spaces.items()} # type: ignore[misc]
71+
if dtypes.get("observations"):
72+
if not isinstance(dtypes["observations"], dict):
73+
dtypes["observations"] = {key: np.dtype(dtypes["observations"]) for key in self.obs_shape}
74+
else:
75+
dtypes["observations"] = {key: np.dtype(dtype) for (key, dtype) in dtypes["observations"].items()}
76+
obs_dtype = {
77+
key: np.dtype(space.dtype) for (key, space) in observation_space.spaces.items()
78+
} # type: ignore[misc]
7479
else:
75-
obs_dtype = observation_space.dtype
80+
obs_dtype = np.dtype(observation_space.dtype)
7681

7782
# Validate the dtypes
78-
self.dtypes = dict(observations=np.dtype(dtypes.get("observations", obs_dtype)),
83+
self.dtypes = dict(observations=dtypes.get("observations", obs_dtype),
7984
actions=np.dtype(dtypes.get("actions", action_space.dtype)))
8085
for space, dtype in self.dtypes.items():
81-
if not hasattr(dtype, "__getitem__"):
86+
if not isinstance(dtype, dict):
8287
dtype = {"": dtype}
8388
for key, subspace_dtype in dtype.items():
8489
if subspace_dtype == object_dtype:

0 commit comments

Comments
 (0)