From 2ce9e2fa2eccc6d829c8c10bc83bebf64e662f89 Mon Sep 17 00:00:00 2001
From: Fenghui Zhang
Date: Wed, 15 Jan 2025 18:43:38 +0000
Subject: [PATCH] Implement attn_logits_soft_cap and pass it through
 multi_queries_paged_attention

---
 test/test_pallas.py                     |  2 +-
 torch_xla/experimental/custom_kernel.py | 34 ++++++++++++++++++++++------------
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/test/test_pallas.py b/test/test_pallas.py
index fdc904bd98e6..3e165dc80352 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -729,7 +729,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
                                               page_indices, effective_q_lens,
                                               num_kv_pages_per_compute_block,
                                               num_queries_per_compute_block,
-                                              use_kernel):
+                                              use_kernel, attn_logits_soft_cap):
       return torch.ops.xla.multi_queries_paged_attention(
           q,
           k_pages,
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index d26d9f649b7e..185d2085e7f1 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -1117,12 +1117,17 @@ def paged_attention_non_xla(q: torch.Tensor,
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool,
-    attn_logits_soft_cap: float | None = None):
+def multi_queries_paged_attention_xla(q: torch.Tensor,
+                                      k_pages: torch.Tensor,
+                                      v_pages: torch.Tensor,
+                                      lengths: torch.Tensor,
+                                      page_indices: torch.Tensor,
+                                      effective_q_lens: torch.Tensor,
+                                      num_kv_pages_per_compute_block: int,
+                                      num_queries_per_compute_block: int,
+                                      use_kernel: bool,
+                                      attn_logits_soft_cap: float |
+                                      None = None):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
                                        page_indices, effective_q_lens,
                                        num_kv_pages_per_compute_block,
@@ -1131,12 +1136,17 @@ def multi_queries_paged_attention_xla(
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool,
-    attn_logits_soft_cap: float | None = None):
+def multi_queries_paged_attention_non_xla(q: torch.Tensor,
+                                          k_pages: torch.Tensor,
+                                          v_pages: torch.Tensor,
+                                          lengths: torch.Tensor,
+                                          page_indices: torch.Tensor,
+                                          effective_q_lens: torch.Tensor,
+                                          num_kv_pages_per_compute_block: int,
+                                          num_queries_per_compute_block: int,
+                                          use_kernel: bool,
+                                          attn_logits_soft_cap: float |
+                                          None = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
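
The kernel-side handling of attn_logits_soft_cap is not part of this diff; the patch only plumbs
the parameter through the op registrations. As a rough sketch of what an attention-logits soft cap
conventionally means (tanh capping of the raw scores before the softmax), assuming a hypothetical,
illustrative-only helper name:

import torch

def soft_cap_logits(attn_logits: torch.Tensor,
                    attn_logits_soft_cap: float | None) -> torch.Tensor:
  # Hypothetical helper, for illustration only: when a cap is configured,
  # squash the raw attention scores into (-cap, cap) via tanh so no single
  # logit can dominate the softmax; with no cap, pass the scores through.
  if attn_logits_soft_cap is None:
    return attn_logits
  return attn_logits_soft_cap * torch.tanh(attn_logits / attn_logits_soft_cap)

A value such as attn_logits_soft_cap=50.0 would bound every pre-softmax score to roughly +/-50,
while attn_logits_soft_cap=None keeps the default uncapped behavior.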