
Commit f4bcbdc (parent: 2e356b7)

Some formatting fixes.

11 files changed: +173 -106 lines

benchmarking/README.md (+1 -1)

@@ -185,4 +185,4 @@ where:
 - `algo` is the algorithm you want to compare against
 
 If `your_runs_dir` contains runs for more than one algorithm, you will have to
-disambiguate using the `--algo` option. 
+disambiguate using the `--algo` option.

src/imitation/algorithms/preference_comparisons.py (+3 -1)

@@ -1678,7 +1678,9 @@ def train(
         unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
         probs = unnormalized_probs / np.sum(unnormalized_probs)
         shares = util.oric(probs * total_comparisons)
-        shares[shares <= 0] = 1  # ensure we at least request one comparison per iteration
+        shares[
+            shares <= 0
+        ] = 1  # ensure we at least request one comparison per iteration
 
         schedule = [initial_comparisons] + shares.tolist()
         print(f"Query schedule: {schedule}")
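For context, the reflowed clamp is the final step in building the query schedule: the schedule function is evaluated on a unit grid, normalized into probabilities, scaled by the comparison budget, rounded to integers, and clamped so every iteration requests at least one comparison. A minimal sketch of that pipeline, with plain `np.round` standing in for imitation's `util.oric` rounding helper:

```python
import numpy as np

def hyperbolic(t: float) -> float:
    """One of the supported schedules: weight early iterations more heavily."""
    return 1.0 / (1.0 + t)

num_iterations = 5
total_comparisons = 100

# Evaluate the schedule on a unit grid, then normalize into probabilities.
unnormalized_probs = np.array(
    [hyperbolic(t) for t in np.linspace(0, 1, num_iterations)]
)
probs = unnormalized_probs / np.sum(unnormalized_probs)

# Scale to the budget and round to integers. The real code uses util.oric,
# which rounds while preserving the exact total; np.round is a stand-in here.
shares = np.round(probs * total_comparisons).astype(int)

# The clamp touched by this commit: never request zero comparisons.
shares[shares <= 0] = 1
print(shares)  # [28 23 19 16 14]
```

The clamp only matters when rounding drives an iteration's share to zero, e.g. with a steep schedule and a small comparison budget.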

src/imitation/scripts/config/tuning.py (+9 -3)

@@ -199,9 +199,13 @@ def pc():
             "named_configs": ["reward.reward_ensemble"],
             "config_updates": {
                 "active_selection_oversampling": tune.randint(1, 11),
-                "comparison_queue_size": tune.randint(1, 1001),  # upper bound determined by total_comparisons=1000
+                "comparison_queue_size": tune.randint(
+                    1, 1001
+                ),  # upper bound determined by total_comparisons=1000
                 "exploration_frac": tune.uniform(0.0, 0.5),
-                "fragment_length": tune.randint(1, 1001),  # trajectories are 1000 steps long
+                "fragment_length": tune.randint(
+                    1, 1001
+                ),  # trajectories are 1000 steps long
                 "gatherer_kwargs": {
                     "temperature": tune.uniform(0.0, 2.0),
                     "discount_factor": tune.uniform(0.95, 1.0),
@@ -213,7 +217,9 @@ def pc():
                     "noise_prob": tune.uniform(0.0, 0.1),
                     "discount_factor": tune.uniform(0.95, 1.0),
                 },
-                "query_schedule": tune.choice(["hyperbolic", "constant", "inverse_quadratic"]),
+                "query_schedule": tune.choice(
+                    ["hyperbolic", "constant", "inverse_quadratic"]
+                ),
                 "trajectory_generator_kwargs": {
                     "switch_prob": tune.uniform(0.1, 1),
                     "random_prob": tune.uniform(0.1, 0.9),

src/imitation/scripts/ingredients/rl.py (-1)

@@ -103,7 +103,6 @@ def dqn():
     rl_cls = sb3.DQN
 
 
-
 def _maybe_add_relabel_buffer(
     rl_kwargs: Dict[str, Any],
     relabel_reward_fn: Optional[RewardFn] = None,

tuning/README.md (+4 -4)

@@ -7,22 +7,22 @@ If you want to specify a custom algorithm and search space, add it to the dict i
 
 You can tune using multiple workers in parallel by running multiple instances of `tune.py` that all point to the same journal log file (see `tune.py --help` for details).
 To easily launch multiple workers on a SLURM cluster and ensure they don't conflict with each other,
-use the `tune_on_slurm.py` script. 
+use the `tune_on_slurm.py` script.
 This script will launch a SLURM job array with the specified number of workers.
 If you want to tune all algorithms on all environments on SLURM, use `tune_all_on_slurm.sh`.
 
 # Legacy Tuning Scripts
 
-Note: There are some legacy tuning scripts that can be used like this: 
+Note: There are some legacy tuning scripts that can be used like this:
 
 The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
 The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
 the search space defined in the `scripts/config/tuning.py`.
 
 The tuning script proceeds in two phases:
 1. Tune the hyperparameters using the search space provided.
-2. Re-evaluate the best hyperparameter config found in the first phase 
-   based on the maximum mean return on a separate set of seeds. 
+2. Re-evaluate the best hyperparameter config found in the first phase
+   based on the maximum mean return on a separate set of seeds.
 Report the mean and standard deviation of these trials.
 
 To use it with the default search space:
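The journal log file that the parallel workers share is Optuna's `JournalStorage`: every worker opens the same file and attaches to the same study, so no database server is needed. Below is a minimal sketch of what each worker effectively does; the objective is a hypothetical stand-in (the real one comes from imitation's tuning scripts), and the study name follows the `tuning_<algo>_with_<env>` pattern that the analysis notebook later relies on:

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Hypothetical stand-in: the real objective trains and evaluates an
    # imitation algorithm with the sampled hyperparameters.
    x = trial.suggest_float("x", -10.0, 10.0)
    return -((x - 2.0) ** 2)

# All workers point at the same journal file; appends are synchronized via
# file locking, which is what makes concurrent SLURM array jobs safe.
storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("./tuning_pc_with_seals_cartpole.log")
)
study = optuna.create_study(
    study_name="tuning_pc_with_seals_cartpole",
    storage=storage,
    direction="maximize",
    load_if_exists=True,  # later workers attach instead of erroring out
)
study.optimize(objective, n_trials=50)
print(study.best_trial.value)
```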

tuning/benchmark_analysis.ipynb (+49 -34)

@@ -40,47 +40,56 @@
     "\n",
     "for log_file in experiment_log_files:\n",
     "    d = dict()\n",
-    "    \n",
-    "    d['logfile'] = log_file\n",
-    "    \n",
-    "    study = optuna.load_study(storage=optuna.storages.JournalStorage(\n",
+    "\n",
+    "    d[\"logfile\"] = log_file\n",
+    "\n",
+    "    study = optuna.load_study(\n",
+    "        storage=optuna.storages.JournalStorage(\n",
     "            optuna.storages.JournalFileStorage(str(log_file))\n",
     "        ),\n",
     "        # in our case, we have one journal file per study so the study name can be\n",
     "        # inferred\n",
     "        study_name=None,\n",
     "    )\n",
-    "    d['study'] = study\n",
-    "    d['study_name'] = study.study_name\n",
-    "    \n",
+    "    d[\"study\"] = study\n",
+    "    d[\"study_name\"] = study.study_name\n",
+    "\n",
     "    trial_state_counter = Counter(t.state for t in study.trials)\n",
     "    n_completed_trials = trial_state_counter[TrialState.COMPLETE]\n",
-    "    d['trials'] = n_completed_trials\n",
-    "    d['trials_running'] = Counter(t.state for t in study.trials)[TrialState.RUNNING]\n",
-    "    d['trials_failed'] = Counter(t.state for t in study.trials)[TrialState.FAIL]\n",
-    "    d['all_trials'] = len(study.trials)\n",
-    "    \n",
+    "    d[\"trials\"] = n_completed_trials\n",
+    "    d[\"trials_running\"] = Counter(t.state for t in study.trials)[TrialState.RUNNING]\n",
+    "    d[\"trials_failed\"] = Counter(t.state for t in study.trials)[TrialState.FAIL]\n",
+    "    d[\"all_trials\"] = len(study.trials)\n",
+    "\n",
     "    if n_completed_trials > 0:\n",
-    "        d['best_value'] = round(study.best_trial.value, 2)\n",
-    "    \n",
+    "        d[\"best_value\"] = round(study.best_trial.value, 2)\n",
+    "\n",
     "    assert \"_\" in study.study_name\n",
-    "    study_segments = study.study_name.split(\"_\") \n",
+    "    study_segments = study.study_name.split(\"_\")\n",
     "    assert len(study_segments) > 3\n",
     "    tuning, algo, with_ = study_segments[:3]\n",
     "    assert (tuning, with_) == (\"tuning\", \"with\")\n",
-    "    \n",
-    "    d['algo'] = algo\n",
-    "    d['env'] = \"_\".join(study_segments[3:])\n",
-    "    d['best_trial_duration'] = study.best_trial.duration\n",
-    "    d['mean_duration'] = sum([t.duration for t in study.trials if t.state == TrialState.COMPLETE], datetime.timedelta())/n_completed_trials\n",
-    "    \n",
+    "\n",
+    "    d[\"algo\"] = algo\n",
+    "    d[\"env\"] = \"_\".join(study_segments[3:])\n",
+    "    d[\"best_trial_duration\"] = study.best_trial.duration\n",
+    "    d[\"mean_duration\"] = (\n",
+    "        sum(\n",
+    "            [t.duration for t in study.trials if t.state == TrialState.COMPLETE],\n",
+    "            datetime.timedelta(),\n",
+    "        )\n",
+    "        / n_completed_trials\n",
+    "    )\n",
+    "\n",
     "    reruns_folder = log_file.parent / \"reruns\"\n",
-    "    rerun_results = [round(run['result']['imit_stats']['monitor_return_mean'], 2)\n",
-    "                     for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)]\n",
-    "    d['rerun_values'] = rerun_results\n",
-    "    \n",
+    "    rerun_results = [\n",
+    "        round(run[\"result\"][\"imit_stats\"][\"monitor_return_mean\"], 2)\n",
+    "        for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)\n",
+    "    ]\n",
+    "    d[\"rerun_values\"] = rerun_results\n",
+    "\n",
     "    raw_study_data.append(d)\n",
-    "    \n",
+    "\n",
     "study_data = pd.DataFrame(raw_study_data)"
    ]
   },
@@ -103,7 +112,7 @@
     "    \"seals_humanoid\",\n",
     "    \"seals_cartpole\",\n",
     "    \"pendulum\",\n",
-    "    \"seals_mountain_car\"\n",
+    "    \"seals_mountain_car\",\n",
     "]\n",
     "\n",
     "pc_paper_700 = dict(\n",
@@ -163,12 +172,14 @@
     "    for env, value in values_by_env.items():\n",
     "        if value == \"-\":\n",
     "            continue\n",
-    "        raw_study_data.append(dict(\n",
-    "            algo=algo,\n",
-    "            env=env,\n",
-    "            best_value=value,\n",
-    "        ))\n",
-    "    \n",
+    "        raw_study_data.append(\n",
+    "            dict(\n",
+    "                algo=algo,\n",
+    "                env=env,\n",
+    "                best_value=value,\n",
+    "            )\n",
+    "        )\n",
+    "\n",
     "study_data = pd.DataFrame(raw_study_data)"
    ]
   },
@@ -185,7 +196,11 @@
     "display(study_data[[\"algo\", \"env\", \"best_value\"]])\n",
     "\n",
     "print(\"Rerun Data\")\n",
-    "display(study_data[[\"algo\", \"env\", \"best_value\", \"rerun_values\"]][study_data[\"rerun_values\"].map(np.std) > 0])"
+    "display(\n",
+    "    study_data[[\"algo\", \"env\", \"best_value\", \"rerun_values\"]][\n",
+    "        study_data[\"rerun_values\"].map(np.std) > 0\n",
+    "    ]\n",
+    ")"
    ]
   }
  ],
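One idiom in the reflowed `mean_duration` expression deserves a note: `sum()` is given an explicit `datetime.timedelta()` start value because its default start of `0` cannot be added to a `timedelta`; dividing the summed `timedelta` by the trial count then yields the mean duration. A self-contained example of the pattern:

```python
import datetime

durations = [
    datetime.timedelta(minutes=42),
    datetime.timedelta(minutes=58),
    datetime.timedelta(hours=1, minutes=10),
]

# Without the timedelta() start value, sum() would try 0 + timedelta(...)
# and raise a TypeError; with it, the accumulation stays a timedelta.
mean_duration = sum(durations, datetime.timedelta()) / len(durations)
print(mean_duration)  # 0:56:40
```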
