1414"""
1515
1616import dataclasses
17- from typing import Any , Callable , Dict , List , Mapping
17+ from typing import Any , Callable , Dict , List , Mapping , Optional
1818
1919import optuna
2020import sacred
21- import stable_baselines3 as sb3
2221
2322import imitation .scripts .train_imitation
24- import imitation .scripts .train_preference_comparisons
23+ import imitation .scripts .train_preference_comparisons as train_pc_script
2524
2625
2726@dataclasses .dataclass
@@ -42,19 +41,27 @@ class RunSacredAsTrial:
4241 suggest_config_updates : Callable [[optuna .Trial ], Mapping [str , Any ]]
4342
4443 """Command name to pass to sacred.run."""
45- command_name : str = None
44+ command_name : Optional [ str ] = None
4645
4746 def __call__ (
48- self , trial : optuna .Trial , run_options : Dict , extra_named_configs : List [str ]
47+ self ,
48+ trial : optuna .Trial ,
49+ run_options : Dict ,
50+ extra_named_configs : List [str ],
4951 ) -> float :
5052 """Run the sacred experiment and return the performance.
5153
5254 Args:
5355 trial: The optuna trial to sample hyperparameters for.
5456 run_options: Options to pass to sacred.run(options=).
5557 extra_named_configs: Additional named configs to pass to sacred.run.
56- """
5758
59+ Returns:
60+ The performance of the trial.
61+
62+ Raises:
63+ RuntimeError: If the trial fails.
64+ """
5865 config_updates = self .suggest_config_updates (trial )
5966 named_configs = self .suggest_named_configs (trial ) + extra_named_configs
6067
@@ -71,15 +78,16 @@ def __call__(
7178 )
7279 if result .status != "COMPLETED" :
7380 raise RuntimeError (
74- f"Trial failed with { result .fail_trace ()} and status { result .status } ."
81+ f"Trial failed with { result .fail_trace ()} and status { result .status } ." ,
7582 )
7683 return result .result ["imit_stats" ]["monitor_return_mean" ]
7784
7885
79- """A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
86+ """A mapping from algorithm names to functions that run the algorithm as an optuna
87+ trial."""
8088objectives_by_algo = dict (
8189 pc = RunSacredAsTrial (
82- sacred_ex = imitation . scripts . train_preference_comparisons .train_preference_comparisons_ex ,
90+ sacred_ex = train_pc_script .train_preference_comparisons_ex ,
8391 suggest_named_configs = lambda _ : ["reward.reward_ensemble" ],
8492 suggest_config_updates = lambda trial : {
8593 "seed" : trial .number ,
@@ -88,61 +96,87 @@ def __call__(
8896 "total_comparisons" : 1000 ,
8997 "active_selection" : True ,
9098 "active_selection_oversampling" : trial .suggest_int (
91- "active_selection_oversampling" , 1 , 11
99+ "active_selection_oversampling" ,
100+ 1 ,
101+ 11 ,
92102 ),
93103 "comparison_queue_size" : trial .suggest_int (
94- "comparison_queue_size" , 1 , 1001
104+ "comparison_queue_size" ,
105+ 1 ,
106+ 1001 ,
95107 ), # upper bound determined by total_comparisons=1000
96108 "exploration_frac" : trial .suggest_float ("exploration_frac" , 0.0 , 0.5 ),
97109 "fragment_length" : trial .suggest_int (
98- "fragment_length" , 1 , 1001
110+ "fragment_length" ,
111+ 1 ,
112+ 1001 ,
99113 ), # trajectories are 1000 steps long
100114 "gatherer_kwargs" : {
101115 "temperature" : trial .suggest_float ("gatherer_temperature" , 0.0 , 2.0 ),
102116 "discount_factor" : trial .suggest_float (
103- "gatherer_discount_factor" , 0.95 , 1.0
117+ "gatherer_discount_factor" ,
118+ 0.95 ,
119+ 1.0 ,
104120 ),
105121 "sample" : trial .suggest_categorical ("gatherer_sample" , [True , False ]),
106122 },
107123 "initial_epoch_multiplier" : trial .suggest_float (
108- "initial_epoch_multiplier" , 1 , 200.0
124+ "initial_epoch_multiplier" ,
125+ 1 ,
126+ 200.0 ,
109127 ),
110128 "initial_comparison_frac" : trial .suggest_float (
111- "initial_comparison_frac" , 0.01 , 1.0
129+ "initial_comparison_frac" ,
130+ 0.01 ,
131+ 1.0 ,
112132 ),
113133 "num_iterations" : trial .suggest_int ("num_iterations" , 1 , 51 ),
114134 "preference_model_kwargs" : {
115135 "noise_prob" : trial .suggest_float (
116- "preference_model_noise_prob" , 0.0 , 0.1
136+ "preference_model_noise_prob" ,
137+ 0.0 ,
138+ 0.1 ,
117139 ),
118140 "discount_factor" : trial .suggest_float (
119- "preference_model_discount_factor" , 0.95 , 1.0
141+ "preference_model_discount_factor" ,
142+ 0.95 ,
143+ 1.0 ,
120144 ),
121145 },
122146 "query_schedule" : trial .suggest_categorical (
123- "query_schedule" , ["hyperbolic" , "constant" , "inverse_quadratic" ]
147+ "query_schedule" ,
148+ [
149+ "hyperbolic" ,
150+ "constant" ,
151+ "inverse_quadratic" ,
152+ ],
124153 ),
125154 "trajectory_generator_kwargs" : {
126155 "switch_prob" : trial .suggest_float ("tr_gen_switch_prob" , 0.1 , 1 ),
127156 "random_prob" : trial .suggest_float ("tr_gen_random_prob" , 0.1 , 0.9 ),
128157 },
129158 "transition_oversampling" : trial .suggest_float (
130- "transition_oversampling" , 0.9 , 2.0
159+ "transition_oversampling" ,
160+ 0.9 ,
161+ 2.0 ,
131162 ),
132163 "reward_trainer_kwargs" : {
133164 "epochs" : trial .suggest_int ("reward_trainer_epochs" , 1 , 11 ),
134165 },
135166 "rl" : {
136167 "rl_kwargs" : {
137168 "ent_coef" : trial .suggest_float (
138- "rl_ent_coef" , 1e-7 , 1e-3 , log = True
169+ "rl_ent_coef" ,
170+ 1e-7 ,
171+ 1e-3 ,
172+ log = True ,
139173 ),
140174 },
141175 },
142176 },
143177 ),
144178 pc_classic_control = RunSacredAsTrial (
145- sacred_ex = imitation . scripts . train_preference_comparisons .train_preference_comparisons_ex ,
179+ sacred_ex = train_pc_script .train_preference_comparisons_ex ,
146180 suggest_named_configs = lambda _ : ["reward.reward_ensemble" ],
147181 suggest_config_updates = lambda trial : {
148182 "seed" : trial .number ,
@@ -151,54 +185,80 @@ def __call__(
151185 "total_comparisons" : 1000 ,
152186 "active_selection" : True ,
153187 "active_selection_oversampling" : trial .suggest_int (
154- "active_selection_oversampling" , 1 , 11
188+ "active_selection_oversampling" ,
189+ 1 ,
190+ 11 ,
155191 ),
156192 "comparison_queue_size" : trial .suggest_int (
157- "comparison_queue_size" , 1 , 1001
193+ "comparison_queue_size" ,
194+ 1 ,
195+ 1001 ,
158196 ), # upper bound determined by total_comparisons=1000
159197 "exploration_frac" : trial .suggest_float ("exploration_frac" , 0.0 , 0.5 ),
160198 "fragment_length" : trial .suggest_int (
161- "fragment_length" , 1 , 201
199+ "fragment_length" ,
200+ 1 ,
201+ 201 ,
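162202 ), # trajectories are 1000 steps long -- NOTE(review): comment copied from the `pc` entry; the upper bound here is 201, not 1001, so confirm the trajectory length for classic control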
163203 "gatherer_kwargs" : {
164204 "temperature" : trial .suggest_float ("gatherer_temperature" , 0.0 , 2.0 ),
165205 "discount_factor" : trial .suggest_float (
166- "gatherer_discount_factor" , 0.95 , 1.0
206+ "gatherer_discount_factor" ,
207+ 0.95 ,
208+ 1.0 ,
167209 ),
168210 "sample" : trial .suggest_categorical ("gatherer_sample" , [True , False ]),
169211 },
170212 "initial_epoch_multiplier" : trial .suggest_float (
171- "initial_epoch_multiplier" , 1 , 200.0
213+ "initial_epoch_multiplier" ,
214+ 1 ,
215+ 200.0 ,
172216 ),
173217 "initial_comparison_frac" : trial .suggest_float (
174- "initial_comparison_frac" , 0.01 , 1.0
218+ "initial_comparison_frac" ,
219+ 0.01 ,
220+ 1.0 ,
175221 ),
176222 "num_iterations" : trial .suggest_int ("num_iterations" , 1 , 51 ),
177223 "preference_model_kwargs" : {
178224 "noise_prob" : trial .suggest_float (
179- "preference_model_noise_prob" , 0.0 , 0.1
225+ "preference_model_noise_prob" ,
226+ 0.0 ,
227+ 0.1 ,
180228 ),
181229 "discount_factor" : trial .suggest_float (
182- "preference_model_discount_factor" , 0.95 , 1.0
230+ "preference_model_discount_factor" ,
231+ 0.95 ,
232+ 1.0 ,
183233 ),
184234 },
185235 "query_schedule" : trial .suggest_categorical (
186- "query_schedule" , ["hyperbolic" , "constant" , "inverse_quadratic" ]
236+ "query_schedule" ,
237+ [
238+ "hyperbolic" ,
239+ "constant" ,
240+ "inverse_quadratic" ,
241+ ],
187242 ),
188243 "trajectory_generator_kwargs" : {
189244 "switch_prob" : trial .suggest_float ("tr_gen_switch_prob" , 0.1 , 1 ),
190245 "random_prob" : trial .suggest_float ("tr_gen_random_prob" , 0.1 , 0.9 ),
191246 },
192247 "transition_oversampling" : trial .suggest_float (
193- "transition_oversampling" , 0.9 , 2.0
248+ "transition_oversampling" ,
249+ 0.9 ,
250+ 2.0 ,
194251 ),
195252 "reward_trainer_kwargs" : {
196253 "epochs" : trial .suggest_int ("reward_trainer_epochs" , 1 , 11 ),
197254 },
198255 "rl" : {
199256 "rl_kwargs" : {
200257 "ent_coef" : trial .suggest_float (
201- "rl_ent_coef" , 1e-7 , 1e-3 , log = True
258+ "rl_ent_coef" ,
259+ 1e-7 ,
260+ 1e-3 ,
261+ log = True ,
202262 ),
203263 },
204264 },
@@ -217,28 +277,41 @@ def __call__(
217277 "rl" : {
218278 "rl_kwargs" : {
219279 "learning_rate" : trial .suggest_float (
220- "learning_rate" , 1e-6 , 1e-2 , log = True
280+ "learning_rate" ,
281+ 1e-6 ,
282+ 1e-2 ,
283+ log = True ,
221284 ),
222285 "buffer_size" : trial .suggest_int ("buffer_size" , 1000 , 100000 ),
223286 "learning_starts" : trial .suggest_int (
224- "learning_starts" , 1000 , 10000
287+ "learning_starts" ,
288+ 1000 ,
289+ 10000 ,
225290 ),
226291 "batch_size" : trial .suggest_int ("batch_size" , 32 , 128 ),
227292 "tau" : trial .suggest_float ("tau" , 0.0 , 1.0 ),
228293 "gamma" : trial .suggest_float ("gamma" , 0.9 , 0.999 ),
229294 "train_freq" : trial .suggest_int ("train_freq" , 1 , 40 ),
230295 "gradient_steps" : trial .suggest_int ("gradient_steps" , 1 , 10 ),
231296 "target_update_interval" : trial .suggest_int (
232- "target_update_interval" , 1 , 10000
297+ "target_update_interval" ,
298+ 1 ,
299+ 10000 ,
233300 ),
234301 "exploration_fraction" : trial .suggest_float (
235- "exploration_fraction" , 0.01 , 0.5
302+ "exploration_fraction" ,
303+ 0.01 ,
304+ 0.5 ,
236305 ),
237306 "exploration_final_eps" : trial .suggest_float (
238- "exploration_final_eps" , 0.01 , 1.0
307+ "exploration_final_eps" ,
308+ 0.01 ,
309+ 1.0 ,
239310 ),
240311 "exploration_initial_eps" : trial .suggest_float (
241- "exploration_initial_eps" , 0.01 , 0.5
312+ "exploration_initial_eps" ,
313+ 0.01 ,
314+ 0.5 ,
242315 ),
243316 "max_grad_norm" : trial .suggest_float ("max_grad_norm" , 0.1 , 10.0 ),
244317 },
0 commit comments