Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
374d74c
fix: rgb_first_person now returns the correct observation size
epignatelli Aug 12, 2025
43a5bb3
build(release): bump version
epignatelli Aug 12, 2025
3c65704
fix: align min tile size with minigrid
epignatelli Aug 12, 2025
87d26e7
fix: align default first obs radius with minigrid
epignatelli Aug 12, 2025
49d3056
fix: align default first obs radius with minigrid
epignatelli Aug 12, 2025
7e90131
fix: align padding value with minigrid in rgb obs
epignatelli Aug 12, 2025
50a0838
fix: centre player first obs
epignatelli Aug 12, 2025
ccde1ae
fix: centre player first obs
epignatelli Aug 12, 2025
ccdb9b6
fix radius crop
epignatelli Aug 12, 2025
a1925f1
fix radius crop
epignatelli Aug 13, 2025
294926d
fix radius crop
epignatelli Aug 13, 2025
31f693c
fix: add minigrid opacity for fp rgb obs
epignatelli Aug 13, 2025
98e1d2d
fix radius crop
epignatelli Aug 13, 2025
f37c109
fix radius crop
epignatelli Aug 13, 2025
88bd163
max sure radius is correct
epignatelli Aug 13, 2025
95416cf
fix: sprites now correctly align with grid
epignatelli Aug 13, 2025
aa3970e
we should not rotate sprites at all
epignatelli Aug 13, 2025
a8997df
refactor: abstract out patch rotation
epignatelli Aug 13, 2025
98ce0db
fix: cull elements in obstructed view
epignatelli Aug 13, 2025
230eec4
fix: cull elements in obstructed view
epignatelli Aug 13, 2025
91f4625
add test for fp observations
epignatelli Aug 14, 2025
8e8dc95
basic first person observations
epignatelli Aug 15, 2025
7ce23d4
perf(obs): compile Python for loop when radius<=10,
epignatelli Aug 15, 2025
7ea18e6
feat: add a way to set the radius of the observation
epignatelli Aug 15, 2025
fc07b75
feat: add obstruction calculation to fp obs
epignatelli Aug 15, 2025
455c1d9
remove print statements
epignatelli Aug 15, 2025
5d770c3
fix: use correct img size for fp rgb obs
epignatelli Aug 15, 2025
a01a449
fix: use correct img size for fp symbolic and categorical obs
epignatelli Aug 15, 2025
fb6e5e8
feat: add obstruction calculation to fp obs
epignatelli Aug 15, 2025
30c68c8
cleanup
epignatelli Aug 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
- name: Setup navix
run: |
pip install . -v
pip install -r requirements_test.txt
- name: Check code quality
run: |
pip install pylint
Expand Down
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.7.1"
__version__ = "0.7.2"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
6 changes: 3 additions & 3 deletions navix/environments/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _get_obs_space_from_fn(
return Discrete.create(n_elements=9, shape=(height, width))
elif observation_fn == observations.categorical_first_person:
radius = observations.RADIUS
return Discrete.create(n_elements=9, shape=(radius + 1, radius * 2 + 1))
return Discrete.create(n_elements=9, shape=(radius * 2 + 1, radius * 2 + 1))
elif observation_fn == observations.rgb:
return Discrete.create(
256,
Expand All @@ -239,7 +239,7 @@ def _get_obs_space_from_fn(
radius = observations.RADIUS
return Discrete.create(
n_elements=256,
shape=(radius * TILE_SIZE + 1, radius * TILE_SIZE * 2 + 1, 3),
shape=((radius * 2 + 1) * TILE_SIZE, (radius * 2 + 1) * TILE_SIZE, 3),
dtype=jnp.uint8,
)
elif observation_fn == observations.symbolic:
Expand All @@ -252,7 +252,7 @@ def _get_obs_space_from_fn(
radius = observations.RADIUS
return Discrete.create(
n_elements=256,
shape=(radius + 1, radius * 2 + 1, 3),
shape=(radius * 2 + 1, radius * 2 + 1, 3),
dtype=jnp.uint8,
)
else:
Expand Down
131 changes: 98 additions & 33 deletions navix/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import jax.numpy as jnp
from jax import Array
from flax import struct
from navix.rendering.registry import TILE_SIZE


Coordinates = Tuple[Array, Array]
Expand Down Expand Up @@ -209,6 +210,28 @@ def align(patch: Array, current_direction: Array, desired_direction: Array) -> A
)


def rotate_tile(patch: Array, num_times_90: Array) -> Array:
"""Rotates a patch of the grid by a given number of 90-degree rotations.

Args:
patch (Array): A patch of the grid.
num_times_90 (int): The number of 90-degree rotations to apply.

Returns:
Array: A patch of the grid rotated by the given number of 90-degree rotations.
"""
return jax.lax.switch(
num_times_90,
(
lambda x: jnp.flip(jnp.swapaxes(x, 0, 1), axis=0), # rot90
lambda x: jnp.flip(jnp.flip(x, axis=0), axis=1), # rot180
lambda x: jnp.flip(jnp.swapaxes(x, 0, 1), axis=1), # rot270
lambda x: x, # rot0
),
patch,
)


def random_positions(
key: Array, grid: Array, n: int = 1, exclude: Array = jnp.asarray((-1, -1))
) -> Array:
Expand Down Expand Up @@ -362,7 +385,7 @@ def horizontal_wall(


def crop(
grid: Array, origin: Array, direction: Array, radius: int, padding_value: int = 0
grid: Array, origin: Array, direction: Array, radius: int, padding_value: int = 100
) -> Array:
"""Crops a grid around a given origin, facing a given direction, with a given radius.

Expand All @@ -375,19 +398,11 @@ def crop(

Returns:
Array: A cropped grid."""
input_shape = grid.shape
# assert radius % 2, "Radius must be an odd number"
# mid = jnp.asarray([g // 2 for g in grid.shape[:2]])
# translated = jnp.roll(grid, mid - origin, axis=(0, 1))

# # crop such that the agent is in the centre of the grid
# cropped = translated.at[: 2 * radius + 1, : 2 * radius + 1].get(
# fill_value=padding_value
# )
diameter = radius * 2

# pad with radius
padding = [(radius, radius), (radius, radius)]
for _ in range(len(input_shape) - 2):
padding = [(diameter, diameter), (diameter, diameter)]
for _ in range(len(grid.shape) - 2):
padding.append((0, 0))

padded = jnp.pad(grid, padding, constant_values=padding_value)
Expand All @@ -396,24 +411,49 @@ def crop(
translated = jnp.roll(padded, -jnp.asarray(origin), axis=(0, 1))

# crop such that the agent is in the centre of the grid
cropped = translated[: 2 * radius + 1, : 2 * radius + 1]
cropped = translated[: 2 * diameter + 1, : 2 * diameter + 1]

# rotate such that the agent is facing north
rotated = jax.lax.switch(
direction,
(
lambda x: jnp.rot90(x, 1), # 0 = transpose, 1 = flip
lambda x: jnp.rot90(x, 2), # 0 = flip, 1 = flip
lambda x: jnp.rot90(x, 3), # 0 = flip, 1 = transpose
lambda x: x,
),
cropped,
)
rotated = rotate_tile(cropped, direction)

cropped = rotated.at[: radius + 1].get(fill_value=padding_value)
# if radius is 6
cropped = rotated.at[: diameter + 1, radius : diameter * 2 - radius + 1].get(
fill_value=padding_value
)
return jnp.asarray(cropped, dtype=grid.dtype)


def apply_minigrid_opacity(image: Array, opacity: Array = jnp.asarray(0.7)) -> Array:
"""Applies minigrid opacity to the given image, used in
`minigrid.wrappers.RGBImgPartialObsWrapper`. The default MiniGrid opacity is 0.7.

Args:
image (Array): The input image to which opacity is applied.
opacity (Array, optional): The opacity value to apply. Defaults to 0.7.

Returns:
Array: The input image with applied opacity.
"""
return jax.numpy.asarray(255 - opacity * (255 - image), dtype=jax.numpy.uint8)


def draw_grid_lines(tile: Array, luminosity: Array = jnp.asarray(100)) -> Array:
"""Draws grid lines on the given tile.

Args:
tile (Array): The input tile to which grid lines are drawn.

Returns:
Array: The tile with drawn grid lines.
"""
# Draw lines (top and left edges) at 3.1% of the tile size as per
# minigrid.core.Grid.render_tile
line_thickness = jnp.ceil(TILE_SIZE * 0.031)
tile = tile.at[:line_thickness, :].set(luminosity)
tile = tile.at[:, :line_thickness].set(luminosity)
return tile


def view_cone(transparency_map: Array, origin: Array, radius: int) -> Array:
"""Computes the view cone of a given origin in a grid with a given radius.
The view cone is a boolean map of transparent (1) and opaque (0) tiles, indicating
Expand All @@ -426,26 +466,51 @@ def view_cone(transparency_map: Array, origin: Array, radius: int) -> Array:

Returns:
Array: The view cone of the given origin in the grid with the given radius."""
# transparency_map is a boolean map of transparent (1) and opaque (0) tiles

def fin_diff(array, _):
array = jnp.roll(array, -1, axis=0) + array + jnp.roll(array, +1, axis=0)
array = jnp.roll(array, -1, axis=1) + array + jnp.roll(array, +1, axis=1)
return array * transparency_map, ()

# initialise the field to all zeros, except at the source (agent's position)
mask = jnp.zeros_like(transparency_map).at[tuple(origin)].set(1)

view = jax.lax.scan(fin_diff, mask, None, radius)[0]

# start the diffusion process using finite differences
# if radius is small, it should be fast enough to compile
MIN_SCAN_RADIUS = 10
if radius <= MIN_SCAN_RADIUS:
view = mask
for _ in range(radius):
view = fin_diff(view, None)[0]
else:
view = jax.lax.scan(fin_diff, mask, None, radius)[0]

# view has anything that is visible > 0
# we now set a hard threshold > 0, but we can also think in the future
# to use a cutoof at a different value to mimic the effect of a torch
# (or eyesight for what matters)
view = jnp.where(view > 0, 1, 0)
# to use a cutoff at a different value to mimic the effect of a torch
vis_free = view > 0

# add frontier obstacles
# frontier obstacles = opaque cells neighbouring any visible-free cell (8-neighbourhood)
opaque = transparency_map == 0
nb = (
vis_free
| jnp.roll(vis_free, +1, 0)
| jnp.roll(vis_free, -1, 0)
| jnp.roll(vis_free, +1, 1)
| jnp.roll(vis_free, -1, 1)
| jnp.roll(jnp.roll(vis_free, +1, 0), +1, 1)
| jnp.roll(jnp.roll(vis_free, +1, 0), -1, 1)
| jnp.roll(jnp.roll(vis_free, -1, 0), +1, 1)
| jnp.roll(jnp.roll(vis_free, -1, 0), -1, 1)
)
frontier = nb & opaque

# we add back the opaque tiles
view = jnp.where(transparency_map == 0, 1, view)
# final visible = transparent region plus blocking frontier
visible = vis_free | frontier
visible = visible.at[tuple(origin)].set(True)

return view
return visible.astype(transparency_map.dtype)


def from_ascii_map(ascii_map: str, mapping: Dict[str, int] = {}) -> Array:
Expand Down
66 changes: 44 additions & 22 deletions navix/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,25 @@
from .rendering.cache import TILE_SIZE, unflatten_patches
from .components import DISCARD_PILE_IDX, Directional, HasColour, Openable
from .states import State
from .grid import align, idx_from_coordinates, crop, view_cone
from .grid import (
apply_minigrid_opacity,
align,
draw_grid_lines,
idx_from_coordinates,
crop,
view_cone,
)
from .entities import EntityIds


RADIUS = 3


def set_radius(radius: int):
global RADIUS
RADIUS = radius


def none(state: State) -> Array:
"""An empty observation represented as an array of shape f32[0].
Useful for testing purposes.
Expand Down Expand Up @@ -217,36 +229,46 @@ def rgb_first_person(state: State) -> Array:
Array: An RGB image of the agent's view, represented as an array of shape \
`u8[(2 * RADIUS + 1) * S, (2 * RADIUS + 1) * S, 3]`, where
`S` is the size of the tile."""
# calculate final image size
# get agent's view
# image_size = (
# state.grid.shape[0] * TILE_SIZE,
# state.grid.shape[1] * TILE_SIZE,
# )
# transparency_map = jnp.where(state.grid == 0, 1, 0)
# positions = state.get_positions()
# transparent = state.get_transparency()
# transparency_map = transparency_map.at[tuple(positions.T)].set(~transparent)
# view = view_cone(transparency_map, player.position, RADIUS)
# view = jax.image.resize(view, image_size, method="nearest")
# view = jnp.tile(view[..., None], (1, 1, 3))
# get the player
player = state.get_player()

# get sprites aligned to player's direction
sprites = state.get_sprites()
sprites = jax.vmap(lambda x: align(x, jnp.asarray(0), player.direction))(sprites)
sprites = state.get_sprites_first_person() # (n_sprites, TILE_SIZE, TILE_SIZE, 3)
# sprites = jax.vmap(lambda x: align(x, jnp.asarray(0), alignment_direction))(sprites)

# align sprites to player's direction
# draw grid lines on tiles
# sprites = jax.vmap(lambda x: draw_grid_lines(x))(sprites)

# update current patchwork
indices = idx_from_coordinates(state.grid, state.get_positions())
patches = state.cache.patches.at[indices].set(sprites)
patches = state.cache.patches.at[indices].set(
sprites
) # ( H * W + 1, TILE_SIZE, TILE_SIZE, 3)

# remove discard pile
patches = patches[:DISCARD_PILE_IDX]
patches = patches[:DISCARD_PILE_IDX] # ( H * W, TILE_SIZE, TILE_SIZE, 3)
# rearrange the sprites in a grid
patchwork = patches.reshape(*state.grid.shape, *patches.shape[1:])
patchwork = patches.reshape(
*state.grid.shape, *patches.shape[1:]
) # (H, W, TILE_SIZE, TILE_SIZE, 3)

# apply minigrid opacity
patchwork = apply_minigrid_opacity(patchwork)

# apply fov
dark_cell_colour = 0 # dark color for unseen tiles
transparency_map = jnp.where(state.grid == 0, 1, 0) # (H, W)
positions = state.get_positions()
transparent = state.get_transparency()
transparency_map = transparency_map.at[tuple(positions.T)].set(transparent)
view = view_cone(transparency_map, player.position, RADIUS) # (H, W)
view = jnp.asarray(view, dtype=jnp.bool)
patchwork = jnp.where(view[..., None, None, None], patchwork, dark_cell_colour)

# crop grid to agent's view
player = state.get_player()
patchwork = crop(patchwork, player.position, player.direction, RADIUS)
patchwork = crop(
patchwork, player.position, player.direction, RADIUS, dark_cell_colour
) # (RADIUS * 2 + 1, RADIUS * 2 + 1, TILE_SIZE, TILE_SIZE, 3)

# reconstruct image
obs = jnp.swapaxes(patchwork, 1, 2)
Expand Down
2 changes: 1 addition & 1 deletion navix/rendering/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
SPRITES_DIR = os.path.normpath(
os.path.join(__file__, "..", "..", "..", "assets", "sprites")
)
MIN_TILE_SIZE = 32
MIN_TILE_SIZE = 8
TILE_SIZE = MIN_TILE_SIZE


Expand Down
13 changes: 12 additions & 1 deletion navix/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
HasColour,
)
from .rendering.cache import RenderingCache
from .rendering.registry import PALETTE
from .rendering.registry import PALETTE, SPRITES_REGISTRY
from .entities import Entity, Entities, Goal, Wall, Ball, Lava, Key, Door, Box, Player


Expand Down Expand Up @@ -447,6 +447,17 @@ def get_sprites(self) -> Array:
"""Get the sprites of all the entities in the state."""
return jnp.concatenate([self.entities[k].sprite for k in self.entities])

def get_sprites_first_person(self) -> Array:
"""Returns the sprites with the agent aligned in the north position"""
player_sprite = SPRITES_REGISTRY[Entities.PLAYER][-1][None] # -1 is north
sprites = []
for k, v in self.entities.items():
if k is not Entities.PLAYER:
sprites.append(v.sprite)
else:
sprites.append(player_sprite)
return jnp.concatenate(sprites)

def get_transparency(self) -> Array:
"""Get the transparency of all the entities in the state."""
return jnp.concatenate([self.entities[k].transparent for k in self.entities])
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
minigrid
18 changes: 18 additions & 0 deletions tests/test_observations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -131,6 +133,22 @@ def test_categorical_first_person():
print(obs)


def test_rgb_first_person():
import gymnasium as gym
import minigrid

navix_env_id = "Navix-Empty-8x8-v0"
gym_env_id = navix_env_id.replace("Navix", "MiniGrid")

env = nx.make(navix_env_id, observation_fn=nx.observations.rgb_first_person)
timestep = env.reset(jax.random.PRNGKey(0))

env = gym.make(gym_env_id)
env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
obs, _ = env.reset()
obs = obs["image"]


if __name__ == "__main__":
test_rgb()
# test_categorical_first_person()
Expand Down
Loading