@@ -572,8 +572,8 @@ def _get_samples(
572572 env : Optional [VecNormalize ] = None ,
573573 ) -> RolloutBufferSamples :
574574 data = (
575- self ._normalize_obs ( self . observations [batch_inds ], env ) ,
576- self .actions [batch_inds ]. astype ( np . float32 , copy = False ) ,
575+ self .observations [batch_inds ],
576+ self .actions [batch_inds ],
577577 self .values [batch_inds ].flatten (),
578578 self .log_probs [batch_inds ].flatten (),
579579 self .advantages [batch_inds ].flatten (),
@@ -893,12 +893,8 @@ def _get_samples( # type: ignore[override]
893893 batch_inds : np .ndarray ,
894894 env : Optional [VecNormalize ] = None ,
895895 ) -> DictRolloutBufferSamples :
896- # Normalize if needed
897- observations : dict [str , np .ndarray ] = self ._normalize_obs (
898- obs = {key : obs [batch_inds ] for (key , obs ) in self .observations .items ()}, env = env
899- ) # type: ignore[assignment]
900896 return DictRolloutBufferSamples (
901- observations = {key : self .to_torch (obs ) for (key , obs ) in observations .items ()},
897+ observations = {key : self .to_torch (obs [ batch_inds ] ) for (key , obs ) in self . observations .items ()},
902898 actions = self .to_torch (self .actions [batch_inds ]),
903899 old_values = self .to_torch (self .values [batch_inds ].flatten ()),
904900 old_log_prob = self .to_torch (self .log_probs [batch_inds ].flatten ()),
0 commit comments