@@ -68,17 +68,22 @@ def __init__(
6868
6969 # Ensure dtypes override is valid for dict observations
7070 if isinstance (observation_space , spaces .Dict ):
71- if dtypes .get ("observations" ) and not hasattr (dtypes ["observations" ], "__getitem__" ):
72- dtypes ["observations" ] = {key : dtypes ["observations" ] for key in self .obs_shape }
73- obs_dtype = {key : space .dtype for (key , space ) in observation_space .spaces .items ()} # type: ignore[misc]
71+ if dtypes .get ("observations" ):
72+ if not isinstance (dtypes ["observations" ], dict ):
73+ dtypes ["observations" ] = {key : np .dtype (dtypes ["observations" ]) for key in self .obs_shape }
74+ else :
75+ dtypes ["observations" ] = {key : np .dtype (dtype ) for (key , dtype ) in dtypes ["observations" ].items ()}
76+ obs_dtype = {
77+ key : np .dtype (space .dtype ) for (key , space ) in observation_space .spaces .items ()
78+ } # type: ignore[misc]
7479 else :
75- obs_dtype = observation_space .dtype
80+ obs_dtype = np . dtype ( observation_space .dtype )
7681
7782 # Validate the dtypes
78- self .dtypes = dict (observations = np . dtype ( dtypes .get ("observations" , obs_dtype ) ),
83+ self .dtypes = dict (observations = dtypes .get ("observations" , obs_dtype ),
7984 actions = np .dtype (dtypes .get ("actions" , action_space .dtype )))
8085 for space , dtype in self .dtypes .items ():
81- if not hasattr (dtype , "__getitem__" ):
86+ if not isinstance (dtype , dict ):
8287 dtype = {"" : dtype }
8388 for key , subspace_dtype in dtype .items ():
8489 if subspace_dtype == object_dtype :
0 commit comments