Skip to content

Commit e9e2ad5

Browse files
committed
Bug
1 parent 5f9f621 commit e9e2ad5

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

scripts/optimize/optuna_crowd.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from logging import ERROR
3+
from typing import Optional
34

45
import cv2
56
import numpy as np
@@ -37,7 +38,7 @@ class Parser(BaseParser):
3738
}
3839

3940

40-
def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float:
41+
def objective(trial: optuna.trial.BaseTrial, idx: Optional[int], worker_id: int, path: str) -> float:
4142
# Get some parameters
4243
lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)
4344
n_episodes = 1
@@ -156,7 +157,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
156157
entity="redtachyon",
157158
sync_tensorboard=True,
158159
config=config,
159-
name=f"trial{trial.number}",
160+
name=f"trial{trial.number}-{idx if idx is not None else ''}",
160161
)
161162

162163
model = RelationModel(config["model"], action_space=env.action_space)
@@ -188,7 +189,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
188189

189190
# EVALUATION
190191
env = UnitySimpleCrowdEnv(
191-
file_name=args.env,
192+
file_name=path,
192193
virtual_display=(1600, 900),
193194
no_graphics=False,
194195
worker_id=worker_id+5,
@@ -243,21 +244,21 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
243244
# Generate the dashboard
244245
print("Generating dashboard")
245246
print("Skipping dashboard")
246-
# trajectory = du.read_trajectory(trajectory_path)
247-
#
248-
# plt.clf()
249-
# du.make_dashboard(trajectory, save_path=dashboard_path)
247+
trajectory = du.read_trajectory(trajectory_path)
248+
249+
plt.clf()
250+
du.make_dashboard(trajectory, save_path=dashboard_path)
250251

251252
# Upload to wandb
252-
# print("Uploading dashboard")
253-
# wandb.log(
254-
# {
255-
# "dashboard": wandb.Image(
256-
# dashboard_path,
257-
# caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}",
258-
# )
259-
# }
260-
# )
253+
print("Uploading dashboard")
254+
wandb.log(
255+
{
256+
"dashboard": wandb.Image(
257+
dashboard_path,
258+
caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}",
259+
)
260+
}
261+
)
261262

262263
frame_size = renders.shape[1:3]
263264

@@ -300,7 +301,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
300301
study = optuna.load_study(storage=f"sqlite:///{args.optuna_name}.db", study_name=args.optuna_name)
301302

302303
study.optimize(
303-
lambda trial: objective(trial, args.worker_id, args.env), n_trials=args.n_trials
304+
lambda trial: objective(trial, None, args.worker_id, args.env), n_trials=args.n_trials
304305
)
305306

306307
print("Best params:", study.best_params)

scripts/optimize/retrain_crowd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ class Parser(BaseParser):
4444
print(f"Trial {idx}")
4545
for i in range(args.n_trials):
4646
print(f"Run {i}")
47-
objective(trial, args.worker_id, args.env)
47+
objective(trial, i, args.worker_id, args.env)

0 commit comments

Comments
 (0)