
Commit c7d151d

Merge branch 'master' into master
2 parents 6134732 + e206fc5 commit c7d151d

13 files changed (+333 −13 lines)


docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion
@@ -3,14 +3,16 @@
 Changelog
 ==========
 
-Release 2.6.1a1 (WIP)
+Release 2.7.0a0 (WIP)
 --------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 
 New Features:
 ^^^^^^^^^^^^^
+- Added support for n-step returns for off-policy algorithms via the `n_steps` parameter
+- Added ``NStepReplayBuffer`` that allows computing n-step returns without additional memory requirement (and without for loops)
 
 Bug Fixes:
 ^^^^^^^^^^
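
For context, a minimal usage sketch of the feature described in these changelog entries. The environment id and hyperparameter values are illustrative assumptions, not taken from this commit:

    # Hedged sketch: enabling n-step returns via the new `n_steps` parameter.
    # "CartPole-v1" and the hyperparameters below are placeholder choices.
    from stable_baselines3 import DQN

    # With n_steps > 1, the algorithm is expected to pick the NStepReplayBuffer
    # automatically and accumulate discounted rewards over 3 steps
    # before bootstrapping from the target network.
    model = DQN("MlpPolicy", "CartPole-v1", n_steps=3, gamma=0.99, verbose=1)
    model.learn(total_timesteps=10_000)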

stable_baselines3/common/base_class.py

Lines changed: 0 additions & 1 deletion
@@ -316,7 +316,6 @@ def _excluded_save_params(self) -> list[str]:
             "replay_buffer",
             "rollout_buffer",
             "_vec_normalize_env",
-            "_episode_storage",
             "_logger",
             "_custom_logger",
         ]

stable_baselines3/common/buffers.py

Lines changed: 110 additions & 0 deletions
@@ -838,3 +838,113 @@ def _get_samples( # type: ignore[override]
             advantages=self.to_torch(self.advantages[batch_inds].flatten()),
             returns=self.to_torch(self.returns[batch_inds].flatten()),
         )
+
+
+class NStepReplayBuffer(ReplayBuffer):
+    """
+    Replay buffer used for computing n-step returns in off-policy algorithms like SAC/DQN.
+
+    The n-step return combines multiple steps of future rewards,
+    discounted by the discount factor gamma.
+    This can help improve sample efficiency and credit assignment.
+
+    This implementation uses the same storage space as a normal replay buffer,
+    and NumPy vectorized operations at sampling time to efficiently compute the
+    n-step return, without requiring extra memory.
+
+    This implementation is inspired by:
+    - https://github.com/younggyoseo/FastTD3
+    - https://github.com/DLR-RM/stable-baselines3/pull/81
+
+    It avoids potential issues such as:
+    - https://github.com/younggyoseo/FastTD3/issues/6
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+    :param n_envs: Number of parallel environments
+    :param optimize_memory_usage: Not supported
+    :param handle_timeout_termination: Handle timeout termination (due to timelimit)
+        separately and treat the task as infinite horizon task.
+        https://github.com/DLR-RM/stable-baselines3/issues/284
+    :param n_steps: Number of steps to accumulate rewards for n-step returns
+    :param gamma: Discount factor for future rewards
+    """
+
+    def __init__(self, *args, n_steps: int = 3, gamma: float = 0.99, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_steps = n_steps
+        self.gamma = gamma
+        if self.optimize_memory_usage:
+            raise NotImplementedError("NStepReplayBuffer doesn't support optimize_memory_usage=True")
+
+    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
+        """
+        Sample a batch of transitions and compute n-step returns.
+
+        For each sampled transition, the method computes the cumulative discounted reward over
+        the next `n_steps`, properly handling episode termination and timeouts.
+        The next observation and done flag correspond to the last transition in the computed n-step trajectory.
+
+        :param batch_inds: Indices of samples to retrieve
+        :param env: Optional VecNormalize environment for normalizing observations/rewards
+        :return: A batch of samples with n-step returns and corresponding observations/actions
+        """
+        # Randomly choose env indices for each sample
+        env_indices = np.random.randint(0, self.n_envs, size=batch_inds.shape)
+
+        # Note: the self.pos index is dangerous (will overlap two different episodes when buffer is full)
+        # so we set self.pos-1 to truncated=True (temporarily) if done=False and truncated=False
+        last_valid_index = self.pos - 1
+        original_timeout_values = self.timeouts[last_valid_index].copy()
+        self.timeouts[last_valid_index] = np.logical_or(original_timeout_values, np.logical_not(self.dones[last_valid_index]))
+
+        # Compute n-step indices with wrap-around
+        steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
+        indices = (batch_inds[:, None] + steps) % self.buffer_size  # shape: [batch, n_steps]
+
+        # Retrieve sequences of transitions
+        rewards_seq = self._normalize_reward(self.rewards[indices, env_indices[:, None]], env)  # [batch, n_steps]
+        dones_seq = self.dones[indices, env_indices[:, None]]  # [batch, n_steps]
+        truncated_seq = self.timeouts[indices, env_indices[:, None]]  # [batch, n_steps]
+
+        # Compute masks: 1 until first done/truncation (inclusive)
+        done_or_truncated = np.logical_or(dones_seq, truncated_seq)
+        done_idx = done_or_truncated.argmax(axis=1)
+        # If no done/truncation, keep full sequence
+        has_done_or_truncated = done_or_truncated.any(axis=1)
+        done_idx = np.where(has_done_or_truncated, done_idx, self.n_steps - 1)
+
+        mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
+        # Compute discount factors for bootstrapping (using target Q-Value)
+        # It is gamma ** n_steps by default but should be adjusted in case of early termination/truncation.
+        target_q_discounts = self.gamma ** mask.sum(axis=1, keepdims=True).astype(np.float32)  # [batch, 1]
+
+        # Apply discount
+        discounts = self.gamma ** np.arange(self.n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]
+        discounted_rewards = rewards_seq * discounts * mask
+        n_step_returns = discounted_rewards.sum(axis=1, keepdims=True)  # [batch, 1]
+
+        # Compute indices of next_obs/done at the final point of the n-step transition
+        last_indices = (batch_inds + done_idx) % self.buffer_size
+        next_obs = self._normalize_obs(self.next_observations[last_indices, env_indices], env)
+        next_dones = self.dones[last_indices, env_indices][:, None].astype(np.float32)
+        next_timeouts = self.timeouts[last_indices, env_indices][:, None].astype(np.float32)
+        final_dones = next_dones * (1.0 - next_timeouts)
+
+        # Revert back tmp changes to avoid sampling across episodes
+        self.timeouts[last_valid_index] = original_timeout_values
+
+        # Gather observations and actions
+        obs = self._normalize_obs(self.observations[batch_inds, env_indices], env)
+        actions = self.actions[batch_inds, env_indices]
+
+        return ReplayBufferSamples(
+            observations=self.to_torch(obs),  # type: ignore[arg-type]
+            actions=self.to_torch(actions),
+            next_observations=self.to_torch(next_obs),  # type: ignore[arg-type]
+            dones=self.to_torch(final_dones),
+            rewards=self.to_torch(n_step_returns),
+            discounts=self.to_torch(target_q_discounts),
+        )
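
To make the vectorized sampling logic above concrete, here is a small self-contained NumPy sketch (toy numbers, not from the diff) of the masking and discounting steps: rewards past the first done/truncation are masked out, per-step discounts are applied, and the bootstrap discount shrinks to gamma**K where K is the number of steps actually used.

    import numpy as np

    gamma, n_steps = 0.99, 3
    # Toy reward/done sequences for a batch of 2 sampled transitions
    rewards_seq = np.array([[1.0, 1.0, 1.0],
                            [1.0, 2.0, 3.0]], dtype=np.float32)  # [batch, n_steps]
    dones_seq = np.array([[False, True, False],
                          [False, False, False]])                # episode 0 ends at step 1

    # Index of the last step to keep (full window if no done/truncation)
    done_idx = np.where(dones_seq.any(axis=1), dones_seq.argmax(axis=1), n_steps - 1)
    mask = np.arange(n_steps)[None, :] <= done_idx[:, None]      # [batch, n_steps]

    step_discounts = gamma ** np.arange(n_steps, dtype=np.float32)[None, :]
    n_step_returns = (rewards_seq * step_discounts * mask).sum(axis=1, keepdims=True)
    bootstrap_discounts = gamma ** mask.sum(axis=1, keepdims=True).astype(np.float32)

    print(n_step_returns)       # [[1 + 0.99*1], [1 + 0.99*2 + 0.99**2*3]]
    print(bootstrap_discounts)  # [[0.99**2], [0.99**3]]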

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 9 additions & 2 deletions
@@ -11,7 +11,7 @@
 from gymnasium import spaces
 
 from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
+from stable_baselines3.common.buffers import DictReplayBuffer, NStepReplayBuffer, ReplayBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
 from stable_baselines3.common.policies import BasePolicy
@@ -51,6 +51,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When n_steps > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
     :param policy_kwargs: Additional arguments to be passed to the policy on creation
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
@@ -93,6 +94,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         policy_kwargs: Optional[dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
@@ -134,7 +136,7 @@ def __init__(
         self.replay_buffer: Optional[ReplayBuffer] = None
         self.replay_buffer_class = replay_buffer_class
         self.replay_buffer_kwargs = replay_buffer_kwargs or {}
-        self._episode_storage = None
+        self.n_steps = n_steps
 
         # Save train freq parameter, will be converted later to TrainFreq object
         self.train_freq = train_freq
@@ -176,6 +178,11 @@ def _setup_model(self) -> None:
         if self.replay_buffer_class is None:
             if isinstance(self.observation_space, spaces.Dict):
                 self.replay_buffer_class = DictReplayBuffer
+                assert self.n_steps == 1, "N-step returns are not supported for Dict observation spaces yet."
+            elif self.n_steps > 1:
+                self.replay_buffer_class = NStepReplayBuffer
+                # Add required arguments for computing n-step returns
+                self.replay_buffer_kwargs.update({"n_steps": self.n_steps, "gamma": self.gamma})
             else:
                 self.replay_buffer_class = ReplayBuffer
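
Based on the `_setup_model` logic above, setting `n_steps > 1` should be roughly equivalent to passing the buffer class and its kwargs explicitly. A hedged sketch (the environment id is a placeholder assumption):

    from stable_baselines3 import SAC
    from stable_baselines3.common.buffers import NStepReplayBuffer

    # What `_setup_model` effectively does when n_steps=3 is requested
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",  # placeholder environment
        replay_buffer_class=NStepReplayBuffer,
        replay_buffer_kwargs={"n_steps": 3, "gamma": 0.99},
    )
    # Per the assert above, Dict observation spaces are not supported with n_steps > 1 yet.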

stable_baselines3/common/type_aliases.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,8 @@ class ReplayBufferSamples(NamedTuple):
     next_observations: th.Tensor
     dones: th.Tensor
     rewards: th.Tensor
+    # For n-step replay buffer
+    discounts: Optional[th.Tensor] = None
 
 
 class DictReplayBufferSamples(NamedTuple):
@@ -60,6 +62,7 @@ class DictReplayBufferSamples(NamedTuple):
     next_observations: TensorDict
     dones: th.Tensor
     rewards: th.Tensor
+    discounts: Optional[th.Tensor] = None
 
 
 class RolloutReturn(NamedTuple):
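
A quick illustration of why this change is backward compatible: the new field has a default value, so buffers that keep returning five fields are unaffected and consumers can fall back to the plain `gamma`. A minimal sketch (assumes the patched `type_aliases` and PyTorch installed):

    import torch as th
    from stable_baselines3.common.type_aliases import ReplayBufferSamples

    dummy = th.zeros(2, 1)
    # No `discounts` passed: the field defaults to None
    sample = ReplayBufferSamples(dummy, dummy, dummy, dummy, dummy)
    assert sample.discounts is None  # training code then uses gamma directly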

stable_baselines3/ddpg/ddpg.py

Lines changed: 4 additions & 1 deletion
@@ -44,6 +44,7 @@ class DDPG(TD3):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When n_steps > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
     :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages
@@ -69,6 +70,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         tensorboard_log: Optional[str] = None,
         policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
@@ -90,12 +92,13 @@ def __init__(
             action_noise=action_noise,
             replay_buffer_class=replay_buffer_class,
             replay_buffer_kwargs=replay_buffer_kwargs,
+            optimize_memory_usage=optimize_memory_usage,
+            n_steps=n_steps,
             policy_kwargs=policy_kwargs,
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             device=device,
             seed=seed,
-            optimize_memory_usage=optimize_memory_usage,
             # Remove all tricks from TD3 to obtain DDPG:
             # we still need to specify target_policy_noise > 0 to avoid errors
             policy_delay=1,

stable_baselines3/dqn/dqn.py

Lines changed: 7 additions & 2 deletions
@@ -44,6 +44,7 @@ class DQN(OffPolicyAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When n_steps > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
     :param target_update_interval: update the target network every ``target_update_interval``
         environment steps.
     :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
@@ -88,6 +89,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         target_update_interval: int = 10000,
         exploration_fraction: float = 0.1,
         exploration_initial_eps: float = 1.0,
@@ -115,14 +117,15 @@ def __init__(
             action_noise=None,  # No action noise
             replay_buffer_class=replay_buffer_class,
             replay_buffer_kwargs=replay_buffer_kwargs,
+            optimize_memory_usage=optimize_memory_usage,
+            n_steps=n_steps,
             policy_kwargs=policy_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             device=device,
             seed=seed,
             sde_support=False,
-            optimize_memory_usage=optimize_memory_usage,
             supported_action_spaces=(spaces.Discrete,),
             support_multi_env=True,
         )
@@ -191,6 +194,8 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
         for _ in range(gradient_steps):
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
+            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma
 
             with th.no_grad():
                 # Compute the next Q-values using the target network
@@ -200,7 +205,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
                 # Avoid potential broadcast issue
                 next_q_values = next_q_values.reshape(-1, 1)
                 # 1-step TD target
-                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values
 
             # Get current Q-values estimates
             current_q_values = self.q_net(replay_data.observations)
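
For reference, the modified target line corresponds to the n-step TD target below (our notation, not from the diff), where K ≤ n_steps is the number of rewards actually accumulated before the first termination or truncation, so that `discounts` equals gamma^K:

    y = \sum_{k=0}^{K-1} \gamma^k r_{t+k} + (1 - d_{t+K}) \, \gamma^{K} \max_{a'} Q_{\text{target}}(s_{t+K}, a')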

stable_baselines3/her/her_replay_buffer.py

Lines changed: 1 addition & 1 deletion
@@ -402,7 +402,7 @@ def truncate_last_trajectory(self) -> None:
             self.dones[self.pos - 1, env_idx] = True
             # make sure that last episodes can be sampled and
             # update next episode start (self._current_ep_start)
-            self._compute_episode_length(env_idx)
+            self._compute_episode_length(int(env_idx))
             # handle infinite horizon tasks
             if self.handle_timeout_termination:
                 self.timeouts[self.pos - 1, env_idx] = True  # not an actual timeout, but it allows bootstrapping

stable_baselines3/sac/sac.py

Lines changed: 7 additions & 2 deletions
@@ -53,6 +53,7 @@ class SAC(OffPolicyAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When n_steps > 1, uses n-step return (with the NStepReplayBuffer) when updating the Q-value network.
     :param ent_coef: Entropy regularization coefficient. (Equivalent to
         inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
         Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
@@ -103,6 +104,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
         target_entropy: Union[str, float] = "auto",
@@ -131,6 +133,8 @@ def __init__(
             action_noise,
             replay_buffer_class=replay_buffer_class,
             replay_buffer_kwargs=replay_buffer_kwargs,
+            optimize_memory_usage=optimize_memory_usage,
+            n_steps=n_steps,
             policy_kwargs=policy_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
@@ -140,7 +144,6 @@ def __init__(
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
-            optimize_memory_usage=optimize_memory_usage,
             supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )
@@ -213,6 +216,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         for gradient_step in range(gradient_steps):
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
+            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma
 
             # We need to sample because `log_std` may have changed between two gradient steps
             if self.use_sde:
@@ -252,7 +257,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # add entropy term
                 next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                 # td error + entropy term
-                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values
 
             # Get current Q-values estimates for each critic network
             # using action from the replay buffer
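
The same n-step structure applies here, with SAC's entropy term attached to the bootstrapped value (our notation, K ≤ n_steps as above, a' sampled from the current policy at s_{t+K}):

    y = \sum_{k=0}^{K-1} \gamma^k r_{t+k} + (1 - d_{t+K}) \, \gamma^{K} \Bigl( \min_{i=1,2} Q_{\theta'_i}(s_{t+K}, a') - \alpha \log \pi(a' \mid s_{t+K}) \Bigr)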
