Skip to content

Commit

Permalink
feat: Wandb integration
Browse files: browse the repository at this point in the history
#112
Note: lint formatting issues are still unresolved.
  • Loading branch information
CinquilCinquil committed Nov 28, 2024
1 parent 7a7de6f commit 6579c81
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion experiments/solves/solve_collectables_sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from gymnasium import spaces
from pysc2.env import sc2_env
from stable_baselines3 import PPO
import wandb
from wandb.integration.sb3 import WandbCallback

from urnai.environments.stablebaselines3.custom_env import CustomEnv
from urnai.sc2.actions.collectables import CollectablesActionSpace
Expand Down Expand Up @@ -36,6 +38,16 @@ def declare_trainer():
models_dir = "saves/models/PPO"
logdir = "saves/logs"

conf_dict = {
"policy":"MlpPolicy",
"model_save_name": "PPO"}

run = wandb.init(
project='solve_collectables',
config=conf_dict,
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
)

model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)

trainer = SB3Trainer(custom_env, models_dir, logdir, model)
Expand All @@ -46,7 +58,8 @@ def main(unused_argv):
try:
trainer = declare_trainer()
trainer.train_model(timesteps=10000, reset_num_timesteps=False,
tb_log_name="PPO", repeat_times=30)
tb_log_name="PPO", repeat_times=30,
callback=WandbCallback())
# trainer.load_model(f"{trainer.models_dir}/290000")
# trainer.test_model(total_steps=10000, deterministic=True)
except KeyboardInterrupt:
Expand Down

0 comments on commit 6579c81

Please sign in to comment.