
Commit 6b785fc

Pipes attn_logits_soft_cap through multi_queries_paged_attention (#8583)
1 parent a295f7d commit 6b785fc

File tree: 4 files changed, +119 -44 lines

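For context on what the new attn_logits_soft_cap argument does: soft-capping squashes the raw attention logits through tanh and rescales them, so every score is bounded by the cap before the softmax. A minimal standalone sketch of that transform (the soft_cap_logits helper is illustrative, not code from this commit):

import torch

def soft_cap_logits(logits: torch.Tensor,
                    attn_logits_soft_cap: float | None) -> torch.Tensor:
  # tanh maps the scaled logits into (-1, 1); multiplying by the cap rescales
  # them into (-cap, cap), matching the pattern added throughout this commit.
  if attn_logits_soft_cap is None:
    return logits
  return torch.tanh(logits / attn_logits_soft_cap) * attn_logits_soft_cap

logits = torch.tensor([-50.0, -1.0, 0.0, 1.0, 50.0])
print(soft_cap_logits(logits, attn_logits_soft_cap=30.0))  # values bounded by +/-30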

test/test_pallas.py

Lines changed: 65 additions & 30 deletions

@@ -1,5 +1,4 @@
 import logging
-import os
 import unittest

 import torch
@@ -597,6 +596,17 @@ def test_paged_attention_multi_queries_wrapper(self):
     page_indices_xla = page_indices.to("xla")
     effective_q_lens_xla = effective_q_lens.to("xla")

+    output_no_cap = multi_queries_paged_attention(
+        q_xla,
+        k_pages_xla,
+        v_pages_xla,
+        kv_seq_lens_xla,
+        page_indices_xla,
+        effective_q_lens_xla,
+        num_kv_pages_per_compute_block=block_kv_size // page_size,
+        num_queries_per_compute_block=num_queries_per_compute_block,
+    )
+
     output = multi_queries_paged_attention(
         q_xla,
         k_pages_xla,
@@ -606,6 +616,7 @@ def test_paged_attention_multi_queries_wrapper(self):
         effective_q_lens_xla,
         num_kv_pages_per_compute_block=block_kv_size // page_size,
         num_queries_per_compute_block=num_queries_per_compute_block,
+        attn_logits_soft_cap=1.0,
     )

     nonkernel_output = multi_queries_paged_attention(
@@ -627,6 +638,19 @@ def test_paged_attention_multi_queries_wrapper(self):
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     effective_q_lens_jax = jnp.array(effective_q_lens.numpy(), dtype=jnp.int32)
     expected_output = torch.from_numpy(
+        np.array(
+            jax_multi_queries_paged_attention(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                kv_seq_lens_jax,
+                page_indices_jax,
+                effective_q_lens_jax,
+                num_kv_pages_per_compute_block=block_kv_size // page_size,
+                num_queries_per_compute_block=num_queries_per_compute_block,
+                attn_logits_soft_cap=1.0,
+            )))
+    expected_output_no_cap = torch.from_numpy(
         np.array(
             jax_multi_queries_paged_attention(
                 q_jax,
@@ -642,9 +666,18 @@ def test_paged_attention_multi_queries_wrapper(self):
     self.assertTrue(
         torch.allclose(
             output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
+    self.assertFalse(
+        torch.allclose(
+            output.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5))
     self.assertTrue(
         torch.allclose(
-            output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
+            output_no_cap.cpu(),
+            expected_output_no_cap.cpu(),
+            atol=1e-5,
+            rtol=1e-5))
+    self.assertTrue(
+        torch.allclose(
+            output_no_cap.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -696,7 +729,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
                                               page_indices, effective_q_lens,
                                               num_kv_pages_per_compute_block,
                                               num_queries_per_compute_block,
-                                              use_kernel):
+                                              use_kernel, attn_logits_soft_cap):
       return torch.ops.xla.multi_queries_paged_attention(
           q,
           k_pages,
@@ -707,38 +740,42 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
          num_kv_pages_per_compute_block,
          num_queries_per_compute_block,
          use_kernel=use_kernel,
+         attn_logits_soft_cap=attn_logits_soft_cap,
      )

    compiled_paged_attention = torch.compile(
        multi_queries_paged_attention_wrapper, backend="openxla")

-    output = compiled_paged_attention(
-        q_xla,
-        k_pages_xla,
-        v_pages_xla,
-        kv_seq_lens_xla,
-        page_indices_xla,
-        effective_q_lens_xla,
-        num_kv_pages_per_compute_block=block_kv_size // page_size,
-        num_queries_per_compute_block=num_queries_per_compute_block,
-        use_kernel=True,
-    )
+    for attn_logits_soft_cap in (1.0, None):
+      output = compiled_paged_attention(
+          q_xla,
+          k_pages_xla,
+          v_pages_xla,
+          kv_seq_lens_xla,
+          page_indices_xla,
+          effective_q_lens_xla,
+          num_kv_pages_per_compute_block=block_kv_size // page_size,
+          num_queries_per_compute_block=num_queries_per_compute_block,
+          use_kernel=True,
+          attn_logits_soft_cap=attn_logits_soft_cap,
+      )

-    nonkernel_output = compiled_paged_attention(
-        q_xla,
-        k_pages_xla,
-        v_pages_xla,
-        kv_seq_lens_xla,
-        page_indices_xla,
-        effective_q_lens_xla,
-        num_kv_pages_per_compute_block=block_kv_size // page_size,
-        num_queries_per_compute_block=num_queries_per_compute_block,
-        use_kernel=False,
-    )
+      nonkernel_output = compiled_paged_attention(
+          q_xla,
+          k_pages_xla,
+          v_pages_xla,
+          kv_seq_lens_xla,
+          page_indices_xla,
+          effective_q_lens_xla,
+          num_kv_pages_per_compute_block=block_kv_size // page_size,
+          num_queries_per_compute_block=num_queries_per_compute_block,
+          use_kernel=False,
+          attn_logits_soft_cap=attn_logits_soft_cap,
+      )

-    self.assertTrue(
-        torch.allclose(
-            output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
+      self.assertTrue(
+          torch.allclose(
+              output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
                    "This test only works on TPUv4 and TPUv5p.")
@@ -822,7 +859,6 @@ def test_paged_attention_wrapper_with_dynamo(self):
     num_kv_heads = 8
     q_kv_head_ratio = 8
     head_dim = 256
-    dtype = torch.float32
     seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

     q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
@@ -899,7 +935,6 @@ def test_paged_attention_wrapper_with_attn_logits_soft_cap(self):
     num_kv_heads = 8
     q_kv_head_ratio = 8
     head_dim = 256
-    dtype = torch.float32
     seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

     q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
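The updated eager test computes both a capped (attn_logits_soft_cap=1.0) and an uncapped output and, besides matching each against its JAX reference, uses the new assertFalse to check that the two are not numerically close, i.e. the cap really changes the result. A small illustrative check (values chosen arbitrarily, not taken from the test) of why a cap of 1.0 flattens the softmax:

import torch

logits = torch.tensor([0.5, 4.0, 80.0])
cap = 1.0
capped = torch.tanh(logits / cap) * cap   # roughly [0.46, 1.00, 1.00]
print(torch.softmax(logits, dim=-1))      # ~[0, 0, 1]: the huge logit dominates
print(torch.softmax(capped, dim=-1))      # much flatter, so the capped output differs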

test/test_tpu_paged_attention_kernel.py

Lines changed: 8 additions & 0 deletions

@@ -45,6 +45,7 @@ def _ref_jax_extended_paged_attention(
     lengths,  # [batch_size], the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
     effective_q_lens,  # [batch_size] the effective q_length
+    attn_logits_soft_cap: float | None = None,
 ):
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -71,6 +72,9 @@ def _ref_jax_extended_paged_attention(
     v = jnp.repeat(v, num_query_per_kv, axis=1)

     attn = jnp.einsum("qhd,khd->hqk", q[i], k)
+    if attn_logits_soft_cap is not None:
+      capped_attn = jnp.tanh(attn / attn_logits_soft_cap)
+      attn = capped_attn * attn_logits_soft_cap
     attn = attn.astype('float32')
     effective_q_len = effective_q_lens[i]
     q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
@@ -111,6 +115,7 @@ def setUp(self):
       head_dim=(128, 256),
       num_queries_per_compute_block=(16, 32),
       block_kv_size=(128, 256),
+      attn_logits_soft_cap=(1.0, None),
   )
   def test_paged_attention_without_query_padding(
       self,
@@ -121,6 +126,7 @@ def test_paged_attention_without_query_padding(
       head_dim,
       num_queries_per_compute_block,
       block_kv_size,
+      attn_logits_soft_cap,
   ):

     max_kv_len = 2048
@@ -160,6 +166,7 @@ def test_paged_attention_without_query_padding(
         effective_q_lens,
         num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
         num_queries_per_compute_block=num_queries_per_compute_block,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )
     # Note kernel execution is async. Without blocking, if an error happens in the kernel, the error may point to some irrelevant and confusing places. See https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631
     actual_output = jax.block_until_ready(actual_output)
@@ -172,6 +179,7 @@ def test_paged_attention_without_query_padding(
         kv_seq_lens,
         page_indices,
         effective_q_lens,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )

     self.assertEqual(actual_output.shape, expected_output.shape)
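The reference implementation applies the cap to the raw einsum logits before the softmax, the same place the kernel does. A simplified, non-paged JAX sketch of that pattern (the ref_attention name and shapes are illustrative; causal masking and paging are omitted for brevity):

import jax
import jax.numpy as jnp

def ref_attention(q, k, v, attn_logits_soft_cap=None):
  # q: [query_len, num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim].
  attn = jnp.einsum("qhd,khd->hqk", q, k)
  if attn_logits_soft_cap is not None:
    attn = jnp.tanh(attn / attn_logits_soft_cap) * attn_logits_soft_cap
  attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1)
  return jnp.einsum("hqk,khd->qhd", attn, v)

q = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (16, 2, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 2, 8))
print(ref_attention(q, k, v, attn_logits_soft_cap=1.0).shape)  # (4, 2, 8)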

torch_xla/experimental/custom_kernel.py

Lines changed: 34 additions & 13 deletions

@@ -506,6 +506,7 @@ def _multi_queries_paged_attention_nonkernel(
     lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens), the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
     effective_q_lens,  # [batch_size], the effective q_length
+    attn_logits_soft_cap: float | None = None,
 ) -> torch.Tensor:  # [batch_size, query_len, num_heads, head_dim]
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -543,6 +544,9 @@ def _multi_queries_paged_attention_nonkernel(
     # For example, it can use bfloat16 instead of float32 or vice versa for performance or simplicity.
     attn = torch.einsum("qhd,khd->hqk", q[i],
                         k)  # [num_query_heads, query_len, kv_len]
+    if attn_logits_soft_cap is not None:
+      capped_attn = torch.tanh(attn / attn_logits_soft_cap)
+      attn = capped_attn * attn_logits_soft_cap
     attn = attn.float()
     empty_mask = torch.ones(query_len, kv_len, device=attn.device)
     effective_q_len = effective_q_lens[i]
@@ -569,6 +573,7 @@ def multi_queries_paged_attention(
     num_kv_pages_per_compute_block,
     num_queries_per_compute_block,
     use_kernel=True,
+    attn_logits_soft_cap: float | None = None,
 ):  # [batch_size, query_len, num_heads, head_dim]:
   assert len(q.shape) == 4, "q should have 4 dimensions."
   if not use_kernel:
@@ -579,6 +584,7 @@ def multi_queries_paged_attention(
         lengths,
         page_indices,
         effective_q_lens,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )

   # Import JAX within the function such that we don't need to call the jax_import_guard()
@@ -595,9 +601,11 @@ def multi_queries_paged_attention(
       effective_q_lens,
       num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
       num_queries_per_compute_block=num_queries_per_compute_block,
+      attn_logits_soft_cap=attn_logits_soft_cap,
      static_argnames=[
          "num_kv_pages_per_compute_block",
          "num_queries_per_compute_block",
+         "attn_logits_soft_cap",
      ],
  )

@@ -1103,29 +1111,42 @@ def paged_attention_non_xla(q: torch.Tensor,


 XLA_LIB.define(
-    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
-)
+    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices,"
+    " Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block,"
+    " bool use_kernel, float? attn_logits_soft_cap=None) -> Tensor",)


 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool):
+def multi_queries_paged_attention_xla(q: torch.Tensor,
+                                      k_pages: torch.Tensor,
+                                      v_pages: torch.Tensor,
+                                      lengths: torch.Tensor,
+                                      page_indices: torch.Tensor,
+                                      effective_q_lens: torch.Tensor,
+                                      num_kv_pages_per_compute_block: int,
+                                      num_queries_per_compute_block: int,
+                                      use_kernel: bool,
+                                      attn_logits_soft_cap: float |
+                                      None = None):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
                                         page_indices, effective_q_lens,
                                         num_kv_pages_per_compute_block,
                                         num_queries_per_compute_block,
-                                        use_kernel)
+                                        use_kernel, attn_logits_soft_cap)


 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool):
+def multi_queries_paged_attention_non_xla(q: torch.Tensor,
+                                          k_pages: torch.Tensor,
+                                          v_pages: torch.Tensor,
+                                          lengths: torch.Tensor,
+                                          page_indices: torch.Tensor,
+                                          effective_q_lens: torch.Tensor,
+                                          num_kv_pages_per_compute_block: int,
+                                          num_queries_per_compute_block: int,
+                                          use_kernel: bool,
+                                          attn_logits_soft_cap: float |
+                                          None = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
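With the schema change (float? attn_logits_soft_cap=None), callers of the Python wrapper or the registered op can pass an optional cap. A hedged usage sketch of the non-kernel reference path; the shapes and values are illustrative, it assumes torch_xla is installed and that the pure-PyTorch use_kernel=False path accepts CPU tensors, and the Pallas kernel path additionally needs a TPU and use_kernel=True:

import torch
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention

# Illustrative sizes: 2 sequences, 4 query tokens each, 4 query heads sharing
# 2 KV heads, 16-token pages, 4 pages per sequence.
batch_size, query_len, num_q_heads, num_kv_heads, head_dim = 2, 4, 4, 2, 64
page_size, pages_per_seq = 16, 4
total_pages = batch_size * pages_per_seq

q = torch.randn(batch_size, query_len, num_q_heads, head_dim)
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim)
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim)
lengths = torch.tensor([64, 33], dtype=torch.int32)  # effective kv lengths
page_indices = torch.arange(
    total_pages, dtype=torch.int32).reshape(batch_size, pages_per_seq)
effective_q_lens = torch.tensor([4, 3], dtype=torch.int32)

out = multi_queries_paged_attention(
    q, k_pages, v_pages, lengths, page_indices, effective_q_lens,
    num_kv_pages_per_compute_block=8,
    num_queries_per_compute_block=4,
    use_kernel=False,  # reference path; the block sizes are not used for tiling here
    attn_logits_soft_cap=1.0)
print(out.shape)  # expected: torch.Size([2, 4, 4, 64])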

torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py

Lines changed: 12 additions & 1 deletion

@@ -116,6 +116,7 @@ def _flash_attention(
     query_len: int,
     page_size: int,
     head_dim: int,
+    attn_logits_soft_cap: float | None,
 ):
   b, kv_head_idx, q_blk_idx, kv_blk_idx = (
       pl.program_id(0),
@@ -143,6 +144,10 @@ def start_new_sequence():
   s = jnp.einsum(
       'qd,td->qt', q, k,
       preferred_element_type=jnp.float32)  # [block_q, block_k]
+  if attn_logits_soft_cap is not None:
+    capped_s = jnp.tanh(s / attn_logits_soft_cap)
+    s = capped_s * attn_logits_soft_cap
+
   assert s.shape == (num_queries_per_compute_block,
                      kv_seq_len_per_kv_compute_blk)

@@ -266,6 +271,7 @@ def paged_flash_attention_kernel(
     num_kv_pages_per_compute_block: int,
     mask_value: float,
     query_len: int,
+    attn_logits_soft_cap: float | None,
 ):
   """Pallas kernel for paged attention."""
   b, kv_head_idx, q_blk_idx, kv_blk_idx = (
@@ -411,6 +417,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
         query_len=query_len,
         page_size=page_size,
         head_dim=head_dim,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )
     # o_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim]
     step_ref[0] = step + 1
@@ -428,6 +435,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
         "num_kv_pages_per_compute_block",
         "num_queries_per_compute_block",
         "mask_value",
+        "attn_logits_soft_cap",
     ],
 )
 def paged_attention(
@@ -441,6 +449,7 @@ def paged_attention(
     mask_value: float = DEFAULT_MASK_VALUE,
     num_kv_pages_per_compute_block: int,
     num_queries_per_compute_block: int = 4,
+    attn_logits_soft_cap: float | None = None,
 ) -> jax.Array:
   """Paged grouped query attention.

@@ -620,7 +629,9 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
           batch_size=batch_size,
           num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
           mask_value=mask_value,
-          query_len=query_len),
+          query_len=query_len,
+          attn_logits_soft_cap=attn_logits_soft_cap,
+      ),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=5,
           in_specs=in_specs,
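Note that attn_logits_soft_cap is added to static_argnames both here and in the trace_pallas call in custom_kernel.py: the cap is used in a Python-level "if ... is not None" branch, so it must be a compile-time constant, and each distinct value (including None) gets its own compiled specialization. A small self-contained JAX sketch of the same idea (the capped_scores function is illustrative, not part of the kernel):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=["soft_cap"])
def capped_scores(s, soft_cap=None):
  # Resolved at trace time: with soft_cap marked static, the branch is baked
  # into the compiled program, one specialization per distinct cap value.
  if soft_cap is not None:
    s = jnp.tanh(s / soft_cap) * soft_cap
  return s

x = jnp.array([0.5, 4.0, 80.0])
print(capped_scores(x, soft_cap=1.0))   # capped into (-1, 1)
print(capped_scores(x, soft_cap=None))  # unchanged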
