From fff746de5e583c08f5cd5f20003fa357902fbe0f Mon Sep 17 00:00:00 2001
From: RickFqt
Date: Wed, 18 Dec 2024 08:35:58 -0300
Subject: [PATCH] refactor: Changed to train and test in only one run

---
 experiments/solves/solve_collectables_sb3.py | 43 ++++++---
 urnai/trainers/stablebaselines3_trainer.py   | 95 ++++++++------
 2 files changed, 67 insertions(+), 71 deletions(-)

diff --git a/experiments/solves/solve_collectables_sb3.py b/experiments/solves/solve_collectables_sb3.py
index b93896a..05904a2 100644
--- a/experiments/solves/solve_collectables_sb3.py
+++ b/experiments/solves/solve_collectables_sb3.py
@@ -10,15 +10,30 @@
 from pysc2.env import sc2_env
 from stable_baselines3 import PPO
+import wandb
 
 from urnai.environments.stablebaselines3.custom_env import CustomEnv
 from urnai.sc2.actions.collectables import CollectablesActionSpace
 from urnai.sc2.environments.sc2environment import SC2Env
 from urnai.sc2.rewards.collectables import CollectablesReward
 from urnai.sc2.states.collectables import CollectablesMethod, CollectablesState
 from urnai.trainers.stablebaselines3_trainer import SB3Trainer
+from wandb.integration.sb3 import WandbCallback
 
 
-def declare_trainer():
+def declare_wandb_run(config_dict : dict, run_id : str = None):
+
+    wandb_run = wandb.init(
+        project='solve_collectables',
+        config=config_dict,
+        name=config_dict['model_save_name'],
+        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
+        resume="must" if run_id else None,
+        id=run_id
+    )
+
+    return wandb_run
+
+def declare_trainer(config_dict : dict):
     players = [sc2_env.Agent(sc2_env.Race.terran)]
     env = SC2Env(map_name='CollectMineralShards', visualize=False,
                  step_mul=16, players=players)
@@ -34,26 +49,30 @@ def declare_trainer():
     custom_env = CustomEnv(env, state, urnai_action_space, reward,
                            observation_space, action_space)
 
-    model_name = "PPOMlp"
-    models_dir = f"saves/models/{model_name}"
+    models_dir = f"saves/models/{config_dict['model_save_name']}"
     logdir = "saves/logs"
 
-    conf_dict = {
-        "policy":"MlpPolicy",
-        "model_save_name": model_name}
-
-    model=PPO("CnnPolicy", custom_env, verbose=1, tensorboard_log=logdir)
+    model=PPO(config_dict['policy'], custom_env, verbose=1, tensorboard_log=logdir)
 
-    trainer = SB3Trainer(custom_env, models_dir, logdir, model, model_name,
-                         "solve_collectables", conf_dict)
+    trainer = SB3Trainer(
+        custom_env, models_dir, logdir, model, config_dict['model_save_name']
+    )
 
     return trainer
 
 
 def main(unused_argv):
     try:
-        trainer = declare_trainer()
+        config_dict = {
+            "policy":"MlpPolicy",
+            "model_save_name": "PPOMlp"}
+        wandb_run = declare_wandb_run(config_dict)
+        trainer = declare_trainer(config_dict)
         # trainer.load_model(f"{trainer.models_dir}/100000")
-        trainer.alternate_train_test(iterations=100, train_steps=10000, test_steps=1000)
+        trainer.alternate_train_test(
+            iterations=100, train_steps=5000, test_steps=2400, test_episodes=100,
+            callback=WandbCallback()
+        )
+        wandb_run.finish()
     except KeyboardInterrupt:
         print("Training interrupted by user")
diff --git a/urnai/trainers/stablebaselines3_trainer.py b/urnai/trainers/stablebaselines3_trainer.py
index bbbc16b..760d55c 100644
--- a/urnai/trainers/stablebaselines3_trainer.py
+++ b/urnai/trainers/stablebaselines3_trainer.py
@@ -1,22 +1,19 @@
 import os
 
 from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.type_aliases import MaybeCallback
 import wandb
 
 from urnai.environments.stablebaselines3.custom_env import CustomEnv
-from wandb.integration.sb3 import WandbCallback
 
 
 class SB3Trainer:
     def __init__(self, custom_env : CustomEnv, models_dir : str, logdir : str,
-                 model : BaseAlgorithm, model_name : str,
-                 wandb_project : str= "default_project", wandb_conf_dict : dict = None):
+                 model : BaseAlgorithm, model_name : str):
         self.custom_env = custom_env
         self.models_dir = models_dir
         self.model = model
         self.model_name = model_name
-        self.wandb_project = wandb_project
-        self.wandb_conf_dict = wandb_conf_dict
 
         if not os.path.exists(models_dir):
             os.makedirs(models_dir)
@@ -27,75 +24,55 @@ def __init__(self, custom_env : CustomEnv, models_dir : str, logdir : str,
     def load_model(self, model_path):
         self.model = self.model.load(model_path, env = self.custom_env)
 
-    def train_model(self, timesteps: int = 10000, log_interval: int = 1,
-                    reset_num_timesteps: bool = False, progress_bar: bool = False,
-                    repeat_times:int = 1, start_from:int = 1, train_run_id:str = None,
-                    ) -> str:
-
-        wandb_run = wandb.init(
-            project=self.wandb_project,
-            config=self.wandb_conf_dict,
-            name=f"train_{self.model_name}",
-            group="train",
-            sync_tensorboard=True,
-            resume="must" if train_run_id else None,
-            id=train_run_id
-        )
+    def train_model(
+        self, timesteps: int = 10000, log_interval: int = 1,
+        reset_num_timesteps: bool = False, progress_bar: bool = False,
+        repeat_times:int = 1, start_from:int = 1, callback : MaybeCallback = None
+    ) -> None:
 
         for repeat_time in range(repeat_times):
-            self.model.learn(total_timesteps = timesteps, callback = WandbCallback(),
+            self.model.learn(total_timesteps = timesteps, callback = callback,
                              log_interval = log_interval,
                              reset_num_timesteps = reset_num_timesteps,
                              progress_bar = progress_bar, tb_log_name = self.model_name)
             self.model.save(f"{self.models_dir}/{timesteps*(repeat_time + start_from)}")
-
-        wandb_run.finish()
-        return wandb_run.id
 
-    def test_model(self, total_steps: int = 10000,
-                   deterministic: bool = True,
-                   test_run_id : str = None
-                   ) -> str:
-        wandb_run = wandb.init(
-            project=self.wandb_project,
-            name=f"test_{self.model_name}",
-            group="test",
-            resume="must" if test_run_id else None,
-            id=test_run_id
-        )
+    def test_model(
+        self, total_steps: int = 10000, episodes : int = 100,
+        deterministic: bool = True
+    ) -> None:
 
         vec_env = self.model.get_env()
         obs = vec_env.reset()
-        total_episodes = 0
         total_reward = 0
+        curr_step = 0
+        for _ in range(episodes):
+            done = False
+            while not done:
+                action, _state = self.model.predict(obs, deterministic=deterministic)
+                obs, rewards, done, info = vec_env.step(action)
 
-        for _ in range(total_steps):
-            action, _state = self.model.predict(obs, deterministic=deterministic)
-            obs, rewards, done, info = vec_env.step(action)
-
-            total_reward += rewards
-            if done:
-                total_episodes += 1
-                wandb.log({
-                    "test/total_reward": total_reward,
-                    "test/episode": total_episodes,
-                })
-                total_reward = 0 # Reset reward for the new episode
-
-        wandb_run.finish()
-        return wandb_run.id
+                total_reward += rewards
+                curr_step += 1
+            wandb.log({
+                "eval/total_reward": total_reward
+            })
+            total_reward = 0 # Reset reward for the new episode
+            if curr_step >= total_steps:
+                break
 
-    def alternate_train_test(self, iterations : int = 100, train_steps : int = 10000,
-                             train_repeat_times : int = 1, test_steps : int = 10000,
-                             train_run_id = None, test_run_id = None
-                             ) -> None:
+    def alternate_train_test(
+        self, iterations : int = 100, train_steps : int = 10000,
+        train_repeat_times : int = 1, test_steps : int = 10000,
+        test_episodes : int = 100, callback : MaybeCallback = None
+    ) -> None:
         for iteration in range(iterations):
-            train_run_id = self.train_model(timesteps=train_steps,
-                                            repeat_times=train_repeat_times,
-                                            train_run_id=train_run_id,
-                                            start_from=iteration*train_repeat_times +1)
-            test_run_id= self.test_model(total_steps=test_steps,test_run_id=test_run_id)
+            self.train_model(
+                timesteps=train_steps, repeat_times=train_repeat_times,
+                start_from=iteration*train_repeat_times +1, callback=callback
+            )
+            self.test_model(total_steps=test_steps, episodes = test_episodes)
\ No newline at end of file
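
Usage sketch (illustrative only, not part of the patch): a minimal example of driving the refactored SB3Trainer with a single wandb run, as main() now does. The Gymnasium CartPole-v1 environment is a stand-in for the SC2 CustomEnv, and the project name, save paths, and step counts below are assumptions, not values required by the code.

import gymnasium as gym
import wandb
from stable_baselines3 import PPO
from wandb.integration.sb3 import WandbCallback

from urnai.trainers.stablebaselines3_trainer import SB3Trainer

# One wandb run now covers both training and evaluation.
config_dict = {"policy": "MlpPolicy", "model_save_name": "PPOMlp"}
run = wandb.init(project="solve_collectables", config=config_dict,
                 name=config_dict["model_save_name"], sync_tensorboard=True)

env = gym.make("CartPole-v1")  # stand-in for the SC2 CustomEnv (assumption)
model = PPO(config_dict["policy"], env, verbose=1, tensorboard_log="saves/logs")
trainer = SB3Trainer(env, f"saves/models/{config_dict['model_save_name']}",
                     "saves/logs", model, config_dict["model_save_name"])

# WandbCallback streams training metrics into the run; test_model() logs
# "eval/total_reward" once per evaluation episode via wandb.log.
trainer.alternate_train_test(iterations=2, train_steps=2048, test_steps=500,
                             test_episodes=5, callback=WandbCallback())
run.finish()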