Pipes attn_logits_soft_cap through multi_queries_paged_attention (#8583)
fenghuizhang authored Jan 21, 2025
1 parent a295f7d commit 6b785fc
Showing 4 changed files with 119 additions and 44 deletions.
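
In short, this change threads an optional `attn_logits_soft_cap` argument from `torch.ops.xla.multi_queries_paged_attention` down through the wrapper in `custom_kernel.py` to the Pallas kernel and the non-kernel reference path. When the cap is set, attention logits are squashed with a tanh before the softmax, as in the reference implementations below. A minimal standalone sketch of the transform (torch-based; the shapes here are illustrative assumptions, not values from the kernel):

```python
import torch

def soft_cap_logits(attn: torch.Tensor, attn_logits_soft_cap: float | None) -> torch.Tensor:
    # Same tanh-based capping used by the reference paths in this commit:
    # logits are squashed into (-cap, cap); None means no capping.
    if attn_logits_soft_cap is None:
        return attn
    return torch.tanh(attn / attn_logits_soft_cap) * attn_logits_soft_cap

# Illustrative shape: [num_query_heads, query_len, kv_len]
logits = torch.randn(8, 16, 128)
capped = soft_cap_logits(logits, attn_logits_soft_cap=1.0)
assert capped.abs().max() <= 1.0  # capped logits stay within the soft cap
```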
95 changes: 65 additions & 30 deletions test/test_pallas.py
@@ -1,5 +1,4 @@
import logging
import os
import unittest

import torch
@@ -597,6 +596,17 @@ def test_paged_attention_multi_queries_wrapper(self):
page_indices_xla = page_indices.to("xla")
effective_q_lens_xla = effective_q_lens.to("xla")

output_no_cap = multi_queries_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
)

output = multi_queries_paged_attention(
q_xla,
k_pages_xla,
@@ -606,6 +616,7 @@ def test_paged_attention_multi_queries_wrapper(self):
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
attn_logits_soft_cap=1.0,
)

nonkernel_output = multi_queries_paged_attention(
@@ -627,6 +638,19 @@ def test_paged_attention_multi_queries_wrapper(self):
page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
effective_q_lens_jax = jnp.array(effective_q_lens.numpy(), dtype=jnp.int32)
expected_output = torch.from_numpy(
np.array(
jax_multi_queries_paged_attention(
q_jax,
k_pages_jax,
v_pages_jax,
kv_seq_lens_jax,
page_indices_jax,
effective_q_lens_jax,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
attn_logits_soft_cap=1.0,
)))
expected_output_no_cap = torch.from_numpy(
np.array(
jax_multi_queries_paged_attention(
q_jax,
@@ -642,9 +666,18 @@ def test_paged_attention_multi_queries_wrapper(self):
self.assertTrue(
torch.allclose(
output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
self.assertFalse(
torch.allclose(
output.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5))
self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
output_no_cap.cpu(),
expected_output_no_cap.cpu(),
atol=1e-5,
rtol=1e-5))
self.assertTrue(
torch.allclose(
output_no_cap.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
"This test only works on TPUv4+.")
@@ -696,7 +729,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
page_indices, effective_q_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel):
use_kernel, attn_logits_soft_cap):
return torch.ops.xla.multi_queries_paged_attention(
q,
k_pages,
@@ -707,38 +740,42 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=use_kernel,
attn_logits_soft_cap=attn_logits_soft_cap,
)

compiled_paged_attention = torch.compile(
multi_queries_paged_attention_wrapper, backend="openxla")

output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=True,
)
for attn_logits_soft_cap in (1.0, None):
output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=True,
attn_logits_soft_cap=attn_logits_soft_cap,
)

nonkernel_output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
)
nonkernel_output = compiled_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
attn_logits_soft_cap=attn_logits_soft_cap,
)

self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
self.assertTrue(
torch.allclose(
output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
"This test only works on TPUv4 and TPUv5p.")
@@ -822,7 +859,6 @@ def test_paged_attention_wrapper_with_dynamo(self):
num_kv_heads = 8
q_kv_head_ratio = 8
head_dim = 256
dtype = torch.float32
seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
@@ -899,7 +935,6 @@ def test_paged_attention_wrapper_with_attn_logits_soft_cap(self):
num_kv_heads = 8
q_kv_head_ratio = 8
head_dim = 256
dtype = torch.float32
seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)

q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
8 changes: 8 additions & 0 deletions test/test_tpu_paged_attention_kernel.py
@@ -45,6 +45,7 @@ def _ref_jax_extended_paged_attention(
lengths, # [batch_size], the effective kv_length.
page_indices, # [batch_size, pages_per_sequence]
effective_q_lens, # [batch_size] the effective q_length
attn_logits_soft_cap: float | None = None,
):
batch_size, query_len, num_query_heads, head_size = q.shape
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -71,6 +72,9 @@ def _ref_jax_extended_paged_attention(
v = jnp.repeat(v, num_query_per_kv, axis=1)

attn = jnp.einsum("qhd,khd->hqk", q[i], k)
if attn_logits_soft_cap is not None:
capped_attn = jnp.tanh(attn / attn_logits_soft_cap)
attn = capped_attn * attn_logits_soft_cap
attn = attn.astype('float32')
effective_q_len = effective_q_lens[i]
q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
@@ -111,6 +115,7 @@ def setUp(self):
head_dim=(128, 256),
num_queries_per_compute_block=(16, 32),
block_kv_size=(128, 256),
attn_logits_soft_cap=(1.0, None),
)
def test_paged_attention_without_query_padding(
self,
@@ -121,6 +126,7 @@ def test_paged_attention_without_query_padding(
head_dim,
num_queries_per_compute_block,
block_kv_size,
attn_logits_soft_cap,
):

max_kv_len = 2048
@@ -160,6 +166,7 @@ def test_paged_attention_without_query_padding(
effective_q_lens,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
attn_logits_soft_cap=attn_logits_soft_cap,
)
# Note kernel execution is async. Without blocking, if an error happens in the kernel, the error may point to some irrelevant and confusing places. See https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631
actual_output = jax.block_until_ready(actual_output)
@@ -172,6 +179,7 @@ def test_paged_attention_without_query_padding(
kv_seq_lens,
page_indices,
effective_q_lens,
attn_logits_soft_cap=attn_logits_soft_cap,
)

self.assertEqual(actual_output.shape, expected_output.shape)
47 changes: 34 additions & 13 deletions torch_xla/experimental/custom_kernel.py
@@ -506,6 +506,7 @@ def _multi_queries_paged_attention_nonkernel(
lengths, # seq_lengths, [batch_size]. nb batch_size = len(seq_lens), the effective kv_length.
page_indices, # [batch_size, pages_per_sequence]
effective_q_lens, # [batch_size], the effective q_length
attn_logits_soft_cap: float | None = None,
) -> torch.Tensor: # [batch_size, query_len, num_heads, head_dim]
batch_size, query_len, num_query_heads, head_size = q.shape
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -543,6 +544,9 @@ def _multi_queries_paged_attention_nonkernel(
# For example, it can use bfloat16 instead of float32 or vice versa for performance or simplicity.
attn = torch.einsum("qhd,khd->hqk", q[i],
k) # [num_query_heads, query_len, kv_len]
if attn_logits_soft_cap is not None:
capped_attn = torch.tanh(attn / attn_logits_soft_cap)
attn = capped_attn * attn_logits_soft_cap
attn = attn.float()
empty_mask = torch.ones(query_len, kv_len, device=attn.device)
effective_q_len = effective_q_lens[i]
@@ -569,6 +573,7 @@ def multi_queries_paged_attention(
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
attn_logits_soft_cap: float | None = None,
): # [batch_size, query_len, num_heads, head_dim]:
assert len(q.shape) == 4, "q should have 4 dimensions."
if not use_kernel:
@@ -579,6 +584,7 @@ def multi_queries_paged_attention(
lengths,
page_indices,
effective_q_lens,
attn_logits_soft_cap=attn_logits_soft_cap,
)

# Import JAX within the function such that we don't need to call the jax_import_guard()
Expand All @@ -595,9 +601,11 @@ def multi_queries_paged_attention(
effective_q_lens,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
attn_logits_soft_cap=attn_logits_soft_cap,
static_argnames=[
"num_kv_pages_per_compute_block",
"num_queries_per_compute_block",
"attn_logits_soft_cap",
],
)

@@ -1103,29 +1111,42 @@ def paged_attention_non_xla(q: torch.Tensor,


XLA_LIB.define(
"multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
)
"multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices,"
" Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block,"
" bool use_kernel, float? attn_logits_soft_cap=None) -> Tensor",)


@impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
def multi_queries_paged_attention_xla(
q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
lengths: torch.Tensor, page_indices: torch.Tensor,
effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int, use_kernel: bool):
def multi_queries_paged_attention_xla(q: torch.Tensor,
k_pages: torch.Tensor,
v_pages: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
effective_q_lens: torch.Tensor,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
use_kernel: bool,
attn_logits_soft_cap: float |
None = None):
return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
page_indices, effective_q_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel)
use_kernel, attn_logits_soft_cap)


@impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
def multi_queries_paged_attention_non_xla(
q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
lengths: torch.Tensor, page_indices: torch.Tensor,
effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int, use_kernel: bool):
def multi_queries_paged_attention_non_xla(q: torch.Tensor,
k_pages: torch.Tensor,
v_pages: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
effective_q_lens: torch.Tensor,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
use_kernel: bool,
attn_logits_soft_cap: float |
None = None):
return non_xla_attetion(q, k_pages, v_pages, "paged")


13 changes: 12 additions & 1 deletion torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -116,6 +116,7 @@ def _flash_attention(
query_len: int,
page_size: int,
head_dim: int,
attn_logits_soft_cap: float | None,
):
b, kv_head_idx, q_blk_idx, kv_blk_idx = (
pl.program_id(0),
@@ -143,6 +144,10 @@ def start_new_sequence():
s = jnp.einsum(
'qd,td->qt', q, k,
preferred_element_type=jnp.float32) # [block_q, block_k]
if attn_logits_soft_cap is not None:
capped_s = jnp.tanh(s / attn_logits_soft_cap)
s = capped_s * attn_logits_soft_cap

assert s.shape == (num_queries_per_compute_block,
kv_seq_len_per_kv_compute_blk)

@@ -266,6 +271,7 @@ def paged_flash_attention_kernel(
num_kv_pages_per_compute_block: int,
mask_value: float,
query_len: int,
attn_logits_soft_cap: float | None,
):
"""Pallas kernel for paged attention."""
b, kv_head_idx, q_blk_idx, kv_blk_idx = (
@@ -411,6 +417,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
query_len=query_len,
page_size=page_size,
head_dim=head_dim,
attn_logits_soft_cap=attn_logits_soft_cap,
)
# o_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim]
step_ref[0] = step + 1
@@ -428,6 +435,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
"num_kv_pages_per_compute_block",
"num_queries_per_compute_block",
"mask_value",
"attn_logits_soft_cap",
],
)
def paged_attention(
Expand All @@ -441,6 +449,7 @@ def paged_attention(
mask_value: float = DEFAULT_MASK_VALUE,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int = 4,
attn_logits_soft_cap: float | None = None,
) -> jax.Array:
"""Paged grouped query attention.
@@ -620,7 +629,9 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
batch_size=batch_size,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
mask_value=mask_value,
query_len=query_len),
query_len=query_len,
attn_logits_soft_cap=attn_logits_soft_cap,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=5,
in_specs=in_specs,

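For reference, a caller-side sketch of exercising the extended wrapper with the new argument, adapted from the test above. It assumes torch_xla on a TPU/XLA device; the concrete shapes and block sizes below are illustrative assumptions chosen to satisfy the kernel's divisibility constraints, not values taken from this commit.

```python
# Sketch only: requires torch_xla and a TPU; shapes and block sizes are assumptions.
import torch
import torch_xla  # registers the "xla" device
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention

batch_size, query_len, num_query_heads, head_dim = 3, 32, 8, 128
num_kv_heads, total_num_pages, page_size, pages_per_sequence = 2, 64, 16, 8

q = torch.randn(batch_size, query_len, num_query_heads, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_kv_heads, total_num_pages, page_size, head_dim).to("xla")
kv_seq_lens = torch.tensor([128, 96, 64], dtype=torch.int32).to("xla")
page_indices = torch.randint(
    0, total_num_pages, (batch_size, pages_per_sequence), dtype=torch.int32).to("xla")
effective_q_lens = torch.tensor([32, 16, 32], dtype=torch.int32).to("xla")

output = multi_queries_paged_attention(
    q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens,
    num_kv_pages_per_compute_block=pages_per_sequence,  # one block covers all pages here
    num_queries_per_compute_block=16,
    use_kernel=True,
    attn_logits_soft_cap=1.0,  # new optional argument; None keeps the previous behavior
)
```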