Skip to content

Commit 515251d

Browse files
committed
chore: Added typehints
1 parent 1e7b2ba commit 515251d

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

urnai/rewards/reward_base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@ class RewardBase(ABC):
99
"""
1010

1111
@abstractmethod
12-
def get_reward(self, obs: list[list], default_reward: int, terminated: bool,
13-
truncated: bool) -> int: ...
12+
def get(
    self,
    obs: list[list],
    default_reward: int,
    terminated: bool,
    truncated: bool,
) -> int:
    """Compute the reward for the current step.

    This is the abstract entry point of the reward interface; concrete
    reward classes are expected to override it.

    Args:
        obs: Observation the reward is computed from.
        default_reward: Reward value supplied by the caller, available
            to implementations as a baseline.
        terminated: Whether the episode has terminated.
        truncated: Whether the episode has been truncated.

    Returns:
        The reward for this step.

    Raises:
        NotImplementedError: Always, when the base implementation is
            reached (e.g. through ``super().get(...)``).
    """
    # Pass a real message: ``NotImplementedError(...)`` would store the
    # Ellipsis object and render as "NotImplementedError: Ellipsis".
    raise NotImplementedError(
        f"{type(self).__name__} must implement get()"
    )
1420

1521
@abstractmethod
1622
# Abstract hook: takes no arguments and returns nothing; the body is
# intentionally empty here and is meant to be provided by subclasses.
def reset(self) -> None: ...

urnai/sc2/environments/stablebaselines3/custom_env.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from typing import Union
2+
13
import gymnasium as gym
4+
import numpy as np
25
from gymnasium import spaces
6+
from stable_baselines3.common.type_aliases import GymResetReturn, GymStepReturn
37

48
from urnai.actions.action_space_base import ActionSpaceBase
59
from urnai.environments.environment_base import EnvironmentBase
@@ -22,30 +26,34 @@ def __init__(self, env: EnvironmentBase, state: StateBase,
2226
self._action_space = urnai_action_space
2327
self._reward = reward
2428
self._obs = None
25-
# SB3 spaces
29+
# space variables, used internally by the gymnasium library
2630
self.action_space = action_space
2731
self.observation_space = observation_space
2832

29-
def step(self, action):
33+
def step(
    self, action: Union[int, np.ndarray]
) -> GymStepReturn:
    """Advance the wrapped environment by one step.

    The incoming action is translated by ``self._action_space.get_action``
    (which also receives the last stored raw observation) and then passed
    to ``self._env.step``. The first element of the returned observation
    is cached in ``self._obs`` and used both to build the state via
    ``self._state.update`` and to compute the reward via
    ``self._reward.get``.

    Args:
        action: Action as provided by the caller.

    Returns:
        A 5-tuple ``(obs, reward, terminated, truncated, info)``; ``info``
        is always an empty dict.
    """
    env_action = self._action_space.get_action(action, self._obs)
    raw_obs, raw_reward, terminated, truncated = self._env.step(env_action)

    # Only the first entry of the environment's observation is kept.
    self._obs = raw_obs[0]
    state = self._state.update(self._obs)
    # The reward object receives the first entry of the environment reward
    # as its default value.
    step_reward = self._reward.get(
        self._obs, raw_reward[0], terminated, truncated
    )
    return state, step_reward, terminated, truncated, {}
3945

40-
def reset(self, seed=None, options=None):
46+
def reset(
    self, seed: Union[int, None] = None, options: Union[dict, None] = None
) -> GymResetReturn:
    """Reset the wrapped environment and return the initial state.

    Args:
        seed: Optional seed. Accepted as part of the signature but not
            used by this method.
        options: Optional reset options. Accepted as part of the signature
            but not used by this method.

    Returns:
        A 2-tuple ``(obs, info)`` where ``obs`` is the result of
        ``self._state.update`` on the first element of the environment's
        reset observation, and ``info`` is an empty dict.
    """
    # NOTE(review): seed/options are not forwarded to self._env.reset();
    # confirm this is intentional.
    raw_obs = self._env.reset()
    self._obs = raw_obs[0]
    obs = self._state.update(self._obs)
    info = {}
    return obs, info
4654

47-
def render(self):
48-
pass
55+
def render(self, mode: Union[str, None] = None) -> None:
    """Signal that rendering is not supported.

    Args:
        mode: Render mode; ignored. It is optional so that a bare
            ``render()`` call reaches the ``NotImplementedError`` below
            instead of failing earlier with a ``TypeError`` for a missing
            argument.

    Raises:
        NotImplementedError: Always.
    """
    # Pass a real message: ``NotImplementedError(...)`` would store the
    # Ellipsis object and render as "NotImplementedError: Ellipsis".
    raise NotImplementedError(
        f"{type(self).__name__} does not support rendering"
    )
4957

50-
def close(self):
58+
def close(self) -> None:
    """Close the wrapped environment by delegating to ``self._env.close()``."""
    wrapped_env = self._env
    wrapped_env.close()

0 commit comments

Comments
 (0)