Skip to content

Commit

Permalink
wip type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Howuhh committed Jan 19, 2024
1 parent 658e69f commit 870cd5f
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 59 deletions.
21 changes: 17 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "xminigrid"
description = "JAX-accelerated meta-reinforcement learning environments inspired by XLand and MiniGrid"
readme = "README.md"
requires-python =">=3.8"
requires-python =">=3.9"
license = {file = "LICENSE"}
authors = [
{name = "Alexander Nikulin", email = "[email protected]"},
Expand All @@ -26,7 +26,6 @@ classifiers = [
"Natural Language :: English",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand All @@ -43,7 +42,8 @@ dependencies = [
[project.optional-dependencies]
dev = [
"ruff>=0.1.6",
"pre-commit>=3.3.3"
"pre-commit>=3.3.3",
"pyright>=1.1.347",
]

baselines = [
Expand Down Expand Up @@ -86,4 +86,17 @@ skip-magic-trailing-comma = false

[tool.ruff.isort]
# see https://github.com/astral-sh/ruff/issues/8571
known-third-party = ["wandb"]
known-third-party = ["wandb"]


[tool.pyright]
include = ["src/xminigrid"]
exclude = [
"**/node_modules",
"**/__pycache__",
]

reportMissingImports = true
reportMissingTypeStubs = false
pythonVersion = "3.10"
pythonPlatform = "All"
22 changes: 14 additions & 8 deletions src/xminigrid/core/actions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from __future__ import annotations

import jax
import jax.numpy as jnp
from typing_extensions import TypeAlias

from ..types import AgentState, GridState
from .constants import DIRECTIONS, TILES_REGISTRY, Colors, Tiles
from .grid import check_can_put, check_pickable, check_walkable, equal

ActionOutput: TypeAlias = tuple[GridState, AgentState, jax.Array]


def _move(position, direction):
def _move(position: jax.Array, direction: jax.Array) -> jax.Array:
direction = jax.lax.dynamic_index_in_dim(DIRECTIONS, direction, keepdims=False)
new_position = position + direction
return new_position


def move_forward(grid, agent):
def move_forward(grid: GridState, agent: AgentState) -> ActionOutput:
next_position = jnp.clip(
_move(agent.position, agent.direction),
a_min=jnp.array((0, 0)),
Expand All @@ -27,19 +33,19 @@ def move_forward(grid, agent):
return grid, new_agent, new_agent.position


def turn_clockwise(grid, agent):
def turn_clockwise(grid: GridState, agent: AgentState) -> ActionOutput:
new_direction = (agent.direction + 1) % 4
new_agent = agent.replace(direction=new_direction)
return grid, new_agent, agent.position


def turn_counterclockwise(grid, agent):
def turn_counterclockwise(grid: GridState, agent: AgentState) -> ActionOutput:
new_direction = (agent.direction - 1) % 4
new_agent = agent.replace(direction=new_direction)
return grid, new_agent, agent.position


def pick_up(grid, agent):
def pick_up(grid: GridState, agent: AgentState) -> ActionOutput:
next_position = _move(agent.position, agent.direction)

is_pickable = check_pickable(grid, next_position)
Expand All @@ -61,7 +67,7 @@ def pick_up(grid, agent):
return new_grid, new_agent, next_position


def put_down(grid, agent):
def put_down(grid: GridState, agent: AgentState) -> ActionOutput:
next_position = _move(agent.position, agent.direction)

can_put = check_can_put(grid, next_position)
Expand All @@ -78,7 +84,7 @@ def put_down(grid, agent):


# TODO: may be this should be open_door action? toggle is too general and box is not supported yet
def toggle(grid, agent):
def toggle(grid: GridState, agent: AgentState) -> ActionOutput:
next_position = _move(agent.position, agent.direction)
next_tile = grid[next_position[0], next_position[1]]

Expand All @@ -103,7 +109,7 @@ def toggle(grid, agent):
return new_grid, agent, next_position


def take_action(grid, agent, action):
def take_action(grid: GridState, agent: AgentState, action: int) -> ActionOutput:
# This will evaluate all actions.
# Can we fix this and choose only one function? It'll speed everything up dramatically.
actions = (
Expand Down
13 changes: 9 additions & 4 deletions src/xminigrid/core/goals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations

import abc

import jax
import jax.numpy as jnp
from flax import struct

from ..types import AgentState, GridState
from .grid import equal, get_neighbouring_tiles, pad_along_axis

MAX_GOAL_ENCODING_LEN = 4 + 1 # for idx


def check_goal(encoding, grid, agent, action, position):
def check_goal(
encoding: jax.Array, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array
) -> jax.Array:
check = jax.lax.switch(
encoding[0],
(
Expand Down Expand Up @@ -37,16 +42,16 @@ def check_goal(encoding, grid, agent, action, position):

class BaseGoal(struct.PyTreeNode):
@abc.abstractmethod
def __call__(self, grid, agent, action, position):
def __call__(self, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array) -> jax.Array:
...

@classmethod
@abc.abstractmethod
def decode(cls, encoding):
def decode(cls, encoding: jax.Array) -> BaseGoal:
...

@abc.abstractmethod
def encode(self):
def encode(self) -> jax.Array:
...


Expand Down
62 changes: 34 additions & 28 deletions src/xminigrid/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from __future__ import annotations

from typing import Callable

import jax
import jax.numpy as jnp

from ..types import GridState, Tile
from .constants import FREE_TO_PUT_DOWN, LOS_BLOCKING, PICKABLE, TILES_REGISTRY, WALKABLE, Colors, Tiles


def empty_world(height, width):
def empty_world(height: int, width: int) -> GridState:
grid = jnp.zeros((height, width, 2), dtype=jnp.uint8)
grid = grid.at[:, :, 0:2].set(TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK])
return grid


# wait, is this just a jnp.array_equal?
def equal(tile1, tile2):
def equal(tile1: Tile, tile2: Tile) -> Tile:
# wait, is this just a jnp.array_equal?
return jnp.all(jnp.equal(tile1, tile2))


def get_neighbouring_tiles(grid, y, x):
def get_neighbouring_tiles(grid: GridState, y: int | jax.Array, x: int | jax.Array) -> tuple[Tile, Tile, Tile, Tile]:
# end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP]
end_of_map = Tiles.END_OF_MAP

Expand All @@ -30,41 +35,41 @@ def get_neighbouring_tiles(grid, y, x):
return up_tile, right_tile, down_tile, left_tile


def horizontal_line(grid, x, y, length, tile):
def horizontal_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState:
grid = grid.at[y, x : x + length].set(tile)
return grid


def vertical_line(grid, x, y, length, tile):
def vertical_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState:
grid = grid.at[y : y + length, x].set(tile)
return grid


def rectangle(grid, x, y, h, w, tile):
def rectangle(grid: GridState, x: int, y: int, h: int, w: int, tile: Tile) -> GridState:
grid = vertical_line(grid, x, y, h, tile)
grid = vertical_line(grid, x + w - 1, y, h, tile)
grid = horizontal_line(grid, x, y, w, tile)
grid = horizontal_line(grid, x, y + h - 1, w, tile)
return grid


def room(height, width):
def room(height: int, width: int) -> GridState:
grid = empty_world(height, width)
grid = rectangle(grid, 0, 0, height, width, tile=TILES_REGISTRY[Tiles.WALL, Colors.GREY])
return grid


def two_rooms(height, width):
wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
def two_rooms(height: int, width: int) -> GridState:
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]

grid = empty_world(height, width)
grid = rectangle(grid, 0, 0, height, width, tile=wall_tile)
grid = vertical_line(grid, width // 2, 0, height, tile=wall_tile)
return grid


def four_rooms(height, width):
wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
def four_rooms(height: int, width: int) -> GridState:
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]

grid = empty_world(height, width)
grid = rectangle(grid, 0, 0, height, width, tile=wall_tile)
Expand All @@ -73,8 +78,8 @@ def four_rooms(height, width):
return grid


def nine_rooms(height, width):
wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
def nine_rooms(height: int, width: int) -> GridState:
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]

grid = empty_world(height, width)
grid = rectangle(grid, 0, 0, height, width, tile=wall_tile)
Expand All @@ -85,34 +90,34 @@ def nine_rooms(height, width):
return grid


def check_walkable(grid, position):
def check_walkable(grid: GridState, position: jax.Array) -> jax.Array:
tile_id = grid[position[0], position[1], 0]
is_walkable = jnp.isin(tile_id, WALKABLE, assume_unique=True)

return is_walkable


def check_pickable(grid, position):
def check_pickable(grid: GridState, position: jax.Array) -> jax.Array:
tile_id = grid[position[0], position[1], 0]
is_pickable = jnp.isin(tile_id, PICKABLE, assume_unique=True)
return is_pickable


def check_can_put(grid, position):
def check_can_put(grid: GridState, position: jax.Array) -> jax.Array:
tile_id = grid[position[0], position[1], 0]
can_put = jnp.isin(tile_id, FREE_TO_PUT_DOWN, assume_unique=True)

return can_put


def check_see_behind(grid, position):
def check_see_behind(grid: GridState, position: jax.Array) -> jax.Array:
tile_id = grid[position[0], position[1], 0]
is_not_blocking = jnp.isin(tile_id, LOS_BLOCKING, assume_unique=True, invert=True)

return is_not_blocking


def align_with_up(grid, direction):
def align_with_up(grid: GridState, direction: int | jax.Array) -> GridState:
aligned_grid = jax.lax.switch(
direction,
(
Expand All @@ -125,35 +130,35 @@ def align_with_up(grid, direction):
return aligned_grid


def grid_coords(grid):
def grid_coords(grid: GridState) -> jax.Array:
coords = jnp.mgrid[: grid.shape[0], : grid.shape[1]]
coords = coords.transpose(1, 2, 0).reshape(-1, 2)
return coords


def transparent_mask(grid):
def transparent_mask(grid: GridState) -> jax.Array:
coords = grid_coords(grid)
mask = jax.vmap(check_see_behind, in_axes=(None, 0))(grid, coords)
mask = mask.reshape(grid.shape[0], grid.shape[1])
return mask


def free_tiles_mask(grid):
def free_tiles_mask(grid: GridState) -> jax.Array:
coords = grid_coords(grid)
mask = jax.vmap(check_can_put, in_axes=(None, 0))(grid, coords)
mask = mask.reshape(grid.shape[0], grid.shape[1])
return mask


def coordinates_mask(grid, address, comparison_fn):
def coordinates_mask(grid: GridState, address: tuple[int, int], comparison_fn: Callable) -> jax.Array:
positions = jnp.mgrid[: grid.shape[0], : grid.shape[1]]
cond_1 = comparison_fn(positions[0], address[0])
cond_2 = comparison_fn(positions[1], address[1])
mask = jnp.logical_and(cond_1, cond_2)
return mask


def sample_coordinates(key, grid, num, mask=None):
def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array:
if mask is None:
mask = jnp.ones((grid.shape[0], grid.shape[1]), dtype=jnp.bool_)

Expand All @@ -169,19 +174,20 @@ def sample_coordinates(key, grid, num, mask=None):
return coords


def sample_direction(key):
def sample_direction(key: jax.Array) -> jax.Array:
return jax.random.randint(key, shape=(), minval=0, maxval=4)


def pad_along_axis(arr, pad_to, axis=0, fill_value=0):
def pad_along_axis(arr: jax.Array, pad_to: int, axis: int = 0, fill_value: int = 0) -> jax.Array:
pad_size = pad_to - arr.shape[axis]
if pad_size <= 0:
return arr

npad = [(0, 0)] * arr.ndim
# manually annotate for pyright
npad: list[tuple[int, int]] = [(0, 0)] * arr.ndim
npad[axis] = (0, pad_size)
return jnp.pad(arr, pad_width=npad, mode="constant", constant_values=fill_value)


def cartesian_product_1d(a, b):
def cartesian_product_1d(a: jax.Array, b: jax.Array) -> jax.Array:
return jnp.dstack(jnp.meshgrid(a, b)).reshape(-1, 2)
Loading

0 comments on commit 870cd5f

Please sign in to comment.