Skip to content

Commit e98c00e

Browse files
authored
Added hyperparameter tuning for RecurrentPPO (#415)
* ppo_lstm sampling added * solution 2, added tiny to ppo * updated tests * added ppo_lstm to test_hyperparms_opt.py * updated formatting in hyperparams_opt.py * Update CHANGELOG.md
1 parent 94e5f72 commit e98c00e

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
### New Features
99
- Add `--eval-env-kwargs` to `train.py` (@Quentin18)
10+
- Added `ppo_lstm` to hyperparams_opt.py (@technocrat13)
1011

1112
### Bug fixes
1213

rl_zoo3/hyperparams_opt.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
2828
gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
2929
max_grad_norm = trial.suggest_categorical("max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5])
3030
vf_coef = trial.suggest_float("vf_coef", 0, 1)
31-
net_arch = trial.suggest_categorical("net_arch", ["small", "medium"])
31+
net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"])
3232
# Uncomment for gSDE (continuous actions)
3333
# log_std_init = trial.suggest_float("log_std_init", -4, 1)
3434
# Uncomment for gSDE (continuous action)
@@ -49,6 +49,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
4949
# Independent networks usually work best
5050
# when not working with images
5151
net_arch = {
52+
"tiny": dict(pi=[64], vf=[64]),
5253
"small": dict(pi=[64, 64], vf=[64, 64]),
5354
"medium": dict(pi=[256, 256], vf=[256, 256]),
5455
}[net_arch]
@@ -76,6 +77,28 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
7677
}
7778

7879

80+
def sample_ppo_lstm_params(trial: optuna.Trial) -> Dict[str, Any]:
81+
"""
82+
Sampler for RecurrentPPO hyperparams.
83+
uses sample_ppo_params(), this function samples for the policy_kwargs
84+
:param trial:
85+
:return:
86+
"""
87+
hyperparams = sample_ppo_params(trial)
88+
89+
enable_critic_lstm = trial.suggest_categorical("enable_critic_lstm", [False, True])
90+
lstm_hidden_size = trial.suggest_categorical("lstm_hidden_size", [16, 32, 64, 128, 256, 512])
91+
92+
hyperparams["policy_kwargs"].update(
93+
{
94+
"enable_critic_lstm": enable_critic_lstm,
95+
"lstm_hidden_size": lstm_hidden_size,
96+
}
97+
)
98+
99+
return hyperparams
100+
101+
79102
def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]:
80103
"""
81104
Sampler for TRPO hyperparams.
@@ -527,6 +550,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
527550
"sac": sample_sac_params,
528551
"tqc": sample_tqc_params,
529552
"ppo": sample_ppo_params,
553+
"ppo_lstm": sample_ppo_lstm_params,
530554
"td3": sample_td3_params,
531555
"trpo": sample_trpo_params,
532556
}

tests/test_hyperparams_opt.py

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def _assert_eq(left, right):
3333
experiments["tqc-parking-v0"] = ("tqc", "parking-v0")
3434
# Test for TQC
3535
experiments["tqc-Pendulum-v1"] = ("tqc", "Pendulum-v1")
36+
# Test for RecurrentPPO (ppo_lstm)
37+
experiments["ppo_lstm-CartPoleNoVel-v1"] = ("ppo_lstm", "CartPoleNoVel-v1")
3638

3739

3840
@pytest.mark.parametrize("sampler", ["random", "tpe"])

0 commit comments

Comments
 (0)