
Commit fc68af8

mlodel and araffin authored
Fixed shared_lstm argument in CNN and MultiInput Policies for RecurrentPPO (#90)
* fixed shared_lstm parameter in CNN and MultiInput Policies
* updated tests
* changelog
* Fix FPS for recurrent PPO
* Fix import
* Update changelog

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 7e687ac commit fc68af8

File tree

5 files changed: +51 -8 lines changed

docs/misc/changelog.rst

Lines changed: 4 additions & 1 deletion
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 1.6.1a0 (WIP)
+Release 1.6.1a1 (WIP)
 -------------------------------
 
 Breaking Changes:
@@ -15,8 +15,10 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
+- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
 - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
 
+
 Deprecations:
 ^^^^^^^^^^^^^
 
@@ -294,3 +296,4 @@ Contributors:
 -------------
 
 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
+@mlodel

sb3_contrib/common/recurrent/policies.py

Lines changed: 4 additions & 0 deletions
@@ -483,6 +483,7 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         lstm_hidden_size: int = 256,
         n_lstm_layers: int = 1,
+        shared_lstm: bool = False,
         enable_critic_lstm: bool = True,
         lstm_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -506,6 +507,7 @@ def __init__(
             optimizer_kwargs,
             lstm_hidden_size,
             n_lstm_layers,
+            shared_lstm,
             enable_critic_lstm,
             lstm_kwargs,
         )
@@ -573,6 +575,7 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         lstm_hidden_size: int = 256,
         n_lstm_layers: int = 1,
+        shared_lstm: bool = False,
         enable_critic_lstm: bool = True,
         lstm_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -596,6 +599,7 @@ def __init__(
             optimizer_kwargs,
             lstm_hidden_size,
             n_lstm_layers,
+            shared_lstm,
             enable_critic_lstm,
             lstm_kwargs,
         )
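The inserted shared_lstm line in each positional super().__init__() call is the heart of the fix: the arguments are forwarded positionally, so judging from the diff, everything after n_lstm_layers previously landed one slot early in the parent RecurrentActorCriticPolicy signature, and enable_critic_lstm was silently consumed as shared_lstm. Below is a minimal sketch of the configuration this commit makes work; it mirrors the new tests, and the env and hyperparameters are illustrative, not part of the commit itself.

from sb3_contrib import RecurrentPPO
from stable_baselines3.common.envs import FakeImageEnv

# Share one LSTM between actor and critic. The critic's separate LSTM
# must be disabled when sharing, as in the test parametrization below.
model = RecurrentPPO(
    "CnnLstmPolicy",
    FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
    n_steps=16,
    policy_kwargs=dict(
        shared_lstm=True,
        enable_critic_lstm=False,
        features_extractor_kwargs=dict(features_dim=32),
    ),
)
model.learn(total_timesteps=32)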

sb3_contrib/ppo_recurrent/ppo_recurrent.py

Lines changed: 4 additions & 2 deletions
@@ -1,3 +1,4 @@
+import sys
 import time
 from copy import deepcopy
 from typing import Any, Dict, Optional, Tuple, Type, Union
@@ -513,13 +514,14 @@ def learn(
 
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
-                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
+                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
                 if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                     self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                     self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                 self.logger.record("time/fps", fps)
-                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
+                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                 self.logger.dump(step=self.num_timesteps)
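The FPS change swaps the time.time() delta for a nanosecond-resolution clock and clamps the elapsed time to sys.float_info.epsilon, so a log line emitted before a low-resolution OS timer ticks over can no longer divide by zero (note that self.start_time is now expected to hold a time.time_ns() value). A standalone sketch of the same guard, with illustrative variable names and step counts:

import sys
import time

start_time = time.time_ns()  # integer nanoseconds, not a rounded float
num_timesteps_at_start, num_timesteps = 0, 512

# ... rollout collection would happen here ...

# On OSes with low-precision timers, two calls to time.time() can return
# the same value, making the naive delta exactly zero. Clamping to the
# smallest positive float keeps the division well-defined.
time_elapsed = max((time.time_ns() - start_time) / 1e9, sys.float_info.epsilon)
fps = int((num_timesteps - num_timesteps_at_start) / time_elapsed)
print(f"fps={fps}, time_elapsed={time_elapsed:.6f}s")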

sb3_contrib/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-1.6.1a0
+1.6.1a1

tests/test_lstm.py

Lines changed: 38 additions & 4 deletions
@@ -57,13 +57,30 @@ def step(self, action):
         return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
 
 
-def test_cnn():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_cnn(policy_kwargs):
     model = RecurrentPPO(
         "CnnLstmPolicy",
         FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
         n_steps=16,
         seed=0,
-        policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
+        policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
     )
 
     model.learn(total_timesteps=32)
@@ -138,9 +155,26 @@ def test_run_sde():
     model.learn(total_timesteps=200, eval_freq=150)
 
 
-def test_dict_obs():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_dict_obs(policy_kwargs):
     env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
-    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64)
+    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64)
     evaluate_policy(model, env, warn=False)
146180
