From cc53ee11abeff98062f4f6946e84741af0df624c Mon Sep 17 00:00:00 2001 From: Henrique Lopes Fouquet <102550763+RickFqt@users.noreply.github.com> Date: Fri, 4 Oct 2024 18:36:54 -0300 Subject: [PATCH] Custom environment from Stable Baselines3 (#103) * feat: Added custom environment from stablebaselines3 * chore: Added typehints * test: Added unit tests for base reward * fix: Fixed error messages --- tests/units/rewards/test_reward_base.py | 37 ++++++++++++ urnai/rewards/reward_base.py | 23 +++++++ urnai/sc2/environments/sc2environment.py | 2 +- .../environments/stablebaselines3/__init__.py | 0 .../stablebaselines3/custom_env.py | 60 +++++++++++++++++++ 5 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 tests/units/rewards/test_reward_base.py create mode 100644 urnai/rewards/reward_base.py create mode 100644 urnai/sc2/environments/stablebaselines3/__init__.py create mode 100644 urnai/sc2/environments/stablebaselines3/custom_env.py diff --git a/tests/units/rewards/test_reward_base.py b/tests/units/rewards/test_reward_base.py new file mode 100644 index 00000000..b77487f9 --- /dev/null +++ b/tests/units/rewards/test_reward_base.py @@ -0,0 +1,37 @@ +import unittest +from abc import ABCMeta + +from urnai.rewards.reward_base import RewardBase + + +class TestRewardBase(unittest.TestCase): + + def test_reset_method(self): + # GIVEN + RewardBase.__abstractmethods__ = set() + + class FakeReward(RewardBase): + def __init__(self): + super().__init__() + + reward = FakeReward() + + # WHEN + reset_return = reward.reset() + + # THEN + assert isinstance(RewardBase, ABCMeta) + assert reset_return is None + + def test_not_implemented_get_method(self): + # GIVEN + RewardBase.__abstractmethods__ = set() + + class FakeReward(RewardBase): + def __init__(self): + super().__init__() + + reward = FakeReward() + + # WHEN / THEN + self.assertRaises(NotImplementedError, reward.get, [[]], 0, False, False) diff --git a/urnai/rewards/reward_base.py b/urnai/rewards/reward_base.py new file mode 100644 index 00000000..2ecf107d --- /dev/null +++ b/urnai/rewards/reward_base.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod + + +class RewardBase(ABC): + """ + Every Agent needs to own an instance of this base class in order to calculate + its rewards. So every time we want to create a new agent, + we should either use an existing RewardBase implementation or create a new one. + """ + + @abstractmethod + def get( + self, + obs: list[list], + default_reward: int, + terminated: bool, + truncated: bool + ) -> int: + raise NotImplementedError("Get method not implemented. You should implement " + + "it in your RewardBase subclass.") + + @abstractmethod + def reset(self) -> None: ... diff --git a/urnai/sc2/environments/sc2environment.py b/urnai/sc2/environments/sc2environment.py index bcaf4f74..367ffcad 100644 --- a/urnai/sc2/environments/sc2environment.py +++ b/urnai/sc2/environments/sc2environment.py @@ -113,6 +113,6 @@ def _parse_timestep( terminated = any(o.player_result for o in self.env_instance._obs) current_steps = self.env_instance._episode_steps limit_steps = self.env_instance._episode_length - truncated = current_steps >= limit_steps + truncated = bool(current_steps >= limit_steps) return obs, reward, terminated, truncated diff --git a/urnai/sc2/environments/stablebaselines3/__init__.py b/urnai/sc2/environments/stablebaselines3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/urnai/sc2/environments/stablebaselines3/custom_env.py b/urnai/sc2/environments/stablebaselines3/custom_env.py new file mode 100644 index 00000000..4a527f24 --- /dev/null +++ b/urnai/sc2/environments/stablebaselines3/custom_env.py @@ -0,0 +1,60 @@ +from typing import Union + +import gymnasium as gym +import numpy as np +from gymnasium import spaces +from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn + +from urnai.actions.action_space_base import ActionSpaceBase +from urnai.environments.environment_base import EnvironmentBase +from urnai.rewards.reward_base import RewardBase +from urnai.states.state_base import StateBase + + +class CustomEnv(gym.Env): + """Custom Environment that follows gym interface.""" + + metadata = {"render_modes": ["human"], "render_fps": 30} + + def __init__(self, env: EnvironmentBase, state: StateBase, + urnai_action_space: ActionSpaceBase, reward: RewardBase, + observation_space: spaces.Space, action_space: spaces.Space): + super().__init__() + + self._env = env + self._state = state + self._action_space = urnai_action_space + self._reward = reward + self._obs = None + # space variables, used internally by the gymnasium library + self.action_space = action_space + self.observation_space = observation_space + + def step( + self, action: Union[int, np.ndarray] + ) -> GymStepReturn: + action = self._action_space.get_action(action, self._obs) + + obs, reward, terminated, truncated = self._env.step(action) + + self._obs = obs[0] + obs = self._state.update(self._obs) + reward = self._reward.get(self._obs, reward[0], terminated, truncated) + info = {} + return obs, reward, terminated, truncated, info + + def reset( + self, seed: int = None, options: dict = None + ) -> GymResetReturn: + obs = self._env.reset() + self._obs = obs[0] + obs = self._state.update(self._obs) + info = {} + return obs, info + + def render(self, mode: str) -> None: + raise NotImplementedError("Render method not implemented. If necessary, you " + + "should implement it in your CustomEnv subclass.") + + def close(self) -> None: + self._env.close() \ No newline at end of file