1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+
1
5
import jax
2
6
import jax .numpy as jnp
3
7
8
+ from ..types import GridState , Tile
4
9
from .constants import FREE_TO_PUT_DOWN , LOS_BLOCKING , PICKABLE , TILES_REGISTRY , WALKABLE , Colors , Tiles
5
10
6
11
7
- def empty_world (height , width ) :
12
+ def empty_world (height : int , width : int ) -> GridState :
8
13
grid = jnp .zeros ((height , width , 2 ), dtype = jnp .uint8 )
9
14
grid = grid .at [:, :, 0 :2 ].set (TILES_REGISTRY [Tiles .FLOOR , Colors .BLACK ])
10
15
return grid
11
16
12
17
13
- # wait, is this just a jnp.array_equal?
14
- def equal ( tile1 , tile2 ):
18
+ def equal ( tile1 : Tile , tile2 : Tile ) -> Tile :
19
+ # wait, is this just a jnp.array_equal?
15
20
return jnp .all (jnp .equal (tile1 , tile2 ))
16
21
17
22
18
- def get_neighbouring_tiles (grid , y , x ) :
23
+ def get_neighbouring_tiles (grid : GridState , y : int | jax . Array , x : int | jax . Array ) -> tuple [ Tile , Tile , Tile , Tile ] :
19
24
# end_of_map = TILES_REGISTRY[Tiles.END_OF_MAP, Colors.END_OF_MAP]
20
25
end_of_map = Tiles .END_OF_MAP
21
26
@@ -30,41 +35,41 @@ def get_neighbouring_tiles(grid, y, x):
30
35
return up_tile , right_tile , down_tile , left_tile
31
36
32
37
33
- def horizontal_line (grid , x , y , length , tile ) :
38
+ def horizontal_line (grid : GridState , x : int , y : int , length : int , tile : Tile ) -> GridState :
34
39
grid = grid .at [y , x : x + length ].set (tile )
35
40
return grid
36
41
37
42
38
- def vertical_line (grid , x , y , length , tile ) :
43
+ def vertical_line (grid : GridState , x : int , y : int , length : int , tile : Tile ) -> GridState :
39
44
grid = grid .at [y : y + length , x ].set (tile )
40
45
return grid
41
46
42
47
43
- def rectangle (grid , x , y , h , w , tile ) :
48
+ def rectangle (grid : GridState , x : int , y : int , h : int , w : int , tile : Tile ) -> GridState :
44
49
grid = vertical_line (grid , x , y , h , tile )
45
50
grid = vertical_line (grid , x + w - 1 , y , h , tile )
46
51
grid = horizontal_line (grid , x , y , w , tile )
47
52
grid = horizontal_line (grid , x , y + h - 1 , w , tile )
48
53
return grid
49
54
50
55
51
- def room (height , width ) :
56
+ def room (height : int , width : int ) -> GridState :
52
57
grid = empty_world (height , width )
53
58
grid = rectangle (grid , 0 , 0 , height , width , tile = TILES_REGISTRY [Tiles .WALL , Colors .GREY ])
54
59
return grid
55
60
56
61
57
- def two_rooms (height , width ) :
58
- wall_tile = TILES_REGISTRY [Tiles .WALL , Colors .GREY ]
62
+ def two_rooms (height : int , width : int ) -> GridState :
63
+ wall_tile : Tile = TILES_REGISTRY [Tiles .WALL , Colors .GREY ]
59
64
60
65
grid = empty_world (height , width )
61
66
grid = rectangle (grid , 0 , 0 , height , width , tile = wall_tile )
62
67
grid = vertical_line (grid , width // 2 , 0 , height , tile = wall_tile )
63
68
return grid
64
69
65
70
66
- def four_rooms (height , width ) :
67
- wall_tile = TILES_REGISTRY [Tiles .WALL , Colors .GREY ]
71
+ def four_rooms (height : int , width : int ) -> GridState :
72
+ wall_tile : Tile = TILES_REGISTRY [Tiles .WALL , Colors .GREY ]
68
73
69
74
grid = empty_world (height , width )
70
75
grid = rectangle (grid , 0 , 0 , height , width , tile = wall_tile )
@@ -73,8 +78,8 @@ def four_rooms(height, width):
73
78
return grid
74
79
75
80
76
- def nine_rooms (height , width ) :
77
- wall_tile = TILES_REGISTRY [Tiles .WALL , Colors .GREY ]
81
+ def nine_rooms (height : int , width : int ) -> GridState :
82
+ wall_tile : Tile = TILES_REGISTRY [Tiles .WALL , Colors .GREY ]
78
83
79
84
grid = empty_world (height , width )
80
85
grid = rectangle (grid , 0 , 0 , height , width , tile = wall_tile )
@@ -85,34 +90,34 @@ def nine_rooms(height, width):
85
90
return grid
86
91
87
92
88
- def check_walkable (grid , position ) :
93
+ def check_walkable (grid : GridState , position : jax . Array ) -> jax . Array :
89
94
tile_id = grid [position [0 ], position [1 ], 0 ]
90
95
is_walkable = jnp .isin (tile_id , WALKABLE , assume_unique = True )
91
96
92
97
return is_walkable
93
98
94
99
95
- def check_pickable (grid , position ) :
100
+ def check_pickable (grid : GridState , position : jax . Array ) -> jax . Array :
96
101
tile_id = grid [position [0 ], position [1 ], 0 ]
97
102
is_pickable = jnp .isin (tile_id , PICKABLE , assume_unique = True )
98
103
return is_pickable
99
104
100
105
101
- def check_can_put (grid , position ) :
106
+ def check_can_put (grid : GridState , position : jax . Array ) -> jax . Array :
102
107
tile_id = grid [position [0 ], position [1 ], 0 ]
103
108
can_put = jnp .isin (tile_id , FREE_TO_PUT_DOWN , assume_unique = True )
104
109
105
110
return can_put
106
111
107
112
108
- def check_see_behind (grid , position ) :
113
+ def check_see_behind (grid : GridState , position : jax . Array ) -> jax . Array :
109
114
tile_id = grid [position [0 ], position [1 ], 0 ]
110
115
is_not_blocking = jnp .isin (tile_id , LOS_BLOCKING , assume_unique = True , invert = True )
111
116
112
117
return is_not_blocking
113
118
114
119
115
- def align_with_up (grid , direction ) :
120
+ def align_with_up (grid : GridState , direction : int | jax . Array ) -> GridState :
116
121
aligned_grid = jax .lax .switch (
117
122
direction ,
118
123
(
@@ -125,35 +130,35 @@ def align_with_up(grid, direction):
125
130
return aligned_grid
126
131
127
132
128
- def grid_coords (grid ) :
133
+ def grid_coords (grid : GridState ) -> jax . Array :
129
134
coords = jnp .mgrid [: grid .shape [0 ], : grid .shape [1 ]]
130
135
coords = coords .transpose (1 , 2 , 0 ).reshape (- 1 , 2 )
131
136
return coords
132
137
133
138
134
- def transparent_mask (grid ) :
139
+ def transparent_mask (grid : GridState ) -> jax . Array :
135
140
coords = grid_coords (grid )
136
141
mask = jax .vmap (check_see_behind , in_axes = (None , 0 ))(grid , coords )
137
142
mask = mask .reshape (grid .shape [0 ], grid .shape [1 ])
138
143
return mask
139
144
140
145
141
- def free_tiles_mask (grid ) :
146
+ def free_tiles_mask (grid : GridState ) -> jax . Array :
142
147
coords = grid_coords (grid )
143
148
mask = jax .vmap (check_can_put , in_axes = (None , 0 ))(grid , coords )
144
149
mask = mask .reshape (grid .shape [0 ], grid .shape [1 ])
145
150
return mask
146
151
147
152
148
- def coordinates_mask (grid , address , comparison_fn ) :
153
+ def coordinates_mask (grid : GridState , address : tuple [ int , int ], comparison_fn : Callable ) -> jax . Array :
149
154
positions = jnp .mgrid [: grid .shape [0 ], : grid .shape [1 ]]
150
155
cond_1 = comparison_fn (positions [0 ], address [0 ])
151
156
cond_2 = comparison_fn (positions [1 ], address [1 ])
152
157
mask = jnp .logical_and (cond_1 , cond_2 )
153
158
return mask
154
159
155
160
156
- def sample_coordinates (key , grid , num , mask = None ):
161
+ def sample_coordinates (key : jax . Array , grid : GridState , num : int , mask : jax . Array | None = None ) -> jax . Array :
157
162
if mask is None :
158
163
mask = jnp .ones ((grid .shape [0 ], grid .shape [1 ]), dtype = jnp .bool_ )
159
164
@@ -169,19 +174,20 @@ def sample_coordinates(key, grid, num, mask=None):
169
174
return coords
170
175
171
176
172
- def sample_direction (key ) :
177
+ def sample_direction (key : jax . Array ) -> jax . Array :
173
178
return jax .random .randint (key , shape = (), minval = 0 , maxval = 4 )
174
179
175
180
176
- def pad_along_axis (arr , pad_to , axis = 0 , fill_value = 0 ) :
181
+ def pad_along_axis (arr : jax . Array , pad_to : int , axis : int = 0 , fill_value : int = 0 ) -> jax . Array :
177
182
pad_size = pad_to - arr .shape [axis ]
178
183
if pad_size <= 0 :
179
184
return arr
180
185
181
- npad = [(0 , 0 )] * arr .ndim
186
+ # manually annotate for pyright
187
+ npad : list [tuple [int , int ]] = [(0 , 0 )] * arr .ndim
182
188
npad [axis ] = (0 , pad_size )
183
189
return jnp .pad (arr , pad_width = npad , mode = "constant" , constant_values = fill_value )
184
190
185
191
186
- def cartesian_product_1d (a , b ) :
192
+ def cartesian_product_1d (a : jax . Array , b : jax . Array ) -> jax . Array :
187
193
return jnp .dstack (jnp .meshgrid (a , b )).reshape (- 1 , 2 )
0 commit comments