Skip to content

Commit 6b81300

Browse files
Fixed the issue found here Farama-Foundation#419
1 parent df4e675 commit 6b81300

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

minigrid/minigrid_env.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -649,19 +649,27 @@ def gen_obs(self):
649649

650650
return obs
651651

652-
def get_pov_render(self, tile_size):
652+
def get_pov_render(self, tile_size, agent_view_size=None):
653653
"""
654654
Render an agent's POV observation for visualization
655655
"""
656-
grid, vis_mask = self.gen_obs_grid()
657-
656+
grid, vis_mask = self.gen_obs_grid(agent_view_size)
657+
658658
# Render the whole grid
659-
img = grid.render(
660-
tile_size,
661-
agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
662-
agent_dir=3,
663-
highlight_mask=vis_mask,
664-
)
659+
if agent_view_size is None:
660+
img = grid.render(
661+
tile_size,
662+
agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
663+
agent_dir=3,
664+
highlight_mask=vis_mask,
665+
)
666+
else:
667+
img = grid.render(
668+
tile_size,
669+
agent_pos=(agent_view_size // 2, agent_view_size - 1),
670+
agent_dir=3,
671+
highlight_mask=vis_mask,
672+
)
665673

666674
return img
667675

minigrid/wrappers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def __init__(self, env, tile_size=8):
364364
self.tile_size = tile_size
365365

366366
obs_shape = env.observation_space.spaces["image"].shape
367+
367368
new_image_space = spaces.Box(
368369
low=0,
369370
high=255,
@@ -376,7 +377,10 @@ def __init__(self, env, tile_size=8):
376377
)
377378

378379
def observation(self, obs):
379-
rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)
380+
rgb_img_partial = self.get_pov_render(
381+
tile_size=self.tile_size,
382+
agent_view_size=self.agent_view_size
383+
)
380384

381385
return {**obs, "image": rgb_img_partial}
382386

0 commit comments

Comments
 (0)