Skip to content

Commit

Permalink
add descriptions with language
Browse files Browse the repository at this point in the history
  • Loading branch information
BartekCupial committed Oct 31, 2024
1 parent a3b796e commit 870926b
Showing 1 changed file with 187 additions and 3 deletions.
190 changes: 187 additions & 3 deletions minigrid/minigrid_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gymnasium.core import ActType, ObsType

from minigrid.core.actions import Actions
from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS
from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS, COLOR_TO_IDX, OBJECT_TO_IDX
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Point, WorldObj
Expand Down Expand Up @@ -45,6 +45,7 @@ def __init__(
highlight: bool = True,
tile_size: int = TILE_PIXELS,
agent_pov: bool = False,
language='english'
):
# Initialize mission
self.mission = mission_space.sample()
Expand Down Expand Up @@ -102,6 +103,9 @@ def __init__(

self.see_through_walls = see_through_walls

# Language of the descriptions
self.language = language

# Current position and direction of the agent
self.agent_pos: np.ndarray | tuple[int, int] = None
self.agent_dir: int = None
Expand Down Expand Up @@ -154,7 +158,10 @@ def reset(
# Return first observation
obs = self.gen_obs()

return obs, {}
# add info Episodic Knowledge to minigrid
info = self.gen_graph(move_forward=None)

return obs, info

def hash(self, size=16):
"""Compute a hash that uniquely identifies the current state of the environment.
Expand Down Expand Up @@ -591,8 +598,16 @@ def step(
self.render()

obs = self.gen_obs()

# add info Episodic Knowledge to minigrid
move_forward = None
if action == self.actions.forward:
move_forward = False
if np.all(self.agent_pos == fwd_pos):
move_forward = True
info = self.gen_graph(move_forward=move_forward)

return obs, reward, terminated, truncated, {}
return obs, reward, terminated, truncated, info

def gen_obs_grid(self, agent_view_size=None):
"""
Expand Down Expand Up @@ -787,3 +802,172 @@ def render(self):
def close(self):
if self.window:
pygame.quit()

def gen_graph(self, move_forward=None):
grid, vis_mask = self.gen_obs_grid()

# Encode the partially observable view into a numpy array
image = grid.encode(vis_mask)
# (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
# State, 0: open, 1: closed, 2: locked
if self.language == 'english':
IDX_TO_STATE = {0: 'open', 1: 'closed', 2: 'locked'}
IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))

elif self.language == 'french':
IDX_TO_STATE = {0: 'ouverte', 1: 'fermée', 2: 'fermée à clef'}
IDX_TO_COLOR = {0: 'rouge', 1: 'verte', 2: 'bleue', 3: 'violette', 4: 'jaune', 5: 'grise'}
IDX_TO_OBJECT = {0: 'non visible', 1: 'vide', 2: 'mur', 3: 'sol', 4: 'porte', 5: 'clef',
6: 'balle', 7: 'boîte', 8: 'but', 9: 'lave', 10: 'agent'}

list_textual_descriptions = []

if self.carrying is not None:
# print('carrying')
if self.language == 'english':
list_textual_descriptions.append("You carry a {} {}".format(self.carrying.color, self.carrying.type))
elif self.language == 'french':
list_textual_descriptions.append("Tu portes une {} {}".format(self.carrying.type, self.carrying.color))

# print('A agent position i: {}, j: {}'.format(self.agent_pos[0], self.agent_pos[1]))
agent_pos_vx, agent_pos_vy = self.get_view_coords(self.agent_pos[0], self.agent_pos[1])
# print('B agent position i: {}, j: {}'.format(agent_pos_vx, agent_pos_vy))

view_field_dictionary = dict()

for i in range(image.shape[0]):
for j in range(image.shape[1]):
if image[i][j][0] != 0 and image[i][j][0] != 1 and image[i][j][0] != 2:
if i not in view_field_dictionary.keys():
view_field_dictionary[i] = dict()
view_field_dictionary[i][j] = image[i][j]
else:
view_field_dictionary[i][j] = image[i][j]

# Find the wall if any
# We describe a wall only if there is no objects between the agent and the wall in straight line

# Find wall in front
j = agent_pos_vy - 1
object_seen = False
while j >= 0 and not object_seen:
if image[agent_pos_vx][j][0] != 0 and image[agent_pos_vx][j][0] != 1:
if image[agent_pos_vx][j][0] == 2:
if self.language == 'english':
list_textual_descriptions.append(
f"You see a wall {agent_pos_vy - j} step{'s' if agent_pos_vy - j > 1 else ''} forward")
elif self.language == 'french':
list_textual_descriptions.append("Tu vois un mur à {} pas devant".format(agent_pos_vy - j))
object_seen = True
else:
object_seen = True
j -= 1
# Find wall left
i = agent_pos_vx - 1
object_seen = False
while i >= 0 and not object_seen:
if image[i][agent_pos_vy][0] != 0 and image[i][agent_pos_vy][0] != 1:
if image[i][agent_pos_vy][0] == 2:
if self.language == 'english':
list_textual_descriptions.append(
f"You see a wall {agent_pos_vx - i} step{'s' if agent_pos_vx - i > 1 else ''} left")
elif self.language == 'french':
list_textual_descriptions.append("Tu vois un mur à {} pas à gauche".format(agent_pos_vx - i))
object_seen = True
else:
object_seen = True
i -= 1
# Find wall right
i = agent_pos_vx + 1
object_seen = False
while i < image.shape[0] and not object_seen:
if image[i][agent_pos_vy][0] != 0 and image[i][agent_pos_vy][0] != 1:
if image[i][agent_pos_vy][0] == 2:
if self.language == 'english':
list_textual_descriptions.append(
f"You see a wall {i - agent_pos_vx} step{'s' if i - agent_pos_vx > 1 else ''} right")
elif self.language == 'french':
list_textual_descriptions.append("Tu vois un mur à {} pas à droite".format(i - agent_pos_vx))
object_seen = True
else:
object_seen = True
i += 1

# returns the position of seen objects relative to you
for i in view_field_dictionary.keys():
for j in view_field_dictionary[i].keys():
if i != agent_pos_vx or j != agent_pos_vy:
object = view_field_dictionary[i][j]
relative_position = dict()

if i - agent_pos_vx > 0:
if self.language == 'english':
relative_position["x_axis"] = ("right", i - agent_pos_vx)
elif self.language == 'french':
relative_position["x_axis"] = ("à droite", i - agent_pos_vx)
elif i - agent_pos_vx == 0:
if self.language == 'english':
relative_position["x_axis"] = ("face", 0)
elif self.language == 'french':
relative_position["x_axis"] = ("en face", 0)
else:
if self.language == 'english':
relative_position["x_axis"] = ("left", agent_pos_vx - i)
elif self.language == 'french':
relative_position["x_axis"] = ("à gauche", agent_pos_vx - i)
if agent_pos_vy - j > 0:
if self.language == 'english':
relative_position["y_axis"] = ("forward", agent_pos_vy - j)
elif self.language == 'french':
relative_position["y_axis"] = ("devant", agent_pos_vy - j)
elif agent_pos_vy - j == 0:
if self.language == 'english':
relative_position["y_axis"] = ("forward", 0)
elif self.language == 'french':
relative_position["y_axis"] = ("devant", 0)

distances = []
if relative_position["x_axis"][0] in ["face", "en face"]:
distances.append((relative_position["y_axis"][1], relative_position["y_axis"][0]))
elif relative_position["y_axis"][1] == 0:
distances.append((relative_position["x_axis"][1], relative_position["x_axis"][0]))
else:
distances.append((relative_position["x_axis"][1], relative_position["x_axis"][0]))
distances.append((relative_position["y_axis"][1], relative_position["y_axis"][0]))

description = ""
if object[0] != 4: # if it is not a door
if self.language == 'english':
description = f"You see a {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
elif self.language == 'french':
description = f"Tu vois une {IDX_TO_OBJECT[object[0]]} {IDX_TO_COLOR[object[1]]} "

else:
if IDX_TO_STATE[object[2]] != 0: # if it is not open
if self.language == 'english':
description = f"You see a {IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
elif self.language == 'french':
description = f"Tu vois une {IDX_TO_OBJECT[object[0]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_STATE[object[2]]} "

else:
if self.language == 'english':
description = f"You see an {IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
elif self.language == 'french':
description = f"Tu vois une {IDX_TO_OBJECT[object[0]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_STATE[object[2]]} "

for _i, _distance in enumerate(distances):
if _i > 0:
if self.language == 'english':
description += " and "
elif self.language == 'french':
description += " et "

if self.language == 'english':
description += f"{_distance[0]} step{'s' if _distance[0] > 1 else ''} {_distance[1]}"
elif self.language == 'french':
description += f"{_distance[0]} pas {_distance[1]}"

list_textual_descriptions.append(description)

return {'descriptions': list_textual_descriptions}

0 comments on commit 870926b

Please sign in to comment.