@@ -127,16 +127,18 @@ def reset(self, **kwargs):
127127 return self .env .reset (** kwargs )
128128
129129
130- # Should be named PositionBonus
131- class StateBonus (Wrapper ):
130+ class PositionBonus (Wrapper ):
132131 """
133132 Adds an exploration bonus based on which positions
134133 are visited on the grid.
135134
135+ Note:
136+ This wrapper was previously called ``StateBonus``.
137+
136138 Example:
137139 >>> import miniworld
138140 >>> import gymnasium as gym
139- >>> from minigrid.wrappers import StateBonus
141+ >>> from minigrid.wrappers import PositionBonus
140142 >>> env = gym.make("MiniGrid-Empty-5x5-v0")
141143 >>> _, _ = env.reset(seed=0)
142144 >>> _, reward, _, _, _ = env.step(1)
@@ -145,7 +147,7 @@ class StateBonus(Wrapper):
145147 >>> _, reward, _, _, _ = env.step(1)
146148 >>> print(reward)
147149 0
148- >>> env_bonus = StateBonus (env)
150+ >>> env_bonus = PositionBonus (env)
149151 >>> obs, _ = env_bonus.reset(seed=0)
150152 >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
151153 >>> print(reward)
@@ -688,6 +690,17 @@ class DirectionObsWrapper(ObservationWrapper):
688690 """
689691 Provides the slope/angular direction to the goal with the observations as modeled by (y2 - y1) / (x2 - x1)
690692 type = {slope , angle}
693+
694+ Example:
695+ >>> import minigrid
696+ >>> import gymnasium as gym
697+ >>> import matplotlib.pyplot as plt
698+ >>> from minigrid.wrappers import DirectionObsWrapper
699+ >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
700+ >>> env_obs = DirectionObsWrapper(env, type="slope")
701+ >>> obs, _ = env_obs.reset()
702+ >>> obs['goal_direction']
703+ 1.0
691704 """
692705
693706 def __init__ (self , env , type = "slope" ):
@@ -696,7 +709,8 @@ def __init__(self, env, type="slope"):
696709 self .type = type
697710
698711 def reset (self ):
699- obs = self .env .reset ()
712+ obs , _ = self .env .reset ()
713+
700714 if not self .goal_position :
701715 self .goal_position = [
702716 x for x , y in enumerate (self .grid .grid ) if isinstance (y , Goal )
@@ -707,14 +721,20 @@ def reset(self):
707721 int (self .goal_position [0 ] / self .height ),
708722 self .goal_position [0 ] % self .width ,
709723 )
710- return obs
724+
725+ return self .observation (obs )
711726
def observation(self, obs):
    """Add a ``goal_direction`` entry to ``obs`` and return it.

    The value is derived from the difference between ``self.goal_position``
    and ``self.agent_pos``: index 1 of the difference is divided by index 0.
    When ``self.type == "angle"`` the arctangent of that ratio is stored,
    otherwise the ratio itself is stored.

    Args:
        obs: Observation mapping; it is modified in place.

    Returns:
        The same ``obs`` object with ``obs["goal_direction"]`` set.

    Note:
        A zero denominator with a non-zero numerator yields ``+/-inf``
        (``+/-pi/2`` as an angle). When both differences are zero the
        direction is undefined; 0.0 is stored instead of NaN so the
        observation stays finite.
    """
    rise = self.goal_position[1] - self.agent_pos[1]
    run = self.goal_position[0] - self.agent_pos[0]

    if rise == 0 and run == 0:
        # 0/0 would produce NaN plus a RuntimeWarning; use a finite placeholder.
        slope = np.float64(0.0)
    else:
        # A zero denominator is an expected case here (result is +/-inf),
        # so do not let NumPy emit a divide-by-zero warning for it.
        with np.errstate(divide="ignore"):
            slope = np.divide(rise, run)

    if self.type == "angle":
        obs["goal_direction"] = np.arctan(slope)
    else:
        obs["goal_direction"] = slope

    return obs
719739
720740
@@ -723,6 +743,20 @@ class SymbolicObsWrapper(ObservationWrapper):
723743 Fully observable grid with a symbolic state representation.
724744 The symbol is a triple of (X, Y, IDX), where X and Y are
725745 the coordinates on the grid, and IDX is the id of the object.
746+
747+ Example:
748+ >>> import minigrid
749+ >>> import gymnasium as gym
750+ >>> import matplotlib.pyplot as plt
751+ >>> from minigrid.wrappers import SymbolicObsWrapper
752+ >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
753+ >>> obs, _ = env.reset()
754+ >>> obs['image'].shape
755+ (7, 7, 3)
756+ >>> env_obs = SymbolicObsWrapper(env)
757+ >>> obs, _ = env_obs.reset()
758+ >>> obs['image'].shape
759+ (11, 11, 3)
726760 """
727761
728762 def __init__ (self , env ):
@@ -749,4 +783,5 @@ def observation(self, obs):
749783 grid = np .transpose (grid , (1 , 2 , 0 ))
750784 grid [agent_pos [0 ], agent_pos [1 ], 2 ] = OBJECT_TO_IDX ["agent" ]
751785 obs ["image" ] = grid
786+
752787 return obs
0 commit comments