Skip to content

Commit

Permalink
fix: Fixed reward class
Browse files Browse the repository at this point in the history
  • Loading branch information
RickFqt committed Aug 13, 2024
1 parent e8ae2bd commit 63d565a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion urnai/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def reset(self, episode=0) -> None:
self.previous_state = None
self.action_space.reset()
self.model.ep_reset(episode)
# self.reward.reset()
self.reward.reset()
self.state_space.reset()

def learn(self, obs, reward, done) -> None:
Expand Down
3 changes: 3 additions & 0 deletions urnai/rewards/reward_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ class RewardBase(ABC):

@abstractmethod
def get_reward(self, obs, reward, done) -> int:
    """Compute the reward value for the current step.

    Abstract hook: the base class provides no implementation, so every
    concrete reward class has to override this method.

    Args:
        obs: presumably the latest environment observation -- confirm
            against callers.
        reward: presumably the raw reward reported by the environment --
            confirm against callers.
        done: presumably the episode-termination flag -- confirm against
            callers.

    Returns:
        The reward, annotated as ``int``.
    """

@abstractmethod
def reset(self) -> None:
    """Reset whatever state the reward object keeps.

    Abstract hook: the base class provides no implementation, so every
    concrete reward class has to override this method.
    """
4 changes: 4 additions & 0 deletions urnai/sc2/rewards/collectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def get_reward(self, obs) -> int:

self.previous_state = obs
return reward

def reset(self) -> None:
    """Clear the remembered observation and restore the shard counter.

    Sets ``old_collectable_counter`` back to
    ``STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS`` and drops the observation
    stored in ``previous_state``.
    """
    # The two assignments are independent of each other.
    self.old_collectable_counter = STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS
    self.previous_state = None

def filter_non_mineral_shard_units(self, obs):
filtered_map = np.zeros((len(obs.feature_minimap[0]),
Expand Down

0 comments on commit 63d565a

Please sign in to comment.