Commit fc68af8
Fixed shared_lstm argument in CNN and MultiInput Policies for RecurrentPPO (#90)

* fixed shared_lstm parameter in CNN and MultiInput Policies

* updated tests

* changelog

* Fix FPS for recurrent PPO

* Fix import

* Update changelog

Co-authored-by: Antonin RAFFIN <[email protected]>
mlodel and araffin authored Jul 25, 2022
1 parent 7e687ac commit fc68af8
Showing 5 changed files with 51 additions and 8 deletions.
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

-Release 1.6.1a0 (WIP)
+Release 1.6.1a1 (WIP)
-------------------------------

Breaking Changes:
@@ -15,8 +15,10 @@ New Features:

Bug Fixes:
^^^^^^^^^^
+- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
+- Fixed division by zero error when computing FPS when a small amount of time has elapsed in operating systems with low-precision timers.


Deprecations:
^^^^^^^^^^^^^

@@ -294,3 +296,4 @@ Contributors:
-------------

@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
+@mlodel
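
Note: a minimal usage sketch of what the shared_lstm changelog entry above enables, mirroring the configurations added to the tests in this commit. The environment and feature-extractor settings are copied from the updated test_cnn test; treat the exact choices as illustrative rather than canonical.

from sb3_contrib import RecurrentPPO
from stable_baselines3.common.envs import FakeImageEnv

# With the fix, shared_lstm is forwarded correctly when passed via policy_kwargs.
# The tests pair shared_lstm=True with enable_critic_lstm=False (one LSTM shared
# by actor and critic instead of a separate critic LSTM).
model = RecurrentPPO(
    "CnnLstmPolicy",
    FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
    n_steps=16,
    policy_kwargs=dict(
        shared_lstm=True,
        enable_critic_lstm=False,
        features_extractor_kwargs=dict(features_dim=32),
    ),
)
model.learn(total_timesteps=32)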
4 changes: 4 additions & 0 deletions sb3_contrib/common/recurrent/policies.py
@@ -483,6 +483,7 @@ def __init__(
optimizer_kwargs: Optional[Dict[str, Any]] = None,
lstm_hidden_size: int = 256,
n_lstm_layers: int = 1,
+shared_lstm: bool = False,
enable_critic_lstm: bool = True,
lstm_kwargs: Optional[Dict[str, Any]] = None,
):
@@ -506,6 +507,7 @@ def __init__(
optimizer_kwargs,
lstm_hidden_size,
n_lstm_layers,
+shared_lstm,
enable_critic_lstm,
lstm_kwargs,
)
@@ -573,6 +575,7 @@ def __init__(
optimizer_kwargs: Optional[Dict[str, Any]] = None,
lstm_hidden_size: int = 256,
n_lstm_layers: int = 1,
+shared_lstm: bool = False,
enable_critic_lstm: bool = True,
lstm_kwargs: Optional[Dict[str, Any]] = None,
):
@@ -596,6 +599,7 @@ def __init__(
optimizer_kwargs,
lstm_hidden_size,
n_lstm_layers,
+shared_lstm,
enable_critic_lstm,
lstm_kwargs,
)
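
Note: a minimal, hypothetical sketch of the bug class fixed by the shared_lstm additions above. Before this change, the child policies forwarded their arguments positionally to a parent whose signature also contains shared_lstm, so every argument after n_lstm_layers landed one slot too early. The Base/Child names below are invented for illustration and are not sb3-contrib classes.

class Base:
    def __init__(self, lstm_hidden_size=256, n_lstm_layers=1, shared_lstm=False,
                 enable_critic_lstm=True, lstm_kwargs=None):
        self.shared_lstm = shared_lstm
        self.enable_critic_lstm = enable_critic_lstm
        self.lstm_kwargs = lstm_kwargs


class Child(Base):
    # Pre-fix shape: shared_lstm is missing here, yet arguments are forwarded positionally.
    def __init__(self, lstm_hidden_size=256, n_lstm_layers=1,
                 enable_critic_lstm=True, lstm_kwargs=None):
        super().__init__(lstm_hidden_size, n_lstm_layers, enable_critic_lstm, lstm_kwargs)


broken = Child(lstm_kwargs=dict(dropout=0.5))
print(broken.shared_lstm)         # True  (the default enable_critic_lstm=True filled the wrong slot)
print(broken.enable_critic_lstm)  # {'dropout': 0.5}  (lstm_kwargs shifted into this slot)
print(broken.lstm_kwargs)         # None

# The fix mirrors the diff: accept shared_lstm in the child and forward it in the
# same position (or forward by keyword), so the parent's slots line up again.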
6 changes: 4 additions & 2 deletions sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -1,3 +1,4 @@
+import sys
import time
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, Union
@@ -513,13 +514,14 @@ def learn(

# Display training infos
if log_interval is not None and iteration % log_interval == 0:
-fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
+time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)

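Note: a standalone sketch of the FPS guard introduced above. On operating systems with low-precision timers, two nearby clock reads can return the same value, so the elapsed time is clamped to a tiny positive epsilon before dividing. The variable names follow the diff, but the surrounding scaffolding is illustrative.

import sys
import time

start_time = time.time_ns()          # in the trainer this is recorded when learning starts
num_timesteps, num_timesteps_at_start = 512, 0

# Clamp the elapsed time so the division can never hit zero, even if the clock
# has not advanced since start_time was recorded.
time_elapsed = max((time.time_ns() - start_time) / 1e9, sys.float_info.epsilon)
fps = int((num_timesteps - num_timesteps_at_start) / time_elapsed)
print(fps)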
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
@@ -1 +1 @@
-1.6.1a0
+1.6.1a1
42 changes: 38 additions & 4 deletions tests/test_lstm.py
@@ -57,13 +57,30 @@ def step(self, action):
return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info


-def test_cnn():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_cnn(policy_kwargs):
model = RecurrentPPO(
"CnnLstmPolicy",
FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
n_steps=16,
seed=0,
-policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
+policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
)

model.learn(total_timesteps=32)
@@ -138,9 +155,26 @@ def test_run_sde():
model.learn(total_timesteps=200, eval_freq=150)


-def test_dict_obs():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_dict_obs(policy_kwargs):
env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64)
model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64)
evaluate_policy(model, env, warn=False)


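Note: a hedged way to exercise just the tests touched by this commit locally, assuming pytest is installed and the sb3-contrib repository root is the current working directory.

import pytest

# Runs the two parametrized tests updated above across all four policy_kwargs configurations.
pytest.main(["tests/test_lstm.py", "-k", "test_cnn or test_dict_obs", "-v"])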
