@@ -841,21 +841,64 @@ def _get_samples( # type: ignore[override]
 
 
 class NStepReplayBuffer(ReplayBuffer):
844+ """
845+ Replay buffer used for computing n-step returns in off-policy algorithms like SAC/DQN.
846+
847+ The n-step return combines multiple steps of future rewards,
848+ discounted by the discount factor gamma.
849+ This can help improve sample efficiency and credit assignment.
850+
851+ This implementation uses the same storage space as a normal replay buffer,
852+ and NumPy vectorized operations at sampling time to efficiently compute the
853+ n-step return, without requiring extra memory.
854+
855+ This implementation is inspired by:
856+ - https://github.com/younggyoseo/FastTD3
857+ - https://github.com/DLR-RM/stable-baselines3/pull/81
858+
859+ It avoids potential issues such as:
860+ - https://github.com/younggyoseo/FastTD3/issues/6
861+
862+ :param buffer_size: Max number of element in the buffer
863+ :param observation_space: Observation space
864+ :param action_space: Action space
865+ :param device: PyTorch device
866+ :param n_envs: Number of parallel environments
867+ :param optimize_memory_usage: Not supported
868+ :param handle_timeout_termination: Handle timeout termination (due to timelimit)
869+ separately and treat the task as infinite horizon task.
870+ https://github.com/DLR-RM/stable-baselines3/issues/284
871+ :param n_steps: Number of steps to accumulate rewards for n-step returns
872+ :param gamma: Discount factor for future rewards
873+ """
874+
     def __init__(self, *args, n_steps: int = 3, gamma: float = 0.99, **kwargs):
         super().__init__(*args, **kwargs)
         self.n_steps = n_steps
         self.gamma = gamma
+        if self.optimize_memory_usage:
+            raise NotImplementedError("NStepReplayBuffer doesn't support optimize_memory_usage=True")
 
     def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
-        n_steps = self.n_steps
+        """
+        Sample a batch of transitions and compute n-step returns.
 
+        For each sampled transition, the method computes the cumulative discounted reward over
+        the next `n_steps`, properly handling episode termination and timeouts.
+        The next observation and done flag correspond to the last transition in the computed n-step trajectory.
+
+        :param batch_inds: Indices of samples to retrieve
+        :param env: Optional VecNormalize environment for normalizing observations/rewards
+        :return: A batch of samples with n-step returns and corresponding observations/actions
+        """
         # Randomly choose env indices for each sample
         env_indices = np.random.randint(0, self.n_envs, size=batch_inds.shape)
 
         # Compute n-step indices with wrap-around
-        steps = np.arange(n_steps).reshape(1, -1)  # shape: [1, n_steps]
+        steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
         # Note: the self.pos index is dangerous (will overlap two different episodes when buffer is full)
-        # so we set self.pos-1 to truncated=True (temporarly) if done=False
+        # so we set self.pos-1 to truncated=True (temporarily) if done=False
+        # TODO: avoid copying the whole array (requires some more indices trickery)
         safe_timeouts = self.timeouts.copy()
         safe_timeouts[self.pos - 1, :] = np.logical_not(self.dones[self.pos - 1, :])
 
@@ -871,12 +914,12 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
         done_idx = done_or_trunc.argmax(axis=1)
         # If no done/truncation, keep full sequence
         has_done_or_trunc = done_or_trunc.any(axis=1)
-        done_idx = np.where(has_done_or_trunc, done_idx, n_steps - 1)
+        done_idx = np.where(has_done_or_trunc, done_idx, self.n_steps - 1)
 
-        mask = np.arange(n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
+        mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
 
         # Apply discount
-        discounts = self.gamma ** np.arange(n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
+        discounts = self.gamma ** np.arange(self.n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
         discounted_rewards = rewards_seq * discounts * mask
         n_step_returns = discounted_rewards.sum(axis=1, keepdims=True)  # [batch, 1]
 
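To make the vectorized return computation in the hunk above concrete, here is a small self-contained NumPy sketch (toy data, not part of the diff) that mirrors the mask/discount/sum steps for a batch of two sampled transitions:

```python
import numpy as np

n_steps, gamma = 3, 0.99
# Toy data: 2 sampled transitions, each followed by n_steps rewards/dones along the time axis.
rewards_seq = np.array([[1.0, 1.0, 1.0],
                        [0.5, 2.0, 4.0]], dtype=np.float32)   # shape: [batch, n_steps]
dones_seq = np.array([[False, True, False],
                      [False, False, False]])                 # sample 0 terminates at its 2nd step

# Index of the first done; if no done occurred, keep the full n-step horizon
done_idx = dones_seq.argmax(axis=1)
has_done = dones_seq.any(axis=1)
done_idx = np.where(has_done, done_idx, n_steps - 1)

# Mask out rewards that belong to the next episode, then apply the discount and sum
mask = np.arange(n_steps).reshape(1, -1) <= done_idx[:, None]            # [batch, n_steps]
discounts = gamma ** np.arange(n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
n_step_returns = (rewards_seq * discounts * mask).sum(axis=1, keepdims=True)

print(n_step_returns)  # [[1 + 0.99*1], [0.5 + 0.99*2 + 0.99**2*4]] -> [[1.99], [6.4004]]
```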
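And a minimal usage sketch, assuming the class ends up exposed alongside the other SB3 buffers (the import path, environment id, and hyperparameter values below are illustrative, not taken from this commit):

```python
from stable_baselines3 import SAC
# Assumed import path: the class is added next to ReplayBuffer in the buffers module.
from stable_baselines3.common.buffers import NStepReplayBuffer

# Off-policy algorithms accept a custom buffer via replay_buffer_class / replay_buffer_kwargs.
# Keeping the buffer's gamma equal to the algorithm's gamma keeps the n-step target consistent.
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    gamma=0.99,
    replay_buffer_class=NStepReplayBuffer,
    replay_buffer_kwargs=dict(n_steps=3, gamma=0.99),
    verbose=1,
)
model.learn(total_timesteps=10_000)
```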