Skip to content

Commit

Permalink
Resolves #96 (#97)
Browse files Browse the repository at this point in the history
* fix(issue-96): Small corrections

Add return in get_action
Correct return type in get_actions and get_excluded_actions

* refactor(issue-96): Formatted import block
  • Loading branch information
CinquilCinquil authored Jul 19, 2024
1 parent 42e5ce7 commit 616f29b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions urnai/actions/action_space_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod

from urnai.actions.action_base import ActionBase


class ActionSpaceBase(ABC):
"""
Expand Down Expand Up @@ -31,12 +33,12 @@ def reset(self) -> None:
...

@abstractmethod
def get_actions(self) -> list[int]:
def get_actions(self) -> list[ActionBase]:
"""Returns all the actions that the agent can choose from."""
...

@abstractmethod
def get_excluded_actions(self, obs) -> list[int]:
def get_excluded_actions(self, obs) -> list[ActionBase]:
"""Returns a subset of actions that can't be chosen by the agent."""
...

Expand Down
2 changes: 1 addition & 1 deletion urnai/actions/chain_of_actions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self):

def get_action(self, action_index) -> ActionBase:
if action_index < self.length:
self.action_list[action_index]
return self.action_list[action_index]

@property
def length(self) -> int:
Expand Down

0 comments on commit 616f29b

Please sign in to comment.