Skip to content

Commit

Permalink
feat: Agent can now load model
Browse files Browse the repository at this point in the history
  • Loading branch information
CinquilCinquil committed Aug 12, 2024
1 parent b47045c commit 7356f5a
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 9 deletions.
1 change: 1 addition & 0 deletions sc2_collectables_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def declare_trainer():
def main(unused_argv):
try:
trainer = declare_trainer()
#trainer.load("saves/")
trainer.train()
# trainer.play()

Expand Down
8 changes: 1 addition & 7 deletions urnai/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,4 @@ def learn(self, obs, reward, done) -> None:
if self.previous_state is not None:
next_state = self.state_space.update(obs)
self.model.learn(self.previous_state, self.previous_action,
reward, next_state, done)

def save(self, savepath) -> None:
if (self.model.persistence == None):
raise 'No persistence set in model'

self.model.persistence.save(savepath)
reward, next_state, done)
1 change: 1 addition & 0 deletions urnai/models/dqn_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(self, action_wrapper: ActionSpaceBase, state_builder: StateBase,
field_names=['state', 'action', 'reward', 'next_state', 'done'])

self.persistence = PersistencePickle(self)
self.persistence.attr_block_list.append(['persistence'])

def make_model(self):
model = QNetwork(
Expand Down
8 changes: 7 additions & 1 deletion urnai/models/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,10 @@ def choose_action(self, state, excluded_actions, is_testing=False) -> int:
exploration algorithm for when is_testing=False.
One such exploration algorithm commonly used it the epsilon greedy strategy.
"""
pass
pass

def save(self, persist_path) -> None:
self.persistence.save(persist_path)

def load(self, persist_path) -> None:
self.persistence.load(persist_path)
5 changes: 4 additions & 1 deletion urnai/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def __init__(self, env, agent, max_training_episodes, max_playing_episodes, max_
def train(self, reward_from_agent=True):
self.training_loop(is_training=True, reward_from_agent=reward_from_agent)

def load(self, persist_path):
self.agent.model.load(persist_path)

def play(self, reward_from_agent=True):
self.training_loop(is_training=False, reward_from_agent=reward_from_agent)

Expand Down Expand Up @@ -88,7 +91,7 @@ def training_loop(self, is_training, reward_from_agent=True):

if done:
print("Episode: %d, Reward: %d" % (current_episodes, ep_reward))
self.agent.save("saves/")
self.agent.model.save("saves/")
break

# if this is not a test (evaluation), saving is enabled and we are in a multiple
Expand Down

0 comments on commit 7356f5a

Please sign in to comment.