Skip to content

Commit

Permalink
feat: Agent can now save model
Browse files Browse the repository at this point in the history
  • Loading branch information
CinquilCinquil committed Aug 12, 2024
1 parent 33b8691 commit b47045c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 1 deletion.
6 changes: 6 additions & 0 deletions urnai/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,9 @@ def learn(self, obs, reward, done) -> None:
next_state = self.state_space.update(obs)
self.model.learn(self.previous_state, self.previous_action,
reward, next_state, done)

def save(self, savepath) -> None:
if (self.model.persistence == None):
raise 'No persistence set in model'

self.model.persistence.save(savepath)
2 changes: 1 addition & 1 deletion urnai/base/persistence_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _get_attributes(self):
def _get_dict(self):
pickleable_attr_dict = {}

for attr in self.object_to_save._get_attributes():
for attr in self._get_attributes():
pickleable_attr_dict[attr] = getattr(self.object_to_save, attr)

return pickleable_attr_dict
1 change: 1 addition & 0 deletions urnai/models/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class ModelBase(ABC):

def __init__(self):
self.learning_data = {}
self.persistence = None

@abstractmethod
def learn(self, current_state, action, reward, next_state):
Expand Down
1 change: 1 addition & 0 deletions urnai/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def training_loop(self, is_training, reward_from_agent=True):

if done:
print("Episode: %d, Reward: %d" % (current_episodes, ep_reward))
self.agent.save("saves/")
break

# if this is not a test (evaluation), saving is enabled and we are in a multiple
Expand Down

0 comments on commit b47045c

Please sign in to comment.