From 6299558a0202d414ee60ca5562fb7e027a4a22b5 Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:11:49 -0300 Subject: [PATCH 1/2] fix(issue-96): Small corrections Add return in get_action Correct return type in get_actions and get_excluded_actions --- urnai/actions/action_space_base.py | 5 +++-- urnai/actions/chain_of_actions_base.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/urnai/actions/action_space_base.py b/urnai/actions/action_space_base.py index 26f4182e..c45fcfec 100644 --- a/urnai/actions/action_space_base.py +++ b/urnai/actions/action_space_base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from urnai.actions.action_base import ActionBase class ActionSpaceBase(ABC): """ @@ -31,12 +32,12 @@ def reset(self) -> None: ... @abstractmethod - def get_actions(self) -> list[int]: + def get_actions(self) -> list[ActionBase]: """Returns all the actions that the agent can choose from.""" ... @abstractmethod - def get_excluded_actions(self, obs) -> list[int]: + def get_excluded_actions(self, obs) -> list[ActionBase]: """Returns a subset of actions that can't be chosen by the agent.""" ... diff --git a/urnai/actions/chain_of_actions_base.py b/urnai/actions/chain_of_actions_base.py index d70131a2..8c23b64d 100644 --- a/urnai/actions/chain_of_actions_base.py +++ b/urnai/actions/chain_of_actions_base.py @@ -14,7 +14,7 @@ def __init__(self): def get_action(self, action_index) -> ActionBase: if action_index < self.length: - self.action_list[action_index] + return self.action_list[action_index] @property def length(self) -> int: From 39485cdd97fa78ccb3a663b0d0420d3c139fdd7c Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:15:34 -0300 Subject: [PATCH 2/2] refactor(issue-96): Formatted import block --- urnai/actions/action_space_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/urnai/actions/action_space_base.py b/urnai/actions/action_space_base.py index c45fcfec..49a7da19 100644 --- a/urnai/actions/action_space_base.py +++ b/urnai/actions/action_space_base.py @@ -2,6 +2,7 @@ from urnai.actions.action_base import ActionBase + class ActionSpaceBase(ABC): """ ActionSpace works as an extra abstraction layer used by the agent