Skip to content

Commit

Permalink
[Chess] Reduce _is_attacked call (#1260)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Oct 23, 2024
1 parent 744cb2b commit 5ae5f01
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@
INIT_LEGAL_ACTION_MASK[ixs] = True

LEGAL_DEST = -np.ones((7, 64, 27), np.int32) # LEGAL_DEST[0, :, :] == -1
LEGAL_DEST_NEAR = -np.ones((64, 16), np.int32)
LEGAL_DEST_FAR = -np.ones((64, 19), np.int32)
LEGAL_DEST_NEAR = -np.ones((64, 16), np.int32) # king and knight moves
LEGAL_DEST_FAR = -np.ones((64, 19), np.int32) # queen moves except king moves
CAN_MOVE = np.zeros((7, 64, 64), dtype=np.bool_)
for from_ in range(64):
legal_dest = {p: [] for p in range(7)}
Expand Down Expand Up @@ -351,6 +351,10 @@ def legal_labels(label):
a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten()
a2 = legal_en_passants()
actions = jnp.hstack((a1, a2)) # include -1
# filter out -1. 200 is big enough for normal play.
ixs = jnp.nonzero(actions >= 0, size=200, fill_value=0)[0]
actions = actions[ixs] # size: 19 * 27 -> 200
# filter ignoring checks and suicides
actions = jnp.where(jax.vmap(is_not_checked)(actions), actions, -1)
mask = jnp.zeros(64 * 73 + 1, dtype=jnp.bool_) # +1 for sentinel
mask = mask.at[actions].set(True)
Expand Down

0 comments on commit 5ae5f01

Please sign in to comment.