"""

import dataclasses
- from typing import Callable, List, Mapping, Any, Dict
+ from typing import Any, Callable, Dict, List, Mapping

import optuna
import sacred
@@ -35,7 +35,6 @@ class RunSacredAsTrial:
    """The sacred experiment to run."""
    sacred_ex: sacred.Experiment

-
    """A function that returns a list of named configs to pass to sacred.run."""
    suggest_named_configs: Callable[[optuna.Trial], List[str]]

@@ -46,10 +45,7 @@ class RunSacredAsTrial:
    command_name: str = None

    def __call__(
-         self,
-         trial: optuna.Trial,
-         run_options: Dict,
-         extra_named_configs: List[str]
+         self, trial: optuna.Trial, run_options: Dict, extra_named_configs: List[str]
    ) -> float:
        """Run the sacred experiment and return the performance.

@@ -77,7 +73,7 @@ def __call__(
            raise RuntimeError(
                f"Trial failed with {result.fail_trace()} and status {result.status}."
            )
-         return result.result['imit_stats']['monitor_return_mean']
+         return result.result["imit_stats"]["monitor_return_mean"]


"""A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
@@ -91,34 +87,56 @@ def __call__(
            "total_timesteps": 2e7,
            "total_comparisons": 1000,
            "active_selection": True,
-             "active_selection_oversampling": trial.suggest_int("active_selection_oversampling", 1, 11),
-             "comparison_queue_size": trial.suggest_int("comparison_queue_size", 1, 1001),  # upper bound determined by total_comparisons=1000
+             "active_selection_oversampling": trial.suggest_int(
+                 "active_selection_oversampling", 1, 11
+             ),
+             "comparison_queue_size": trial.suggest_int(
+                 "comparison_queue_size", 1, 1001
+             ),  # upper bound determined by total_comparisons=1000
            "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
-             "fragment_length": trial.suggest_int("fragment_length", 1, 1001),  # trajectories are 1000 steps long
+             "fragment_length": trial.suggest_int(
+                 "fragment_length", 1, 1001
+             ),  # trajectories are 1000 steps long
            "gatherer_kwargs": {
                "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
-                 "discount_factor": trial.suggest_float("gatherer_discount_factor", 0.95, 1.0),
+                 "discount_factor": trial.suggest_float(
+                     "gatherer_discount_factor", 0.95, 1.0
+                 ),
                "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
            },
-             "initial_epoch_multiplier": trial.suggest_float("initial_epoch_multiplier", 1, 200.0),
-             "initial_comparison_frac": trial.suggest_float("initial_comparison_frac", 0.01, 1.0),
+             "initial_epoch_multiplier": trial.suggest_float(
+                 "initial_epoch_multiplier", 1, 200.0
+             ),
+             "initial_comparison_frac": trial.suggest_float(
+                 "initial_comparison_frac", 0.01, 1.0
+             ),
            "num_iterations": trial.suggest_int("num_iterations", 1, 51),
            "preference_model_kwargs": {
-                 "noise_prob": trial.suggest_float("preference_model_noise_prob", 0.0, 0.1),
-                 "discount_factor": trial.suggest_float("preference_model_discount_factor", 0.95, 1.0),
+                 "noise_prob": trial.suggest_float(
+                     "preference_model_noise_prob", 0.0, 0.1
+                 ),
+                 "discount_factor": trial.suggest_float(
+                     "preference_model_discount_factor", 0.95, 1.0
+                 ),
            },
-             "query_schedule": trial.suggest_categorical("query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]),
+             "query_schedule": trial.suggest_categorical(
+                 "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
+             ),
            "trajectory_generator_kwargs": {
                "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
                "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
            },
-             "transition_oversampling": trial.suggest_float("transition_oversampling", 0.9, 2.0),
+             "transition_oversampling": trial.suggest_float(
+                 "transition_oversampling", 0.9, 2.0
+             ),
            "reward_trainer_kwargs": {
                "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
            },
            "rl": {
                "rl_kwargs": {
-                     "ent_coef": trial.suggest_float("rl_ent_coef", 1e-7, 1e-3, log=True),
+                     "ent_coef": trial.suggest_float(
+                         "rl_ent_coef", 1e-7, 1e-3, log=True
+                     ),
                },
            },
        },
@@ -132,34 +150,56 @@ def __call__(
            "total_timesteps": 1e6,
            "total_comparisons": 1000,
            "active_selection": True,
-             "active_selection_oversampling": trial.suggest_int("active_selection_oversampling", 1, 11),
-             "comparison_queue_size": trial.suggest_int("comparison_queue_size", 1, 1001),  # upper bound determined by total_comparisons=1000
+             "active_selection_oversampling": trial.suggest_int(
+                 "active_selection_oversampling", 1, 11
+             ),
+             "comparison_queue_size": trial.suggest_int(
+                 "comparison_queue_size", 1, 1001
+             ),  # upper bound determined by total_comparisons=1000
            "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
-             "fragment_length": trial.suggest_int("fragment_length", 1, 201),  # trajectories are 1000 steps long
+             "fragment_length": trial.suggest_int(
+                 "fragment_length", 1, 201
+             ),  # trajectories are 1000 steps long
            "gatherer_kwargs": {
                "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
-                 "discount_factor": trial.suggest_float("gatherer_discount_factor", 0.95, 1.0),
+                 "discount_factor": trial.suggest_float(
+                     "gatherer_discount_factor", 0.95, 1.0
+                 ),
                "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
            },
-             "initial_epoch_multiplier": trial.suggest_float("initial_epoch_multiplier", 1, 200.0),
-             "initial_comparison_frac": trial.suggest_float("initial_comparison_frac", 0.01, 1.0),
+             "initial_epoch_multiplier": trial.suggest_float(
+                 "initial_epoch_multiplier", 1, 200.0
+             ),
+             "initial_comparison_frac": trial.suggest_float(
+                 "initial_comparison_frac", 0.01, 1.0
+             ),
            "num_iterations": trial.suggest_int("num_iterations", 1, 51),
            "preference_model_kwargs": {
-                 "noise_prob": trial.suggest_float("preference_model_noise_prob", 0.0, 0.1),
-                 "discount_factor": trial.suggest_float("preference_model_discount_factor", 0.95, 1.0),
+                 "noise_prob": trial.suggest_float(
+                     "preference_model_noise_prob", 0.0, 0.1
+                 ),
+                 "discount_factor": trial.suggest_float(
+                     "preference_model_discount_factor", 0.95, 1.0
+                 ),
            },
-             "query_schedule": trial.suggest_categorical("query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]),
+             "query_schedule": trial.suggest_categorical(
+                 "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
+             ),
            "trajectory_generator_kwargs": {
                "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
                "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
            },
-             "transition_oversampling": trial.suggest_float("transition_oversampling", 0.9, 2.0),
+             "transition_oversampling": trial.suggest_float(
+                 "transition_oversampling", 0.9, 2.0
+             ),
            "reward_trainer_kwargs": {
                "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
            },
            "rl": {
                "rl_kwargs": {
-                     "ent_coef": trial.suggest_float("rl_ent_coef", 1e-7, 1e-3, log=True),
+                     "ent_coef": trial.suggest_float(
+                         "rl_ent_coef", 1e-7, 1e-3, log=True
+                     ),
                },
            },
        },
@@ -176,22 +216,33 @@ def __call__(
            },
            "rl": {
                "rl_kwargs": {
-                     "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True),
+                     "learning_rate": trial.suggest_float(
+                         "learning_rate", 1e-6, 1e-2, log=True
+                     ),
                    "buffer_size": trial.suggest_int("buffer_size", 1000, 100000),
-                     "learning_starts": trial.suggest_int("learning_starts", 1000, 10000),
+                     "learning_starts": trial.suggest_int(
+                         "learning_starts", 1000, 10000
+                     ),
                    "batch_size": trial.suggest_int("batch_size", 32, 128),
-                     "tau": trial.suggest_float("tau", 0., 1.),
+                     "tau": trial.suggest_float("tau", 0.0, 1.0),
                    "gamma": trial.suggest_float("gamma", 0.9, 0.999),
                    "train_freq": trial.suggest_int("train_freq", 1, 40),
                    "gradient_steps": trial.suggest_int("gradient_steps", 1, 10),
-                     "target_update_interval": trial.suggest_int("target_update_interval", 1, 10000),
-                     "exploration_fraction": trial.suggest_float("exploration_fraction", 0.01, 0.5),
-                     "exploration_final_eps": trial.suggest_float("exploration_final_eps", 0.01, 1.0),
-                     "exploration_initial_eps": trial.suggest_float("exploration_initial_eps", 0.01, 0.5),
+                     "target_update_interval": trial.suggest_int(
+                         "target_update_interval", 1, 10000
+                     ),
+                     "exploration_fraction": trial.suggest_float(
+                         "exploration_fraction", 0.01, 0.5
+                     ),
+                     "exploration_final_eps": trial.suggest_float(
+                         "exploration_final_eps", 0.01, 1.0
+                     ),
+                     "exploration_initial_eps": trial.suggest_float(
+                         "exploration_initial_eps", 0.01, 0.5
+                     ),
                    "max_grad_norm": trial.suggest_float("max_grad_norm", 0.1, 10.0),
-
                },
            },
        },
    ),
- )
+ )
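
For orientation, below is a minimal sketch (not part of this diff) of how a RunSacredAsTrial objective could be plugged into an optuna study. Only the dataclass fields and the __call__(trial, run_options, extra_named_configs) signature come from the code above; the sacred experiment, the empty named-config list, the run options, and the trial count are placeholder assumptions, and fields not shown in these hunks may also be required.

# Illustrative sketch only -- not taken from this PR.
import optuna
import sacred

demo_ex = sacred.Experiment("demo")  # placeholder sacred experiment, assumed here

# Construct the objective; RunSacredAsTrial appears to be a dataclass, so
# keyword construction with the fields shown in the diff is assumed to work.
objective = RunSacredAsTrial(
    sacred_ex=demo_ex,
    suggest_named_configs=lambda trial: [],  # suggest no named configs in this sketch
)

# __call__ returns the mean monitored return, so we maximize it.
study = optuna.create_study(direction="maximize")
study.optimize(
    lambda trial: objective(trial, run_options={}, extra_named_configs=[]),
    n_trials=10,
)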