Skip to content

Commit 954007c

Browse files
sbodenstein authored and copybara-github committed
Format splash_attention_kernel_test.py.
PiperOrigin-RevId: 834807012
1 parent 35eb863 commit 954007c

File tree

1 file changed

+37
-17
lines changed

1 file changed

+37
-17
lines changed

tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel_test.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ def attention_strategy(draw: Draw) -> tuple[int, int, int, int, np.dtype]:
216216
# tests.
217217
dtype = np.dtype("float32")
218218
else:
219-
dtype = draw(hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)]))
219+
dtype = draw(
220+
hps.sampled_from([np.dtype("float32"), np.dtype(jnp.bfloat16)])
221+
)
220222
return q_seq_len, kv_seq_len, head_dim_qk, head_dim_v, dtype
221223

222224

@@ -392,9 +394,17 @@ def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data):
392394
use_sinks=(False, True),
393395
)
394396
@hp.given(hps.data())
395-
def test_splash_attention_fwd(self, is_mqa, is_segmented, is_dynamic_mask,
396-
use_base2_exp, use_max_logit_estimate,
397-
fuse_reciprocal, use_sinks, data):
397+
def test_splash_attention_fwd(
398+
self,
399+
is_mqa,
400+
is_segmented,
401+
is_dynamic_mask,
402+
use_base2_exp,
403+
use_max_logit_estimate,
404+
fuse_reciprocal,
405+
use_sinks,
406+
data,
407+
):
398408
# TODO: Re-enable once dynamic masks are fixed.
399409
if is_dynamic_mask:
400410
self.skipTest("Dynamic masks not supported.")
@@ -468,9 +478,7 @@ def test_splash_attention_fwd(self, is_mqa, is_segmented, is_dynamic_mask,
468478
elif use_max_logit_estimate == "value_2d":
469479
max_logit_value = max_val * jnp.ones((num_q_heads,), dtype=jnp.bfloat16)
470480

471-
make_mask_fn = partial(
472-
make_mask_fn, config=config, save_residuals=True
473-
)
481+
make_mask_fn = partial(make_mask_fn, config=config, save_residuals=True)
474482
attn = make_mask_fn(mask)
475483
attn_ref = partial(
476484
splash.attention_reference,
@@ -495,18 +503,21 @@ def test_splash_attention_fwd(self, is_mqa, is_segmented, is_dynamic_mask,
495503
res_tol = dict(atol=1e-3, rtol=3e-3)
496504
if use_sinks:
497505
o_tol = dict(atol=1e-2, rtol=1e-2)
498-
elif (use_base2_exp or use_max_logit_estimate is not None
499-
or not fuse_reciprocal):
506+
elif (
507+
use_base2_exp
508+
or use_max_logit_estimate is not None
509+
or not fuse_reciprocal
510+
):
500511
o_tol = dict(atol=8e-3, rtol=3e-3)
501512
else:
502513
o_tol = dict(atol=4e-3, rtol=3e-3)
503514

504515
self._assert_allclose(o, o_ref, **o_tol)
505-
self._assert_allclose(stats["logsumexp"],
506-
stats_ref["logsumexp"], **res_tol)
516+
self._assert_allclose(stats["logsumexp"], stats_ref["logsumexp"], **res_tol)
507517
if use_max_logit_estimate is None:
508-
self._assert_allclose(stats["max_logits"],
509-
stats_ref["max_logits"], **res_tol)
518+
self._assert_allclose(
519+
stats["max_logits"], stats_ref["max_logits"], **res_tol
520+
)
510521

511522
@parameterized.product(
512523
is_mqa=(False, True),
@@ -614,8 +625,14 @@ def test_splash_attention_bwd(
614625
)
615626
attn = make_mask_fn(mask)
616627

617-
o, attn_vjp = jax.vjp(partial(attn, max_logit_value=max_logit_value),
618-
q, k, v, segment_ids, sinks)
628+
o, attn_vjp = jax.vjp(
629+
partial(attn, max_logit_value=max_logit_value),
630+
q,
631+
k,
632+
v,
633+
segment_ids,
634+
sinks,
635+
)
619636
q32, k32, v32 = jax.tree.map(lambda x: x.astype(jnp.float32), (q, k, v))
620637
o_ref, stats_ref = splash.attention_reference(
621638
q32,
@@ -630,8 +647,11 @@ def test_splash_attention_bwd(
630647
)
631648
if use_sinks:
632649
o_tol = dict(atol=1e-2, rtol=1e-2)
633-
elif (use_base2_exp or use_max_logit_estimate is not None
634-
or not fuse_reciprocal):
650+
elif (
651+
use_base2_exp
652+
or use_max_logit_estimate is not None
653+
or not fuse_reciprocal
654+
):
635655
o_tol = dict(atol=8e-3, rtol=1e-2)
636656
else:
637657
o_tol = dict(atol=4e-3, rtol=3e-3)

0 commit comments

Comments (0)