diff --git a/sc2_collectables_trainer.py b/sc2_collectables_trainer.py
index 12639ef..21f59d2 100644
--- a/sc2_collectables_trainer.py
+++ b/sc2_collectables_trainer.py
@@ -37,6 +37,7 @@ def declare_trainer():
 def main(unused_argv):
     try:
         trainer = declare_trainer()
+        #trainer.load("saves/")
         trainer.train()
         # trainer.play()
 
diff --git a/urnai/agents/agent_base.py b/urnai/agents/agent_base.py
index be2cacd..4772c78 100644
--- a/urnai/agents/agent_base.py
+++ b/urnai/agents/agent_base.py
@@ -56,10 +56,4 @@ def learn(self, obs, reward, done) -> None:
         if self.previous_state is not None:
             next_state = self.state_space.update(obs)
             self.model.learn(self.previous_state, self.previous_action,
-                             reward, next_state, done)
-
-    def save(self, savepath) -> None:
-        if (self.model.persistence == None):
-            raise 'No persistence set in model'
-
-        self.model.persistence.save(savepath)
+                             reward, next_state, done)
\ No newline at end of file
diff --git a/urnai/models/dqn_pytorch.py b/urnai/models/dqn_pytorch.py
index 0c4ef15..a92a193 100644
--- a/urnai/models/dqn_pytorch.py
+++ b/urnai/models/dqn_pytorch.py
@@ -118,6 +118,7 @@ def __init__(self, action_wrapper: ActionSpaceBase, state_builder: StateBase,
             field_names=['state', 'action', 'reward', 'next_state', 'done'])
 
         self.persistence = PersistencePickle(self)
+        self.persistence.attr_block_list.append('persistence')
 
     def make_model(self):
         model = QNetwork(
diff --git a/urnai/models/model_base.py b/urnai/models/model_base.py
index 8a4a887..21563bb 100644
--- a/urnai/models/model_base.py
+++ b/urnai/models/model_base.py
@@ -29,4 +29,10 @@ def choose_action(self, state, excluded_actions, is_testing=False) -> int:
         exploration algorithm for when is_testing=False. One such exploration
         algorithm commonly used it the epsilon greedy strategy.
         """
-        pass
\ No newline at end of file
+        pass
+
+    def save(self, persist_path) -> None:
+        self.persistence.save(persist_path)
+
+    def load(self, persist_path) -> None:
+        self.persistence.load(persist_path)
\ No newline at end of file
diff --git a/urnai/trainers/trainer.py b/urnai/trainers/trainer.py
index 9efb1c1..910a1fc 100644
--- a/urnai/trainers/trainer.py
+++ b/urnai/trainers/trainer.py
@@ -25,6 +25,9 @@ def __init__(self, env, agent, max_training_episodes, max_playing_episodes, max_
     def train(self, reward_from_agent=True):
         self.training_loop(is_training=True, reward_from_agent=reward_from_agent)
 
+    def load(self, persist_path):
+        self.agent.model.load(persist_path)
+
     def play(self, reward_from_agent=True):
         self.training_loop(is_training=False, reward_from_agent=reward_from_agent)
 
@@ -88,7 +91,7 @@ def training_loop(self, is_training, reward_from_agent=True):
 
             if done:
                 print("Episode: %d, Reward: %d" % (current_episodes, ep_reward))
-                self.agent.save("saves/")
+                self.agent.model.save("saves/")
                 break
 
             # if this is not a test (evaluation), saving is enabled and we are in a multiple
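
For context, a minimal usage sketch of the save/load flow introduced above (not part of the diff; it assumes declare_trainer() builds a Trainer whose agent.model is the DQN from dqn_pytorch.py, with its PersistencePickle already attached):

    # Hypothetical resume-from-checkpoint flow in sc2_collectables_trainer.py,
    # built on the methods added in this diff. Assumes "saves/" already holds a
    # model written by a previous run's agent.model.save("saves/") at episode end.
    trainer = declare_trainer()
    trainer.load("saves/")  # Trainer.load -> agent.model.load -> persistence.load
    trainer.train()         # re-saves via agent.model.save("saves/") whenever an episode ends
    # trainer.play()        # evaluation loop (is_training=False), no further learning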