From 6b813004d948fe77f6f90e82fcd3b4d2c40d923a Mon Sep 17 00:00:00 2001 From: vishwassathish Date: Sun, 25 Feb 2024 01:56:08 -0800 Subject: [PATCH] Fixed the issue found here https://github.com/Farama-Foundation/Minigrid/issues/419 --- minigrid/minigrid_env.py | 26 +++++++++++++++++--------- minigrid/wrappers.py | 6 +++++- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/minigrid/minigrid_env.py b/minigrid/minigrid_env.py index b50499b40..037f89c00 100755 --- a/minigrid/minigrid_env.py +++ b/minigrid/minigrid_env.py @@ -649,19 +649,27 @@ def gen_obs(self): return obs - def get_pov_render(self, tile_size): + def get_pov_render(self, tile_size, agent_view_size=None): """ Render an agent's POV observation for visualization """ - grid, vis_mask = self.gen_obs_grid() - + grid, vis_mask = self.gen_obs_grid(agent_view_size) + # Render the whole grid - img = grid.render( - tile_size, - agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1), - agent_dir=3, - highlight_mask=vis_mask, - ) + if agent_view_size is None: + img = grid.render( + tile_size, + agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1), + agent_dir=3, + highlight_mask=vis_mask, + ) + else: + img = grid.render( + tile_size, + agent_pos=(agent_view_size // 2, agent_view_size - 1), + agent_dir=3, + highlight_mask=vis_mask, + ) return img diff --git a/minigrid/wrappers.py b/minigrid/wrappers.py index 569fa11d0..5f4bf774b 100644 --- a/minigrid/wrappers.py +++ b/minigrid/wrappers.py @@ -364,6 +364,7 @@ def __init__(self, env, tile_size=8): self.tile_size = tile_size obs_shape = env.observation_space.spaces["image"].shape + new_image_space = spaces.Box( low=0, high=255, @@ -376,7 +377,10 @@ def __init__(self, env, tile_size=8): ) def observation(self, obs): - rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True) + rgb_img_partial = self.get_pov_render( + tile_size=self.tile_size, + agent_view_size=self.agent_view_size + ) return {**obs, "image": rgb_img_partial}