
Commit 29a481a

Authored by honglu2875, araffin, qgallouedec, and BurakDmb
Include running_mean and running_var when updating target networks (#1004)

* Include `running_mean` and `running_var` when updating target networks in DQN, SAC, TD3.
* Update stable_baselines3/common/utils.py
  Co-authored-by: Antonin RAFFIN <[email protected]>
* Precompute batch norm parameters in `_setup_model` and directly copy them in the target update.
* Fix `DictReplayBuffer.next_observations` type (#1013)
  * Fix DictReplayBuffer.next_observations type
  * Update changelog
  Co-authored-by: Antonin RAFFIN <[email protected]>
* Fixed missing verbose parameter passing (#1011)
  Co-authored-by: Quentin Gallouédec <[email protected]>
* Support for `device=auto` buffers and set it as default value (#1009)
  * Default device is "auto" for buffer + auto device support in BufferBaseClass
  * Update docstring
  * Update tests
  * Unify tests
  * Update changelog
  * Fix tests on CUDA device
  Co-authored-by: Antonin RAFFIN <[email protected]>
  Co-authored-by: Antonin Raffin <[email protected]>
* Update test
* Add comments and update tests
* Bump version
* Remove one extra space to conform to code style.
* Update docstrings

Co-authored-by: Antonin RAFFIN <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Burak Demirbilek <[email protected]>
Co-authored-by: Antonin Raffin <[email protected]>
1 parent 01cc127 commit 29a481a
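Background for the fix (an illustrative sketch, not part of the diff): in PyTorch, a batch norm layer's running statistics are registered as buffers, not as parameters, so iterating over `model.parameters()` — which is what `polyak_update` previously received — never touches them:

import torch as th

bn = th.nn.BatchNorm1d(4)
# The learnable affine terms show up in .parameters() ...
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
# ... but the running statistics are buffers and do not
print([name for name, _ in bn.named_buffers()])  # ['running_mean', 'running_var', 'num_batches_tracked']

Hence a target network updated only through `polyak_update(model.parameters(), ...)` kept stale `running_mean`/`running_var` values, which matters whenever the target is evaluated in eval mode.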

File tree

8 files changed: +91 -25 lines


docs/misc/changelog.rst (+3 -2)

@@ -3,7 +3,7 @@
 Changelog
 ==========

-Release 1.6.1a1 (WIP)
+Release 1.6.1a2 (WIP)
 ---------------------------

 Breaking Changes:
@@ -23,6 +23,7 @@ Bug Fixes:
 - Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.
 - Added multidimensional action space support (@qgallouedec)
 - Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb)
+- Fixed the issue that the ``running_mean`` and ``running_var`` properties of batch norm layers were not updated when updating the target network in DQN, SAC and TD3 (@honglu2875)

 Deprecations:
 ^^^^^^^^^^^^^
@@ -1026,4 +1027,4 @@ And all the contributors:
 @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
 @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
-@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont
+@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875

stable_baselines3/common/utils.py (+25 -12)

@@ -4,7 +4,7 @@
 import random
 from collections import deque
 from itertools import zip_longest
-from typing import Dict, Iterable, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union

 import gym
 import numpy as np
@@ -67,8 +67,8 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
     Update the learning rate for a given optimizer.
     Useful when doing linear schedule.

-    :param optimizer:
-    :param learning_rate:
+    :param optimizer: Pytorch optimizer
+    :param learning_rate: New learning rate value
     """
     for param_group in optimizer.param_groups:
         param_group["lr"] = learning_rate
@@ -79,8 +79,8 @@ def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
     Transform (if needed) learning rate and clip range (for PPO)
     to callable.

-    :param value_schedule:
-    :return:
+    :param value_schedule: Constant value or schedule function
+    :return: Schedule function (can return constant value)
     """
     # If the passed schedule is a float
     # create a constant function
@@ -104,7 +104,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
     :param end_fraction: fraction of ``progress_remaining``
         where end is reached, e.g. 0.1 means end is reached after 10%
         of the complete training process.
-    :return:
+    :return: Linear schedule function.
     """

     def func(progress_remaining: float) -> float:
@@ -121,8 +121,8 @@ def constant_fn(val: float) -> Schedule:
     Create a function that returns a constant.
     It is useful for learning rate schedules (to avoid code duplication).

-    :param val:
-    :return:
+    :param val: constant value
+    :return: Constant schedule function.
     """

     def func(_):
@@ -139,7 +139,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
     By default, it tries to use the gpu.

     :param device: One of 'auto', 'cuda', 'cpu'
-    :return:
+    :return: Supported Pytorch device
     """
     # Cuda by default
     if device == "auto":
@@ -386,12 +386,25 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
     Compute the mean of an array if there is at least one element.
     For empty array, return NaN. It is used for logging only.

-    :param arr:
+    :param arr: Numpy array or list of values
     :return:
     """
     return np.nan if len(arr) == 0 else np.mean(arr)


+def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
+    """
+    Extract parameters from the state dict of ``model``
+    if the name contains one of the strings in ``included_names``.
+
+    :param model: the model where the parameters come from.
+    :param included_names: substrings of names to include.
+    :return: List of parameter values (Pytorch tensors)
+        that match the queried names.
+    """
+    return [param for name, param in model.state_dict().items() if any([key in name for key in included_names])]
+
+
 def zip_strict(*iterables: Iterable) -> Iterable:
     r"""
     ``zip()`` function but enforces that iterables are of equal length.
@@ -411,8 +424,8 @@ def zip_strict(*iterables: Iterable) -> Iterable:


 def polyak_update(
-    params: Iterable[th.nn.Parameter],
-    target_params: Iterable[th.nn.Parameter],
+    params: Iterable[th.Tensor],
+    target_params: Iterable[th.Tensor],
     tau: float,
 ) -> None:
     """

stable_baselines3/dqn/dqn.py (+6 -1)

@@ -11,7 +11,7 @@
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import maybe_transpose
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
+from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
 from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy


@@ -140,6 +140,9 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()
         self._create_aliases()
+        # Copy running stats, see GH issue #996
+        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
+        self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
         self.exploration_schedule = get_linear_fn(
             self.exploration_initial_eps,
             self.exploration_final_eps,
@@ -170,6 +173,8 @@ def _on_step(self) -> None:
         self._n_calls += 1
         if self._n_calls % self.target_update_interval == 0:
             polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
+            # Copy running stats, see GH issue #996
+            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

         self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
         self.logger.record("rollout/exploration_rate", self.exploration_rate)

stable_baselines3/sac/sac.py (+6 -1)

@@ -10,7 +10,7 @@
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import polyak_update
+from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
 from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy


@@ -152,6 +152,9 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()
         self._create_aliases()
+        # Running mean and running var
+        self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
+        self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
         # Target entropy is used when learning the entropy coefficient
         if self.target_entropy == "auto":
             # automatically set target entropy if needed
@@ -272,6 +275,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             # Update target networks
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
+                # Copy running stats, see GH issue #996
+                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

         self._n_updates += gradient_steps

stable_baselines3/td3/td3.py (+9 -1)

@@ -10,7 +10,7 @@
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import polyak_update
+from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
 from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy


@@ -131,6 +131,11 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()
         self._create_aliases()
+        # Running mean and running var
+        self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"])
+        self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
+        self.actor_batch_norm_stats_target = get_parameters_by_name(self.actor_target, ["running_"])
+        self.critic_batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])

     def _create_aliases(self) -> None:
         self.actor = self.policy.actor
@@ -189,6 +194,9 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:

                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
+                # Copy running stats, see GH issue #996
+                polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
+                polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)

         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         if len(actor_losses) > 0:
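Putting the pieces together outside any algorithm class, the update pattern this commit introduces looks roughly like the following (a self-contained sketch; the toy network is an assumption, only `get_parameters_by_name` and `polyak_update` come from the diff):

from copy import deepcopy

import torch as th

from stable_baselines3.common.utils import get_parameters_by_name, polyak_update

net = th.nn.Sequential(th.nn.Linear(3, 3), th.nn.BatchNorm1d(3))
target_net = deepcopy(net)

# Precomputed once, as _setup_model now does; state_dict tensors share
# storage with the live buffers, so in-place updates propagate.
batch_norm_stats = get_parameters_by_name(net, ["running_"])
batch_norm_stats_target = get_parameters_by_name(target_net, ["running_"])

net.train()
net(th.randn(8, 3))  # a forward pass in train mode moves the running stats

# Soft update of the learnable weights, hard copy of the running stats
polyak_update(net.parameters(), target_net.parameters(), tau=0.005)
polyak_update(batch_norm_stats, batch_norm_stats_target, tau=1.0)

assert th.allclose(batch_norm_stats[0], batch_norm_stats_target[0])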

stable_baselines3/version.txt (+1 -1)

@@ -1 +1 @@
-1.6.1a1
+1.6.1a2

tests/test_train_eval_mode.py (+18 -6)

@@ -143,7 +143,8 @@ def test_dqn_train_with_batch_norm():
         policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
         learning_starts=0,
         seed=1,
-        tau=0,  # do not clone the target
+        tau=0.0,  # do not clone the target
+        target_update_interval=100,  # Copy the stats to the target
     )

     (
@@ -154,6 +155,9 @@
     ) = clone_dqn_batch_norm_stats(model)

     model.learn(total_timesteps=200)
+    # Force stats copy
+    model.target_update_interval = 1
+    model._on_step()

     (
         q_net_bias_after,
@@ -165,8 +169,12 @@
     assert ~th.isclose(q_net_bias_before, q_net_bias_after).all()
     assert ~th.isclose(q_net_running_mean_before, q_net_running_mean_after).all()

+    # No weight update
+    assert th.isclose(q_net_bias_before, q_net_target_bias_after).all()
     assert th.isclose(q_net_target_bias_before, q_net_target_bias_after).all()
-    assert th.isclose(q_net_target_running_mean_before, q_net_target_running_mean_after).all()
+    # Running stats should be copied even when tau=0
+    assert th.isclose(q_net_running_mean_before, q_net_target_running_mean_before).all()
+    assert th.isclose(q_net_running_mean_after, q_net_target_running_mean_after).all()


 def test_td3_train_with_batch_norm():
@@ -210,10 +218,12 @@ def test_td3_train_with_batch_norm():
     assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()

     assert th.isclose(actor_target_bias_before, actor_target_bias_after).all()
-    assert th.isclose(actor_target_running_mean_before, actor_target_running_mean_after).all()
+    # Running stats should be copied even when tau=0
+    assert th.isclose(actor_running_mean_after, actor_target_running_mean_after).all()

     assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
-    assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
+    # Running stats should be copied even when tau=0
+    assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()


 def test_sac_train_with_batch_norm():
@@ -250,10 +260,12 @@ def test_sac_train_with_batch_norm():
     assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()

     assert ~th.isclose(critic_bias_before, critic_bias_after).all()
-    assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
+    # Running stats should be copied even when tau=0
+    assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all()

     assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
-    assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
+    # Running stats should be copied even when tau=0
+    assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()


 @pytest.mark.parametrize("model_class", [A2C, PPO])
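One subtlety in these tests: the `clone_*_batch_norm_stats` helpers presumably snapshot tensors with `.clone()`, because `polyak_update` mutates the target buffers in place; without cloning, the "before" and "after" tensors would alias the same storage and the `th.isclose` assertions would be vacuous. A hypothetical minimal helper illustrating the idea (the name is not from the diff):

import torch as th

def clone_batch_norm_stats(batch_norm: th.nn.BatchNorm1d):
    # .clone() detaches the snapshot from the live buffers
    return batch_norm.running_mean.clone(), batch_norm.running_var.clone()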

tests/test_utils.py (+23 -1)

@@ -14,7 +14,13 @@
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
-from stable_baselines3.common.utils import get_system_info, is_vectorized_observation, polyak_update, zip_strict
+from stable_baselines3.common.utils import (
+    get_parameters_by_name,
+    get_system_info,
+    is_vectorized_observation,
+    polyak_update,
+    zip_strict,
+)
 from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv


@@ -322,6 +328,22 @@ def test_vec_noise():
     assert len(vec.noises) == num_envs


+def test_get_parameters_by_name():
+    model = th.nn.Sequential(th.nn.Linear(5, 5), th.nn.BatchNorm1d(5))
+    # Initialize stats
+    model(th.ones(3, 5))
+    included_names = ["weight", "bias", "running_"]
+    # 2 x weight, 2 x bias, 1 x running_mean, 1 x running_var; ignore num_batches_tracked.
+    parameters = get_parameters_by_name(model, included_names)
+    assert len(parameters) == 6
+    assert th.allclose(parameters[4], model[1].running_mean)
+    assert th.allclose(parameters[5], model[1].running_var)
+    parameters = get_parameters_by_name(model, ["running_"])
+    assert len(parameters) == 2
+    assert th.allclose(parameters[0], model[1].running_mean)
+    assert th.allclose(parameters[1], model[1].running_var)
+
+
 def test_polyak():
     param1, param2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))
     target1, target2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))
