@@ -216,7 +216,9 @@ def attention_strategy(draw: Draw) -> tuple[int, int, int, int, np.dtype]:
216216 # tests.
217217 dtype = np .dtype ("float32" )
218218 else :
219- dtype = draw (hps .sampled_from ([np .dtype ("float32" ), np .dtype (jnp .bfloat16 )]))
219+ dtype = draw (
220+ hps .sampled_from ([np .dtype ("float32" ), np .dtype (jnp .bfloat16 )])
221+ )
220222 return q_seq_len , kv_seq_len , head_dim_qk , head_dim_v , dtype
221223
222224
@@ -392,9 +394,17 @@ def test_splash_attention(self, is_mqa, is_segmented, is_dynamic_mask, data):
392394 use_sinks = (False , True ),
393395 )
394396 @hp .given (hps .data ())
395- def test_splash_attention_fwd (self , is_mqa , is_segmented , is_dynamic_mask ,
396- use_base2_exp , use_max_logit_estimate ,
397- fuse_reciprocal , use_sinks , data ):
397+ def test_splash_attention_fwd (
398+ self ,
399+ is_mqa ,
400+ is_segmented ,
401+ is_dynamic_mask ,
402+ use_base2_exp ,
403+ use_max_logit_estimate ,
404+ fuse_reciprocal ,
405+ use_sinks ,
406+ data ,
407+ ):
398408 # TODO: Re-enable once dynamic masks are fixed.
399409 if is_dynamic_mask :
400410 self .skipTest ("Dynamic masks not supported." )
@@ -468,9 +478,7 @@ def test_splash_attention_fwd(self, is_mqa, is_segmented, is_dynamic_mask,
468478 elif use_max_logit_estimate == "value_2d" :
469479 max_logit_value = max_val * jnp .ones ((num_q_heads ,), dtype = jnp .bfloat16 )
470480
471- make_mask_fn = partial (
472- make_mask_fn , config = config , save_residuals = True
473- )
481+ make_mask_fn = partial (make_mask_fn , config = config , save_residuals = True )
474482 attn = make_mask_fn (mask )
475483 attn_ref = partial (
476484 splash .attention_reference ,
@@ -495,18 +503,21 @@ def test_splash_attention_fwd(self, is_mqa, is_segmented, is_dynamic_mask,
495503 res_tol = dict (atol = 1e-3 , rtol = 3e-3 )
496504 if use_sinks :
497505 o_tol = dict (atol = 1e-2 , rtol = 1e-2 )
498- elif (use_base2_exp or use_max_logit_estimate is not None
499- or not fuse_reciprocal ):
506+ elif (
507+ use_base2_exp
508+ or use_max_logit_estimate is not None
509+ or not fuse_reciprocal
510+ ):
500511 o_tol = dict (atol = 8e-3 , rtol = 3e-3 )
501512 else :
502513 o_tol = dict (atol = 4e-3 , rtol = 3e-3 )
503514
504515 self ._assert_allclose (o , o_ref , ** o_tol )
505- self ._assert_allclose (stats ["logsumexp" ],
506- stats_ref ["logsumexp" ], ** res_tol )
516+ self ._assert_allclose (stats ["logsumexp" ], stats_ref ["logsumexp" ], ** res_tol )
507517 if use_max_logit_estimate is None :
508- self ._assert_allclose (stats ["max_logits" ],
509- stats_ref ["max_logits" ], ** res_tol )
518+ self ._assert_allclose (
519+ stats ["max_logits" ], stats_ref ["max_logits" ], ** res_tol
520+ )
510521
511522 @parameterized .product (
512523 is_mqa = (False , True ),
@@ -614,8 +625,14 @@ def test_splash_attention_bwd(
614625 )
615626 attn = make_mask_fn (mask )
616627
617- o , attn_vjp = jax .vjp (partial (attn , max_logit_value = max_logit_value ),
618- q , k , v , segment_ids , sinks )
628+ o , attn_vjp = jax .vjp (
629+ partial (attn , max_logit_value = max_logit_value ),
630+ q ,
631+ k ,
632+ v ,
633+ segment_ids ,
634+ sinks ,
635+ )
619636 q32 , k32 , v32 = jax .tree .map (lambda x : x .astype (jnp .float32 ), (q , k , v ))
620637 o_ref , stats_ref = splash .attention_reference (
621638 q32 ,
@@ -630,8 +647,11 @@ def test_splash_attention_bwd(
630647 )
631648 if use_sinks :
632649 o_tol = dict (atol = 1e-2 , rtol = 1e-2 )
633- elif (use_base2_exp or use_max_logit_estimate is not None
634- or not fuse_reciprocal ):
650+ elif (
651+ use_base2_exp
652+ or use_max_logit_estimate is not None
653+ or not fuse_reciprocal
654+ ):
635655 o_tol = dict (atol = 8e-3 , rtol = 1e-2 )
636656 else :
637657 o_tol = dict (atol = 4e-3 , rtol = 3e-3 )
0 commit comments