@@ -513,43 +513,50 @@ def functional_linear(self, weights, bias=None):
     res += bias
   return res
 
-@register_function(torch.ops.xla.dynamo_set_buffer_donor_)
-def _dynamo_set_buffer_donor(self, donor):
+try:
+  # TODO: Currently the following ops are wrapped in a try/except
+  # block because torch.ops.xla is not in the torch ops registry.
+  # Either we import torch_xla in the upper level, or modify
+  # register_function to support this.
+  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
+  def _dynamo_set_buffer_donor(self, donor):
+    pass
+
+  @register_function(torch.ops.xla.ragged_paged_attention)
+  def _ragged_paged_attention(
+      q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
+      kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
+      kv_lens: jax.Array,  # i32[max_num_seqs]
+      page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
+      cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+      num_seqs: jax.Array,  # i32[1]
+      use_kernel: bool = True,
+      sm_scale: float = 1.0,
+      sliding_window: int | None = None,
+      soft_cap: float | None = None,
+      mask_value: float | None = None,
+      num_kv_pages_per_block: int | None = None,
+      num_queries_per_block: int | None = None,
+      vmem_limit_bytes: int | None = None,
+  ):
+
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
+    return ragged_paged_attention_kernel(
+        q=q,
+        kv_pages=kv_pages,
+        kv_lens=kv_lens,
+        page_indices=page_indices,
+        cu_q_lens=cu_q_lens,
+        num_seqs=num_seqs,
+        sm_scale=sm_scale,
+        sliding_window=sliding_window,
+        soft_cap=soft_cap,
+        mask_value=mask_value,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        num_queries_per_block=num_queries_per_block,
+        vmem_limit_bytes=vmem_limit_bytes,
+    )
+except Exception as e:
   pass
 
-@register_function(torch.ops.xla.ragged_paged_attention)
-def _ragged_paged_attention(
-    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
-    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
-    kv_lens: jax.Array,  # i32[max_num_seqs]
-    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
-    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
-    num_seqs: jax.Array,  # i32[1]
-    use_kernel: bool = True,
-    sm_scale: float = 1.0,
-    sliding_window: int | None = None,
-    soft_cap: float | None = None,
-    mask_value: float | None = None,
-    num_kv_pages_per_block: int | None = None,
-    num_queries_per_block: int | None = None,
-    vmem_limit_bytes: int | None = None,
-):
-
-  from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
-  return ragged_paged_attention_kernel(
-      q=q,
-      kv_pages=kv_pages,
-      kv_lens=kv_lens,
-      page_indices=page_indices,
-      cu_q_lens=cu_q_lens,
-      num_seqs=num_seqs,
-      sm_scale=sm_scale,
-      sliding_window=sliding_window,
-      soft_cap=soft_cap,
-      mask_value=mask_value,
-      num_kv_pages_per_block=num_kv_pages_per_block,
-      num_queries_per_block=num_queries_per_block,
-      vmem_limit_bytes=vmem_limit_bytes,
-  )
-
 
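Background on the failure mode the TODO describes: `torch.ops.xla.*` attributes only resolve once torch_xla has registered those operators with the torch dispatcher, so evaluating the decorator arguments in an environment without torch_xla raises at import time; the try/except skips the registrations instead of breaking module import. The sketch below illustrates that behavior with a toy stand-in for `register_function` (the stand-in registry and names are assumptions made for illustration, not torchax's actual implementation).

# Minimal sketch, assuming torch is installed and torch_xla has NOT been
# imported. `register_function` here is a hypothetical stand-in decorator.
import torch

_TOY_REGISTRY = {}  # hypothetical map: torch op -> JAX implementation


def register_function(op):
  def inner(fn):
    _TOY_REGISTRY[op] = fn
    return fn
  return inner


try:
  # Without `import torch_xla`, xla::dynamo_set_buffer_donor_ is not in the
  # torch op registry, so this attribute lookup raises and the block is skipped.
  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
  def _dynamo_set_buffer_donor(self, donor):
    pass
except Exception:
  pass

print(f"registered {len(_TOY_REGISTRY)} xla op(s)")  # 0 unless torch_xla has been imported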
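The shape comments on `_ragged_paged_attention`'s arguments describe the paged KV-cache layout the wrapper forwards to the Pallas kernel. Below is a small sketch of dummy inputs matching those documented shapes; all concrete sizes are made-up example values, not requirements of the kernel.

# Dummy arrays matching the documented argument shapes; sizes are arbitrary
# illustrative values.
import jax.numpy as jnp

max_num_batched_tokens, num_q_heads, head_dim = 64, 8, 128
total_num_pages, page_size, num_combined_kv_heads = 32, 16, 16
max_num_seqs, pages_per_seq = 4, 8

q = jnp.zeros((max_num_batched_tokens, num_q_heads, head_dim), jnp.bfloat16)
kv_pages = jnp.zeros((total_num_pages, page_size, num_combined_kv_heads, head_dim), jnp.bfloat16)
kv_lens = jnp.zeros((max_num_seqs,), jnp.int32)            # i32[max_num_seqs]
page_indices = jnp.zeros((max_num_seqs, pages_per_seq), jnp.int32)
cu_q_lens = jnp.zeros((max_num_seqs + 1,), jnp.int32)      # cumulative query lengths
num_seqs = jnp.array([max_num_seqs], jnp.int32)            # i32[1]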