From 515251d99d407b11b20b36ca54eceee274ad5285 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Mon, 30 Sep 2024 17:36:21 -0300 Subject: [PATCH] chore: Added typehints --- urnai/rewards/reward_base.py | 10 +++++++-- .../stablebaselines3/custom_env.py | 22 +++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/urnai/rewards/reward_base.py b/urnai/rewards/reward_base.py index 7d6b051..34b125b 100644 --- a/urnai/rewards/reward_base.py +++ b/urnai/rewards/reward_base.py @@ -9,8 +9,14 @@ class RewardBase(ABC): """ @abstractmethod - def get_reward(self, obs: list[list], default_reward: int, terminated: bool, - truncated: bool) -> int: ... + def get( + self, + obs: list[list], + default_reward: int, + terminated: bool, + truncated: bool + ) -> int: + raise NotImplementedError(...) @abstractmethod def reset(self) -> None: ... diff --git a/urnai/sc2/environments/stablebaselines3/custom_env.py b/urnai/sc2/environments/stablebaselines3/custom_env.py index 30218c2..2e3c536 100644 --- a/urnai/sc2/environments/stablebaselines3/custom_env.py +++ b/urnai/sc2/environments/stablebaselines3/custom_env.py @@ -1,5 +1,9 @@ +from typing import Union + import gymnasium as gym +import numpy as np from gymnasium import spaces +from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn from urnai.actions.action_space_base import ActionSpaceBase from urnai.environments.environment_base import EnvironmentBase @@ -22,30 +26,34 @@ def __init__(self, env: EnvironmentBase, state: StateBase, self._action_space = urnai_action_space self._reward = reward self._obs = None - # SB3 spaces + # space variables, used internally by the gymnasium library self.action_space = action_space self.observation_space = observation_space - def step(self, action): + def step( + self, action: Union[int, np.ndarray] + ) -> GymStepReturn: action = self._action_space.get_action(action, self._obs) obs, reward, terminated, truncated = self._env.step(action) self._obs = obs[0] obs = self._state.update(self._obs) - reward = self._reward.get_reward(self._obs, reward[0], terminated, truncated) + reward = self._reward.get(self._obs, reward[0], terminated, truncated) info = {} return obs, reward, terminated, truncated, info - def reset(self, seed=None, options=None): + def reset( + self, seed: int = None, options: dict = None + ) -> GymResetReturn: obs = self._env.reset() self._obs = obs[0] obs = self._state.update(self._obs) info = {} return obs, info - def render(self): - pass + def render(self, mode: str) -> None: + raise NotImplementedError(...) - def close(self): + def close(self) -> None: self._env.close() \ No newline at end of file