Skip to content

Commit 872a7e5

Browse files
committed
Revert _normalize_obs calls in rollout buffers
1 parent 2da14c7 commit 872a7e5

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

stable_baselines3/common/buffers.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,8 @@ def _get_samples(
572572
env: Optional[VecNormalize] = None,
573573
) -> RolloutBufferSamples:
574574
data = (
575-
self._normalize_obs(self.observations[batch_inds], env),
576-
self.actions[batch_inds].astype(np.float32, copy=False),
575+
self.observations[batch_inds],
576+
self.actions[batch_inds],
577577
self.values[batch_inds].flatten(),
578578
self.log_probs[batch_inds].flatten(),
579579
self.advantages[batch_inds].flatten(),
@@ -893,12 +893,8 @@ def _get_samples( # type: ignore[override]
893893
batch_inds: np.ndarray,
894894
env: Optional[VecNormalize] = None,
895895
) -> DictRolloutBufferSamples:
896-
# Normalize if needed
897-
observations: dict[str, np.ndarray] = self._normalize_obs(
898-
obs={key: obs[batch_inds] for (key, obs) in self.observations.items()}, env=env
899-
) # type: ignore[assignment]
900896
return DictRolloutBufferSamples(
901-
observations={key: self.to_torch(obs) for (key, obs) in observations.items()},
897+
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
902898
actions=self.to_torch(self.actions[batch_inds]),
903899
old_values=self.to_torch(self.values[batch_inds].flatten()),
904900
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),

0 commit comments

Comments
 (0)