Skip to content

Commit

Permalink
COME ON
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Mar 12, 2022
1 parent c88f0b0 commit 769ac79
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
1 change: 1 addition & 0 deletions coltra/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def train(
trial.report(mean_reward, step)
if trial.should_prune():
print("Trial was pruned at step {}".format(step))
self.env.close()
raise optuna.TrialPruned()

return metrics
Expand Down
31 changes: 16 additions & 15 deletions scripts/optimize/optuna_crowd.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def objective(trial: optuna.Trial, worker_id: int, path: str) -> float:
disable_tqdm=False,
save_path=trainer.path,
collect_kwargs=config["environment"],
trial=trial,
# trial=trial,
)

env.close()
Expand All @@ -166,7 +166,7 @@ def objective(trial: optuna.Trial, worker_id: int, path: str) -> float:
file_name=args.env,
virtual_display=(1600, 900),
no_graphics=False,
worker_id=worker_id,
worker_id=worker_id+5,
)
config["environment"]["evaluation_mode"] = 1.0
env.reset(**config["environment"])
Expand Down Expand Up @@ -217,21 +217,22 @@ def objective(trial: optuna.Trial, worker_id: int, path: str) -> float:

# Generate the dashboard
print("Generating dashboard")
trajectory = du.read_trajectory(trajectory_path)

plt.clf()
du.make_dashboard(trajectory, save_path=dashboard_path)
print("Skipping dashboard")
# trajectory = du.read_trajectory(trajectory_path)
#
# plt.clf()
# du.make_dashboard(trajectory, save_path=dashboard_path)

# Upload to wandb
print("Uploading dashboard")
wandb.log(
{
"dashboard": wandb.Image(
dashboard_path,
caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}",
)
}
)
# print("Uploading dashboard")
# wandb.log(
# {
# "dashboard": wandb.Image(
# dashboard_path,
# caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}",
# )
# }
# )

frame_size = renders.shape[1:3]

Expand Down
2 changes: 1 addition & 1 deletion scripts/optimize/optuna_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ class Parser(BaseParser):
study = optuna.create_study(
storage=f"sqlite:///{args.name}.db",
study_name=args.name,
pruner=optuna.pruners.HyperbandPruner(),
pruner=None,
direction="maximize",
)

0 comments on commit 769ac79

Please sign in to comment.