@@ -28,7 +28,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
28
28
gae_lambda = trial .suggest_categorical ("gae_lambda" , [0.8 , 0.9 , 0.92 , 0.95 , 0.98 , 0.99 , 1.0 ])
29
29
max_grad_norm = trial .suggest_categorical ("max_grad_norm" , [0.3 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 1 , 2 , 5 ])
30
30
vf_coef = trial .suggest_float ("vf_coef" , 0 , 1 )
31
- net_arch = trial .suggest_categorical ("net_arch" , ["small" , "medium" ])
31
+ net_arch = trial .suggest_categorical ("net_arch" , ["tiny" , " small" , "medium" ])
32
32
# Uncomment for gSDE (continuous actions)
33
33
# log_std_init = trial.suggest_float("log_std_init", -4, 1)
34
34
# Uncomment for gSDE (continuous action)
@@ -49,6 +49,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
49
49
# Independent networks usually work best
50
50
# when not working with images
51
51
net_arch = {
52
+ "tiny" : dict (pi = [64 ], vf = [64 ]),
52
53
"small" : dict (pi = [64 , 64 ], vf = [64 , 64 ]),
53
54
"medium" : dict (pi = [256 , 256 ], vf = [256 , 256 ]),
54
55
}[net_arch ]
@@ -76,6 +77,28 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
76
77
}
77
78
78
79
80
+ def sample_ppo_lstm_params (trial : optuna .Trial ) -> Dict [str , Any ]:
81
+ """
82
+ Sampler for RecurrentPPO hyperparams.
83
+ uses sample_ppo_params(), this function samples for the policy_kwargs
84
+ :param trial:
85
+ :return:
86
+ """
87
+ hyperparams = sample_ppo_params (trial )
88
+
89
+ enable_critic_lstm = trial .suggest_categorical ("enable_critic_lstm" , [False , True ])
90
+ lstm_hidden_size = trial .suggest_categorical ("lstm_hidden_size" , [16 , 32 , 64 , 128 , 256 , 512 ])
91
+
92
+ hyperparams ["policy_kwargs" ].update (
93
+ {
94
+ "enable_critic_lstm" : enable_critic_lstm ,
95
+ "lstm_hidden_size" : lstm_hidden_size ,
96
+ }
97
+ )
98
+
99
+ return hyperparams
100
+
101
+
79
102
def sample_trpo_params (trial : optuna .Trial ) -> Dict [str , Any ]:
80
103
"""
81
104
Sampler for TRPO hyperparams.
@@ -527,6 +550,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
527
550
"sac" : sample_sac_params ,
528
551
"tqc" : sample_tqc_params ,
529
552
"ppo" : sample_ppo_params ,
553
+ "ppo_lstm" : sample_ppo_lstm_params ,
530
554
"td3" : sample_td3_params ,
531
555
"trpo" : sample_trpo_params ,
532
556
}
0 commit comments