diff --git a/pyproject.toml b/pyproject.toml index 3b3dd03..75c9021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "xminigrid" description = "JAX-accelerated meta-reinforcement learning environments inspired by XLand and MiniGrid" readme = "README.md" -requires-python =">=3.8" +requires-python =">=3.9" license = {file = "LICENSE"} authors = [ {name = "Alexander Nikulin", email = "a.p.nikulin@tinkoff.ai"}, @@ -26,7 +26,6 @@ classifiers = [ "Natural Language :: English", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Topic :: Scientific/Engineering :: Artificial Intelligence", @@ -43,7 +42,8 @@ dependencies = [ [project.optional-dependencies] dev = [ "ruff>=0.1.6", - "pre-commit>=3.3.3" + "pre-commit>=3.3.3", + "pyright>=1.1.347", ] baselines = [ @@ -86,4 +86,17 @@ skip-magic-trailing-comma = false [tool.ruff.isort] # see https://github.com/astral-sh/ruff/issues/8571 -known-third-party = ["wandb"] \ No newline at end of file +known-third-party = ["wandb"] + + +[tool.pyright] +include = ["src/xminigrid"] +exclude = [ + "**/node_modules", + "**/__pycache__", +] + +reportMissingImports = true +reportMissingTypeStubs = false +pythonVersion = "3.10" +pythonPlatform = "All" \ No newline at end of file diff --git a/src/xminigrid/core/actions.py b/src/xminigrid/core/actions.py index c54169c..4cc0c3a 100644 --- a/src/xminigrid/core/actions.py +++ b/src/xminigrid/core/actions.py @@ -1,17 +1,23 @@ +from __future__ import annotations + import jax import jax.numpy as jnp +from typing_extensions import TypeAlias +from ..types import AgentState, GridState from .constants import DIRECTIONS, TILES_REGISTRY, Colors, Tiles from .grid import check_can_put, check_pickable, check_walkable, equal +ActionOutput: TypeAlias = tuple[GridState, AgentState, jax.Array] + -def _move(position, direction): +def _move(position: jax.Array, direction: jax.Array) -> jax.Array: direction = jax.lax.dynamic_index_in_dim(DIRECTIONS, direction, keepdims=False) new_position = position + direction return new_position -def move_forward(grid, agent): +def move_forward(grid: GridState, agent: AgentState) -> ActionOutput: next_position = jnp.clip( _move(agent.position, agent.direction), a_min=jnp.array((0, 0)), @@ -27,19 +33,19 @@ def move_forward(grid, agent): return grid, new_agent, new_agent.position -def turn_clockwise(grid, agent): +def turn_clockwise(grid: GridState, agent: AgentState) -> ActionOutput: new_direction = (agent.direction + 1) % 4 new_agent = agent.replace(direction=new_direction) return grid, new_agent, agent.position -def turn_counterclockwise(grid, agent): +def turn_counterclockwise(grid: GridState, agent: AgentState) -> ActionOutput: new_direction = (agent.direction - 1) % 4 new_agent = agent.replace(direction=new_direction) return grid, new_agent, agent.position -def pick_up(grid, agent): +def pick_up(grid: GridState, agent: AgentState) -> ActionOutput: next_position = _move(agent.position, agent.direction) is_pickable = check_pickable(grid, next_position) @@ -61,7 +67,7 @@ def pick_up(grid, agent): return new_grid, new_agent, next_position -def put_down(grid, agent): +def put_down(grid: GridState, agent: AgentState) -> ActionOutput: next_position = _move(agent.position, agent.direction) can_put = check_can_put(grid, next_position) @@ -78,7 +84,7 @@ def put_down(grid, agent): # TODO: may be this should be open_door action? toggle is too general and box is not supported yet -def toggle(grid, agent): +def toggle(grid: GridState, agent: AgentState) -> ActionOutput: next_position = _move(agent.position, agent.direction) next_tile = grid[next_position[0], next_position[1]] @@ -103,7 +109,7 @@ def toggle(grid, agent): return new_grid, agent, next_position -def take_action(grid, agent, action): +def take_action(grid: GridState, agent: AgentState, action: int) -> ActionOutput: # This will evaluate all actions. # Can we fix this and choose only one function? It'll speed everything up dramatically. actions = ( diff --git a/src/xminigrid/core/goals.py b/src/xminigrid/core/goals.py index f45a6f7..6985b18 100644 --- a/src/xminigrid/core/goals.py +++ b/src/xminigrid/core/goals.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import abc import jax import jax.numpy as jnp from flax import struct +from ..types import AgentState, GridState from .grid import equal, get_neighbouring_tiles, pad_along_axis MAX_GOAL_ENCODING_LEN = 4 + 1 # for idx -def check_goal(encoding, grid, agent, action, position): +def check_goal( + encoding: jax.Array, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array +) -> jax.Array: check = jax.lax.switch( encoding[0], ( @@ -37,16 +42,16 @@ def check_goal(encoding, grid, agent, action, position): class BaseGoal(struct.PyTreeNode): @abc.abstractmethod - def __call__(self, grid, agent, action, position): + def __call__(self, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array) -> jax.Array: ... @classmethod @abc.abstractmethod - def decode(cls, encoding): + def decode(cls, encoding: jax.Array) -> BaseGoal: ... @abc.abstractmethod - def encode(self): + def encode(self) -> jax.Array: ... diff --git a/src/xminigrid/core/grid.py b/src/xminigrid/core/grid.py index 2d14890..eff1c13 100644 --- a/src/xminigrid/core/grid.py +++ b/src/xminigrid/core/grid.py @@ -1,21 +1,26 @@ +from __future__ import annotations + +from typing import Callable + import jax import jax.numpy as jnp +from ..types import GridState, Tile from .constants import FREE_TO_PUT_DOWN, LOS_BLOCKING, PICKABLE, TILES_REGISTRY, WALKABLE, Colors, Tiles -def empty_world(height, width): +def empty_world(height: int, width: int) -> GridState: grid = jnp.zeros((height, width, 2), dtype=jnp.uint8) grid = grid.at[:, :, 0:2].set(TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK]) return grid -# wait, is this just a jnp.array_equal? -def equal(tile1, tile2): +def equal(tile1: Tile, tile2: Tile) -> Tile: + # wait, is this just a jnp.array_equal? return jnp.all(jnp.equal(tile1, tile2)) -def get_neighbouring_tiles(grid, y, x): +def get_neighbouring_tiles(grid: GridState, y: int | jax.Array, x: int | jax.Array) -> tuple[Tile, Tile, Tile, Tile]: # end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP] end_of_map = Tiles.END_OF_MAP @@ -30,17 +35,17 @@ def get_neighbouring_tiles(grid, y, x): return up_tile, right_tile, down_tile, left_tile -def horizontal_line(grid, x, y, length, tile): +def horizontal_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState: grid = grid.at[y, x : x + length].set(tile) return grid -def vertical_line(grid, x, y, length, tile): +def vertical_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState: grid = grid.at[y : y + length, x].set(tile) return grid -def rectangle(grid, x, y, h, w, tile): +def rectangle(grid: GridState, x: int, y: int, h: int, w: int, tile: Tile) -> GridState: grid = vertical_line(grid, x, y, h, tile) grid = vertical_line(grid, x + w - 1, y, h, tile) grid = horizontal_line(grid, x, y, w, tile) @@ -48,14 +53,14 @@ def rectangle(grid, x, y, h, w, tile): return grid -def room(height, width): +def room(height: int, width: int) -> GridState: grid = empty_world(height, width) grid = rectangle(grid, 0, 0, height, width, tile=TILES_REGISTRY[Tiles.WALL, Colors.GREY]) return grid -def two_rooms(height, width): - wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] +def two_rooms(height: int, width: int) -> GridState: + wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] grid = empty_world(height, width) grid = rectangle(grid, 0, 0, height, width, tile=wall_tile) @@ -63,8 +68,8 @@ def two_rooms(height, width): return grid -def four_rooms(height, width): - wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] +def four_rooms(height: int, width: int) -> GridState: + wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] grid = empty_world(height, width) grid = rectangle(grid, 0, 0, height, width, tile=wall_tile) @@ -73,8 +78,8 @@ def four_rooms(height, width): return grid -def nine_rooms(height, width): - wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] +def nine_rooms(height: int, width: int) -> GridState: + wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY] grid = empty_world(height, width) grid = rectangle(grid, 0, 0, height, width, tile=wall_tile) @@ -85,34 +90,34 @@ def nine_rooms(height, width): return grid -def check_walkable(grid, position): +def check_walkable(grid: GridState, position: jax.Array) -> jax.Array: tile_id = grid[position[0], position[1], 0] is_walkable = jnp.isin(tile_id, WALKABLE, assume_unique=True) return is_walkable -def check_pickable(grid, position): +def check_pickable(grid: GridState, position: jax.Array) -> jax.Array: tile_id = grid[position[0], position[1], 0] is_pickable = jnp.isin(tile_id, PICKABLE, assume_unique=True) return is_pickable -def check_can_put(grid, position): +def check_can_put(grid: GridState, position: jax.Array) -> jax.Array: tile_id = grid[position[0], position[1], 0] can_put = jnp.isin(tile_id, FREE_TO_PUT_DOWN, assume_unique=True) return can_put -def check_see_behind(grid, position): +def check_see_behind(grid: GridState, position: jax.Array) -> jax.Array: tile_id = grid[position[0], position[1], 0] is_not_blocking = jnp.isin(tile_id, LOS_BLOCKING, assume_unique=True, invert=True) return is_not_blocking -def align_with_up(grid, direction): +def align_with_up(grid: GridState, direction: int | jax.Array) -> GridState: aligned_grid = jax.lax.switch( direction, ( @@ -125,27 +130,27 @@ def align_with_up(grid, direction): return aligned_grid -def grid_coords(grid): +def grid_coords(grid: GridState) -> jax.Array: coords = jnp.mgrid[: grid.shape[0], : grid.shape[1]] coords = coords.transpose(1, 2, 0).reshape(-1, 2) return coords -def transparent_mask(grid): +def transparent_mask(grid: GridState) -> jax.Array: coords = grid_coords(grid) mask = jax.vmap(check_see_behind, in_axes=(None, 0))(grid, coords) mask = mask.reshape(grid.shape[0], grid.shape[1]) return mask -def free_tiles_mask(grid): +def free_tiles_mask(grid: GridState) -> jax.Array: coords = grid_coords(grid) mask = jax.vmap(check_can_put, in_axes=(None, 0))(grid, coords) mask = mask.reshape(grid.shape[0], grid.shape[1]) return mask -def coordinates_mask(grid, address, comparison_fn): +def coordinates_mask(grid: GridState, address: tuple[int, int], comparison_fn: Callable) -> jax.Array: positions = jnp.mgrid[: grid.shape[0], : grid.shape[1]] cond_1 = comparison_fn(positions[0], address[0]) cond_2 = comparison_fn(positions[1], address[1]) @@ -153,7 +158,7 @@ def coordinates_mask(grid, address, comparison_fn): return mask -def sample_coordinates(key, grid, num, mask=None): +def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array: if mask is None: mask = jnp.ones((grid.shape[0], grid.shape[1]), dtype=jnp.bool_) @@ -169,19 +174,20 @@ def sample_coordinates(key, grid, num, mask=None): return coords -def sample_direction(key): +def sample_direction(key: jax.Array) -> jax.Array: return jax.random.randint(key, shape=(), minval=0, maxval=4) -def pad_along_axis(arr, pad_to, axis=0, fill_value=0): +def pad_along_axis(arr: jax.Array, pad_to: int, axis: int = 0, fill_value: int = 0) -> jax.Array: pad_size = pad_to - arr.shape[axis] if pad_size <= 0: return arr - npad = [(0, 0)] * arr.ndim + # manually annotate for pyright + npad: list[tuple[int, int]] = [(0, 0)] * arr.ndim npad[axis] = (0, pad_size) return jnp.pad(arr, pad_width=npad, mode="constant", constant_values=fill_value) -def cartesian_product_1d(a, b): +def cartesian_product_1d(a: jax.Array, b: jax.Array) -> jax.Array: return jnp.dstack(jnp.meshgrid(a, b)).reshape(-1, 2) diff --git a/src/xminigrid/core/observation.py b/src/xminigrid/core/observation.py index bf6a9c5..e2c4ab8 100644 --- a/src/xminigrid/core/observation.py +++ b/src/xminigrid/core/observation.py @@ -1,11 +1,12 @@ import jax import jax.numpy as jnp +from ..types import AgentState, GridState from .constants import Tiles from .grid import align_with_up, check_see_behind -def crop_field_of_view(grid, agent, height, width): +def crop_field_of_view(grid: GridState, agent: AgentState, height: int, width: int) -> jax.Array: # TODO: assert height and width are odd and >= 3 # TODO: in theory we don't need padding from all 4 sides, only for out of bounds sides grid = jnp.pad( @@ -32,7 +33,7 @@ def crop_field_of_view(grid, agent, height, width): return fov_crop -def transparent_field_of_view(grid, agent, height, width): +def transparent_field_of_view(grid: GridState, agent: AgentState, height: int, width: int) -> jax.Array: fov_grid = crop_field_of_view(grid, agent, height, width) fov_grid = align_with_up(fov_grid, agent.direction) @@ -51,7 +52,7 @@ def transparent_field_of_view(grid, agent, height, width): # https://github.com/Farama-Foundation/Minigrid/blob/e6f34bee70c5eb45ca9bfa2ea061cf06dd03e7b3/minigrid/core/grid.py#L291C9-L291C20 # noqa # but adapted to jax and transposed grid # WARN: only works for field of view crop aligned with UP direction, use align_with_up before! -def generate_viz_mask_minigrid(grid): +def generate_viz_mask_minigrid(grid: GridState) -> jax.Array: H, W = grid.shape[0], grid.shape[1] viz_mask = jnp.zeros((H, W), dtype=jnp.bool_) # agent position with UP alignment, always visible @@ -104,8 +105,8 @@ def _main_body(viz_mask, y): return viz_mask -# TODO: works well with unroll=16 and random actions, but very slow with PPO even with high unroll. Fix! -def minigrid_field_of_view(grid, agent, height, width): +# TODO: works well with unroll=16 and random actions, but very slow with PPO even with high unroll! +def minigrid_field_of_view(grid: GridState, agent: AgentState, height: int, width: int) -> jax.Array: fov_grid = crop_field_of_view(grid, agent, height, width) fov_grid = align_with_up(fov_grid, agent.direction) mask = generate_viz_mask_minigrid(fov_grid) diff --git a/src/xminigrid/core/rules.py b/src/xminigrid/core/rules.py index 9e43ee6..5f903b0 100644 --- a/src/xminigrid/core/rules.py +++ b/src/xminigrid/core/rules.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import abc import jax import jax.numpy as jnp from flax import struct +from ..types import AgentState, GridState from .constants import TILES_REGISTRY, Colors, Tiles from .grid import equal, get_neighbouring_tiles, pad_along_axis @@ -13,7 +16,9 @@ # this is very costly, will evaluate all rules under vmap. Submit a PR if you know how to do it better! # In general, we need a way to select specific function/class based on ID number. # We can not just decode without evaluation, as then return type will be different between branches -def check_rule(encodings, grid, agent, action, position): +def check_rule( + encodings: jax.Array, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array +) -> tuple[GridState, AgentState]: def _check(carry, encoding): grid, agent = carry grid, agent = jax.lax.switch( @@ -44,16 +49,18 @@ def _check(carry, encoding): class BaseRule(struct.PyTreeNode): @abc.abstractmethod - def __call__(self, grid, agent, action, position): + def __call__( + self, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array + ) -> tuple[GridState, AgentState]: ... @classmethod @abc.abstractmethod - def decode(cls, encoding): + def decode(cls, encoding: jax.Array) -> BaseRule: ... @abc.abstractmethod - def encode(self): + def encode(self) -> jax.Array: ... diff --git a/src/xminigrid/types.py b/src/xminigrid/types.py index 418b3d6..326437b 100644 --- a/src/xminigrid/types.py +++ b/src/xminigrid/types.py @@ -1,8 +1,7 @@ -from typing import TypeAlias - import jax import jax.numpy as jnp from flax import struct +from typing_extensions import TypeAlias from .core.constants import TILES_REGISTRY, Colors, Tiles @@ -14,6 +13,7 @@ class RuleSet(struct.PyTreeNode): GridState: TypeAlias = jax.Array +Tile: TypeAlias = jax.Array class AgentState(struct.PyTreeNode): @@ -30,7 +30,7 @@ class State(struct.PyTreeNode): key: jax.random.PRNGKey step_num: jax.Array - grid: jax.Array + grid: GridState agent: AgentState goal_encoding: jax.Array rule_encoding: jax.Array @@ -39,9 +39,9 @@ class State(struct.PyTreeNode): class StepType(jnp.uint8): - FIRST: int = jnp.asarray(0, dtype=jnp.uint8) - MID: int = jnp.asarray(1, dtype=jnp.uint8) - LAST: int = jnp.asarray(2, dtype=jnp.uint8) + FIRST: jax.Array = jnp.asarray(0, dtype=jnp.uint8) + MID: jax.Array = jnp.asarray(1, dtype=jnp.uint8) + LAST: jax.Array = jnp.asarray(2, dtype=jnp.uint8) class TimeStep(struct.PyTreeNode):