Commit 2ce9e2f

Implements attn_logits_soft_cap and passes it through multi_queries_paged_attention

1 parent 351de89 · commit 2ce9e2f
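
For context, attn_logits_soft_cap bounds the pre-softmax attention scores with a tanh squashing, so logits saturate smoothly inside (-cap, cap) instead of growing without bound. A minimal standalone sketch of the conventional transform (the function name is ours for illustration; this is not code from the commit):

```python
import torch


def soft_cap_logits(logits: torch.Tensor,
                    attn_logits_soft_cap: float | None) -> torch.Tensor:
  """Squash attention logits into (-cap, cap) via cap * tanh(logits / cap)."""
  if attn_logits_soft_cap is None:
    return logits  # capping disabled, matching the op's default of None
  return attn_logits_soft_cap * torch.tanh(logits / attn_logits_soft_cap)
```

The Pallas kernel presumably applies the equivalent elementwise operation to each block of attention scores before the softmax.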

File tree: 2 files changed (+36, -12 lines)

test/test_pallas.py (14 additions, 0 deletions)

```diff
@@ -729,8 +729,22 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
                                            page_indices, effective_q_lens,
                                            num_kv_pages_per_compute_block,
                                            num_queries_per_compute_block,
+<<<<<<< HEAD
                                            use_kernel,
                                            attn_logits_soft_cap):
+=======
+<<<<<<< HEAD
+                                           use_kernel,
+                                           attn_logits_soft_cap):
+=======
+<<<<<<< HEAD
+                                           use_kernel,
+                                           attn_logits_soft_cap):
+=======
+                                           use_kernel, attn_logits_soft_cap):
+>>>>>>> 0a91471da (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention)
+>>>>>>> 47e8d1d00 (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention)
+>>>>>>> d430cb4e9 (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention)
   return torch.ops.xla.multi_queries_paged_attention(
       q,
       k_pages,
```
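
Note that this hunk checks nested, unresolved merge-conflict markers into the test. All three conflicted variants add the same two trailing parameters, so a resolved wrapper would presumably collapse to the following sketch (reconstructed from the hunk above, with the forwarded arguments taken from the op signature below):

```python
def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
                                          page_indices, effective_q_lens,
                                          num_kv_pages_per_compute_block,
                                          num_queries_per_compute_block,
                                          use_kernel,
                                          attn_logits_soft_cap):
  # Every conflicted branch threads the same two flags through, so the
  # resolution is a single signature forwarding them to the XLA op.
  return torch.ops.xla.multi_queries_paged_attention(
      q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens,
      num_kv_pages_per_compute_block, num_queries_per_compute_block,
      use_kernel, attn_logits_soft_cap)
```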

torch_xla/experimental/custom_kernel.py (22 additions, 12 deletions)

```diff
@@ -1117,12 +1117,17 @@ def paged_attention_non_xla(q: torch.Tensor,
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool,
-    attn_logits_soft_cap: float | None = None):
+def multi_queries_paged_attention_xla(q: torch.Tensor,
+                                      k_pages: torch.Tensor,
+                                      v_pages: torch.Tensor,
+                                      lengths: torch.Tensor,
+                                      page_indices: torch.Tensor,
+                                      effective_q_lens: torch.Tensor,
+                                      num_kv_pages_per_compute_block: int,
+                                      num_queries_per_compute_block: int,
+                                      use_kernel: bool,
+                                      attn_logits_soft_cap: float |
+                                      None = None):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
                                        page_indices, effective_q_lens,
                                        num_kv_pages_per_compute_block,
@@ -1131,12 +1136,17 @@ def multi_queries_paged_attention_xla(
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool,
-    attn_logits_soft_cap: float | None = None):
+def multi_queries_paged_attention_non_xla(q: torch.Tensor,
+                                          k_pages: torch.Tensor,
+                                          v_pages: torch.Tensor,
+                                          lengths: torch.Tensor,
+                                          page_indices: torch.Tensor,
+                                          effective_q_lens: torch.Tensor,
+                                          num_kv_pages_per_compute_block: int,
+                                          num_queries_per_compute_block: int,
+                                          use_kernel: bool,
+                                          attn_logits_soft_cap: float |
+                                          None = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
```