Commit 55b339f

wrap xla ops in try catch
Parent: 0f4b40e

File tree: 1 file changed (+44, -37 lines)

torchax/torchax/ops/jtorch.py

Lines changed: 44 additions & 37 deletions
@@ -513,43 +513,50 @@ def functional_linear(self, weights, bias=None):
     res += bias
   return res
 
-@register_function(torch.ops.xla.dynamo_set_buffer_donor_)
-def _dynamo_set_buffer_donor(self, donor):
+try:
+  # TODO: Currently the following ops are wrapped in the try
+  # catch block because torch.ops.xla is not in the torch ops
+  # registry. Either we import torch_xla in the upper level,
+  # or modify the the register_function to support this.
+  @register_function(torch.ops.xla.dynamo_set_buffer_donor_)
+  def _dynamo_set_buffer_donor(self, donor):
+    pass
+
+  @register_function(torch.ops.xla.ragged_paged_attention)
+  def _ragged_paged_attention(
+      q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
+      kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
+      kv_lens: jax.Array,  # i32[max_num_seqs]
+      page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
+      cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
+      num_seqs: jax.Array,  # i32[1]
+      use_kernel: bool = True,
+      sm_scale: float = 1.0,
+      sliding_window: int | None = None,
+      soft_cap: float | None = None,
+      mask_value: float | None = None,
+      num_kv_pages_per_block: int | None = None,
+      num_queries_per_block: int | None = None,
+      vmem_limit_bytes: int | None = None,
+  ):
+
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
+    return ragged_paged_attention_kernel(
+        q = q,
+        kv_pages = kv_pages,
+        kv_lens = kv_lens,
+        page_indices = page_indices,
+        cu_q_lens = cu_q_lens,
+        num_seqs = num_seqs,
+        sm_scale = sm_scale,
+        sliding_window = sliding_window,
+        soft_cap = soft_cap,
+        mask_value = mask_value,
+        num_kv_pages_per_block = num_kv_pages_per_block,
+        num_queries_per_block = num_queries_per_block,
+        vmem_limit_bytes = vmem_limit_bytes,
+    )
+except Exception as e:
   pass
 
-@register_function(torch.ops.xla.ragged_paged_attention)
-def _ragged_paged_attention(
-    q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
-    kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
-    kv_lens: jax.Array,  # i32[max_num_seqs]
-    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
-    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
-    num_seqs: jax.Array,  # i32[1]
-    use_kernel: bool = True,
-    sm_scale: float = 1.0,
-    sliding_window: int | None = None,
-    soft_cap: float | None = None,
-    mask_value: float | None = None,
-    num_kv_pages_per_block: int | None = None,
-    num_queries_per_block: int | None = None,
-    vmem_limit_bytes: int | None = None,
-):
-
-  from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_paged_attention_kernel
-  return ragged_paged_attention_kernel(
-      q = q,
-      kv_pages = kv_pages,
-      kv_lens = kv_lens,
-      page_indices = page_indices,
-      cu_q_lens = cu_q_lens,
-      num_seqs = num_seqs,
-      sm_scale = sm_scale,
-      sliding_window = sliding_window,
-      soft_cap = soft_cap,
-      mask_value = mask_value,
-      num_kv_pages_per_block = num_kv_pages_per_block,
-      num_queries_per_block = num_queries_per_block,
-      vmem_limit_bytes = vmem_limit_bytes,
-  )
-

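Note: the TODO in the added block points at an alternative to the blanket try/except: teaching the registration helper itself to tolerate ops that are missing from the torch op registry. The snippet below is a minimal sketch of that idea, assuming a string-based op path resolved lazily at decoration time; register_function_if_available and _jax_ops are hypothetical names used only for illustration and are not part of torchax.

import torch

# Hypothetical stand-in for torchax's registry machinery; only the lazy,
# guarded op lookup is the point of this sketch.
_jax_ops = {}

def register_function_if_available(op_path):
  """Register a lowering only when the custom op exists in this process.

  op_path is a string such as "xla.ragged_paged_attention". The attribute
  lookup on torch.ops is deferred until decoration time and guarded, so
  importing the module never requires torch_xla to have been imported first.
  """
  def decorator(fn):
    namespace, _, op_name = op_path.partition(".")
    try:
      op = getattr(getattr(torch.ops, namespace), op_name)
    except (AttributeError, RuntimeError):
      return fn  # op not registered (e.g. torch_xla never imported): skip
    _jax_ops[op] = fn
    return fn
  return decorator

@register_function_if_available("xla.dynamo_set_buffer_donor_")
def _dynamo_set_buffer_donor(self, donor):
  pass

Compared with wrapping the whole block in a single try/except (the approach taken in this commit), a per-registration guard keeps each op independent, so failing to resolve one op would not silently skip every registration that follows it.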