diff --git a/test/test_pallas.py b/test/test_pallas.py
index d49df491dc0..106b917528d 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -1,5 +1,4 @@
 import logging
-import os
 import unittest
 
 import torch
@@ -597,6 +596,17 @@ def test_paged_attention_multi_queries_wrapper(self):
     page_indices_xla = page_indices.to("xla")
     effective_q_lens_xla = effective_q_lens.to("xla")
 
+    output_no_cap = multi_queries_paged_attention(
+        q_xla,
+        k_pages_xla,
+        v_pages_xla,
+        kv_seq_lens_xla,
+        page_indices_xla,
+        effective_q_lens_xla,
+        num_kv_pages_per_compute_block=block_kv_size // page_size,
+        num_queries_per_compute_block=num_queries_per_compute_block,
+    )
+
     output = multi_queries_paged_attention(
         q_xla,
         k_pages_xla,
@@ -606,6 +616,7 @@ def test_paged_attention_multi_queries_wrapper(self):
         effective_q_lens_xla,
         num_kv_pages_per_compute_block=block_kv_size // page_size,
         num_queries_per_compute_block=num_queries_per_compute_block,
+        attn_logits_soft_cap=1.0,
     )
 
     nonkernel_output = multi_queries_paged_attention(
@@ -627,6 +638,19 @@ def test_paged_attention_multi_queries_wrapper(self):
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     effective_q_lens_jax = jnp.array(effective_q_lens.numpy(), dtype=jnp.int32)
     expected_output = torch.from_numpy(
+        np.array(
+            jax_multi_queries_paged_attention(
+                q_jax,
+                k_pages_jax,
+                v_pages_jax,
+                kv_seq_lens_jax,
+                page_indices_jax,
+                effective_q_lens_jax,
+                num_kv_pages_per_compute_block=block_kv_size // page_size,
+                num_queries_per_compute_block=num_queries_per_compute_block,
+                attn_logits_soft_cap=1.0,
+            )))
+    expected_output_no_cap = torch.from_numpy(
         np.array(
             jax_multi_queries_paged_attention(
                 q_jax,
@@ -642,9 +666,18 @@ def test_paged_attention_multi_queries_wrapper(self):
     self.assertTrue(
         torch.allclose(
             output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5))
+    self.assertFalse(
+        torch.allclose(
+            output.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5))
     self.assertTrue(
         torch.allclose(
-            output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
+            output_no_cap.cpu(),
+            expected_output_no_cap.cpu(),
+            atol=1e-5,
+            rtol=1e-5))
+    self.assertTrue(
+        torch.allclose(
+            output_no_cap.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -696,7 +729,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
                                               page_indices, effective_q_lens,
                                               num_kv_pages_per_compute_block,
                                               num_queries_per_compute_block,
-                                              use_kernel):
+                                              use_kernel, attn_logits_soft_cap):
       return torch.ops.xla.multi_queries_paged_attention(
           q,
          k_pages,
@@ -707,38 +740,42 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
           num_kv_pages_per_compute_block,
           num_queries_per_compute_block,
           use_kernel=use_kernel,
+          attn_logits_soft_cap=attn_logits_soft_cap,
       )
 
     compiled_paged_attention = torch.compile(
         multi_queries_paged_attention_wrapper, backend="openxla")
 
-    output = compiled_paged_attention(
-        q_xla,
-        k_pages_xla,
-        v_pages_xla,
-        kv_seq_lens_xla,
-        page_indices_xla,
-        effective_q_lens_xla,
-        num_kv_pages_per_compute_block=block_kv_size // page_size,
-        num_queries_per_compute_block=num_queries_per_compute_block,
-        use_kernel=True,
-    )
+    for attn_logits_soft_cap in (1.0, None):
+      output = compiled_paged_attention(
+          q_xla,
+          k_pages_xla,
+          v_pages_xla,
+          kv_seq_lens_xla,
+          page_indices_xla,
+          effective_q_lens_xla,
+          num_kv_pages_per_compute_block=block_kv_size // page_size,
+          num_queries_per_compute_block=num_queries_per_compute_block,
+          use_kernel=True,
+          attn_logits_soft_cap=attn_logits_soft_cap,
+      )
 
-    nonkernel_output = compiled_paged_attention(
-        q_xla,
-        k_pages_xla,
-        v_pages_xla,
-        kv_seq_lens_xla,
-        page_indices_xla,
-        effective_q_lens_xla,
-        num_kv_pages_per_compute_block=block_kv_size // page_size,
-        num_queries_per_compute_block=num_queries_per_compute_block,
-        use_kernel=False,
-    )
+      nonkernel_output = compiled_paged_attention(
+          q_xla,
+          k_pages_xla,
+          v_pages_xla,
+          kv_seq_lens_xla,
+          page_indices_xla,
+          effective_q_lens_xla,
+          num_kv_pages_per_compute_block=block_kv_size // page_size,
+          num_queries_per_compute_block=num_queries_per_compute_block,
+          use_kernel=False,
+          attn_logits_soft_cap=attn_logits_soft_cap,
+      )
 
-    self.assertTrue(
-        torch.allclose(
-            output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
+      self.assertTrue(
+          torch.allclose(
+              output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2))
 
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4,
                    "This test only works on TPUv4 and TPUv5p.")
@@ -822,7 +859,6 @@ def test_paged_attention_wrapper_with_dynamo(self):
     num_kv_heads = 8
     q_kv_head_ratio = 8
     head_dim = 256
-    dtype = torch.float32
     seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
 
     q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
@@ -899,7 +935,6 @@ def test_paged_attention_wrapper_with_attn_logits_soft_cap(self):
     num_kv_heads = 8
     q_kv_head_ratio = 8
     head_dim = 256
-    dtype = torch.float32
     seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
 
     q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py
index 746439ba4d0..8749334b711 100644
--- a/test/test_tpu_paged_attention_kernel.py
+++ b/test/test_tpu_paged_attention_kernel.py
@@ -45,6 +45,7 @@ def _ref_jax_extended_paged_attention(
     lengths,  # [batch_size], the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
     effective_q_lens,  # [batch_size] the effective q_length
+    attn_logits_soft_cap: float | None = None,
 ):
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -71,6 +72,9 @@ def _ref_jax_extended_paged_attention(
     v = jnp.repeat(v, num_query_per_kv, axis=1)
 
     attn = jnp.einsum("qhd,khd->hqk", q[i], k)
+    if attn_logits_soft_cap is not None:
+      capped_attn = jnp.tanh(attn / attn_logits_soft_cap)
+      attn = capped_attn * attn_logits_soft_cap
     attn = attn.astype('float32')
     effective_q_len = effective_q_lens[i]
     q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
@@ -111,6 +115,7 @@ def setUp(self):
       head_dim=(128, 256),
       num_queries_per_compute_block=(16, 32),
       block_kv_size=(128, 256),
+      attn_logits_soft_cap=(1.0, None),
   )
   def test_paged_attention_without_query_padding(
       self,
@@ -121,6 +126,7 @@ def test_paged_attention_without_query_padding(
       head_dim,
       num_queries_per_compute_block,
       block_kv_size,
+      attn_logits_soft_cap,
   ):
     max_kv_len = 2048
 
@@ -160,6 +166,7 @@ def test_paged_attention_without_query_padding(
         effective_q_lens,
         num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
         num_queries_per_compute_block=num_queries_per_compute_block,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )
     # Note kernel execution is async. Without blocking, if an error happens in the kernel, the error may point to some irrelevant and confusing places. See https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631
     actual_output = jax.block_until_ready(actual_output)
@@ -172,6 +179,7 @@ def test_paged_attention_without_query_padding(
         kv_seq_lens,
         page_indices,
         effective_q_lens,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )
 
     self.assertEqual(actual_output.shape, expected_output.shape)
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 8ccc4ddbc59..185d2085e7f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -506,6 +506,7 @@ def _multi_queries_paged_attention_nonkernel(
     lengths,  # seq_lengths, [batch_size]. nb batch_size = len(seq_lens), the effective kv_length.
     page_indices,  # [batch_size, pages_per_sequence]
     effective_q_lens,  # [batch_size], the effective q_length
+    attn_logits_soft_cap: float | None = None,
 ) -> torch.Tensor:  # [batch_size, query_len, num_heads, head_dim]
   batch_size, query_len, num_query_heads, head_size = q.shape
   num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -543,6 +544,9 @@ def _multi_queries_paged_attention_nonkernel(
     # For example, it can use bfloat16 instead of float32 or vice versa for performance or simplicity.
     attn = torch.einsum("qhd,khd->hqk", q[i],
                         k)  # [num_query_heads, query_len, kv_len]
+    if attn_logits_soft_cap is not None:
+      capped_attn = torch.tanh(attn / attn_logits_soft_cap)
+      attn = capped_attn * attn_logits_soft_cap
     attn = attn.float()
     empty_mask = torch.ones(query_len, kv_len, device=attn.device)
     effective_q_len = effective_q_lens[i]
@@ -569,6 +573,7 @@ def multi_queries_paged_attention(
     num_kv_pages_per_compute_block,
     num_queries_per_compute_block,
     use_kernel=True,
+    attn_logits_soft_cap: float | None = None,
 ):  # [batch_size, query_len, num_heads, head_dim]:
   assert len(q.shape) == 4, "q should have 4 dimensions."
   if not use_kernel:
@@ -579,6 +584,7 @@ def multi_queries_paged_attention(
         lengths,
         page_indices,
         effective_q_lens,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )
 
   # Import JAX within the function such that we don't need to call the jax_import_guard()
@@ -595,9 +601,11 @@ def multi_queries_paged_attention(
       effective_q_lens,
       num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
       num_queries_per_compute_block=num_queries_per_compute_block,
+      attn_logits_soft_cap=attn_logits_soft_cap,
       static_argnames=[
           "num_kv_pages_per_compute_block",
          "num_queries_per_compute_block",
+          "attn_logits_soft_cap",
      ],
   )
 
@@ -1103,29 +1111,42 @@ def paged_attention_non_xla(q: torch.Tensor,
 
 
 XLA_LIB.define(
-    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
-)
+    "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices,"
+    " Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block,"
+    " bool use_kernel, float? attn_logits_soft_cap=None) -> Tensor",)
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool):
+def multi_queries_paged_attention_xla(q: torch.Tensor,
+                                      k_pages: torch.Tensor,
+                                      v_pages: torch.Tensor,
+                                      lengths: torch.Tensor,
+                                      page_indices: torch.Tensor,
+                                      effective_q_lens: torch.Tensor,
+                                      num_kv_pages_per_compute_block: int,
+                                      num_queries_per_compute_block: int,
+                                      use_kernel: bool,
+                                      attn_logits_soft_cap: float |
+                                      None = None):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
                                        page_indices, effective_q_lens,
                                        num_kv_pages_per_compute_block,
                                        num_queries_per_compute_block,
-                                       use_kernel)
+                                       use_kernel, attn_logits_soft_cap)
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool):
+def multi_queries_paged_attention_non_xla(q: torch.Tensor,
+                                          k_pages: torch.Tensor,
+                                          v_pages: torch.Tensor,
+                                          lengths: torch.Tensor,
+                                          page_indices: torch.Tensor,
+                                          effective_q_lens: torch.Tensor,
+                                          num_kv_pages_per_compute_block: int,
+                                          num_queries_per_compute_block: int,
+                                          use_kernel: bool,
+                                          attn_logits_soft_cap: float |
+                                          None = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
index 84d6ad530e5..dc03d7bca85 100644
--- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
+++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py
@@ -116,6 +116,7 @@ def _flash_attention(
     query_len: int,
     page_size: int,
     head_dim: int,
+    attn_logits_soft_cap: float | None,
 ):
   b, kv_head_idx, q_blk_idx, kv_blk_idx = (
       pl.program_id(0),
@@ -143,6 +144,10 @@ def start_new_sequence():
   s = jnp.einsum(
       'qd,td->qt', q, k,
       preferred_element_type=jnp.float32)  # [block_q, block_k]
+  if attn_logits_soft_cap is not None:
+    capped_s = jnp.tanh(s / attn_logits_soft_cap)
+    s = capped_s * attn_logits_soft_cap
+
   assert s.shape == (num_queries_per_compute_block,
                      kv_seq_len_per_kv_compute_blk)
 
@@ -266,6 +271,7 @@ def paged_flash_attention_kernel(
     num_kv_pages_per_compute_block: int,
     mask_value: float,
     query_len: int,
+    attn_logits_soft_cap: float | None,
 ):
   """Pallas kernel for paged attention."""
   b, kv_head_idx, q_blk_idx, kv_blk_idx = (
@@ -411,6 +417,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
       query_len=query_len,
       page_size=page_size,
       head_dim=head_dim,
+      attn_logits_soft_cap=attn_logits_soft_cap,
   )  # o_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim]
 
   step_ref[0] = step + 1
@@ -428,6 +435,7 @@ def prefetch_next_block():  # pylint: disable=unused-variable
         "num_kv_pages_per_compute_block",
         "num_queries_per_compute_block",
         "mask_value",
+        "attn_logits_soft_cap",
     ],
 )
 def paged_attention(
@@ -441,6 +449,7 @@ def paged_attention(
     mask_value: float = DEFAULT_MASK_VALUE,
     num_kv_pages_per_compute_block: int,
     num_queries_per_compute_block: int = 4,
+    attn_logits_soft_cap: float | None = None,
 ) -> jax.Array:
   """Paged grouped query attention.
 
@@ -620,7 +629,9 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_):
           batch_size=batch_size,
           num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
           mask_value=mask_value,
-          query_len=query_len),
+          query_len=query_len,
+          attn_logits_soft_cap=attn_logits_soft_cap,
+      ),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=5,
           in_specs=in_specs,
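
For reference (not part of the patch): every code path touched by this diff applies the same elementwise soft cap, `tanh(logits / cap) * cap`, which bounds the magnitude of the attention logits by `attn_logits_soft_cap` before masking and softmax, and is a no-op when the cap is `None`. A minimal standalone sketch in plain PyTorch with illustrative shapes; the helper name `soft_cap_logits` is ours, not part of the torch_xla API:

```python
import torch


def soft_cap_logits(attn: torch.Tensor,
                    attn_logits_soft_cap: float | None) -> torch.Tensor:
  """Applies the tanh soft cap used above; returns the input unchanged when the cap is None."""
  if attn_logits_soft_cap is None:
    return attn
  capped_attn = torch.tanh(attn / attn_logits_soft_cap)
  return capped_attn * attn_logits_soft_cap


# Illustrative shape: [num_query_heads, query_len, kv_len].
attn = torch.randn(8, 16, 128) * 10.0
capped = soft_cap_logits(attn, attn_logits_soft_cap=1.0)
assert capped.abs().max() <= 1.0  # logit magnitudes are bounded by the cap
```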