From fc68af8841f172d739f42d90c4f93b4a526c7ea7 Mon Sep 17 00:00:00 2001
From: Max Lodel
Date: Tue, 26 Jul 2022 00:27:17 +0200
Subject: [PATCH] Fixed shared_lstm argument in CNN and MultiInput Policies
 for RecurrentPPO (#90)

* fixed shared_lstm parameter in CNN and MultiInput Policies

* updated tests

* changelog

* Fix FPS for recurrent PPO

* Fix import

* Update changelog

Co-authored-by: Antonin RAFFIN
---
 docs/misc/changelog.rst                    |  5 ++-
 sb3_contrib/common/recurrent/policies.py   |  4 +++
 sb3_contrib/ppo_recurrent/ppo_recurrent.py |  6 ++--
 sb3_contrib/version.txt                    |  2 +-
 tests/test_lstm.py                         | 42 +++++++++++++++++++---
 5 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 40ca9e92..5a5f92fe 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@ Changelog
 ==========
 
-Release 1.6.1a0 (WIP)
+Release 1.6.1a1 (WIP)
 -------------------------------
 
 Breaking Changes:
@@ -15,8 +15,10 @@ New Features:
 ^^^^^^^^^^^^^
 
 Bug Fixes:
 ^^^^^^^^^^
+- Fixed the issue of wrongly passing policy arguments when using CnnLstmPolicy or MultiInputLstmPolicy with ``RecurrentPPO`` (@mlodel)
 - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
+
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -294,3 +296,4 @@ Contributors:
 -------------
 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
+@mlodel
diff --git a/sb3_contrib/common/recurrent/policies.py b/sb3_contrib/common/recurrent/policies.py
index 2285baa6..1ba52734 100644
--- a/sb3_contrib/common/recurrent/policies.py
+++ b/sb3_contrib/common/recurrent/policies.py
@@ -483,6 +483,7 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         lstm_hidden_size: int = 256,
         n_lstm_layers: int = 1,
+        shared_lstm: bool = False,
         enable_critic_lstm: bool = True,
         lstm_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -506,6 +507,7 @@ def __init__(
             optimizer_kwargs,
             lstm_hidden_size,
             n_lstm_layers,
+            shared_lstm,
             enable_critic_lstm,
             lstm_kwargs,
         )
@@ -573,6 +575,7 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         lstm_hidden_size: int = 256,
         n_lstm_layers: int = 1,
+        shared_lstm: bool = False,
         enable_critic_lstm: bool = True,
         lstm_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -596,6 +599,7 @@ def __init__(
             optimizer_kwargs,
             lstm_hidden_size,
             n_lstm_layers,
+            shared_lstm,
             enable_critic_lstm,
             lstm_kwargs,
         )
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index 3f9b6e41..645375e4 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -1,3 +1,4 @@
+import sys
 import time
 from copy import deepcopy
 from typing import Any, Dict, Optional, Tuple, Type, Union
@@ -513,13 +514,14 @@ def learn(
 
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
-                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
+                time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
+                fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
                 if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                     self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                     self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                 self.logger.record("time/fps", fps)
-                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
+                self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                 self.logger.dump(step=self.num_timesteps)
 
diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt
index 035e3b6c..e36b7272 100644
--- a/sb3_contrib/version.txt
+++ b/sb3_contrib/version.txt
@@ -1 +1 @@
-1.6.1a0
+1.6.1a1
diff --git a/tests/test_lstm.py b/tests/test_lstm.py
index f0ba3e6a..eb320bf4 100644
--- a/tests/test_lstm.py
+++ b/tests/test_lstm.py
@@ -57,13 +57,30 @@ def step(self, action):
         return CartPoleNoVelEnv._pos_obs(full_obs), rew, done, info
 
 
-def test_cnn():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_cnn(policy_kwargs):
     model = RecurrentPPO(
         "CnnLstmPolicy",
         FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
         n_steps=16,
         seed=0,
-        policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
+        policy_kwargs=dict(**policy_kwargs, features_extractor_kwargs=dict(features_dim=32)),
     )
 
     model.learn(total_timesteps=32)
@@ -138,9 +155,26 @@ def test_run_sde():
     model.learn(total_timesteps=200, eval_freq=150)
 
 
-def test_dict_obs():
+@pytest.mark.parametrize(
+    "policy_kwargs",
+    [
+        {},
+        dict(shared_lstm=True, enable_critic_lstm=False),
+        dict(
+            enable_critic_lstm=True,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+        dict(
+            enable_critic_lstm=False,
+            lstm_hidden_size=4,
+            lstm_kwargs=dict(dropout=0.5),
+        ),
+    ],
+)
+def test_dict_obs(policy_kwargs):
     env = make_vec_env("CartPole-v1", n_envs=1, wrapper_class=ToDictWrapper)
-    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32).learn(64)
+    model = RecurrentPPO("MultiInputLstmPolicy", env, n_steps=32, policy_kwargs=policy_kwargs).learn(64)
     evaluate_policy(model, env, warn=False)
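
For reference, a minimal sketch of how the fixed argument is exercised, mirroring the parametrized test above. It assumes sb3-contrib with this patch applied and reuses the FakeImageEnv helper imported in tests/test_lstm.py from stable-baselines3; the hyperparameters are illustrative, not a recommended configuration.

from sb3_contrib import RecurrentPPO
from stable_baselines3.common.envs import FakeImageEnv

# Before this patch, CnnLstmPolicy/MultiInputLstmPolicy did not accept `shared_lstm`
# and forwarded the remaining keyword arguments into the wrong positional slots of
# the parent recurrent policy; with the fix, policy_kwargs reach it unchanged.
model = RecurrentPPO(
    "CnnLstmPolicy",
    FakeImageEnv(screen_height=40, screen_width=40, n_channels=3),
    n_steps=16,
    seed=0,
    policy_kwargs=dict(
        shared_lstm=True,          # actor and critic share one LSTM
        enable_critic_lstm=False,  # the tests pair shared_lstm=True with this
        features_extractor_kwargs=dict(features_dim=32),
    ),
)
model.learn(total_timesteps=32)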
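The FPS change in ppo_recurrent.py guards the division by clamping the elapsed time to at least one float epsilon, so it can never be zero on operating systems with low-precision timers. A standalone sketch of that pattern follows; the function and variable names are illustrative only, not part of the library's API.

import sys
import time

def compute_fps(num_timesteps: int, num_timesteps_at_start: int, start_time_ns: int) -> int:
    """Frames per second with a guard against a zero elapsed time."""
    # time.time_ns() keeps integer nanosecond resolution; dividing by 1e9 converts to
    # seconds, and sys.float_info.epsilon keeps the denominator strictly positive.
    time_elapsed = max((time.time_ns() - start_time_ns) / 1e9, sys.float_info.epsilon)
    return int((num_timesteps - num_timesteps_at_start) / time_elapsed)

start = time.time_ns()
# ... collect rollouts here ...
print(compute_fps(1024, 0, start))  # no ZeroDivisionError even if called immediately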