
Commit

Implements attn_logits_soft_cap and passes it through multi_queries_paged_attention
fenghuizhang committed Jan 21, 2025
1 parent 351de89 commit 2ce9e2f
Showing 2 changed files with 36 additions and 12 deletions.
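For context, attention logit soft capping bounds the pre-softmax attention scores with a scaled tanh so that extreme logits cannot dominate the softmax. A minimal sketch of the general technique in PyTorch, not the Pallas kernel's exact implementation:

import torch

def soft_cap(attn_logits: torch.Tensor,
             attn_logits_soft_cap: float | None) -> torch.Tensor:
    # Smoothly squash logits into (-cap, +cap); None means "no capping".
    if attn_logits_soft_cap is None:
        return attn_logits
    return attn_logits_soft_cap * torch.tanh(attn_logits / attn_logits_soft_cap)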
14 changes: 14 additions & 0 deletions test/test_pallas.py
@@ -729,8 +729,22 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
page_indices, effective_q_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel,
attn_logits_soft_cap):
return torch.ops.xla.multi_queries_paged_attention(
q,
k_pages,
34 changes: 22 additions & 12 deletions torch_xla/experimental/custom_kernel.py
@@ -1117,12 +1117,17 @@ def paged_attention_non_xla(q: torch.Tensor,


@impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
def multi_queries_paged_attention_xla(
q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
lengths: torch.Tensor, page_indices: torch.Tensor,
effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int, use_kernel: bool,
attn_logits_soft_cap: float | None = None):
def multi_queries_paged_attention_xla(q: torch.Tensor,
k_pages: torch.Tensor,
v_pages: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
effective_q_lens: torch.Tensor,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
use_kernel: bool,
attn_logits_soft_cap: float |
None = None):
return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
page_indices, effective_q_lens,
num_kv_pages_per_compute_block,
@@ -1131,12 +1136,17 @@ def multi_queries_paged_attention_xla(


@impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
def multi_queries_paged_attention_non_xla(
q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
lengths: torch.Tensor, page_indices: torch.Tensor,
effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int, use_kernel: bool,
attn_logits_soft_cap: float | None = None):
def multi_queries_paged_attention_non_xla(q: torch.Tensor,
k_pages: torch.Tensor,
v_pages: torch.Tensor,
lengths: torch.Tensor,
page_indices: torch.Tensor,
effective_q_lens: torch.Tensor,
num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int,
use_kernel: bool,
attn_logits_soft_cap: float |
None = None):
return non_xla_attetion(q, k_pages, v_pages, "paged")


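For reference, a hedged usage sketch of the updated op, based only on the signature shown in the diff above; the helper name, block sizes, and soft-cap value are illustrative assumptions rather than values taken from the repository:

import torch
import torch_xla  # importing torch_xla registers the xla custom ops

def paged_attention_with_soft_cap(q, k_pages, v_pages, kv_seq_lens, page_indices,
                                  effective_q_lens, soft_cap=50.0):
    # Mirrors the positional call used by the test wrapper above;
    # pass soft_cap=None to run without logit capping.
    return torch.ops.xla.multi_queries_paged_attention(
        q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens,
        8,         # num_kv_pages_per_compute_block (illustrative)
        4,         # num_queries_per_compute_block (illustrative)
        True,      # use_kernel
        soft_cap)  # attn_logits_soft_cap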
