Skip to content

Commit 3d9a975

Browse files
jak3122 and araffin authored
Fix QRDQN loading target_update_interval (#259)
* Fix QRDQN loading target_update_interval
* Update changelog
* Update version

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 42595a5 commit 3d9a975

File tree

5 files changed

+21
-9
lines changed

5 files changed

+21
-9
lines changed

CONTRIBUTING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,13 @@ To run tests with `pytest`:
152152
make pytest
153153
```
154154

155-
Type checking with `pytype` and `mypy`:
155+
Type checking with `mypy`:
156156

157157
```
158158
make type
159159
```
160160

161-
Codestyle check with `black`, `isort` and `flake8`:
161+
Codestyle check with `black` and `ruff`:
162162

163163
```
164164
make check-codestyle

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Changelog
44
==========
55

66

7-
Release 2.4.0a8 (WIP)
7+
Release 2.4.0a9 (WIP)
88
--------------------------
99

1010
Breaking Changes:
@@ -19,6 +19,7 @@ Bug Fixes:
1919
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
2020
- Updated QR-DQN paper link in docs (@corentinlger)
2121
- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False)
22+
- Fixed `target_update_interval` being changed when loading a QRDQN model (@jak3122)
2223

2324
Deprecations:
2425
^^^^^^^^^^^^^

sb3_contrib/qrdqn/qrdqn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,7 @@ def _setup_model(self) -> None:
153153
self.exploration_schedule = get_linear_fn(
154154
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
155155
)
156-
# Account for multiple environments
157-
# each call to step() corresponds to n_envs transitions
156+
158157
if self.n_envs > 1:
159158
if self.n_envs > self.target_update_interval:
160159
warnings.warn(
@@ -164,8 +163,6 @@ def _setup_model(self) -> None:
164163
f"which corresponds to {self.n_envs} steps."
165164
)
166165

167-
self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
168-
169166
def _create_aliases(self) -> None:
170167
self.quantile_net = self.policy.quantile_net
171168
self.quantile_net_target = self.policy.quantile_net_target
@@ -177,7 +174,9 @@ def _on_step(self) -> None:
177174
This method is called in ``collect_rollouts()`` after each step in the environment.
178175
"""
179176
self._n_calls += 1
180-
if self._n_calls % self.target_update_interval == 0:
177+
# Account for multiple environments
178+
# each call to step() corresponds to n_envs transitions
179+
if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
181180
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
182181
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
183182
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

sb3_contrib/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.4.0a8
1+
2.4.0a9

tests/test_save_load.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
import torch as th
1010
from stable_baselines3.common.base_class import BaseAlgorithm
11+
from stable_baselines3.common.env_util import make_vec_env
1112
from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox
1213
from stable_baselines3.common.utils import get_device
1314
from stable_baselines3.common.vec_env import DummyVecEnv
@@ -481,3 +482,14 @@ def test_save_load_pytorch_var(tmp_path):
481482
assert model.log_ent_coef is None
482483
# Check that the entropy coefficient is still the same
483484
assert th.allclose(ent_coef_before, ent_coef_after)
485+
486+
487+
def test_dqn_target_update_interval(tmp_path):
488+
# `target_update_interval` should not change when reloading the model. See GH Issue #258.
489+
env = make_vec_env(env_id="CartPole-v1", n_envs=2)
490+
model = QRDQN("MlpPolicy", env, verbose=1, target_update_interval=100)
491+
model.save(tmp_path / "dqn_cartpole")
492+
del model
493+
model = QRDQN.load(tmp_path / "dqn_cartpole")
494+
os.remove(tmp_path / "dqn_cartpole.zip")
495+
assert model.target_update_interval == 100

0 commit comments

Comments
 (0)