@@ -37,17 +37,20 @@ class BufferDTypes:
     """
 
     MAP_TORCH_DTYPES: ClassVar[dict] = dict(complex32="complex64", float="float32", bfloat16="float32", bool="bool_")
-    _observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
-    _actions: InitVar[DTypeLike]
-    observations: Union[np.dtype, MappingProxyType[str, np.dtype]] = field(init=False)
-    actions: np.dtype = field(init=False)
-
-    def __post_init__(self, _observations: Union[DTypeLike, Mapping[str, DTypeLike]], _actions: DTypeLike):
-        if isinstance(_observations, Mapping):
-            self.observations = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in _observations.items()})
+
+    observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
+    actions: InitVar[DTypeLike]
+
+    dict_obs: MappingProxyType[str, np.dtype] = field(init=False)
+    obs: Optional[np.dtype] = field(default=None, init=False)
+    act: Optional[np.dtype] = field(default=None, init=False)
+
+    def __post_init__(self, observations: Union[DTypeLike, Mapping[str, DTypeLike]], actions: DTypeLike):
+        if isinstance(observations, Mapping):
+            self.dict_obs = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in observations.items()})
         else:
-            self.observations = self.to_numpy_dtype(_observations)
-        self.actions = self.to_numpy_dtype(_actions)
+            self.obs = self.to_numpy_dtype(observations)
+        self.act = self.to_numpy_dtype(actions)
 
     @classmethod
     def to_numpy_dtype(cls, dtype_like: DTypeLike) -> np.dtype:
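For reference, here is a minimal usage sketch of the reworked dataclass. It assumes `BufferDTypes` is importable from `stable_baselines3.common.buffers` (the file being edited) and that `to_numpy_dtype` accepts any NumPy `DTypeLike`, as its signature suggests; the concrete dtypes are illustrative.

```python
import numpy as np

from stable_baselines3.common.buffers import BufferDTypes  # assumed import path

# Flat observation space: the keyword InitVars populate `.obs` and `.act`
flat = BufferDTypes(observations=np.float32, actions=np.int64)
print(flat.obs, flat.act)  # expected: float32 int64 (as np.dtype objects)

# Dict observation space: per-key dtypes land in the read-only `.dict_obs` mapping,
# while `.obs` keeps its `None` default
nested = BufferDTypes(observations={"image": np.uint8, "vector": np.float64}, actions=np.float32)
print(dict(nested.dict_obs))  # expected: {'image': dtype('uint8'), 'vector': dtype('float64')}
```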
@@ -111,11 +114,11 @@ def __init__(
         # see https://github.com/DLR-RM/stable-baselines3/issues/2162
         if isinstance(observation_space, spaces.Dict):
             self.dtypes = BufferDTypes(
-                {key: space.dtype for (key, space) in observation_space.spaces.items()},
-                action_space.dtype,
+                observations={key: space.dtype for (key, space) in observation_space.spaces.items()},
+                actions=action_space.dtype,
             )
         else:
-            self.dtypes = BufferDTypes(observation_space.dtype, action_space.dtype)
+            self.dtypes = BufferDTypes(observations=observation_space.dtype, actions=action_space.dtype)
 
     @staticmethod
     def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
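A hedged end-to-end sketch of how the new keyword-only call sites would be fed from Gymnasium spaces; the space shapes, bounds, and dtypes below are made up for illustration.

```python
import numpy as np
from gymnasium import spaces

from stable_baselines3.common.buffers import BufferDTypes  # assumed import path

observation_space = spaces.Dict(
    {
        "image": spaces.Box(0, 255, shape=(3, 84, 84), dtype=np.uint8),
        "vector": spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float64),
    }
)
action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)

# Mirrors the `spaces.Dict` branch above
dtypes = BufferDTypes(
    observations={key: space.dtype for (key, space) in observation_space.spaces.items()},
    actions=action_space.dtype,
)
print(dict(dtypes.dict_obs), dtypes.act)
```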
@@ -268,14 +271,14 @@ def __init__(
         )
         self.optimize_memory_usage = optimize_memory_usage
 
-        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
+        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs)
 
         if not optimize_memory_usage:
             # When optimizing memory, `observations` contains also the next observation
-            self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
+            self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs)
 
         self.actions = np.zeros(
-            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.actions)
+            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.act)
         )
 
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
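Keeping the space's native dtype instead of a blanket `float32` mostly matters for replay-buffer memory. A quick back-of-the-envelope check with made-up buffer dimensions (not taken from the diff):

```python
import numpy as np

buffer_size, n_envs, obs_shape = 100_000, 1, (3, 84, 84)  # illustrative sizes only

uint8_obs = np.zeros((buffer_size, n_envs, *obs_shape), dtype=np.uint8)
float32_obs = np.zeros((buffer_size, n_envs, *obs_shape), dtype=np.float32)

print(f"{uint8_obs.nbytes / 1e9:.2f} GB")    # ~2.12 GB for image observations stored as uint8
print(f"{float32_obs.nbytes / 1e9:.2f} GB")  # ~8.47 GB for the same data upcast to float32
```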
@@ -447,8 +450,8 @@ def __init__(
         self.reset()
 
     def reset(self) -> None:
-        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.observations)
-        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.actions)
+        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.dtypes.obs)
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -569,7 +572,7 @@ def _get_samples(
         env: Optional[VecNormalize] = None,
     ) -> RolloutBufferSamples:
         data = (
-            self.observations[batch_inds].astype(np.float32, copy=False),
+            self._normalize_obs(self.observations[batch_inds], env),
             self.actions[batch_inds].astype(np.float32, copy=False),
             self.values[batch_inds].flatten(),
             self.log_probs[batch_inds].flatten(),
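The sampled observations are now routed through the base class's `_normalize_obs` helper instead of an unconditional `float32` cast. The helper's body is not part of this diff; as an assumption, it behaves roughly like the sketch below, delegating to `VecNormalize.normalize_obs` when a wrapper is supplied and passing the batch through untouched (dtype included) otherwise.

```python
import numpy as np


def _normalize_obs_sketch(obs, env=None):
    """Hedged paraphrase of the base-class helper; not the actual stable-baselines3 code."""
    if env is not None:
        # VecNormalize tracks running observation statistics and rescales the batch
        return env.normalize_obs(obs)
    return obs


batch = np.zeros((32, 4), dtype=np.float32)
assert _normalize_obs_sketch(batch) is batch  # no wrapper: returned unchanged
```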
@@ -626,16 +629,16 @@ def __init__(
         self.optimize_memory_usage = optimize_memory_usage
 
         self.observations = {
-            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.observations[key])
+            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.dict_obs[key])
             for key, _obs_shape in self.obs_shape.items()
         }
         self.next_observations = {
-            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.observations[key])
+            key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=self.dtypes.dict_obs[key])
             for key, _obs_shape in self.obs_shape.items()
         }
 
         self.actions = np.zeros(
-            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.actions)
+            (self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(self.dtypes.act)
         )
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -804,9 +807,9 @@ def reset(self) -> None:
         self.observations = {}
         for key, obs_input_shape in self.obs_shape.items():
             self.observations[key] = np.zeros(
-                (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes.observations[key]
+                (self.buffer_size, self.n_envs, *obs_input_shape), dtype=self.dtypes.dict_obs[key]
             )
-        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.actions)
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=self.dtypes.act)
         self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -890,12 +893,13 @@ def _get_samples(  # type: ignore[override]
         batch_inds: np.ndarray,
         env: Optional[VecNormalize] = None,
     ) -> DictRolloutBufferSamples:
+        # Normalize if needed
+        observations: dict[str, np.ndarray] = self._normalize_obs(
+            obs={key: obs[batch_inds] for (key, obs) in self.observations.items()}, env=env
+        )  # type: ignore[assignment]
         return DictRolloutBufferSamples(
-            observations={
-                key: self.to_torch(obs[batch_inds].astype(dtype=np.float32, copy=False))
-                for (key, obs) in self.observations.items()
-            },
-            actions=self.to_torch(self.actions[batch_inds].astype(dtype=np.float32, copy=False)),
+            observations={key: self.to_torch(obs) for (key, obs) in observations.items()},
+            actions=self.to_torch(self.actions[batch_inds]),
             old_values=self.to_torch(self.values[batch_inds].flatten()),
             old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
             advantages=self.to_torch(self.advantages[batch_inds].flatten()),
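Dropping the `astype(np.float32)` calls means the per-key batches keep their storage dtypes all the way to `to_torch`. A small self-contained illustration of the dict slicing pattern used above (the keys, shapes, and dtypes are hypothetical):

```python
import numpy as np

# Mimics `self.observations` with two keys and the (buffer_size, n_envs, *obs_shape) layout
observations = {
    "image": np.zeros((16, 1, 3, 8, 8), dtype=np.uint8),
    "vector": np.zeros((16, 1, 4), dtype=np.float32),
}
batch_inds = np.array([0, 3, 7])

batch = {key: obs[batch_inds] for (key, obs) in observations.items()}
print({key: (arr.shape, arr.dtype) for key, arr in batch.items()})
# {'image': ((3, 1, 3, 8, 8), dtype('uint8')), 'vector': ((3, 1, 4), dtype('float32'))}
```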