Skip to content

Commit 9677ebc

Browse files
Bug fix in DirectionObsWrapper, new tests, & name change (#310)
Co-authored-by: Mark Towers <[email protected]>
1 parent f7c1750 commit 9677ebc

File tree

3 files changed

+89
-12
lines changed

3 files changed

+89
-12
lines changed

docs/api/wrappers.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ lastpage:
5858
.. autoclass:: minigrid.wrappers.RGBImgObsWrapper
5959
```
6060

61-
# State Bonus
61+
# Position Bonus
6262

6363
```{eval-rst}
64-
.. autoclass:: minigrid.wrappers.StateBonus
64+
.. autoclass:: minigrid.wrappers.PositionBonus
6565
```
6666

6767
# Symbolic Obs

minigrid/wrappers.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,18 @@ def reset(self, **kwargs):
127127
return self.env.reset(**kwargs)
128128

129129

130-
# Should be named PositionBonus
131-
class StateBonus(Wrapper):
130+
class PositionBonus(Wrapper):
132131
"""
133132
Adds an exploration bonus based on which positions
134133
are visited on the grid.
135134
135+
Note:
136+
This wrapper was previously called ``StateBonus``.
137+
136138
Example:
137139
>>> import minigrid
138140
>>> import gymnasium as gym
139-
>>> from minigrid.wrappers import StateBonus
141+
>>> from minigrid.wrappers import PositionBonus
140142
>>> env = gym.make("MiniGrid-Empty-5x5-v0")
141143
>>> _, _ = env.reset(seed=0)
142144
>>> _, reward, _, _, _ = env.step(1)
@@ -145,7 +147,7 @@ class StateBonus(Wrapper):
145147
>>> _, reward, _, _, _ = env.step(1)
146148
>>> print(reward)
147149
0
148-
>>> env_bonus = StateBonus(env)
150+
>>> env_bonus = PositionBonus(env)
149151
>>> obs, _ = env_bonus.reset(seed=0)
150152
>>> obs, reward, terminated, truncated, info = env_bonus.step(1)
151153
>>> print(reward)
@@ -688,6 +690,17 @@ class DirectionObsWrapper(ObservationWrapper):
688690
"""
689691
Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y1) / (x2 - x1)
690692
type = {slope , angle}
693+
694+
Example:
695+
>>> import minigrid
696+
>>> import gymnasium as gym
697+
>>> import matplotlib.pyplot as plt
698+
>>> from minigrid.wrappers import DirectionObsWrapper
699+
>>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
700+
>>> env_obs = DirectionObsWrapper(env, type="slope")
701+
>>> obs, _ = env_obs.reset()
702+
>>> obs['goal_direction']
703+
1.0
691704
"""
692705

693706
def __init__(self, env, type="slope"):
@@ -696,7 +709,8 @@ def __init__(self, env, type="slope"):
696709
self.type = type
697710

698711
def reset(self):
699-
obs = self.env.reset()
712+
obs, _ = self.env.reset()
713+
700714
if not self.goal_position:
701715
self.goal_position = [
702716
x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
@@ -707,14 +721,20 @@ def reset(self):
707721
int(self.goal_position[0] / self.height),
708722
self.goal_position[0] % self.width,
709723
)
710-
return obs
724+
725+
return self.observation(obs)
711726

712727
def observation(self, obs):
713728
slope = np.divide(
714729
self.goal_position[1] - self.agent_pos[1],
715730
self.goal_position[0] - self.agent_pos[0],
716731
)
717-
obs["goal_direction"] = np.arctan(slope) if self.type == "angle" else slope
732+
733+
if self.type == "angle":
734+
obs["goal_direction"] = np.arctan(slope)
735+
else:
736+
obs["goal_direction"] = slope
737+
718738
return obs
719739

720740

@@ -723,6 +743,20 @@ class SymbolicObsWrapper(ObservationWrapper):
723743
Fully observable grid with a symbolic state representation.
724744
The symbol is a triple of (X, Y, IDX), where X and Y are
725745
the coordinates on the grid, and IDX is the id of the object.
746+
747+
Example:
748+
>>> import minigrid
749+
>>> import gymnasium as gym
750+
>>> import matplotlib.pyplot as plt
751+
>>> from minigrid.wrappers import SymbolicObsWrapper
752+
>>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
753+
>>> obs, _ = env.reset()
754+
>>> obs['image'].shape
755+
(7, 7, 3)
756+
>>> env_obs = SymbolicObsWrapper(env)
757+
>>> obs, _ = env_obs.reset()
758+
>>> obs['image'].shape
759+
(11, 11, 3)
726760
"""
727761

728762
def __init__(self, env):
@@ -749,4 +783,5 @@ def observation(self, obs):
749783
grid = np.transpose(grid, (1, 2, 0))
750784
grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
751785
obs["image"] = grid
786+
752787
return obs

tests/test_wrappers.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
from minigrid.wrappers import (
1212
ActionBonus,
1313
DictObservationSpaceWrapper,
14+
DirectionObsWrapper,
1415
FlatObsWrapper,
1516
FullyObsWrapper,
1617
ImgObsWrapper,
1718
OneHotPartialObsWrapper,
19+
PositionBonus,
1820
ReseedWrapper,
1921
RGBImgObsWrapper,
2022
RGBImgPartialObsWrapper,
21-
StateBonus,
23+
SymbolicObsWrapper,
2224
ViewSizeWrapper,
2325
)
2426
from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
@@ -77,9 +79,9 @@ def test_reseed_wrapper(env_spec):
7779

7880

7981
@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
80-
def test_state_bonus_wrapper(env_id):
82+
def test_position_bonus_wrapper(env_id):
8183
env = gym.make(env_id)
82-
wrapped_env = StateBonus(gym.make(env_id))
84+
wrapped_env = PositionBonus(gym.make(env_id))
8385

8486
action_forward = Actions.forward
8587
action_left = Actions.left
@@ -260,3 +262,43 @@ def test_viewsize_wrapper(view_size):
260262
obs, _, _, _, _ = env.step(0)
261263
assert obs["image"].shape == (view_size, view_size, 3)
262264
env.close()
265+
266+
267+
@pytest.mark.parametrize("env_id", ["MiniGrid-LavaCrossingS11N5-v0"])
268+
@pytest.mark.parametrize("type", ["slope", "angle"])
269+
def test_direction_obs_wrapper(env_id, type):
270+
env = gym.make(env_id)
271+
env = DirectionObsWrapper(env, type=type)
272+
obs = env.reset()
273+
274+
slope = np.divide(
275+
env.goal_position[1] - env.agent_pos[1],
276+
env.goal_position[0] - env.agent_pos[0],
277+
)
278+
if type == "slope":
279+
assert obs["goal_direction"] == slope
280+
elif type == "angle":
281+
assert obs["goal_direction"] == np.arctan(slope)
282+
283+
obs, _, _, _, _ = env.step(0)
284+
slope = np.divide(
285+
env.goal_position[1] - env.agent_pos[1],
286+
env.goal_position[0] - env.agent_pos[0],
287+
)
288+
if type == "slope":
289+
assert obs["goal_direction"] == slope
290+
elif type == "angle":
291+
assert obs["goal_direction"] == np.arctan(slope)
292+
293+
env.close()
294+
295+
296+
@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
297+
def test_symbolic_obs_wrapper(env_id):
298+
env = gym.make(env_id)
299+
env = SymbolicObsWrapper(env)
300+
obs, _ = env.reset()
301+
assert obs["image"].shape == (env.width, env.height, 3)
302+
obs, _, _, _, _ = env.step(0)
303+
assert obs["image"].shape == (env.width, env.height, 3)
304+
env.close()

0 commit comments

Comments
 (0)