Skip to content

Commit 870cd5f

Browse files
committed
wip type hints
1 parent 658e69f commit 870cd5f

File tree

7 files changed

+97
-59
lines changed

7 files changed

+97
-59
lines changed

pyproject.toml

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "xminigrid"
33
description = "JAX-accelerated meta-reinforcement learning environments inspired by XLand and MiniGrid"
44
readme = "README.md"
5-
requires-python =">=3.8"
5+
requires-python =">=3.9"
66
license = {file = "LICENSE"}
77
authors = [
88
{name = "Alexander Nikulin", email = "[email protected]"},
@@ -26,7 +26,6 @@ classifiers = [
2626
"Natural Language :: English",
2727
"Programming Language :: Python",
2828
"Programming Language :: Python :: 3",
29-
"Programming Language :: Python :: 3.8",
3029
"Programming Language :: Python :: 3.9",
3130
"Programming Language :: Python :: 3.10",
3231
"Topic :: Scientific/Engineering :: Artificial Intelligence",
@@ -43,7 +42,8 @@ dependencies = [
4342
[project.optional-dependencies]
4443
dev = [
4544
"ruff>=0.1.6",
46-
"pre-commit>=3.3.3"
45+
"pre-commit>=3.3.3",
46+
"pyright>=1.1.347",
4747
]
4848

4949
baselines = [
@@ -86,4 +86,17 @@ skip-magic-trailing-comma = false
8686

8787
[tool.ruff.isort]
8888
# see https://github.com/astral-sh/ruff/issues/8571
89-
known-third-party = ["wandb"]
89+
known-third-party = ["wandb"]
90+
91+
92+
[tool.pyright]
93+
include = ["src/xminigrid"]
94+
exclude = [
95+
"**/node_modules",
96+
"**/__pycache__",
97+
]
98+
99+
reportMissingImports = true
100+
reportMissingTypeStubs = false
101+
pythonVersion = "3.10"
102+
pythonPlatform = "All"

src/xminigrid/core/actions.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1+
from __future__ import annotations
2+
13
import jax
24
import jax.numpy as jnp
5+
from typing_extensions import TypeAlias
36

7+
from ..types import AgentState, GridState
48
from .constants import DIRECTIONS, TILES_REGISTRY, Colors, Tiles
59
from .grid import check_can_put, check_pickable, check_walkable, equal
610

11+
ActionOutput: TypeAlias = tuple[GridState, AgentState, jax.Array]
12+
713

8-
def _move(position, direction):
14+
def _move(position: jax.Array, direction: jax.Array) -> jax.Array:
915
direction = jax.lax.dynamic_index_in_dim(DIRECTIONS, direction, keepdims=False)
1016
new_position = position + direction
1117
return new_position
1218

1319

14-
def move_forward(grid, agent):
20+
def move_forward(grid: GridState, agent: AgentState) -> ActionOutput:
1521
next_position = jnp.clip(
1622
_move(agent.position, agent.direction),
1723
a_min=jnp.array((0, 0)),
@@ -27,19 +33,19 @@ def move_forward(grid, agent):
2733
return grid, new_agent, new_agent.position
2834

2935

30-
def turn_clockwise(grid, agent):
36+
def turn_clockwise(grid: GridState, agent: AgentState) -> ActionOutput:
3137
new_direction = (agent.direction + 1) % 4
3238
new_agent = agent.replace(direction=new_direction)
3339
return grid, new_agent, agent.position
3440

3541

36-
def turn_counterclockwise(grid, agent):
42+
def turn_counterclockwise(grid: GridState, agent: AgentState) -> ActionOutput:
3743
new_direction = (agent.direction - 1) % 4
3844
new_agent = agent.replace(direction=new_direction)
3945
return grid, new_agent, agent.position
4046

4147

42-
def pick_up(grid, agent):
48+
def pick_up(grid: GridState, agent: AgentState) -> ActionOutput:
4349
next_position = _move(agent.position, agent.direction)
4450

4551
is_pickable = check_pickable(grid, next_position)
@@ -61,7 +67,7 @@ def pick_up(grid, agent):
6167
return new_grid, new_agent, next_position
6268

6369

64-
def put_down(grid, agent):
70+
def put_down(grid: GridState, agent: AgentState) -> ActionOutput:
6571
next_position = _move(agent.position, agent.direction)
6672

6773
can_put = check_can_put(grid, next_position)
@@ -78,7 +84,7 @@ def put_down(grid, agent):
7884

7985

8086
# TODO: may be this should be open_door action? toggle is too general and box is not supported yet
81-
def toggle(grid, agent):
87+
def toggle(grid: GridState, agent: AgentState) -> ActionOutput:
8288
next_position = _move(agent.position, agent.direction)
8389
next_tile = grid[next_position[0], next_position[1]]
8490

@@ -103,7 +109,7 @@ def toggle(grid, agent):
103109
return new_grid, agent, next_position
104110

105111

106-
def take_action(grid, agent, action):
112+
def take_action(grid: GridState, agent: AgentState, action: int) -> ActionOutput:
107113
# This will evaluate all actions.
108114
# Can we fix this and choose only one function? It'll speed everything up dramatically.
109115
actions = (

src/xminigrid/core/goals.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
from __future__ import annotations
2+
13
import abc
24

35
import jax
46
import jax.numpy as jnp
57
from flax import struct
68

9+
from ..types import AgentState, GridState
710
from .grid import equal, get_neighbouring_tiles, pad_along_axis
811

912
MAX_GOAL_ENCODING_LEN = 4 + 1 # for idx
1013

1114

12-
def check_goal(encoding, grid, agent, action, position):
15+
def check_goal(
16+
encoding: jax.Array, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array
17+
) -> jax.Array:
1318
check = jax.lax.switch(
1419
encoding[0],
1520
(
@@ -37,16 +42,16 @@ def check_goal(encoding, grid, agent, action, position):
3742

3843
class BaseGoal(struct.PyTreeNode):
3944
@abc.abstractmethod
40-
def __call__(self, grid, agent, action, position):
45+
def __call__(self, grid: GridState, agent: AgentState, action: int | jax.Array, position: jax.Array) -> jax.Array:
4146
...
4247

4348
@classmethod
4449
@abc.abstractmethod
45-
def decode(cls, encoding):
50+
def decode(cls, encoding: jax.Array) -> BaseGoal:
4651
...
4752

4853
@abc.abstractmethod
49-
def encode(self):
54+
def encode(self) -> jax.Array:
5055
...
5156

5257

src/xminigrid/core/grid.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import Callable
4+
15
import jax
26
import jax.numpy as jnp
37

8+
from ..types import GridState, Tile
49
from .constants import FREE_TO_PUT_DOWN, LOS_BLOCKING, PICKABLE, TILES_REGISTRY, WALKABLE, Colors, Tiles
510

611

7-
def empty_world(height, width):
12+
def empty_world(height: int, width: int) -> GridState:
813
grid = jnp.zeros((height, width, 2), dtype=jnp.uint8)
914
grid = grid.at[:, :, 0:2].set(TILES_REGISTRY[Tiles.FLOOR, Colors.BLACK])
1015
return grid
1116

1217

13-
# wait, is this just a jnp.array_equal?
14-
def equal(tile1, tile2):
18+
def equal(tile1: Tile, tile2: Tile) -> Tile:
19+
# wait, is this just a jnp.array_equal?
1520
return jnp.all(jnp.equal(tile1, tile2))
1621

1722

18-
def get_neighbouring_tiles(grid, y, x):
23+
def get_neighbouring_tiles(grid: GridState, y: int | jax.Array, x: int | jax.Array) -> tuple[Tile, Tile, Tile, Tile]:
1924
# end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP]
2025
end_of_map = Tiles.END_OF_MAP
2126

@@ -30,41 +35,41 @@ def get_neighbouring_tiles(grid, y, x):
3035
return up_tile, right_tile, down_tile, left_tile
3136

3237

33-
def horizontal_line(grid, x, y, length, tile):
38+
def horizontal_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState:
3439
grid = grid.at[y, x : x + length].set(tile)
3540
return grid
3641

3742

38-
def vertical_line(grid, x, y, length, tile):
43+
def vertical_line(grid: GridState, x: int, y: int, length: int, tile: Tile) -> GridState:
3944
grid = grid.at[y : y + length, x].set(tile)
4045
return grid
4146

4247

43-
def rectangle(grid, x, y, h, w, tile):
48+
def rectangle(grid: GridState, x: int, y: int, h: int, w: int, tile: Tile) -> GridState:
4449
grid = vertical_line(grid, x, y, h, tile)
4550
grid = vertical_line(grid, x + w - 1, y, h, tile)
4651
grid = horizontal_line(grid, x, y, w, tile)
4752
grid = horizontal_line(grid, x, y + h - 1, w, tile)
4853
return grid
4954

5055

51-
def room(height, width):
56+
def room(height: int, width: int) -> GridState:
5257
grid = empty_world(height, width)
5358
grid = rectangle(grid, 0, 0, height, width, tile=TILES_REGISTRY[Tiles.WALL, Colors.GREY])
5459
return grid
5560

5661

57-
def two_rooms(height, width):
58-
wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
62+
def two_rooms(height: int, width: int) -> GridState:
63+
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
5964

6065
grid = empty_world(height, width)
6166
grid = rectangle(grid, 0, 0, height, width, tile=wall_tile)
6267
grid = vertical_line(grid, width // 2, 0, height, tile=wall_tile)
6368
return grid
6469

6570

66-
def four_rooms(height, width):
67-
wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
71+
def four_rooms(height: int, width: int) -> GridState:
72+
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
6873

6974
grid = empty_world(height, width)
7075
grid = rectangle(grid, 0, 0, height, width, tile=wall_tile)
@@ -73,8 +78,8 @@ def four_rooms(height, width):
7378
return grid
7479

7580

76-
def nine_rooms(height, width):
77-
wall_tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
81+
def nine_rooms(height: int, width: int) -> GridState:
82+
wall_tile: Tile = TILES_REGISTRY[Tiles.WALL, Colors.GREY]
7883

7984
grid = empty_world(height, width)
8085
grid = rectangle(grid, 0, 0, height, width, tile=wall_tile)
@@ -85,34 +90,34 @@ def nine_rooms(height, width):
8590
return grid
8691

8792

88-
def check_walkable(grid, position):
93+
def check_walkable(grid: GridState, position: jax.Array) -> jax.Array:
8994
tile_id = grid[position[0], position[1], 0]
9095
is_walkable = jnp.isin(tile_id, WALKABLE, assume_unique=True)
9196

9297
return is_walkable
9398

9499

95-
def check_pickable(grid, position):
100+
def check_pickable(grid: GridState, position: jax.Array) -> jax.Array:
96101
tile_id = grid[position[0], position[1], 0]
97102
is_pickable = jnp.isin(tile_id, PICKABLE, assume_unique=True)
98103
return is_pickable
99104

100105

101-
def check_can_put(grid, position):
106+
def check_can_put(grid: GridState, position: jax.Array) -> jax.Array:
102107
tile_id = grid[position[0], position[1], 0]
103108
can_put = jnp.isin(tile_id, FREE_TO_PUT_DOWN, assume_unique=True)
104109

105110
return can_put
106111

107112

108-
def check_see_behind(grid, position):
113+
def check_see_behind(grid: GridState, position: jax.Array) -> jax.Array:
109114
tile_id = grid[position[0], position[1], 0]
110115
is_not_blocking = jnp.isin(tile_id, LOS_BLOCKING, assume_unique=True, invert=True)
111116

112117
return is_not_blocking
113118

114119

115-
def align_with_up(grid, direction):
120+
def align_with_up(grid: GridState, direction: int | jax.Array) -> GridState:
116121
aligned_grid = jax.lax.switch(
117122
direction,
118123
(
@@ -125,35 +130,35 @@ def align_with_up(grid, direction):
125130
return aligned_grid
126131

127132

128-
def grid_coords(grid):
133+
def grid_coords(grid: GridState) -> jax.Array:
129134
coords = jnp.mgrid[: grid.shape[0], : grid.shape[1]]
130135
coords = coords.transpose(1, 2, 0).reshape(-1, 2)
131136
return coords
132137

133138

134-
def transparent_mask(grid):
139+
def transparent_mask(grid: GridState) -> jax.Array:
135140
coords = grid_coords(grid)
136141
mask = jax.vmap(check_see_behind, in_axes=(None, 0))(grid, coords)
137142
mask = mask.reshape(grid.shape[0], grid.shape[1])
138143
return mask
139144

140145

141-
def free_tiles_mask(grid):
146+
def free_tiles_mask(grid: GridState) -> jax.Array:
142147
coords = grid_coords(grid)
143148
mask = jax.vmap(check_can_put, in_axes=(None, 0))(grid, coords)
144149
mask = mask.reshape(grid.shape[0], grid.shape[1])
145150
return mask
146151

147152

148-
def coordinates_mask(grid, address, comparison_fn):
153+
def coordinates_mask(grid: GridState, address: tuple[int, int], comparison_fn: Callable) -> jax.Array:
149154
positions = jnp.mgrid[: grid.shape[0], : grid.shape[1]]
150155
cond_1 = comparison_fn(positions[0], address[0])
151156
cond_2 = comparison_fn(positions[1], address[1])
152157
mask = jnp.logical_and(cond_1, cond_2)
153158
return mask
154159

155160

156-
def sample_coordinates(key, grid, num, mask=None):
161+
def sample_coordinates(key: jax.Array, grid: GridState, num: int, mask: jax.Array | None = None) -> jax.Array:
157162
if mask is None:
158163
mask = jnp.ones((grid.shape[0], grid.shape[1]), dtype=jnp.bool_)
159164

@@ -169,19 +174,20 @@ def sample_coordinates(key, grid, num, mask=None):
169174
return coords
170175

171176

172-
def sample_direction(key):
177+
def sample_direction(key: jax.Array) -> jax.Array:
173178
return jax.random.randint(key, shape=(), minval=0, maxval=4)
174179

175180

176-
def pad_along_axis(arr, pad_to, axis=0, fill_value=0):
181+
def pad_along_axis(arr: jax.Array, pad_to: int, axis: int = 0, fill_value: int = 0) -> jax.Array:
177182
pad_size = pad_to - arr.shape[axis]
178183
if pad_size <= 0:
179184
return arr
180185

181-
npad = [(0, 0)] * arr.ndim
186+
# manually annotate for pyright
187+
npad: list[tuple[int, int]] = [(0, 0)] * arr.ndim
182188
npad[axis] = (0, pad_size)
183189
return jnp.pad(arr, pad_width=npad, mode="constant", constant_values=fill_value)
184190

185191

186-
def cartesian_product_1d(a, b):
192+
def cartesian_product_1d(a: jax.Array, b: jax.Array) -> jax.Array:
187193
return jnp.dstack(jnp.meshgrid(a, b)).reshape(-1, 2)

0 commit comments

Comments
 (0)