|
1 | 1 | import os
|
2 | 2 | from logging import ERROR
|
| 3 | +from typing import Optional |
3 | 4 |
|
4 | 5 | import cv2
|
5 | 6 | import numpy as np
|
@@ -37,7 +38,7 @@ class Parser(BaseParser):
|
37 | 38 | }
|
38 | 39 |
|
39 | 40 |
|
40 |
| -def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float: |
| 41 | +def objective(trial: optuna.trial.BaseTrial, idx: Optional[int], worker_id: int, path: str) -> float: |
41 | 42 | # Get some parameters
|
42 | 43 | lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)
|
43 | 44 | n_episodes = 1
|
@@ -156,7 +157,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
|
156 | 157 | entity="redtachyon",
|
157 | 158 | sync_tensorboard=True,
|
158 | 159 | config=config,
|
159 |
| - name=f"trial{trial.number}", |
| 160 | + name=f"trial{trial.number}-{idx if idx is not None else ''}", |
160 | 161 | )
|
161 | 162 |
|
162 | 163 | model = RelationModel(config["model"], action_space=env.action_space)
|
@@ -188,7 +189,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
|
188 | 189 |
|
189 | 190 | # EVALUATION
|
190 | 191 | env = UnitySimpleCrowdEnv(
|
191 |
| - file_name=args.env, |
| 192 | + file_name=path, |
192 | 193 | virtual_display=(1600, 900),
|
193 | 194 | no_graphics=False,
|
194 | 195 | worker_id=worker_id+5,
|
@@ -243,21 +244,21 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
|
243 | 244 | # Generate the dashboard
|
244 | 245 | print("Generating dashboard")
|
245 | 246 | print("Skipping dashboard")
|
246 |
| - # trajectory = du.read_trajectory(trajectory_path) |
247 |
| - # |
248 |
| - # plt.clf() |
249 |
| - # du.make_dashboard(trajectory, save_path=dashboard_path) |
| 247 | + trajectory = du.read_trajectory(trajectory_path) |
| 248 | + |
| 249 | + plt.clf() |
| 250 | + du.make_dashboard(trajectory, save_path=dashboard_path) |
250 | 251 |
|
251 | 252 | # Upload to wandb
|
252 |
| - # print("Uploading dashboard") |
253 |
| - # wandb.log( |
254 |
| - # { |
255 |
| - # "dashboard": wandb.Image( |
256 |
| - # dashboard_path, |
257 |
| - # caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}", |
258 |
| - # ) |
259 |
| - # } |
260 |
| - # ) |
| 253 | + print("Uploading dashboard") |
| 254 | + wandb.log( |
| 255 | + { |
| 256 | + "dashboard": wandb.Image( |
| 257 | + dashboard_path, |
| 258 | + caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}", |
| 259 | + ) |
| 260 | + } |
| 261 | + ) |
261 | 262 |
|
262 | 263 | frame_size = renders.shape[1:3]
|
263 | 264 |
|
@@ -300,7 +301,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
|
300 | 301 | study = optuna.load_study(storage=f"sqlite:///{args.optuna_name}.db", study_name=args.optuna_name)
|
301 | 302 |
|
302 | 303 | study.optimize(
|
303 |
| - lambda trial: objective(trial, args.worker_id, args.env), n_trials=args.n_trials |
| 304 | + lambda trial: objective(trial, None, args.worker_id, args.env), n_trials=args.n_trials |
304 | 305 | )
|
305 | 306 |
|
306 | 307 | print("Best params:", study.best_params)
|
|
0 commit comments