@@ -898,7 +898,9 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
898898 # so we set self.pos-1 to truncated=True (temporarily) if done=False
899899 # TODO: avoid copying the whole array (requires some more indices trickery)
900900 safe_timeouts = self .timeouts .copy ()
901- safe_timeouts [self .pos - 1 , :] = np .logical_not (self .dones [self .pos - 1 , :])
901+ last_valid_index = self .pos - 1
902+ tmp_timeout = np .logical_not (self .dones [last_valid_index , :])
903+ safe_timeouts [last_valid_index , :] = np .logical_or (safe_timeouts [last_valid_index , :], tmp_timeout )
902904
903905 # Compute n-step indices with wrap-around
904906 steps = np .arange (self .n_steps ).reshape (1 , - 1 ) # shape: [1, n_steps]
@@ -907,14 +909,14 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
907909 # Retrieve sequences of transitions
908910 rewards_seq = self ._normalize_reward (self .rewards [indices , env_indices [:, None ]], env ) # [batch, n_steps]
909911 dones_seq = self .dones [indices , env_indices [:, None ]] # [batch, n_steps]
910- truncs_seq = safe_timeouts [indices , env_indices [:, None ]] # [batch, n_steps]
912+ truncated_seq = safe_timeouts [indices , env_indices [:, None ]] # [batch, n_steps]
911913
912914 # Compute masks: 1 until first done/truncation (inclusive)
913- done_or_trunc = np .logical_or (dones_seq , truncs_seq )
914- done_idx = done_or_trunc .argmax (axis = 1 )
915+ done_or_truncated = np .logical_or (dones_seq , truncated_seq )
916+ done_idx = done_or_truncated .argmax (axis = 1 )
915917 # If no done/truncation, keep full sequence
916- has_done_or_trunc = done_or_trunc .any (axis = 1 )
917- done_idx = np .where (has_done_or_trunc , done_idx , self .n_steps - 1 )
918+ has_done_or_truncated = done_or_truncated .any (axis = 1 )
919+ done_idx = np .where (has_done_or_truncated , done_idx , self .n_steps - 1 )
918920
919921 mask = np .arange (self .n_steps ).reshape (1 , - 1 ) <= done_idx [:, None ] # shape: [batch, n_steps]
920922 # Compute discount factors for bootstrapping (using target Q-Value)
0 commit comments