
Commit 7d1b46e

Add docstring and remove unused method
1 parent b91050c commit 7d1b46e

2 files changed: +50 -19 lines changed


stable_baselines3/common/buffers.py

Lines changed: 49 additions & 6 deletions
@@ -841,21 +841,64 @@ def _get_samples( # type: ignore[override]
 class NStepReplayBuffer(ReplayBuffer):
+    """
+    Replay buffer used for computing n-step returns in off-policy algorithms like SAC/DQN.
+
+    The n-step return combines multiple steps of future rewards,
+    discounted by the discount factor gamma.
+    This can help improve sample efficiency and credit assignment.
+
+    This implementation uses the same storage space as a normal replay buffer,
+    and NumPy vectorized operations at sampling time to efficiently compute the
+    n-step return, without requiring extra memory.
+
+    This implementation is inspired by:
+    - https://github.com/younggyoseo/FastTD3
+    - https://github.com/DLR-RM/stable-baselines3/pull/81
+
+    It avoids potential issues such as:
+    - https://github.com/younggyoseo/FastTD3/issues/6
+
+    :param buffer_size: Max number of element in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+    :param n_envs: Number of parallel environments
+    :param optimize_memory_usage: Not supported
+    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
+        separately and treat the task as infinite horizon task.
+        https://github.com/DLR-RM/stable-baselines3/issues/284
+    :param n_steps: Number of steps to accumulate rewards for n-step returns
+    :param gamma: Discount factor for future rewards
+    """
+
     def __init__(self, *args, n_steps: int = 3, gamma: float = 0.99, **kwargs):
         super().__init__(*args, **kwargs)
         self.n_steps = n_steps
         self.gamma = gamma
+        if self.optimize_memory_usage:
+            raise NotImplementedError("NStepReplayBuffer doesn't support optimize_memory_usage=True")

     def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
-        n_steps = self.n_steps
+        """
+        Sample a batch of transitions and compute n-step returns.

+        For each sampled transition, the method computes the cumulative discounted reward over
+        the next `n_steps`, properly handling episode termination and timeouts.
+        The next observation and done flag correspond to the last transition in the computed n-step trajectory.
+
+        :param batch_inds: Indices of samples to retrieve
+        :param env: Optional VecNormalize environment for normalizing observations/rewards
+        :return: A batch of samples with n-step returns and corresponding observations/actions
+        """
         # Randomly choose env indices for each sample
         env_indices = np.random.randint(0, self.n_envs, size=batch_inds.shape)

         # Compute n-step indices with wrap-around
-        steps = np.arange(n_steps).reshape(1, -1)  # shape: [1, n_steps]
+        steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
         # Note: the self.pos index is dangerous (will overlap two different episodes when buffer is full)
-        # so we set self.pos-1 to truncated=True (temporarly) if done=False
+        # so we set self.pos-1 to truncated=True (temporarily) if done=False
+        # TODO: avoid copying the whole array (requires some more indices trickery)
         safe_timeouts = self.timeouts.copy()
         safe_timeouts[self.pos - 1, :] = np.logical_not(self.dones[self.pos - 1, :])

@@ -871,12 +914,12 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
         done_idx = done_or_trunc.argmax(axis=1)
         # If no done/truncation, keep full sequence
         has_done_or_trunc = done_or_trunc.any(axis=1)
-        done_idx = np.where(has_done_or_trunc, done_idx, n_steps - 1)
+        done_idx = np.where(has_done_or_trunc, done_idx, self.n_steps - 1)

-        mask = np.arange(n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
+        mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]

         # Apply discount
-        discounts = self.gamma ** np.arange(n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
+        discounts = self.gamma ** np.arange(self.n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
         discounted_rewards = rewards_seq * discounts * mask
         n_step_returns = discounted_rewards.sum(axis=1, keepdims=True)  # [batch, 1]
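
The vectorized n-step return above is a masked, discounted sum over each sampled reward window. A minimal standalone NumPy sketch of that computation (array names mirror the diff; shapes and values are illustrative assumptions, not the buffer's actual internals):

import numpy as np

gamma, n_steps = 0.99, 3
# Pretend the next n_steps rewards and done/truncation flags were already gathered per sample
rewards_seq = np.ones((2, n_steps), dtype=np.float32)                    # [batch, n_steps]
done_or_trunc = np.array([[False, True, False], [False, False, False]])  # sample 0 ends at step 1

# Index of the first termination/truncation, or n_steps - 1 if the window is complete
done_idx = np.where(done_or_trunc.any(axis=1), done_or_trunc.argmax(axis=1), n_steps - 1)
# Mask out rewards collected after the episode ended
mask = np.arange(n_steps).reshape(1, -1) <= done_idx[:, None]            # [batch, n_steps]
# Discount step k by gamma**k and sum along the n-step axis
discounts = gamma ** np.arange(n_steps, dtype=np.float32).reshape(1, -1)
n_step_returns = (rewards_seq * discounts * mask).sum(axis=1, keepdims=True)
print(n_step_returns)  # approx [[1.99], [2.9701]] = [1 + gamma, 1 + gamma + gamma**2]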

tests/test_n_step_replay.py

Lines changed: 1 addition & 13 deletions
@@ -82,21 +82,9 @@ def fill_buffer(buffer, length, done_at=None, truncated_at=None):
 def compute_expected_nstep_reward(gamma, n_steps, stop_idx=None):
     """
-    Compute the expected n-step reward starting from zero idx,
+    Compute the expected n-step reward for the test env (reward=1 for each step),
     optionally stopping early due to termination/truncation.
     """
-    rewards = [1.0 * (gamma**i) for i in range(n_steps)]
-    if stop_idx is not None:
-        rewards = rewards[: stop_idx + 1]
-    return sum(rewards)
-
-
-def compute_expected_nstep_reward2(gamma, n_steps, stop_idx=None):
-    """
-    Compute the expected n-step reward,
-    optionally stopping early due to termination/truncation.
-    Alternative implementation that can handle different rewards.
-    """
     returns = np.zeros(n_steps)
     rewards = np.ones(n_steps)
     last_sum = 0.0
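
With reward 1.0 at every step, the value this helper checks reduces to a plain geometric sum in gamma, which is exactly what the removed list-comprehension body computed. A small sketch of that equivalence (the helper name below is illustrative, not the test's API):

def expected_nstep_reward(gamma, n_steps, stop_idx=None):
    # Geometric sum 1 + gamma + ... over n_steps terms, cut short after stop_idx if given
    num_terms = n_steps if stop_idx is None else stop_idx + 1
    return sum(gamma**i for i in range(num_terms))

assert abs(expected_nstep_reward(0.99, 3) - (1 + 0.99 + 0.99**2)) < 1e-12
assert expected_nstep_reward(0.99, 3, stop_idx=0) == 1.0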

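For completeness, constructing the new buffer should look like constructing any other SB3 replay buffer, plus the two new keyword arguments. A hedged sketch, assuming the standard ReplayBuffer constructor signature and that this commit is installed (spaces and values are illustrative):

import numpy as np
from gymnasium import spaces

from stable_baselines3.common.buffers import NStepReplayBuffer

# Toy continuous spaces, e.g. for a SAC-style agent
observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

buffer = NStepReplayBuffer(
    10_000,               # buffer_size
    observation_space,
    action_space,
    device="cpu",
    n_envs=1,
    n_steps=3,            # accumulate rewards over 3 steps
    gamma=0.99,           # should match the agent's discount factor
)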