From e9e2ad5c1b0ecb2fc6259f22cb04705cad70a2fe Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Fri, 18 Mar 2022 14:37:09 +0100 Subject: [PATCH] Bug --- scripts/optimize/optuna_crowd.py | 35 ++++++++++++++++--------------- scripts/optimize/retrain_crowd.py | 2 +- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/scripts/optimize/optuna_crowd.py b/scripts/optimize/optuna_crowd.py index b7d09211..6f3982cb 100644 --- a/scripts/optimize/optuna_crowd.py +++ b/scripts/optimize/optuna_crowd.py @@ -1,5 +1,6 @@ import os from logging import ERROR +from typing import Optional import cv2 import numpy as np @@ -37,7 +38,7 @@ class Parser(BaseParser): } -def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float: +def objective(trial: optuna.trial.BaseTrial, idx: Optional[int], worker_id: int, path: str) -> float: # Get some parameters lr = trial.suggest_loguniform("lr", 1e-5, 1e-2) n_episodes = 1 @@ -156,7 +157,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float entity="redtachyon", sync_tensorboard=True, config=config, - name=f"trial{trial.number}", + name=f"trial{trial.number}-{idx if idx is not None else ''}", ) model = RelationModel(config["model"], action_space=env.action_space) @@ -188,7 +189,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float # EVALUATION env = UnitySimpleCrowdEnv( - file_name=args.env, + file_name=path, virtual_display=(1600, 900), no_graphics=False, worker_id=worker_id+5, @@ -243,21 +244,21 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float # Generate the dashboard print("Generating dashboard") print("Skipping dashboard") - # trajectory = du.read_trajectory(trajectory_path) - # - # plt.clf() - # du.make_dashboard(trajectory, save_path=dashboard_path) + trajectory = du.read_trajectory(trajectory_path) + + plt.clf() + du.make_dashboard(trajectory, save_path=dashboard_path) # Upload to wandb - # print("Uploading dashboard") - # wandb.log( - # { - # "dashboard": wandb.Image( - # dashboard_path, - # caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}", - # ) - # } - # ) + print("Uploading dashboard") + wandb.log( + { + "dashboard": wandb.Image( + dashboard_path, + caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}", + ) + } + ) frame_size = renders.shape[1:3] @@ -300,7 +301,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float study = optuna.load_study(storage=f"sqlite:///{args.optuna_name}.db", study_name=args.optuna_name) study.optimize( - lambda trial: objective(trial, args.worker_id, args.env), n_trials=args.n_trials + lambda trial: objective(trial, None, args.worker_id, args.env), n_trials=args.n_trials ) print("Best params:", study.best_params) diff --git a/scripts/optimize/retrain_crowd.py b/scripts/optimize/retrain_crowd.py index 08b5e245..bc965e88 100644 --- a/scripts/optimize/retrain_crowd.py +++ b/scripts/optimize/retrain_crowd.py @@ -44,4 +44,4 @@ class Parser(BaseParser): print(f"Trial {idx}") for i in range(args.n_trials): print(f"Run {i}") - objective(trial, args.worker_id, args.env) + objective(trial, i, args.worker_id, args.env)