Skip to content

Commit da29e7e

Browse files
authored
[Format] Use line_length = 120 (#1136)
1 parent b12cbbb commit da29e7e

37 files changed

+431
-1307
lines changed

pgx/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
from pgx._src.api_test import api_test
22
from pgx._src.baseline import BaselineModelId, make_baseline_model
33
from pgx._src.types import Array, PRNGKey
4-
from pgx._src.visualizer import (
5-
save_svg,
6-
save_svg_animation,
7-
set_visualization_config,
8-
)
4+
from pgx._src.visualizer import save_svg, save_svg_animation, set_visualization_config
95
from pgx.core import Env, EnvId, State, available_envs, make
106

117
__version__ = "2.0.1"

pgx/_src/api_test.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def api_test_single(env: Env, num: int = 100, use_key=True):
5454
for _ in range(num):
5555
rng, subkey = jax.random.split(rng)
5656
state = init(subkey)
57-
assert (
58-
state.legal_action_mask.sum() != 0
59-
), "legal_action_mask at init state cannot be zero."
57+
assert state.legal_action_mask.sum() != 0, "legal_action_mask at init state cannot be zero."
6058

6159
assert state._step_count == 0
6260
curr_steps = state._step_count
@@ -72,9 +70,7 @@ def api_test_single(env: Env, num: int = 100, use_key=True):
7270
if not use_key:
7371
subkey = None
7472
state = step(state, action, subkey)
75-
assert (
76-
state._step_count == curr_steps + 1
77-
), f"{state._step_count}, {curr_steps}"
73+
assert state._step_count == curr_steps + 1, f"{state._step_count}, {curr_steps}"
7874
curr_steps += 1
7975

8076
_validate_state(state)
@@ -135,9 +131,7 @@ def _validate_state(state: State):
135131
assert state.current_player.dtype == jnp.int32, state.current_player.dtype
136132
assert state.terminated.dtype == jnp.bool_, state.terminated.dtype
137133
assert state.rewards.dtype == jnp.float32, state.rewards.dtype
138-
assert (
139-
state.legal_action_mask.dtype == jnp.bool_
140-
), state.legal_action_mask.dtype
134+
assert state.legal_action_mask.dtype == jnp.bool_, state.legal_action_mask.dtype
141135

142136
# check public attributes
143137
public_attributes = [
@@ -158,9 +152,7 @@ def _validate_legal_actions(state: State):
158152
if state.terminated:
159153
# Agent can take any action at terminal state (but give no effect to the next state)
160154
# This is to avoid zero-division error by normalizing action probability by legal actions
161-
assert (
162-
state.legal_action_mask == jnp.ones_like(state.legal_action_mask)
163-
).all(), state.legal_action_mask
155+
assert (state.legal_action_mask == jnp.ones_like(state.legal_action_mask)).all(), state.legal_action_mask
164156
else:
165157
...
166158

pgx/_src/baseline.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
]
2222

2323

24-
def make_baseline_model(
25-
model_id: BaselineModelId, download_dir: str = "baselines"
26-
):
24+
def make_baseline_model(model_id: BaselineModelId, download_dir: str = "baselines"):
2725
if model_id in (
2826
"animal_shogi_v0",
2927
"gardner_chess_v0",
@@ -44,41 +42,29 @@ def make_baseline_model(
4442
assert False
4543

4644

47-
def _make_az_baseline_model(
48-
model_id: BaselineModelId, download_dir: str = "baselines"
49-
):
45+
def _make_az_baseline_model(model_id: BaselineModelId, download_dir: str = "baselines"):
5046
import haiku as hk
5147

52-
model_args, model_params, model_state = _load_baseline_model(
53-
model_id, download_dir
54-
)
48+
model_args, model_params, model_state = _load_baseline_model(model_id, download_dir)
5549

5650
def forward_fn(x, is_eval=False):
5751
net = _create_az_model_v0(**model_args)
58-
policy_out, value_out = net(
59-
x, is_training=not is_eval, test_local_stats=False
60-
)
52+
policy_out, value_out = net(x, is_training=not is_eval, test_local_stats=False)
6153
return policy_out, value_out
6254

6355
forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
6456

6557
def apply(obs):
66-
(logits, value), _ = forward.apply(
67-
model_params, model_state, obs, is_eval=True
68-
)
58+
(logits, value), _ = forward.apply(model_params, model_state, obs, is_eval=True)
6959
return logits, value
7060

7161
return apply
7262

7363

74-
def _make_minatar_baseline_model(
75-
model_id: BaselineModelId, download_dir: str = "baselines"
76-
):
64+
def _make_minatar_baseline_model(model_id: BaselineModelId, download_dir: str = "baselines"):
7765
import haiku as hk
7866

79-
model_args, model_params, model_state = _load_baseline_model(
80-
model_id, download_dir
81-
)
67+
model_args, model_params, model_state = _load_baseline_model(model_id, download_dir)
8268
del model_state
8369

8470
class ActorCritic(hk.Module):
@@ -96,9 +82,7 @@ def __call__(self, x):
9682
activation = jax.nn.tanh
9783
x = hk.Conv2D(32, kernel_shape=2)(x)
9884
x = jax.nn.relu(x)
99-
x = hk.avg_pool(
100-
x, window_shape=(2, 2), strides=(2, 2), padding="VALID"
101-
)
85+
x = hk.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")
10286
x = x.reshape((x.shape[0], -1)) # flatten
10387
x = hk.Linear(64)(x)
10488
x = jax.nn.relu(x)
@@ -130,9 +114,7 @@ def apply(obs):
130114
return apply
131115

132116

133-
def _load_baseline_model(
134-
baseline_model: BaselineModelId, basedir: str = "baselines"
135-
):
117+
def _load_baseline_model(baseline_model: BaselineModelId, basedir: str = "baselines"):
136118
os.makedirs(basedir, exist_ok=True)
137119

138120
# download baseline model if not exists
@@ -226,36 +208,26 @@ def __call__(self, x, is_training, test_local_stats):
226208
x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
227209

228210
if not self.resnet_v2:
229-
x = hk.BatchNorm(True, True, 0.9)(
230-
x, is_training, test_local_stats
231-
)
211+
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
232212
x = jax.nn.relu(x)
233213

234214
for i in range(self.num_layers):
235-
x = self.resnet_cls(self.num_channels, name=f"block_{i}")(
236-
x, is_training, test_local_stats
237-
)
215+
x = self.resnet_cls(self.num_channels, name=f"block_{i}")(x, is_training, test_local_stats)
238216

239217
if self.resnet_v2:
240-
x = hk.BatchNorm(True, True, 0.9)(
241-
x, is_training, test_local_stats
242-
)
218+
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
243219
x = jax.nn.relu(x)
244220

245221
# policy head
246222
logits = hk.Conv2D(output_channels=2, kernel_shape=1)(x)
247-
logits = hk.BatchNorm(True, True, 0.9)(
248-
logits, is_training, test_local_stats
249-
)
223+
logits = hk.BatchNorm(True, True, 0.9)(logits, is_training, test_local_stats)
250224
logits = jax.nn.relu(logits)
251225
logits = hk.Flatten()(logits)
252226
logits = hk.Linear(self.num_actions)(logits)
253227

254228
# value head
255229
value = hk.Conv2D(output_channels=1, kernel_shape=1)(x)
256-
value = hk.BatchNorm(True, True, 0.9)(
257-
value, is_training, test_local_stats
258-
)
230+
value = hk.BatchNorm(True, True, 0.9)(value, is_training, test_local_stats)
259231
value = jax.nn.relu(value)
260232
value = hk.Flatten()(value)
261233
value = hk.Linear(self.num_channels)(value)

pgx/_src/chess_utils.py

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,10 @@
8282
if jnp.abs(r1 - r0) == 1 and jnp.abs(c1 - c0) <= 1:
8383
legal_dst.append(to)
8484
# init move
85-
if (r0 == 1 or r0 == 6) and (
86-
jnp.abs(c1 - c0) == 0 and jnp.abs(r1 - r0) == 2
87-
):
85+
if (r0 == 1 or r0 == 6) and (jnp.abs(c1 - c0) == 0 and jnp.abs(r1 - r0) == 2):
8886
legal_dst.append(to)
8987
assert len(legal_dst) <= 8
90-
CAN_MOVE = CAN_MOVE.at[1, from_, : len(legal_dst)].set(
91-
jnp.int32(legal_dst)
92-
)
88+
CAN_MOVE = CAN_MOVE.at[1, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
9389
# KNIGHT
9490
for from_ in range(64):
9591
r0, c0 = from_ % 8, from_ // 8
@@ -101,9 +97,7 @@
10197
if jnp.abs(r1 - r0) == 2 and jnp.abs(c1 - c0) == 1:
10298
legal_dst.append(to)
10399
assert len(legal_dst) <= 27
104-
CAN_MOVE = CAN_MOVE.at[2, from_, : len(legal_dst)].set(
105-
jnp.int32(legal_dst)
106-
)
100+
CAN_MOVE = CAN_MOVE.at[2, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
107101
# BISHOP
108102
for from_ in range(64):
109103
r0, c0 = from_ % 8, from_ // 8
@@ -115,9 +109,7 @@
115109
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
116110
legal_dst.append(to)
117111
assert len(legal_dst) <= 27
118-
CAN_MOVE = CAN_MOVE.at[3, from_, : len(legal_dst)].set(
119-
jnp.int32(legal_dst)
120-
)
112+
CAN_MOVE = CAN_MOVE.at[3, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
121113
# ROOK
122114
for from_ in range(64):
123115
r0, c0 = from_ % 8, from_ // 8
@@ -129,9 +121,7 @@
129121
if jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0:
130122
legal_dst.append(to)
131123
assert len(legal_dst) <= 27
132-
CAN_MOVE = CAN_MOVE.at[4, from_, : len(legal_dst)].set(
133-
jnp.int32(legal_dst)
134-
)
124+
CAN_MOVE = CAN_MOVE.at[4, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
135125
# QUEEN
136126
for from_ in range(64):
137127
r0, c0 = from_ % 8, from_ // 8
@@ -145,9 +135,7 @@
145135
if jnp.abs(r1 - r0) == jnp.abs(c1 - c0):
146136
legal_dst.append(to)
147137
assert len(legal_dst) <= 27
148-
CAN_MOVE = CAN_MOVE.at[5, from_, : len(legal_dst)].set(
149-
jnp.int32(legal_dst)
150-
)
138+
CAN_MOVE = CAN_MOVE.at[5, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
151139
# KING
152140
for from_ in range(64):
153141
r0, c0 = from_ % 8, from_ // 8
@@ -164,9 +152,7 @@
164152
# if from_ == 39:
165153
# legal_dst += [23, 55]
166154
assert len(legal_dst) <= 8
167-
CAN_MOVE = CAN_MOVE.at[6, from_, : len(legal_dst)].set(
168-
jnp.int32(legal_dst)
169-
)
155+
CAN_MOVE = CAN_MOVE.at[6, from_, : len(legal_dst)].set(jnp.int32(legal_dst))
170156

171157
assert (CAN_MOVE[0, :, :] == -1).all()
172158

@@ -181,9 +167,7 @@
181167
to = CAN_MOVE[2, from_, i] # KNIGHT
182168
if to >= 0:
183169
legal_dst.append(to)
184-
CAN_MOVE_ANY = CAN_MOVE_ANY.at[from_, : len(legal_dst)].set(
185-
jnp.int32(legal_dst)
186-
)
170+
CAN_MOVE_ANY = CAN_MOVE_ANY.at[from_, : len(legal_dst)].set(jnp.int32(legal_dst))
187171

188172

189173
# Between
@@ -192,10 +176,7 @@
192176
for to in range(64):
193177
r0, c0 = from_ % 8, from_ // 8
194178
r1, c1 = to % 8, to // 8
195-
if not (
196-
(jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0)
197-
or (jnp.abs(r1 - r0) == jnp.abs(c1 - c0))
198-
):
179+
if not ((jnp.abs(r1 - r0) == 0 or jnp.abs(c1 - c0) == 0) or (jnp.abs(r1 - r0) == jnp.abs(c1 - c0))):
199180
continue
200181
dr = max(min(r1 - r0, 1), -1)
201182
dc = max(min(c1 - c0, 1), -1)
@@ -229,22 +210,14 @@
229210

230211
key = jax.random.PRNGKey(238290)
231212
key, subkey = jax.random.split(key)
232-
ZOBRIST_BOARD = jax.random.randint(
233-
subkey, shape=(64, 13, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
234-
)
213+
ZOBRIST_BOARD = jax.random.randint(subkey, shape=(64, 13, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
235214
key, subkey = jax.random.split(key)
236-
ZOBRIST_SIDE = jax.random.randint(
237-
subkey, shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
238-
)
215+
ZOBRIST_SIDE = jax.random.randint(subkey, shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
239216

240217
key, subkey = jax.random.split(key)
241-
ZOBRIST_CASTLING_QUEEN = jax.random.randint(
242-
subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
243-
)
218+
ZOBRIST_CASTLING_QUEEN = jax.random.randint(subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
244219
key, subkey = jax.random.split(key)
245-
ZOBRIST_CASTLING_KING = jax.random.randint(
246-
subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32
247-
)
220+
ZOBRIST_CASTLING_KING = jax.random.randint(subkey, shape=(2, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
248221
key, subkey = jax.random.split(key)
249222
ZOBRIST_EN_PASSANT = jax.random.randint(
250223
subkey,

pgx/_src/dwg/animalshogi.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,7 @@ def _make_animalshogi_dwg(dwg, state: AnimalShogiState, config: dict):
170170
)
171171

172172
# # hand
173-
for i, piece_num, piece_type in zip(
174-
range(6), state._hand.flatten(), ["P", "R", "B", "P", "R", "B"]
175-
):
173+
for i, piece_num, piece_type in zip(range(6), state._hand.flatten(), ["P", "R", "B", "P", "R", "B"]):
176174
is_black = i < 3 if state._turn == 0 else 3 <= i # type: ignore
177175
_g = p1_pieces_g if is_black else p2_pieces_g
178176
_g.add(

pgx/_src/dwg/bridge_bidding.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
6464
)
6565

6666
# card
67-
card = [
68-
TO_CARD[i % NUM_CARD_TYPE]
69-
for i in hand
70-
if j * NUM_CARD_TYPE <= i < (j + 1) * NUM_CARD_TYPE
71-
][::-1]
67+
card = [TO_CARD[i % NUM_CARD_TYPE] for i in hand if j * NUM_CARD_TYPE <= i < (j + 1) * NUM_CARD_TYPE][::-1]
7268
if card != [] and card[-1] == "A":
7369
card = card[-1:] + card[:-1]
7470
card_str = " ".join(card)
@@ -81,9 +77,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
8177
x_offset[i] + 40,
8278
y_offset[i] + 30 * (j + 1) + newline_offset,
8379
),
84-
fill="orangered"
85-
if 0 < j < 3
86-
else color_set.text_color,
80+
fill="orangered" if 0 < j < 3 else color_set.text_color,
8781
font_size="24px",
8882
font_family="Courier",
8983
font_weight="bold",
@@ -97,9 +91,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
9791
x_offset[i] + 40,
9892
y_offset[i] + 30 * (j + 1) + newline_offset,
9993
),
100-
fill="orangered"
101-
if 0 < j < 3
102-
else color_set.text_color,
94+
fill="orangered" if 0 < j < 3 else color_set.text_color,
10395
font_size="24px",
10496
font_family="Courier",
10597
font_weight="bold",
@@ -131,9 +123,7 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
131123
x_offset[i] + 40,
132124
y_offset[i] + 30 * (j + 1) + newline_offset,
133125
),
134-
fill="orangered"
135-
if 0 < j < 3
136-
else color_set.text_color,
126+
fill="orangered" if 0 < j < 3 else color_set.text_color,
137127
font_size="24px",
138128
font_family="Courier",
139129
font_weight="bold",
@@ -211,20 +201,13 @@ def _make_bridge_dwg(dwg, state: BridgeBiddingState, config):
211201
if act == -1:
212202
break
213203
act_str = (
214-
str((act - BID_OFFSET_NUM) // 5 + 1)
215-
+ DENOMINATIONS[(act - BID_OFFSET_NUM) % 5]
204+
str((act - BID_OFFSET_NUM) // 5 + 1) + DENOMINATIONS[(act - BID_OFFSET_NUM) % 5]
216205
if BID_OFFSET_NUM <= act < 35 + BID_OFFSET_NUM
217206
else ACT[act]
218207
)
219208
color = (
220209
"orangered"
221-
if (
222-
(act > BID_OFFSET_NUM)
223-
and (
224-
(act - BID_OFFSET_NUM) % 5 == 1
225-
or (act - BID_OFFSET_NUM) % 5 == 2
226-
)
227-
)
210+
if ((act > BID_OFFSET_NUM) and ((act - BID_OFFSET_NUM) % 5 == 1 or (act - BID_OFFSET_NUM) % 5 == 2))
228211
or act == 1
229212
or act == 2
230213
else color_set.text_color

pgx/_src/dwg/hex.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,7 @@ def _make_hex_dwg(dwg, state: HexState, config):
8989
b_points = []
9090
w_points = []
9191
x1, y1 = (BOARD_SIZE - 1) * GRID_SIZE * r3, 0
92-
x2, y2 = ((BOARD_SIZE - 1) / 2) * GRID_SIZE * r3, (
93-
BOARD_SIZE - 1
94-
) * GRID_SIZE * 3 / 2
92+
x2, y2 = ((BOARD_SIZE - 1) / 2) * GRID_SIZE * r3, (BOARD_SIZE - 1) * GRID_SIZE * 3 / 2
9593
cx, cy = four_dig((x1 + x2) / 2), four_dig((y1 + y2) / 2)
9694
# fmt:off
9795
for i in range(BOARD_SIZE):

0 commit comments

Comments
 (0)