1
+ from typing import Union
2
+
1
3
import gymnasium as gym
4
+ import numpy as np
2
5
from gymnasium import spaces
6
+ from stable_baselines3 .common .type_aliases import GymResetReturn , GymStepReturn
3
7
4
8
from urnai .actions .action_space_base import ActionSpaceBase
5
9
from urnai .environments .environment_base import EnvironmentBase
@@ -22,30 +26,34 @@ def __init__(self, env: EnvironmentBase, state: StateBase,
22
26
self ._action_space = urnai_action_space
23
27
self ._reward = reward
24
28
self ._obs = None
25
- # SB3 spaces
29
+ # space variables, used internally by the gymnasium library
26
30
self .action_space = action_space
27
31
self .observation_space = observation_space
28
32
29
- def step (self , action ):
33
+ def step (
34
+ self , action : Union [int , np .ndarray ]
35
+ ) -> GymStepReturn :
30
36
action = self ._action_space .get_action (action , self ._obs )
31
37
32
38
obs , reward , terminated , truncated = self ._env .step (action )
33
39
34
40
self ._obs = obs [0 ]
35
41
obs = self ._state .update (self ._obs )
36
- reward = self ._reward .get_reward (self ._obs , reward [0 ], terminated , truncated )
42
+ reward = self ._reward .get (self ._obs , reward [0 ], terminated , truncated )
37
43
info = {}
38
44
return obs , reward , terminated , truncated , info
39
45
40
- def reset (self , seed = None , options = None ):
46
+ def reset (
47
+ self , seed : int = None , options : dict = None
48
+ ) -> GymResetReturn :
41
49
obs = self ._env .reset ()
42
50
self ._obs = obs [0 ]
43
51
obs = self ._state .update (self ._obs )
44
52
info = {}
45
53
return obs , info
46
54
47
- def render (self ) :
48
- pass
55
+ def render (self , mode : str ) -> None :
56
+ raise NotImplementedError (...)
49
57
50
- def close (self ):
58
+ def close (self ) -> None :
51
59
self ._env .close ()
0 commit comments