Skip to content

Commit

Permalink
chore: Added typehints
Browse files (browse the repository at this point in the history)
  • Loading branch information
RickFqt committed Sep 30, 2024
1 parent 1e7b2ba commit 515251d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
10 changes: 8 additions & 2 deletions urnai/rewards/reward_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,8 +9,14 @@ class RewardBase(ABC):
"""

@abstractmethod
def get_reward(self, obs: list[list], default_reward: int, terminated: bool,
truncated: bool) -> int: ...
def get(
self,
obs: list[list],
default_reward: int,
terminated: bool,
truncated: bool
) -> int:
raise NotImplementedError(...)

@abstractmethod
def reset(self) -> None: ...
22 changes: 15 additions & 7 deletions urnai/sc2/environments/stablebaselines3/custom_env.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,9 @@
from typing import Union

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn

from urnai.actions.action_space_base import ActionSpaceBase
from urnai.environments.environment_base import EnvironmentBase
Expand All @@ -22,30 +26,34 @@ def __init__(self, env: EnvironmentBase, state: StateBase,
self._action_space = urnai_action_space
self._reward = reward
self._obs = None
# SB3 spaces
# space variables, used internally by the gymnasium library
self.action_space = action_space
self.observation_space = observation_space

def step(self, action):
def step(
self, action: Union[int, np.ndarray]
) -> GymStepReturn:
action = self._action_space.get_action(action, self._obs)

obs, reward, terminated, truncated = self._env.step(action)

self._obs = obs[0]
obs = self._state.update(self._obs)
reward = self._reward.get_reward(self._obs, reward[0], terminated, truncated)
reward = self._reward.get(self._obs, reward[0], terminated, truncated)
info = {}
return obs, reward, terminated, truncated, info

def reset(self, seed=None, options=None):
def reset(
self, seed: int = None, options: dict = None
) -> GymResetReturn:
obs = self._env.reset()
self._obs = obs[0]
obs = self._state.update(self._obs)
info = {}
return obs, info

def render(self):
pass
def render(self, mode: str) -> None:
raise NotImplementedError(...)

def close(self):
def close(self) -> None:
self._env.close()

0 comments on commit 515251d

Please sign in to comment.