diff --git a/urnai/agents/agent_base.py b/urnai/agents/agent_base.py index 1094ad8..be2cacd 100644 --- a/urnai/agents/agent_base.py +++ b/urnai/agents/agent_base.py @@ -57,3 +57,9 @@ def learn(self, obs, reward, done) -> None: next_state = self.state_space.update(obs) self.model.learn(self.previous_state, self.previous_action, reward, next_state, done) + + def save(self, savepath) -> None: + if (self.model.persistence == None): + raise 'No persistence set in model' + + self.model.persistence.save(savepath) diff --git a/urnai/base/persistence_pickle.py b/urnai/base/persistence_pickle.py index 6708138..93ff108 100644 --- a/urnai/base/persistence_pickle.py +++ b/urnai/base/persistence_pickle.py @@ -108,7 +108,7 @@ def _get_attributes(self): def _get_dict(self): pickleable_attr_dict = {} - for attr in self.object_to_save._get_attributes(): + for attr in self._get_attributes(): pickleable_attr_dict[attr] = getattr(self.object_to_save, attr) return pickleable_attr_dict diff --git a/urnai/models/model_base.py b/urnai/models/model_base.py index 616f62b..8a4a887 100644 --- a/urnai/models/model_base.py +++ b/urnai/models/model_base.py @@ -8,6 +8,7 @@ class ModelBase(ABC): def __init__(self): self.learning_data = {} + self.persistence = None @abstractmethod def learn(self, current_state, action, reward, next_state): diff --git a/urnai/trainers/trainer.py b/urnai/trainers/trainer.py index 61a723e..9efb1c1 100644 --- a/urnai/trainers/trainer.py +++ b/urnai/trainers/trainer.py @@ -88,6 +88,7 @@ def training_loop(self, is_training, reward_from_agent=True): if done: print("Episode: %d, Reward: %d" % (current_episodes, ep_reward)) + self.agent.save("saves/") break # if this is not a test (evaluation), saving is enabled and we are in a multiple