diff --git a/urnai/actions/action_space_base.py b/urnai/actions/action_space_base.py index 26f4182e..49a7da19 100644 --- a/urnai/actions/action_space_base.py +++ b/urnai/actions/action_space_base.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +from urnai.actions.action_base import ActionBase + class ActionSpaceBase(ABC): """ @@ -31,12 +33,12 @@ def reset(self) -> None: ... @abstractmethod - def get_actions(self) -> list[int]: + def get_actions(self) -> list[ActionBase]: """Returns all the actions that the agent can choose from.""" ... @abstractmethod - def get_excluded_actions(self, obs) -> list[int]: + def get_excluded_actions(self, obs) -> list[ActionBase]: """Returns a subset of actions that can't be chosen by the agent.""" ... diff --git a/urnai/actions/chain_of_actions_base.py b/urnai/actions/chain_of_actions_base.py index d70131a2..8c23b64d 100644 --- a/urnai/actions/chain_of_actions_base.py +++ b/urnai/actions/chain_of_actions_base.py @@ -14,7 +14,7 @@ def __init__(self): def get_action(self, action_index) -> ActionBase: if action_index < self.length: - self.action_list[action_index] + return self.action_list[action_index] @property def length(self) -> int: