@@ -838,3 +838,113 @@ def _get_samples( # type: ignore[override]
             advantages=self.to_torch(self.advantages[batch_inds].flatten()),
             returns=self.to_torch(self.returns[batch_inds].flatten()),
         )
+
+
+class NStepReplayBuffer(ReplayBuffer):
+    """
+    Replay buffer used for computing n-step returns in off-policy algorithms like SAC/DQN.
+
+    The n-step return combines multiple steps of future rewards,
+    discounted by the discount factor gamma.
+    This can help improve sample efficiency and credit assignment.
+
+    This implementation uses the same storage space as a normal replay buffer
+    and relies on NumPy vectorized operations at sampling time to efficiently compute
+    the n-step return, without requiring extra memory.
+
+    This implementation is inspired by:
+    - https://github.com/younggyoseo/FastTD3
+    - https://github.com/DLR-RM/stable-baselines3/pull/81
+
+    It avoids potential issues such as:
+    - https://github.com/younggyoseo/FastTD3/issues/6
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+    :param n_envs: Number of parallel environments
+    :param optimize_memory_usage: Not supported
+    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
+        separately and treat the task as an infinite horizon task.
+        https://github.com/DLR-RM/stable-baselines3/issues/284
+    :param n_steps: Number of steps over which to accumulate rewards for the n-step return
+    :param gamma: Discount factor for future rewards
+    """
+
+    def __init__(self, *args, n_steps: int = 3, gamma: float = 0.99, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_steps = n_steps
+        self.gamma = gamma
+        if self.optimize_memory_usage:
+            raise NotImplementedError("NStepReplayBuffer doesn't support optimize_memory_usage=True")
+
+    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
+        """
+        Sample a batch of transitions and compute n-step returns.
+
+        For each sampled transition, the method computes the cumulative discounted reward over
+        the next `n_steps`, properly handling episode termination and timeouts.
+        The next observation and done flag correspond to the last transition in the computed n-step trajectory.
+
+        :param batch_inds: Indices of samples to retrieve
+        :param env: Optional VecNormalize environment for normalizing observations/rewards
+        :return: A batch of samples with n-step returns and corresponding observations/actions
+        """
+        # Randomly choose env indices for each sample
+        env_indices = np.random.randint(0, self.n_envs, size=batch_inds.shape)
+
+        # Note: sampling across the self.pos index is dangerous (when the buffer is full, an n-step window
+        # crossing it would overlap two different episodes), so we temporarily set self.pos - 1 to
+        # truncated=True if done=False and truncated=False
+        last_valid_index = self.pos - 1
+        original_timeout_values = self.timeouts[last_valid_index].copy()
+        self.timeouts[last_valid_index] = np.logical_or(original_timeout_values, np.logical_not(self.dones[last_valid_index]))
+
+        # Compute n-step indices with wrap-around
+        steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
+        indices = (batch_inds[:, None] + steps) % self.buffer_size  # shape: [batch, n_steps]
+
+        # Retrieve sequences of transitions
+        rewards_seq = self._normalize_reward(self.rewards[indices, env_indices[:, None]], env)  # [batch, n_steps]
+        dones_seq = self.dones[indices, env_indices[:, None]]  # [batch, n_steps]
+        truncated_seq = self.timeouts[indices, env_indices[:, None]]  # [batch, n_steps]
+
+        # Compute masks: 1 until the first done/truncation (inclusive)
+        done_or_truncated = np.logical_or(dones_seq, truncated_seq)
+        done_idx = done_or_truncated.argmax(axis=1)
+        # If there is no done/truncation, keep the full sequence
+        has_done_or_truncated = done_or_truncated.any(axis=1)
+        done_idx = np.where(has_done_or_truncated, done_idx, self.n_steps - 1)
+
+        mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
+        # Compute discount factors for bootstrapping (using the target Q-value).
+        # It is gamma ** n_steps by default but must be adjusted in case of early termination/truncation.
+        target_q_discounts = self.gamma ** mask.sum(axis=1, keepdims=True).astype(np.float32)  # [batch, 1]
+
+        # Apply discount
+        discounts = self.gamma ** np.arange(self.n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
+        discounted_rewards = rewards_seq * discounts * mask
+        n_step_returns = discounted_rewards.sum(axis=1, keepdims=True)  # [batch, 1]
+
+        # Compute the indices of next_obs/done at the final point of the n-step transition
+        last_indices = (batch_inds + done_idx) % self.buffer_size
+        next_obs = self._normalize_obs(self.next_observations[last_indices, env_indices], env)
+        next_dones = self.dones[last_indices, env_indices][:, None].astype(np.float32)
+        next_timeouts = self.timeouts[last_indices, env_indices][:, None].astype(np.float32)
+        final_dones = next_dones * (1.0 - next_timeouts)
+
+        # Revert the temporary change to avoid sampling across episodes
+        self.timeouts[last_valid_index] = original_timeout_values
+
+        # Gather observations and actions
+        obs = self._normalize_obs(self.observations[batch_inds, env_indices], env)
+        actions = self.actions[batch_inds, env_indices]
+
+        return ReplayBufferSamples(
+            observations=self.to_torch(obs),  # type: ignore[arg-type]
+            actions=self.to_torch(actions),
+            next_observations=self.to_torch(next_obs),  # type: ignore[arg-type]
+            dones=self.to_torch(final_dones),
+            rewards=self.to_torch(n_step_returns),
+            discounts=self.to_torch(target_q_discounts),
+        )
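To make the sampling-time arithmetic concrete, here is a standalone NumPy sketch of the same masking-and-discounting trick on toy data. Everything in it (the array values, `rewards`, `dones`, `timeouts`, `batch_inds`) is illustrative and independent of the diff above; it covers a single environment and skips reward normalization and the temporary `self.pos - 1` masking.

import numpy as np

# Toy illustration of the vectorized n-step return computation (illustrative values only).
gamma, n_steps, buffer_size = 0.99, 3, 8

rewards = np.ones(buffer_size, dtype=np.float32)      # every transition gives reward 1
dones = np.zeros(buffer_size, dtype=bool)
dones[2] = True                                       # an episode terminates at index 2
timeouts = np.zeros(buffer_size, dtype=bool)          # no truncations in this toy example

batch_inds = np.array([0, 1, 4])                      # sampled start indices

# Indices of the next n_steps transitions, with wrap-around
steps = np.arange(n_steps).reshape(1, -1)
indices = (batch_inds[:, None] + steps) % buffer_size  # shape: [batch, n_steps]

rewards_seq = rewards[indices]
done_or_truncated = np.logical_or(dones[indices], timeouts[indices])

# Index of the first termination/truncation, or n_steps - 1 if there is none
done_idx = done_or_truncated.argmax(axis=1)
done_idx = np.where(done_or_truncated.any(axis=1), done_idx, n_steps - 1)

# Keep rewards up to and including the first done/truncation
mask = np.arange(n_steps).reshape(1, -1) <= done_idx[:, None]

discounts = gamma ** np.arange(n_steps, dtype=np.float32)
n_step_returns = (rewards_seq * discounts * mask).sum(axis=1)
target_q_discounts = gamma ** mask.sum(axis=1)        # discount applied to the bootstrap term

print(n_step_returns)      # [2.9701, 1.99, 2.9701] = [1 + 0.99 + 0.99**2, 1 + 0.99, 1 + 0.99 + 0.99**2]
print(target_q_discounts)  # [0.970299, 0.9801, 0.970299] = [0.99**3, 0.99**2, 0.99**3]

For the second sampled index the window runs into the terminal transition at index 2, so only two rewards are summed and the bootstrap discount drops to gamma ** 2. Doing this arithmetic at sampling time is what lets the buffer keep ordinary single-step storage instead of materializing n-step transitions when they are added.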