
Commit 6b65820

Add n_steps param and update changelog
1 parent bbf1036 commit 6b65820

9 files changed: +36, -19 lines changed

docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion

@@ -3,14 +3,16 @@
 Changelog
 ==========

-Release 2.6.1a1 (WIP)
+Release 2.7.0a0 (WIP)
 --------------------------

 Breaking Changes:
 ^^^^^^^^^^^^^^^^^

 New Features:
 ^^^^^^^^^^^^^
+- Added support for n-step returns for off-policy algorithms via the ``n_steps`` parameter
+- Added ``NStepReplayBuffer``, which makes it possible to compute n-step returns without additional memory requirements (and without for loops)

 Bug Fixes:
 ^^^^^^^^^^
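For reference, a minimal usage sketch of the new parameter, assuming the API added in this commit (the hyperparameter values are placeholders, not recommendations):

    from stable_baselines3 import SAC

    # Passing n_steps > 1 lets the algorithm select NStepReplayBuffer
    # automatically and forward n_steps and gamma to it.
    model = SAC("MlpPolicy", "Pendulum-v1", n_steps=3, verbose=1)
    model.learn(total_timesteps=10_000)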

stable_baselines3/common/base_class.py

Lines changed: 0 additions & 1 deletion

@@ -316,7 +316,6 @@ def _excluded_save_params(self) -> list[str]:
             "replay_buffer",
             "rollout_buffer",
             "_vec_normalize_env",
-            "_episode_storage",
             "_logger",
             "_custom_logger",
         ]

stable_baselines3/common/off_policy_algorithm.py

Lines changed: 9 additions & 2 deletions

@@ -11,7 +11,7 @@
 from gymnasium import spaces

 from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
+from stable_baselines3.common.buffers import DictReplayBuffer, NStepReplayBuffer, ReplayBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
 from stable_baselines3.common.policies import BasePolicy

@@ -51,6 +51,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When ``n_steps > 1``, use n-step returns (with the ``NStepReplayBuffer``) when updating the Q-value network.
     :param policy_kwargs: Additional arguments to be passed to the policy on creation
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over

@@ -93,6 +94,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         policy_kwargs: Optional[dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,

@@ -134,7 +136,7 @@ def __init__(
         self.replay_buffer: Optional[ReplayBuffer] = None
         self.replay_buffer_class = replay_buffer_class
         self.replay_buffer_kwargs = replay_buffer_kwargs or {}
-        self._episode_storage = None
+        self.n_steps = n_steps

         # Save train freq parameter, will be converted later to TrainFreq object
         self.train_freq = train_freq

@@ -176,6 +178,11 @@ def _setup_model(self) -> None:
         if self.replay_buffer_class is None:
             if isinstance(self.observation_space, spaces.Dict):
                 self.replay_buffer_class = DictReplayBuffer
+                assert self.n_steps == 1, "N-step returns are not supported for Dict observation spaces yet."
+            elif self.n_steps > 1:
+                self.replay_buffer_class = NStepReplayBuffer
+                # Add required arguments for computing n-step returns
+                self.replay_buffer_kwargs.update({"n_steps": self.n_steps, "gamma": self.gamma})
             else:
                 self.replay_buffer_class = ReplayBuffer

stable_baselines3/ddpg/ddpg.py

Lines changed: 4 additions & 1 deletion

@@ -44,6 +44,7 @@ class DDPG(TD3):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When ``n_steps > 1``, use n-step returns (with the ``NStepReplayBuffer``) when updating the Q-value network.
     :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies`
     :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
         debug messages

@@ -69,6 +70,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         tensorboard_log: Optional[str] = None,
         policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,

@@ -90,12 +92,13 @@ def __init__(
             action_noise=action_noise,
             replay_buffer_class=replay_buffer_class,
             replay_buffer_kwargs=replay_buffer_kwargs,
+            optimize_memory_usage=optimize_memory_usage,
+            n_steps=n_steps,
             policy_kwargs=policy_kwargs,
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             device=device,
             seed=seed,
-            optimize_memory_usage=optimize_memory_usage,
             # Remove all tricks from TD3 to obtain DDPG:
             # we still need to specify target_policy_noise > 0 to avoid errors
             policy_delay=1,

stable_baselines3/dqn/dqn.py

Lines changed: 4 additions & 1 deletion

@@ -44,6 +44,7 @@ class DQN(OffPolicyAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When ``n_steps > 1``, use n-step returns (with the ``NStepReplayBuffer``) when updating the Q-value network.
     :param target_update_interval: update the target network every ``target_update_interval``
         environment steps.
     :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced

@@ -88,6 +89,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         target_update_interval: int = 10000,
         exploration_fraction: float = 0.1,
         exploration_initial_eps: float = 1.0,

@@ -115,14 +117,15 @@ def __init__(
             action_noise=None,  # No action noise
             replay_buffer_class=replay_buffer_class,
             replay_buffer_kwargs=replay_buffer_kwargs,
+            optimize_memory_usage=optimize_memory_usage,
+            n_steps=n_steps,
             policy_kwargs=policy_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             device=device,
             seed=seed,
             sde_support=False,
-            optimize_memory_usage=optimize_memory_usage,
             supported_action_spaces=(spaces.Discrete,),
             support_multi_env=True,
         )

stable_baselines3/sac/sac.py

Lines changed: 4 additions & 1 deletion

@@ -53,6 +53,7 @@ class SAC(OffPolicyAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When ``n_steps > 1``, use n-step returns (with the ``NStepReplayBuffer``) when updating the Q-value network.
     :param ent_coef: Entropy regularization coefficient. (Equivalent to
         inverse of reward scale in the original SAC paper.) Controlling exploration/exploitation trade-off.
         Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)

@@ -103,6 +104,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
         target_entropy: Union[str, float] = "auto",

@@ -131,6 +133,8 @@ def __init__(
             action_noise,
             replay_buffer_class=replay_buffer_class,
             replay_buffer_kwargs=replay_buffer_kwargs,
+            optimize_memory_usage=optimize_memory_usage,
+            n_steps=n_steps,
             policy_kwargs=policy_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,

@@ -140,7 +144,6 @@ def __init__(
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
-            optimize_memory_usage=optimize_memory_usage,
             supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )

stable_baselines3/td3/td3.py

Lines changed: 4 additions & 1 deletion

@@ -48,6 +48,7 @@ class TD3(OffPolicyAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When ``n_steps > 1``, use n-step returns (with the ``NStepReplayBuffer``) when updating the Q-value network.
     :param policy_delay: Policy and target networks will only be updated once every policy_delay steps
         per training steps. The Q values will be updated policy_delay more often (update every training step).
     :param target_policy_noise: Standard deviation of Gaussian noise added to target policy

@@ -92,6 +93,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         policy_delay: int = 2,
         target_policy_noise: float = 0.2,
         target_noise_clip: float = 0.5,

@@ -117,14 +119,15 @@ def __init__(
             action_noise=action_noise,
             replay_buffer_class=replay_buffer_class,
             replay_buffer_kwargs=replay_buffer_kwargs,
+            optimize_memory_usage=optimize_memory_usage,
+            n_steps=n_steps,
             policy_kwargs=policy_kwargs,
             stats_window_size=stats_window_size,
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             device=device,
             seed=seed,
             sde_support=False,
-            optimize_memory_usage=optimize_memory_usage,
             supported_action_spaces=(spaces.Box,),
             support_multi_env=True,
         )

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-2.6.1a1
+2.7.0a0

tests/test_n_step_replay.py

Lines changed: 7 additions & 10 deletions

@@ -2,33 +2,30 @@
 import numpy as np
 import pytest

-from stable_baselines3 import DQN, SAC
+from stable_baselines3 import DQN, SAC, TD3
 from stable_baselines3.common.buffers import NStepReplayBuffer, ReplayBuffer
 from stable_baselines3.common.env_util import make_vec_env


-@pytest.mark.parametrize("model_class", [SAC, DQN])
+@pytest.mark.parametrize("model_class", [SAC, DQN, TD3])
 def test_run(model_class):
     env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1"
     env = make_vec_env(env_id, n_envs=2)
-
-    n_steps = 2
-    gamma = 0.99
+    gamma = 0.989

     model = model_class(
         "MlpPolicy",
         env,
-        replay_buffer_class=NStepReplayBuffer,
-        replay_buffer_kwargs=dict(
-            n_steps=n_steps,
-            gamma=gamma,
-        ),
         train_freq=4,
+        n_steps=3,
         policy_kwargs=dict(net_arch=[64]),
         learning_starts=100,
         buffer_size=int(2e4),
         gamma=gamma,
     )
+    assert isinstance(model.replay_buffer, NStepReplayBuffer)
+    assert model.replay_buffer.n_steps == 3
+    assert model.replay_buffer.gamma == gamma

     model.learn(total_timesteps=150)