Skip to content

Commit 606cb46

Browse files
committed
Do not overwrite timeout and rename variables
1 parent 59f0023 commit 606cb46

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

stable_baselines3/common/buffers.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -898,7 +898,9 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
898898
# so we set self.pos-1 to truncated=True (temporarily) if done=False
899899
# TODO: avoid copying the whole array (requires some more indices trickery)
900900
safe_timeouts = self.timeouts.copy()
901-
safe_timeouts[self.pos - 1, :] = np.logical_not(self.dones[self.pos - 1, :])
901+
last_valid_index = self.pos - 1
902+
tmp_timeout = np.logical_not(self.dones[last_valid_index, :])
903+
safe_timeouts[last_valid_index, :] = np.logical_or(safe_timeouts[last_valid_index, :], tmp_timeout)
902904

903905
# Compute n-step indices with wrap-around
904906
steps = np.arange(self.n_steps).reshape(1, -1) # shape: [1, n_steps]
@@ -907,14 +909,14 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
907909
# Retrieve sequences of transitions
908910
rewards_seq = self._normalize_reward(self.rewards[indices, env_indices[:, None]], env) # [batch, n_steps]
909911
dones_seq = self.dones[indices, env_indices[:, None]] # [batch, n_steps]
910-
truncs_seq = safe_timeouts[indices, env_indices[:, None]] # [batch, n_steps]
912+
truncated_seq = safe_timeouts[indices, env_indices[:, None]] # [batch, n_steps]
911913

912914
# Compute masks: 1 until first done/truncation (inclusive)
913-
done_or_trunc = np.logical_or(dones_seq, truncs_seq)
914-
done_idx = done_or_trunc.argmax(axis=1)
915+
done_or_truncated = np.logical_or(dones_seq, truncated_seq)
916+
done_idx = done_or_truncated.argmax(axis=1)
915917
# If no done/truncation, keep full sequence
916-
has_done_or_trunc = done_or_trunc.any(axis=1)
917-
done_idx = np.where(has_done_or_trunc, done_idx, self.n_steps - 1)
918+
has_done_or_truncated = done_or_truncated.any(axis=1)
919+
done_idx = np.where(has_done_or_truncated, done_idx, self.n_steps - 1)
918920

919921
mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None] # shape: [batch, n_steps]
920922
# Compute discount factors for bootstrapping (using target Q-Value)

0 commit comments

Comments (0)