@@ -1117,12 +1117,17 @@ def paged_attention_non_xla(q: torch.Tensor,
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
-def multi_queries_paged_attention_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool,
-    attn_logits_soft_cap: float | None = None):
+def multi_queries_paged_attention_xla(q: torch.Tensor,
+                                      k_pages: torch.Tensor,
+                                      v_pages: torch.Tensor,
+                                      lengths: torch.Tensor,
+                                      page_indices: torch.Tensor,
+                                      effective_q_lens: torch.Tensor,
+                                      num_kv_pages_per_compute_block: int,
+                                      num_queries_per_compute_block: int,
+                                      use_kernel: bool,
+                                      attn_logits_soft_cap: float |
+                                      None = None):
   return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
                                        page_indices, effective_q_lens,
                                        num_kv_pages_per_compute_block,
@@ -1131,12 +1136,17 @@ def multi_queries_paged_attention_xla(
 
 
 @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd")
-def multi_queries_paged_attention_non_xla(
-    q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
-    lengths: torch.Tensor, page_indices: torch.Tensor,
-    effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
-    num_queries_per_compute_block: int, use_kernel: bool,
-    attn_logits_soft_cap: float | None = None):
+def multi_queries_paged_attention_non_xla(q: torch.Tensor,
+                                          k_pages: torch.Tensor,
+                                          v_pages: torch.Tensor,
+                                          lengths: torch.Tensor,
+                                          page_indices: torch.Tensor,
+                                          effective_q_lens: torch.Tensor,
+                                          num_kv_pages_per_compute_block: int,
+                                          num_queries_per_compute_block: int,
+                                          use_kernel: bool,
+                                          attn_logits_soft_cap: float |
+                                          None = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
 
 
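
For context, a minimal call sketch of the reformatted wrapper. Only the argument order mirrors the signatures in this diff; the `xla` op namespace, the registering module path, and the tensor shapes below are assumptions for illustration and are not shown here.

```python
# Hedged usage sketch -- shapes, dtypes, and the `xla` namespace are assumptions;
# only the argument order follows the wrappers in this diff.
import torch
import torch_xla.experimental.custom_kernel  # assumed module that registers the op

q = torch.randn(1, 4, 8, 128)                    # assumed [batch, q_len, heads, head_dim]
k_pages = torch.randn(8, 32, 16, 128)            # assumed [kv_heads, pages, page_size, head_dim]
v_pages = torch.randn(8, 32, 16, 128)
lengths = torch.tensor([64], dtype=torch.int32)  # kv length per sequence
page_indices = torch.randint(0, 32, (1, 8), dtype=torch.int32)
effective_q_lens = torch.tensor([4], dtype=torch.int32)

out = torch.ops.xla.multi_queries_paged_attention(
    q, k_pages, v_pages, lengths, page_indices, effective_q_lens,
    4,      # num_kv_pages_per_compute_block
    4,      # num_queries_per_compute_block
    False,  # use_kernel; off-TPU this dispatches to the non-XLA stub above
    None)   # attn_logits_soft_cap
```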