refactor: Changed to train and test in only one run
RickFqt committed Dec 18, 2024
1 parent c25ced0 commit fff746d
Showing 2 changed files with 67 additions and 71 deletions.
43 changes: 31 additions & 12 deletions experiments/solves/solve_collectables_sb3.py
@@ -10,15 +10,30 @@
 from pysc2.env import sc2_env
 from stable_baselines3 import PPO

+import wandb
 from urnai.environments.stablebaselines3.custom_env import CustomEnv
 from urnai.sc2.actions.collectables import CollectablesActionSpace
 from urnai.sc2.environments.sc2environment import SC2Env
 from urnai.sc2.rewards.collectables import CollectablesReward
 from urnai.sc2.states.collectables import CollectablesMethod, CollectablesState
 from urnai.trainers.stablebaselines3_trainer import SB3Trainer
+from wandb.integration.sb3 import WandbCallback


-def declare_trainer():
+def declare_wandb_run(config_dict : dict, run_id : str = None):
+
+    wandb_run = wandb.init(
+        project='solve_collectables',
+        config=config_dict,
+        name=config_dict['model_save_name'],
+        sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
+        resume="must" if run_id else None,
+        id=run_id
+    )
+
+    return wandb_run
+
+def declare_trainer(config_dict : dict):
players = [sc2_env.Agent(sc2_env.Race.terran)]
env = SC2Env(map_name='CollectMineralShards', visualize=False,
step_mul=16, players=players)
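Note on the declare_wandb_run helper added above: when no run_id is given, resume stays None and a fresh W&B run is created; passing an earlier run's id reattaches to it, so training and evaluation metrics share a single run history. A minimal, self-contained sketch of that behavior (the project name matches the script; the two-process flow and everything else here is illustrative, not part of the commit):

import wandb

config_dict = {"policy": "MlpPolicy", "model_save_name": "PPOMlp"}

# Fresh run: no run_id, so resume is None and wandb assigns a new id.
run = wandb.init(project='solve_collectables', config=config_dict,
                 name=config_dict['model_save_name'], sync_tensorboard=True)
saved_id = run.id
run.finish()

# Later process: reattach to the same run so all metrics land in one history.
run = wandb.init(project='solve_collectables', config=config_dict,
                 name=config_dict['model_save_name'], sync_tensorboard=True,
                 resume="must", id=saved_id)
run.finish()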
@@ -34,26 +49,30 @@ def declare_trainer():
     custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
                            action_space)

-    model_name = "PPOMlp"
-    models_dir = f"saves/models/{model_name}"
+    models_dir = f"saves/models/{config_dict['model_save_name']}"
     logdir = "saves/logs"

-    conf_dict = {
-        "policy":"MlpPolicy",
-        "model_save_name": model_name}
-
-    model=PPO("CnnPolicy", custom_env, verbose=1, tensorboard_log=logdir)
+    model=PPO(config_dict['policy'], custom_env, verbose=1, tensorboard_log=logdir)

-    trainer = SB3Trainer(custom_env, models_dir, logdir, model, model_name,
-                         "solve_collectables", conf_dict)
+    trainer = SB3Trainer(
+        custom_env, models_dir, logdir, model, config_dict['model_save_name']
+    )

     return trainer

 def main(unused_argv):
     try:
-        trainer = declare_trainer()
+        config_dict = {
+            "policy":"MlpPolicy",
+            "model_save_name": "PPOMlp"}
+        wandb_run = declare_wandb_run(config_dict)
+        trainer = declare_trainer(config_dict)
         # trainer.load_model(f"{trainer.models_dir}/100000")
-        trainer.alternate_train_test(iterations=100, train_steps=10000, test_steps=1000)
+        trainer.alternate_train_test(
+            iterations=100, train_steps=5000, test_steps=2400, test_episodes=100,
+            callback=WandbCallback()
+        )
+        wandb_run.finish()
except KeyboardInterrupt:
print("Training interrupted by user")

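For illustration, resuming an interrupted experiment with this reworked API could look like the sketch below. This is a hypothetical variant of main(), not code from the commit: the run id placeholder and the 100000-step checkpoint (echoing the commented-out load_model line above) are only examples.

def resume_main(unused_argv):
    config_dict = {
        "policy": "MlpPolicy",
        "model_save_name": "PPOMlp"}
    # Reattach to the earlier W&B run instead of starting a new one.
    wandb_run = declare_wandb_run(config_dict, run_id="<previous-run-id>")
    trainer = declare_trainer(config_dict)
    # Continue from a previously saved checkpoint, e.g. the one written after 100000 steps.
    trainer.load_model(f"{trainer.models_dir}/100000")
    trainer.alternate_train_test(
        iterations=100, train_steps=5000, test_steps=2400, test_episodes=100,
        callback=WandbCallback()
    )
    wandb_run.finish()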
95 changes: 36 additions & 59 deletions urnai/trainers/stablebaselines3_trainer.py
@@ -1,22 +1,19 @@
 import os

 from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.type_aliases import MaybeCallback

 import wandb
 from urnai.environments.stablebaselines3.custom_env import CustomEnv
-from wandb.integration.sb3 import WandbCallback


 class SB3Trainer:
     def __init__(self, custom_env : CustomEnv, models_dir : str, logdir : str,
-                 model : BaseAlgorithm, model_name : str,
-                 wandb_project : str= "default_project", wandb_conf_dict : dict = None):
+                 model : BaseAlgorithm, model_name : str):
         self.custom_env = custom_env
         self.models_dir = models_dir
         self.model = model
         self.model_name = model_name
-        self.wandb_project = wandb_project
-        self.wandb_conf_dict = wandb_conf_dict

if not os.path.exists(models_dir):
os.makedirs(models_dir)
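With the wandb arguments removed, the trainer only needs an environment, the two save paths, a model and a name; experiment tracking becomes whatever callback the caller passes in (see the train_model hunk below). A rough, self-contained usage sketch of the slimmed-down constructor, substituting a Gymnasium CartPole environment for the SC2 CustomEnv and made-up paths purely for illustration, assuming an SB3 version built on Gymnasium:

import gymnasium as gym
from stable_baselines3 import PPO

from urnai.trainers.stablebaselines3_trainer import SB3Trainer

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=0, tensorboard_log="saves/logs")
trainer = SB3Trainer(env, "saves/models/demo", "saves/logs", model, "demo")

# Train with no experiment tracking at all...
trainer.train_model(timesteps=1000, callback=None)
# ...or hand it the wandb callback, as the experiment script above now does:
# trainer.train_model(timesteps=1000, callback=WandbCallback())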
@@ -27,75 +24,55 @@ def __init__(self, custom_env : CustomEnv, models_dir : str, logdir : str,
     def load_model(self, model_path):
         self.model = self.model.load(model_path, env = self.custom_env)

-    def train_model(self, timesteps: int = 10000, log_interval: int = 1,
-                    reset_num_timesteps: bool = False, progress_bar: bool = False,
-                    repeat_times:int = 1, start_from:int = 1, train_run_id:str = None,
-                    ) -> str:
-
-        wandb_run = wandb.init(
-            project=self.wandb_project,
-            config=self.wandb_conf_dict,
-            name=f"train_{self.model_name}",
-            group="train",
-            sync_tensorboard=True,
-            resume="must" if train_run_id else None,
-            id=train_run_id
-        )
+    def train_model(
+        self, timesteps: int = 10000, log_interval: int = 1,
+        reset_num_timesteps: bool = False, progress_bar: bool = False,
+        repeat_times:int = 1, start_from:int = 1, callback : MaybeCallback = None
+    ) -> None:

         for repeat_time in range(repeat_times):
-            self.model.learn(total_timesteps = timesteps, callback = WandbCallback(),
+            self.model.learn(total_timesteps = timesteps, callback = callback,
                              log_interval = log_interval,
                              reset_num_timesteps = reset_num_timesteps,
                              progress_bar = progress_bar,
                              tb_log_name = self.model_name)
             self.model.save(f"{self.models_dir}/{timesteps*(repeat_time + start_from)}")

-        wandb_run.finish()
-        return wandb_run.id
-
-    def test_model(self, total_steps: int = 10000,
-                   deterministic: bool = True,
-                   test_run_id : str = None
-                   ) -> str:
-        wandb_run = wandb.init(
-            project=self.wandb_project,
-            name=f"test_{self.model_name}",
-            group="test",
-            resume="must" if test_run_id else None,
-            id=test_run_id
-        )
+    def test_model(
+        self, total_steps: int = 10000, episodes : int = 100,
+        deterministic: bool = True
+    ) -> None:

         vec_env = self.model.get_env()
         obs = vec_env.reset()

-        total_episodes = 0
         total_reward = 0
+        curr_step = 0
+        for _ in range(episodes):
+            done = False
+            while not done:
+                action, _state = self.model.predict(obs, deterministic=deterministic)
+                obs, rewards, done, info = vec_env.step(action)

-        for _ in range(total_steps):
-            action, _state = self.model.predict(obs, deterministic=deterministic)
-            obs, rewards, done, info = vec_env.step(action)
-
-            total_reward += rewards
-            if done:
-                total_episodes += 1
-                wandb.log({
-                    "test/total_reward": total_reward,
-                    "test/episode": total_episodes,
-                })
-                total_reward = 0 # Reset reward for the new episode
-
-        wandb_run.finish()
-        return wandb_run.id
+                total_reward += rewards
+                curr_step += 1
+            wandb.log({
+                "eval/total_reward": total_reward
+            })
+            total_reward = 0 # Reset reward for the new episode
+            if curr_step >= total_steps:
+                break

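The evaluation loop above drives SB3's vectorized env wrapper, whose step() returns batched numpy arrays (one entry per sub-environment) rather than scalars. A standalone sketch of the same per-episode accumulation with a single-env DummyVecEnv; CartPole stands in for the SC2 map and is purely illustrative, assuming an SB3 version built on Gymnasium:

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
model = PPO("MlpPolicy", vec_env, verbose=0).learn(total_timesteps=2048)

obs = vec_env.reset()
episode_reward, done = 0.0, False
while not done:
    action, _state = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    episode_reward += float(rewards[0])  # index 0: the single sub-environment
    done = bool(dones[0])
print(f"episode reward: {episode_reward}")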
-    def alternate_train_test(self, iterations : int = 100, train_steps : int = 10000,
-                             train_repeat_times : int = 1, test_steps : int = 10000,
-                             train_run_id = None, test_run_id = None
-                             ) -> None:
+    def alternate_train_test(
+        self, iterations : int = 100, train_steps : int = 10000,
+        train_repeat_times : int = 1, test_steps : int = 10000,
+        test_episodes : int = 100, callback : MaybeCallback = None
+    ) -> None:

         for iteration in range(iterations):
-            train_run_id = self.train_model(timesteps=train_steps,
-                                            repeat_times=train_repeat_times,
-                                            train_run_id=train_run_id,
-                                            start_from=iteration*train_repeat_times +1)
-            test_run_id= self.test_model(total_steps=test_steps,test_run_id=test_run_id)
+            self.train_model(
+                timesteps=train_steps, repeat_times=train_repeat_times,
+                start_from=iteration*train_repeat_times +1, callback=callback
+            )
+            self.test_model(total_steps=test_steps, episodes = test_episodes)

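As a worked example of what the new loop saves: with the values used in solve_collectables_sb3.py (train_steps=5000, train_repeat_times=1), each iteration increases start_from by one, so the checkpoints written by train_model land at 5000, 10000, 15000, ... under the model's save directory.

models_dir = "saves/models/PPOMlp"  # f"saves/models/{config_dict['model_save_name']}"
train_steps, train_repeat_times = 5000, 1

for iteration in range(3):
    start_from = iteration * train_repeat_times + 1
    for repeat_time in range(train_repeat_times):
        # Mirrors self.model.save(f"{self.models_dir}/{timesteps*(repeat_time + start_from)}")
        print(f"{models_dir}/{train_steps * (repeat_time + start_from)}")
# -> saves/models/PPOMlp/5000
# -> saves/models/PPOMlp/10000
# -> saves/models/PPOMlp/15000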