
Commit f4bcbdc

Committed Feb 29, 2024
Some formatting fixes.
1 parent 2e356b7 commit f4bcbdc

File tree: 11 files changed, +173 −106 lines

benchmarking/README.md (+1 −1)

```diff
@@ -185,4 +185,4 @@ where:
 - `algo` is the algorithm you want to compare against
 
 If `your_runs_dir` contains runs for more than one algorithm, you will have to
-disambiguate using the `--algo` option.
\ No newline at end of file
+disambiguate using the `--algo` option.
```

src/imitation/algorithms/preference_comparisons.py (+3 −1)

```diff
@@ -1678,7 +1678,9 @@ def train(
         unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
         probs = unnormalized_probs / np.sum(unnormalized_probs)
         shares = util.oric(probs * total_comparisons)
-        shares[shares <= 0] = 1  # ensure we at least request one comparison per iteration
+        shares[
+            shares <= 0
+        ] = 1  # ensure we at least request one comparison per iteration
 
         schedule = [initial_comparisons] + shares.tolist()
         print(f"Query schedule: {schedule}")
```

src/imitation/scripts/config/tuning.py (+9 −3)

```diff
@@ -199,9 +199,13 @@ def pc():
         "named_configs": ["reward.reward_ensemble"],
         "config_updates": {
             "active_selection_oversampling": tune.randint(1, 11),
-            "comparison_queue_size": tune.randint(1, 1001),  # upper bound determined by total_comparisons=1000
+            "comparison_queue_size": tune.randint(
+                1, 1001
+            ),  # upper bound determined by total_comparisons=1000
             "exploration_frac": tune.uniform(0.0, 0.5),
-            "fragment_length": tune.randint(1, 1001),  # trajectories are 1000 steps long
+            "fragment_length": tune.randint(
+                1, 1001
+            ),  # trajectories are 1000 steps long
             "gatherer_kwargs": {
                 "temperature": tune.uniform(0.0, 2.0),
                 "discount_factor": tune.uniform(0.95, 1.0),
@@ -213,7 +217,9 @@ def pc():
                 "noise_prob": tune.uniform(0.0, 0.1),
                 "discount_factor": tune.uniform(0.95, 1.0),
             },
-            "query_schedule": tune.choice(["hyperbolic", "constant", "inverse_quadratic"]),
+            "query_schedule": tune.choice(
+                ["hyperbolic", "constant", "inverse_quadratic"]
+            ),
             "trajectory_generator_kwargs": {
                 "switch_prob": tune.uniform(0.1, 1),
                 "random_prob": tune.uniform(0.1, 0.9),
```

src/imitation/scripts/ingredients/rl.py (−1)

```diff
@@ -103,7 +103,6 @@ def dqn():
     rl_cls = sb3.DQN
 
 
-
 def _maybe_add_relabel_buffer(
     rl_kwargs: Dict[str, Any],
     relabel_reward_fn: Optional[RewardFn] = None,
```

tuning/README.md (+4 −4)

```diff
@@ -7,22 +7,22 @@ If you want to specify a custom algorithm and search space, add it to the dict i
 
 You can tune using multiple workers in parallel by running multiple instances of `tune.py` that all point to the same journal log file (see `tune.py --help` for details).
 To easily launch multiple workers on a SLURM cluster and ensure they don't conflict with each other,
-use the `tune_on_slurm.py` script. 
+use the `tune_on_slurm.py` script.
 This script will launch a SLURM job array with the specified number of workers.
 If you want to tune all algorithms on all environments on SLURM, use `tune_all_on_slurm.sh`.
 
 # Legacy Tuning Scripts
 
-Note: There are some legacy tuning scripts that can be used like this: 
+Note: There are some legacy tuning scripts that can be used like this:
 
 The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
 The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
 the search space defined in the `scripts/config/tuning.py`.
 
 The tuning script proceeds in two phases:
 1. Tune the hyperparameters using the search space provided.
-2. Re-evaluate the best hyperparameter config found in the first phase 
-   based on the maximum mean return on a separate set of seeds. 
+2. Re-evaluate the best hyperparameter config found in the first phase
+   based on the maximum mean return on a separate set of seeds.
 Report the mean and standard deviation of these trials.
 
 To use it with the default search space:
```
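
The hunk ends before the example command, which stays elided here. The journal-file coordination described at the top of this README comes from optuna's journal storage; a minimal sketch of that mechanism (study name, path, and toy objective are hypothetical):

```python
# Sketch: several workers can run this same snippet concurrently; the
# journal file (NFS-safe) serializes their access to one shared study.
import optuna

storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("./optuna_study.log")  # hypothetical path
)
study = optuna.create_study(
    study_name="tuning_pc_with_pendulum",  # hypothetical name
    storage=storage,
    direction="maximize",
    load_if_exists=True,  # first worker creates the study, the rest attach
)
study.optimize(lambda t: -(t.suggest_float("x", -1, 1) ** 2), n_trials=10)
```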

tuning/benchmark_analysis.ipynb (+49 −34)

```diff
@@ -40,47 +40,56 @@
     "\n",
     "for log_file in experiment_log_files:\n",
     "    d = dict()\n",
-    "    \n",
-    "    d['logfile'] = log_file\n",
-    "    \n",
-    "    study = optuna.load_study(storage=optuna.storages.JournalStorage(\n",
+    "\n",
+    "    d[\"logfile\"] = log_file\n",
+    "\n",
+    "    study = optuna.load_study(\n",
+    "        storage=optuna.storages.JournalStorage(\n",
     "            optuna.storages.JournalFileStorage(str(log_file))\n",
     "        ),\n",
     "        # in our case, we have one journal file per study so the study name can be\n",
     "        # inferred\n",
     "        study_name=None,\n",
     "    )\n",
-    "    d['study'] = study\n",
-    "    d['study_name'] = study.study_name\n",
-    "    \n",
+    "    d[\"study\"] = study\n",
+    "    d[\"study_name\"] = study.study_name\n",
+    "\n",
     "    trial_state_counter = Counter(t.state for t in study.trials)\n",
     "    n_completed_trials = trial_state_counter[TrialState.COMPLETE]\n",
-    "    d['trials'] = n_completed_trials\n",
-    "    d['trials_running'] = Counter(t.state for t in study.trials)[TrialState.RUNNING]\n",
-    "    d['trials_failed'] = Counter(t.state for t in study.trials)[TrialState.FAIL]\n",
-    "    d['all_trials'] = len(study.trials)\n",
-    "    \n",
+    "    d[\"trials\"] = n_completed_trials\n",
+    "    d[\"trials_running\"] = Counter(t.state for t in study.trials)[TrialState.RUNNING]\n",
+    "    d[\"trials_failed\"] = Counter(t.state for t in study.trials)[TrialState.FAIL]\n",
+    "    d[\"all_trials\"] = len(study.trials)\n",
+    "\n",
     "    if n_completed_trials > 0:\n",
-    "        d['best_value'] = round(study.best_trial.value, 2)\n",
-    "    \n",
+    "        d[\"best_value\"] = round(study.best_trial.value, 2)\n",
+    "\n",
     "    assert \"_\" in study.study_name\n",
-    "    study_segments = study.study_name.split(\"_\") \n",
+    "    study_segments = study.study_name.split(\"_\")\n",
     "    assert len(study_segments) > 3\n",
     "    tuning, algo, with_ = study_segments[:3]\n",
     "    assert (tuning, with_) == (\"tuning\", \"with\")\n",
-    "    \n",
-    "    d['algo'] = algo\n",
-    "    d['env'] = \"_\".join(study_segments[3:])\n",
-    "    d['best_trial_duration'] = study.best_trial.duration\n",
-    "    d['mean_duration'] = sum([t.duration for t in study.trials if t.state == TrialState.COMPLETE], datetime.timedelta())/n_completed_trials\n",
-    "    \n",
+    "\n",
+    "    d[\"algo\"] = algo\n",
+    "    d[\"env\"] = \"_\".join(study_segments[3:])\n",
+    "    d[\"best_trial_duration\"] = study.best_trial.duration\n",
+    "    d[\"mean_duration\"] = (\n",
+    "        sum(\n",
+    "            [t.duration for t in study.trials if t.state == TrialState.COMPLETE],\n",
+    "            datetime.timedelta(),\n",
+    "        )\n",
+    "        / n_completed_trials\n",
+    "    )\n",
+    "\n",
     "    reruns_folder = log_file.parent / \"reruns\"\n",
-    "    rerun_results = [round(run['result']['imit_stats']['monitor_return_mean'], 2)\n",
-    "                     for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)]\n",
-    "    d['rerun_values'] = rerun_results\n",
-    "    \n",
+    "    rerun_results = [\n",
+    "        round(run[\"result\"][\"imit_stats\"][\"monitor_return_mean\"], 2)\n",
+    "        for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)\n",
+    "    ]\n",
+    "    d[\"rerun_values\"] = rerun_results\n",
+    "\n",
     "    raw_study_data.append(d)\n",
-    "    \n",
+    "\n",
     "study_data = pd.DataFrame(raw_study_data)"
    ]
   },
@@ -103,7 +112,7 @@
     "    \"seals_humanoid\",\n",
     "    \"seals_cartpole\",\n",
     "    \"pendulum\",\n",
-    "    \"seals_mountain_car\"\n",
+    "    \"seals_mountain_car\",\n",
     "]\n",
     "\n",
     "pc_paper_700 = dict(\n",
@@ -163,12 +172,14 @@
     "    for env, value in values_by_env.items():\n",
     "        if value == \"-\":\n",
     "            continue\n",
-    "        raw_study_data.append(dict(\n",
-    "            algo=algo,\n",
-    "            env=env,\n",
-    "            best_value=value,\n",
-    "        ))\n",
-    "    \n",
+    "        raw_study_data.append(\n",
+    "            dict(\n",
+    "                algo=algo,\n",
+    "                env=env,\n",
+    "                best_value=value,\n",
+    "            )\n",
+    "        )\n",
+    "\n",
     "study_data = pd.DataFrame(raw_study_data)"
    ]
   },
@@ -185,7 +196,11 @@
     "display(study_data[[\"algo\", \"env\", \"best_value\"]])\n",
     "\n",
     "print(\"Rerun Data\")\n",
-    "display(study_data[[\"algo\", \"env\", \"best_value\", \"rerun_values\"]][study_data[\"rerun_values\"].map(np.std) > 0])"
+    "display(\n",
+    "    study_data[[\"algo\", \"env\", \"best_value\", \"rerun_values\"]][\n",
+    "        study_data[\"rerun_values\"].map(np.std) > 0\n",
+    "    ]\n",
+    ")"
    ]
   }
  ],
```
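
Condensed, the cell above loads each study from its journal file, tallies trial states, and records the best value. A runnable sketch of that core (the journal path is hypothetical):

```python
# Sketch of the notebook's per-study summary logic.
from collections import Counter

import optuna
from optuna.trial import TrialState

log_file = "optuna_study.log"  # hypothetical path
study = optuna.load_study(
    storage=optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage(log_file)
    ),
    study_name=None,  # one study per journal file, so the name is inferred
)

states = Counter(t.state for t in study.trials)
print("completed:", states[TrialState.COMPLETE])
print("running:", states[TrialState.RUNNING])
print("failed:", states[TrialState.FAIL])
if states[TrialState.COMPLETE] > 0:
    print("best value:", round(study.best_trial.value, 2))
```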

tuning/hp_search_spaces.py (+89 −38)

```diff
@@ -14,7 +14,7 @@
 """
 
 import dataclasses
-from typing import Callable, List, Mapping, Any, Dict
+from typing import Any, Callable, Dict, List, Mapping
 
 import optuna
 import sacred
@@ -35,7 +35,6 @@ class RunSacredAsTrial:
     """The sacred experiment to run."""
     sacred_ex: sacred.Experiment
 
-
     """A function that returns a list of named configs to pass to sacred.run."""
     suggest_named_configs: Callable[[optuna.Trial], List[str]]
 
@@ -46,10 +45,7 @@ class RunSacredAsTrial:
     command_name: str = None
 
     def __call__(
-        self,
-        trial: optuna.Trial,
-        run_options: Dict,
-        extra_named_configs: List[str]
+        self, trial: optuna.Trial, run_options: Dict, extra_named_configs: List[str]
     ) -> float:
         """Run the sacred experiment and return the performance.
 
@@ -77,7 +73,7 @@ def __call__(
             raise RuntimeError(
                 f"Trial failed with {result.fail_trace()} and status {result.status}."
             )
-        return result.result['imit_stats']['monitor_return_mean']
+        return result.result["imit_stats"]["monitor_return_mean"]
 
 
 """A mapping from algorithm names to functions that run the algorithm as an optuna trial."""
@@ -91,34 +87,56 @@ def __call__(
             "total_timesteps": 2e7,
             "total_comparisons": 1000,
             "active_selection": True,
-            "active_selection_oversampling": trial.suggest_int("active_selection_oversampling", 1, 11),
-            "comparison_queue_size": trial.suggest_int("comparison_queue_size", 1, 1001),  # upper bound determined by total_comparisons=1000
+            "active_selection_oversampling": trial.suggest_int(
+                "active_selection_oversampling", 1, 11
+            ),
+            "comparison_queue_size": trial.suggest_int(
+                "comparison_queue_size", 1, 1001
+            ),  # upper bound determined by total_comparisons=1000
             "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
-            "fragment_length": trial.suggest_int("fragment_length", 1, 1001),  # trajectories are 1000 steps long
+            "fragment_length": trial.suggest_int(
+                "fragment_length", 1, 1001
+            ),  # trajectories are 1000 steps long
             "gatherer_kwargs": {
                 "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
-                "discount_factor": trial.suggest_float("gatherer_discount_factor", 0.95, 1.0),
+                "discount_factor": trial.suggest_float(
+                    "gatherer_discount_factor", 0.95, 1.0
+                ),
                 "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
             },
-            "initial_epoch_multiplier": trial.suggest_float("initial_epoch_multiplier", 1, 200.0),
-            "initial_comparison_frac": trial.suggest_float("initial_comparison_frac", 0.01, 1.0),
+            "initial_epoch_multiplier": trial.suggest_float(
+                "initial_epoch_multiplier", 1, 200.0
+            ),
+            "initial_comparison_frac": trial.suggest_float(
+                "initial_comparison_frac", 0.01, 1.0
+            ),
             "num_iterations": trial.suggest_int("num_iterations", 1, 51),
             "preference_model_kwargs": {
-                "noise_prob": trial.suggest_float("preference_model_noise_prob", 0.0, 0.1),
-                "discount_factor": trial.suggest_float("preference_model_discount_factor", 0.95, 1.0),
+                "noise_prob": trial.suggest_float(
+                    "preference_model_noise_prob", 0.0, 0.1
+                ),
+                "discount_factor": trial.suggest_float(
+                    "preference_model_discount_factor", 0.95, 1.0
+                ),
             },
-            "query_schedule": trial.suggest_categorical("query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]),
+            "query_schedule": trial.suggest_categorical(
+                "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
+            ),
             "trajectory_generator_kwargs": {
                 "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
                 "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
             },
-            "transition_oversampling": trial.suggest_float("transition_oversampling", 0.9, 2.0),
+            "transition_oversampling": trial.suggest_float(
+                "transition_oversampling", 0.9, 2.0
+            ),
             "reward_trainer_kwargs": {
                 "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
             },
             "rl": {
                 "rl_kwargs": {
-                    "ent_coef": trial.suggest_float("rl_ent_coef", 1e-7, 1e-3, log=True),
+                    "ent_coef": trial.suggest_float(
+                        "rl_ent_coef", 1e-7, 1e-3, log=True
+                    ),
                 },
             },
         },
@@ -132,34 +150,56 @@ def __call__(
             "total_timesteps": 1e6,
             "total_comparisons": 1000,
             "active_selection": True,
-            "active_selection_oversampling": trial.suggest_int("active_selection_oversampling", 1, 11),
-            "comparison_queue_size": trial.suggest_int("comparison_queue_size", 1, 1001),  # upper bound determined by total_comparisons=1000
+            "active_selection_oversampling": trial.suggest_int(
+                "active_selection_oversampling", 1, 11
+            ),
+            "comparison_queue_size": trial.suggest_int(
+                "comparison_queue_size", 1, 1001
+            ),  # upper bound determined by total_comparisons=1000
             "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
-            "fragment_length": trial.suggest_int("fragment_length", 1, 201),  # trajectories are 1000 steps long
+            "fragment_length": trial.suggest_int(
+                "fragment_length", 1, 201
+            ),  # trajectories are 1000 steps long
             "gatherer_kwargs": {
                 "temperature": trial.suggest_float("gatherer_temperature", 0.0, 2.0),
-                "discount_factor": trial.suggest_float("gatherer_discount_factor", 0.95, 1.0),
+                "discount_factor": trial.suggest_float(
+                    "gatherer_discount_factor", 0.95, 1.0
+                ),
                 "sample": trial.suggest_categorical("gatherer_sample", [True, False]),
             },
-            "initial_epoch_multiplier": trial.suggest_float("initial_epoch_multiplier", 1, 200.0),
-            "initial_comparison_frac": trial.suggest_float("initial_comparison_frac", 0.01, 1.0),
+            "initial_epoch_multiplier": trial.suggest_float(
+                "initial_epoch_multiplier", 1, 200.0
+            ),
+            "initial_comparison_frac": trial.suggest_float(
+                "initial_comparison_frac", 0.01, 1.0
+            ),
             "num_iterations": trial.suggest_int("num_iterations", 1, 51),
             "preference_model_kwargs": {
-                "noise_prob": trial.suggest_float("preference_model_noise_prob", 0.0, 0.1),
-                "discount_factor": trial.suggest_float("preference_model_discount_factor", 0.95, 1.0),
+                "noise_prob": trial.suggest_float(
+                    "preference_model_noise_prob", 0.0, 0.1
+                ),
+                "discount_factor": trial.suggest_float(
+                    "preference_model_discount_factor", 0.95, 1.0
+                ),
             },
-            "query_schedule": trial.suggest_categorical("query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]),
+            "query_schedule": trial.suggest_categorical(
+                "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
+            ),
             "trajectory_generator_kwargs": {
                 "switch_prob": trial.suggest_float("tr_gen_switch_prob", 0.1, 1),
                 "random_prob": trial.suggest_float("tr_gen_random_prob", 0.1, 0.9),
             },
-            "transition_oversampling": trial.suggest_float("transition_oversampling", 0.9, 2.0),
+            "transition_oversampling": trial.suggest_float(
+                "transition_oversampling", 0.9, 2.0
+            ),
             "reward_trainer_kwargs": {
                 "epochs": trial.suggest_int("reward_trainer_epochs", 1, 11),
             },
             "rl": {
                 "rl_kwargs": {
-                    "ent_coef": trial.suggest_float("rl_ent_coef", 1e-7, 1e-3, log=True),
+                    "ent_coef": trial.suggest_float(
+                        "rl_ent_coef", 1e-7, 1e-3, log=True
+                    ),
                 },
             },
         },
@@ -176,22 +216,33 @@ def __call__(
             },
             "rl": {
                 "rl_kwargs": {
-                    "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-2, log=True),
+                    "learning_rate": trial.suggest_float(
+                        "learning_rate", 1e-6, 1e-2, log=True
+                    ),
                     "buffer_size": trial.suggest_int("buffer_size", 1000, 100000),
-                    "learning_starts": trial.suggest_int("learning_starts", 1000, 10000),
+                    "learning_starts": trial.suggest_int(
+                        "learning_starts", 1000, 10000
+                    ),
                     "batch_size": trial.suggest_int("batch_size", 32, 128),
-                    "tau": trial.suggest_float("tau", 0., 1.),
+                    "tau": trial.suggest_float("tau", 0.0, 1.0),
                     "gamma": trial.suggest_float("gamma", 0.9, 0.999),
                     "train_freq": trial.suggest_int("train_freq", 1, 40),
                     "gradient_steps": trial.suggest_int("gradient_steps", 1, 10),
-                    "target_update_interval": trial.suggest_int("target_update_interval", 1, 10000),
-                    "exploration_fraction": trial.suggest_float("exploration_fraction", 0.01, 0.5),
-                    "exploration_final_eps": trial.suggest_float("exploration_final_eps", 0.01, 1.0),
-                    "exploration_initial_eps": trial.suggest_float("exploration_initial_eps", 0.01, 0.5),
+                    "target_update_interval": trial.suggest_int(
+                        "target_update_interval", 1, 10000
+                    ),
+                    "exploration_fraction": trial.suggest_float(
+                        "exploration_fraction", 0.01, 0.5
+                    ),
+                    "exploration_final_eps": trial.suggest_float(
+                        "exploration_final_eps", 0.01, 1.0
+                    ),
+                    "exploration_initial_eps": trial.suggest_float(
+                        "exploration_initial_eps", 0.01, 0.5
+                    ),
                     "max_grad_norm": trial.suggest_float("max_grad_norm", 0.1, 10.0),
-
                 },
             },
         },
     ),
-)
\ No newline at end of file
+)
```
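
The search spaces above use optuna's define-by-run style: `trial.suggest_*` calls declare the space from inside the objective. A toy sketch of that pattern, with a stand-in objective in place of the real sacred run (which returns `monitor_return_mean`):

```python
# Sketch of the define-by-run pattern used by hp_search_spaces.py.
import optuna


def objective(trial: optuna.Trial) -> float:
    config = {
        "fragment_length": trial.suggest_int("fragment_length", 1, 1001),
        "exploration_frac": trial.suggest_float("exploration_frac", 0.0, 0.5),
        "query_schedule": trial.suggest_categorical(
            "query_schedule", ["hyperbolic", "constant", "inverse_quadratic"]
        ),
        "ent_coef": trial.suggest_float("rl_ent_coef", 1e-7, 1e-3, log=True),
    }
    # Stand-in score; the real objective trains an agent and returns its
    # mean monitored return.
    return -abs(config["exploration_frac"] - 0.25)


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=5)
print(study.best_params)
```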

tuning/rerun_best_trial.py (+11 −13)

```diff
@@ -3,18 +3,15 @@
 import random
 from typing import List, Tuple
 
+import hp_search_spaces
 import optuna
 import sacred
 
-import hp_search_spaces
-
 
 def make_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
-        description=
-        "Re-run the best trial from a previous tuning run.",
-        epilog=f"Example usage:\n"
-               f"python rerun_best_trials.py tuning_run.json\n",
+        description="Re-run the best trial from a previous tuning run.",
+        epilog=f"Example usage:\n" f"python rerun_best_trials.py tuning_run.json\n",
         formatter_class=argparse.RawDescriptionHelpFormatter,
     )
     parser.add_argument(
@@ -23,18 +20,18 @@ def make_parser() -> argparse.ArgumentParser:
         default=None,
         choices=hp_search_spaces.objectives_by_algo.keys(),
         help="The algorithm that has been tuned. "
-             "Can usually be deduced from the study name.",
+        "Can usually be deduced from the study name.",
     )
     parser.add_argument(
         "journal_log",
         type=str,
-        help="The optuna journal file of the previous tuning run."
+        help="The optuna journal file of the previous tuning run.",
     )
     parser.add_argument(
         "--seed",
         type=int,
         default=random.randint(0, 2**32 - 1),
-        help="The seed to use for the re-run. A random seed is used by default."
+        help="The seed to use for the re-run. A random seed is used by default.",
     )
     return parser
 
@@ -46,7 +43,7 @@ def infer_algo_name(study: optuna.Study) -> str:
     """
     assert study.study_name.startswith("tuning_")
     assert "_with_" in study.study_name
-    return study.study_name[len("tuning_"):].split("_with_")[0]
+    return study.study_name[len("tuning_") :].split("_with_")[0]
 
 
 def main():
@@ -63,21 +60,22 @@ def main():
     trial = study.best_trial
 
     algo_name = args.algo or infer_algo_name(study)
-    sacred_experiment: sacred.Experiment = hp_search_spaces.objectives_by_algo[algo_name].sacred_ex
+    sacred_experiment: sacred.Experiment = hp_search_spaces.objectives_by_algo[
+        algo_name
+    ].sacred_ex
 
     config_updates = trial.user_attrs["config_updates"].copy()
     config_updates["seed"] = args.seed
     result = sacred_experiment.run(
         config_updates=config_updates,
         named_configs=trial.user_attrs["named_configs"],
         options={"--name": study.study_name, "--file_storage": "sacred"},
-
     )
     if result.status != "COMPLETED":
         raise RuntimeError(
             f"Trial failed with {result.fail_trace()} and status {result.status}."
         )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
```
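
In outline, the script loads the study, pulls the winning configuration out of the best trial's user attributes, and re-runs it under a fresh seed. A sketch of that flow with the sacred call elided (the journal path is hypothetical):

```python
# Sketch of rerun_best_trial.py's core flow.
import random

import optuna

study = optuna.load_study(
    storage=optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage("tuning_run.json")  # hypothetical
    ),
    study_name=None,
)
trial = study.best_trial

# The tuning run stored its sacred config on the trial; reuse it with a
# fresh random seed, as the --seed default above does.
config_updates = trial.user_attrs["config_updates"].copy()
config_updates["seed"] = random.randint(0, 2**32 - 1)
named_configs = trial.user_attrs["named_configs"]
print(config_updates, named_configs)  # passed to sacred_experiment.run(...)
```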

tuning/tune.py (+5 −9)

```diff
@@ -2,7 +2,6 @@
 import argparse
 
 import optuna
-
 from hp_search_spaces import objectives_by_algo
 
 
@@ -31,21 +30,18 @@ def make_parser() -> argparse.ArgumentParser:
         nargs="+",
         default=[],
         help="Additional named configs to pass to the sacred experiment. "
-             "Use this to select the environment to tune on.",
+        "Use this to select the environment to tune on.",
     )
     parser.add_argument(
-        "--num_trials",
-        type=int,
-        default=100,
-        help="Number of trials to run."
+        "--num_trials", type=int, default=100, help="Number of trials to run."
     )
     parser.add_argument(
         "-j",
         "--journal-log",
         type=str,
         default=None,
         help="A journal file to synchronize multiple instances of this script. "
-             "Works on NFS storage."
+        "Works on NFS storage.",
     )
     return parser
 
@@ -75,12 +71,12 @@ def main():
         lambda trial: objectives_by_algo[args.algo](
             trial,
             run_options={"--name": study.study_name, "--file_storage": "sacred"},
-            extra_named_configs=args.named_configs
+            extra_named_configs=args.named_configs,
         ),
         callbacks=[optuna.study.MaxTrialsCallback(args.num_trials)],
         gc_after_trial=True,
     )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
```
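
Note that the `study.optimize` call above passes no `n_trials`: the `MaxTrialsCallback` caps trials at the study level, so parallel workers sharing a journal file collectively stop once the study reaches the limit rather than each running the full count. A toy sketch of that stopping behavior:

```python
# Sketch: MaxTrialsCallback stops optimization once the study holds 100
# completed trials, even though optimize() itself sets no trial count.
import optuna

study = optuna.create_study()
study.optimize(
    lambda trial: trial.suggest_float("x", -10, 10) ** 2,
    callbacks=[optuna.study.MaxTrialsCallback(100)],
    gc_after_trial=True,  # free memory between trials, as tune.py does
)
print(len(study.trials))
```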

tuning/tune_all_on_slurm.sh (+1 −1)

```diff
@@ -12,4 +12,4 @@ sbatch --job-name=tuning_pc_on_pendulum tune_on_slurm.sh pc pendulum
 sbatch --job-name=tuning_pc_on_seals_mountain_car tune_on_slurm.sh pc seals_mountain_car
 
 sbatch --job-name=tuning_sqil_on_seals_mountain_car tune_on_slurm.sh sqil seals_mountain_car
-sbatch --job-name=tuning_sqil_on_seals_cartpole tune_on_slurm.sh sqil seals_cartpole
\ No newline at end of file
+sbatch --job-name=tuning_sqil_on_seals_cartpole tune_on_slurm.sh sqil seals_cartpole
```

tuning/tune_on_slurm.sh (+1 −1)

```diff
@@ -72,4 +72,4 @@ fi
 
 cd "$SLURM_JOB_NAME/$SLURM_ARRAY_TASK_ID" || exit
 
-srun --output=cout.txt python ../../tune.py --num_trials 400 -j ../optuna_study.log "$1" "$2"
\ No newline at end of file
+srun --output=cout.txt python ../../tune.py --num_trials 400 -j ../optuna_study.log "$1" "$2"
```
