Skip to content

Commit

Permalink
[Format] Use line_length = 120 (#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Dec 28, 2023
1 parent b12cbbb commit da29e7e
Show file tree
Hide file tree
Showing 37 changed files with 431 additions and 1,307 deletions.
6 changes: 1 addition & 5 deletions pgx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from pgx._src.api_test import api_test
from pgx._src.baseline import BaselineModelId, make_baseline_model
from pgx._src.types import Array, PRNGKey
from pgx._src.visualizer import (
save_svg,
save_svg_animation,
set_visualization_config,
)
from pgx._src.visualizer import save_svg, save_svg_animation, set_visualization_config
from pgx.core import Env, EnvId, State, available_envs, make

__version__ = "2.0.1"
Expand Down
16 changes: 4 additions & 12 deletions pgx/_src/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def api_test_single(env: Env, num: int = 100, use_key=True):
for _ in range(num):
rng, subkey = jax.random.split(rng)
state = init(subkey)
assert (
state.legal_action_mask.sum() != 0
), "legal_action_mask at init state cannot be zero."
assert state.legal_action_mask.sum() != 0, "legal_action_mask at init state cannot be zero."

assert state._step_count == 0
curr_steps = state._step_count
Expand All @@ -72,9 +70,7 @@ def api_test_single(env: Env, num: int = 100, use_key=True):
if not use_key:
subkey = None
state = step(state, action, subkey)
assert (
state._step_count == curr_steps + 1
), f"{state._step_count}, {curr_steps}"
assert state._step_count == curr_steps + 1, f"{state._step_count}, {curr_steps}"
curr_steps += 1

_validate_state(state)
Expand Down Expand Up @@ -135,9 +131,7 @@ def _validate_state(state: State):
assert state.current_player.dtype == jnp.int32, state.current_player.dtype
assert state.terminated.dtype == jnp.bool_, state.terminated.dtype
assert state.rewards.dtype == jnp.float32, state.rewards.dtype
assert (
state.legal_action_mask.dtype == jnp.bool_
), state.legal_action_mask.dtype
assert state.legal_action_mask.dtype == jnp.bool_, state.legal_action_mask.dtype

# check public attributes
public_attributes = [
Expand All @@ -158,9 +152,7 @@ def _validate_legal_actions(state: State):
if state.terminated:
# Agent can take any action at terminal state (but give no effect to the next state)
# This is to avoid zero-division error by normalizing action probability by legal actions
assert (
state.legal_action_mask == jnp.ones_like(state.legal_action_mask)
).all(), state.legal_action_mask
assert (state.legal_action_mask == jnp.ones_like(state.legal_action_mask)).all(), state.legal_action_mask
else:
...

Expand Down
56 changes: 14 additions & 42 deletions pgx/_src/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
]


def make_baseline_model(
model_id: BaselineModelId, download_dir: str = "baselines"
):
def make_baseline_model(model_id: BaselineModelId, download_dir: str = "baselines"):
if model_id in (
"animal_shogi_v0",
"gardner_chess_v0",
Expand All @@ -44,41 +42,29 @@ def make_baseline_model(
assert False


def _make_az_baseline_model(
model_id: BaselineModelId, download_dir: str = "baselines"
):
def _make_az_baseline_model(model_id: BaselineModelId, download_dir: str = "baselines"):
import haiku as hk

model_args, model_params, model_state = _load_baseline_model(
model_id, download_dir
)
model_args, model_params, model_state = _load_baseline_model(model_id, download_dir)

def forward_fn(x, is_eval=False):
net = _create_az_model_v0(**model_args)
policy_out, value_out = net(
x, is_training=not is_eval, test_local_stats=False
)
policy_out, value_out = net(x, is_training=not is_eval, test_local_stats=False)
return policy_out, value_out

forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))

def apply(obs):
(logits, value), _ = forward.apply(
model_params, model_state, obs, is_eval=True
)
(logits, value), _ = forward.apply(model_params, model_state, obs, is_eval=True)
return logits, value

return apply


def _make_minatar_baseline_model(
model_id: BaselineModelId, download_dir: str = "baselines"
):
def _make_minatar_baseline_model(model_id: BaselineModelId, download_dir: str = "baselines"):
import haiku as hk

model_args, model_params, model_state = _load_baseline_model(
model_id, download_dir
)
model_args, model_params, model_state = _load_baseline_model(model_id, download_dir)
del model_state

class ActorCritic(hk.Module):
Expand All @@ -96,9 +82,7 @@ def __call__(self, x):
activation = jax.nn.tanh
x = hk.Conv2D(32, kernel_shape=2)(x)
x = jax.nn.relu(x)
x = hk.avg_pool(
x, window_shape=(2, 2), strides=(2, 2), padding="VALID"
)
x = hk.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
x = x.reshape((x.shape[0], -1)) # flatten
x = hk.Linear(64)(x)
x = jax.nn.relu(x)
Expand Down Expand Up @@ -130,9 +114,7 @@ def apply(obs):
return apply


def _load_baseline_model(
baseline_model: BaselineModelId, basedir: str = "baselines"
):
def _load_baseline_model(baseline_model: BaselineModelId, basedir: str = "baselines"):
os.makedirs(basedir, exist_ok=True)

# download baseline model if not exists
Expand Down Expand Up @@ -226,36 +208,26 @@ def __call__(self, x, is_training, test_local_stats):
x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)

if not self.resnet_v2:
x = hk.BatchNorm(True, True, 0.9)(
x, is_training, test_local_stats
)
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
x = jax.nn.relu(x)

for i in range(self.num_layers):
x = self.resnet_cls(self.num_channels, name=f"block_{i}")(
x, is_training, test_local_stats
)
x = self.resnet_cls(self.num_channels, name=f"block_{i}")(x, is_training, test_local_stats)

if self.resnet_v2:
x = hk.BatchNorm(True, True, 0.9)(
x, is_training, test_local_stats
)
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
x = jax.nn.relu(x)

# policy head
logits = hk.Conv2D(output_channels=2, kernel_shape=1)(x)
logits = hk.BatchNorm(True, True, 0.9)(
logits, is_training, test_local_stats
)
logits = hk.BatchNorm(True, True, 0.9)(logits, is_training, test_local_stats)
logits = jax.nn.relu(logits)
logits = hk.Flatten()(logits)
logits = hk.Linear(self.num_actions)(logits)

# value head
value = hk.Conv2D(output_channels=1, kernel_shape=1)(x)
value = hk.BatchNorm(True, True, 0.9)(
value, is_training, test_local_stats
)
value = hk.BatchNorm(True, True, 0.9)(value, is_training, test_local_stats)
value = jax.nn.relu(value)
value = hk.Flatten()(value)
value = hk.Linear(self.num_channels)(value)
Expand Down
53 changes: 13 additions & 40 deletions pgx/_src/chess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,10 @@
if jnp.abs(r1 - r0) == 1 and jnp.abs(c1 - c0) <= 1:
legal_dst.append(to)
# init move
if (r0 == 1 or r0 == 6) and (
jnp.abs(c1 - c0) == 0 and jnp.abs(r1 - r0) == 2
):
if (r0 == 1 or r0 == 6) and (jnp.abs(c1 - c0) == 0 and jnp.abs(r1 - r0) == 2):
legal_dst.append(to)
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[1, from_, : len(legal_dst)].set(
jnp.int32(legal_dst)
)
CAN_MOVE = CAN_MOVE.at[1, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
# KNIGHT
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -101,9 +97,7 @@
if jnp.abs(r1 - r0) == 2 and jnp.abs(c1 - c0) == 1:
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[2, from_, : len(legal_dst)].set(
jnp.int32(legal_dst)
)
CAN_MOVE = CAN_MOVE.at[2, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
# BISHOP
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -115,9 +109,7 @@
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[3, from_, : len(legal_dst)].set(
jnp.int32(legal_dst)
)
CAN_MOVE = CAN_MOVE.at[3, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
# ROOK
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -129,9 +121,7 @@
if jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0:
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[4, from_, : len(legal_dst)].set(
jnp.int32(legal_dst)
)
CAN_MOVE = CAN_MOVE.at[4, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
# QUEEN
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -145,9 +135,7 @@
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
legal_dst.append(to)
assert len(legal_dst) <= 27
CAN_MOVE = CAN_MOVE.at[5, from_, : len(legal_dst)].set(
jnp.int32(legal_dst)
)
CAN_MOVE = CAN_MOVE.at[5, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
# KING
for from_ in range(64):
r0, c0 = from_ % 8, from_ // 8
Expand All @@ -164,9 +152,7 @@
# if from_ == 39:
# legal_dst += [23, 55]
assert len(legal_dst) <= 8
CAN_MOVE = CAN_MOVE.at[6, from_, : len(legal_dst)].set(
jnp.int32(legal_dst)
)
CAN_MOVE = CAN_MOVE.at[6, from_, : len(legal_dst)].set(jnp.int32(legal_dst))

assert (CAN_MOVE[0, :, :] == -1).all()

Expand All @@ -181,9 +167,7 @@
to = CAN_MOVE[2, from_, i] # KNIGHT
if to >= 0:
legal_dst.append(to)
CAN_MOVE_ANY = CAN_MOVE_ANY.at[from_, : len(legal_dst)].set(
jnp.int32(legal_dst)
)
CAN_MOVE_ANY = CAN_MOVE_ANY.at[from_, : len(legal_dst)].set(jnp.int32(legal_dst))


# Between
Expand All @@ -192,10 +176,7 @@
for to in range(64):
r0, c0 = from_ % 8, from_ // 8
r1, c1 = to % 8, to // 8
if not (
(jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0)
or (jnp.abs(r1 - r0) == jnp.abs(c1 - c0))
):
if not ((jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0) or (jnp.abs(r1 - r0) == jnp.abs(c1 - c0))):
continue
dr = max(min(r1 - r0, 1), -1)
dc = max(min(c1 - c0, 1), -1)
Expand Down Expand Up @@ -229,22 +210,14 @@

key = jax.random.PRNGKey(238290)
key, subkey = jax.random.split(key)
ZOBRIST_BOARD = jax.random.randint(
subkey, shape=(64, 13, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
)
ZOBRIST_BOARD = jax.random.randint(subkey, shape=(64, 13, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
key, subkey = jax.random.split(key)
ZOBRIST_SIDE = jax.random.randint(
subkey, shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
)
ZOBRIST_SIDE = jax.random.randint(subkey, shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)

key, subkey = jax.random.split(key)
ZOBRIST_CASTLING_QUEEN = jax.random.randint(
subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
)
ZOBRIST_CASTLING_QUEEN = jax.random.randint(subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
key, subkey = jax.random.split(key)
ZOBRIST_CASTLING_KING = jax.random.randint(
subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
)
ZOBRIST_CASTLING_KING = jax.random.randint(subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
key, subkey = jax.random.split(key)
ZOBRIST_EN_PASSANT = jax.random.randint(
subkey,
Expand Down
4 changes: 1 addition & 3 deletions pgx/_src/dwg/animalshogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,7 @@ def _make_animalshogi_dwg(dwg, state: AnimalShogiState, config: dict):
)

# # hand
for i, piece_num, piece_type in zip(
range(6), state._hand.flatten(), ["P", "R", "B", "P", "R", "B"]
):
for i, piece_num, piece_type in zip(range(6), state._hand.flatten(), ["P", "R", "B", "P", "R", "B"]):
is_black = i < 3 if state._turn == 0 else 3 <= i # type: ignore
_g = p1_pieces_g if is_black else p2_pieces_g
_g.add(
Expand Down
29 changes: 6 additions & 23 deletions pgx/_src/dwg/bridge_bidding.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
)

# card
card = [
TO_CARD[i % NUM_CARD_TYPE]
for i in hand
if j * NUM_CARD_TYPE <= i < (j + 1) * NUM_CARD_TYPE
][::-1]
card = [TO_CARD[i % NUM_CARD_TYPE] for i in hand if j * NUM_CARD_TYPE <= i < (j + 1) * NUM_CARD_TYPE][::-1]
if card != [] and card[-1] == "A":
card = card[-1:] + card[:-1]
card_str = " ".join(card)
Expand All @@ -81,9 +77,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
x_offset[i] + 40,
y_offset[i] + 30 * (j + 1) + newline_offset,
),
fill="orangered"
if 0 < j < 3
else color_set.text_color,
fill="orangered" if 0 < j < 3 else color_set.text_color,
font_size="24px",
font_family="Courier",
font_weight="bold",
Expand All @@ -97,9 +91,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
x_offset[i] + 40,
y_offset[i] + 30 * (j + 1) + newline_offset,
),
fill="orangered"
if 0 < j < 3
else color_set.text_color,
fill="orangered" if 0 < j < 3 else color_set.text_color,
font_size="24px",
font_family="Courier",
font_weight="bold",
Expand Down Expand Up @@ -131,9 +123,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
x_offset[i] + 40,
y_offset[i] + 30 * (j + 1) + newline_offset,
),
fill="orangered"
if 0 < j < 3
else color_set.text_color,
fill="orangered" if 0 < j < 3 else color_set.text_color,
font_size="24px",
font_family="Courier",
font_weight="bold",
Expand Down Expand Up @@ -211,20 +201,13 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
if act == -1:
break
act_str = (
str((act - BID_OFFSET_NUM) // 5 + 1)
+ DENOMINATIONS[(act - BID_OFFSET_NUM) % 5]
str((act - BID_OFFSET_NUM) // 5 + 1) + DENOMINATIONS[(act - BID_OFFSET_NUM) % 5]
if BID_OFFSET_NUM <= act < 35 + BID_OFFSET_NUM
else ACT[act]
)
color = (
"orangered"
if (
(act > BID_OFFSET_NUM)
and (
(act - BID_OFFSET_NUM) % 5 == 1
or (act - BID_OFFSET_NUM) % 5 == 2
)
)
if ((act > BID_OFFSET_NUM) and ((act - BID_OFFSET_NUM) % 5 == 1 or (act - BID_OFFSET_NUM) % 5 == 2))
or act == 1
or act == 2
else color_set.text_color
Expand Down
4 changes: 1 addition & 3 deletions pgx/_src/dwg/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def _make_hex_dwg(dwg, state: HexState, config):
b_points = []
w_points = []
x1, y1 = (BOARD_SIZE - 1) * GRID_SIZE * r3, 0
x2, y2 = ((BOARD_SIZE - 1) / 2) * GRID_SIZE * r3, (
BOARD_SIZE - 1
) * GRID_SIZE * 3 / 2
x2, y2 = ((BOARD_SIZE - 1) / 2) * GRID_SIZE * r3, (BOARD_SIZE - 1) * GRID_SIZE * 3 / 2
cx, cy = four_dig((x1 + x2) / 2), four_dig((y1 + y2) / 2)
# fmt:off
for i in range(BOARD_SIZE):
Expand Down
Loading

0 comments on commit da29e7e

Please sign in to comment.