
Conversation


@ysiraichi (Collaborator) commented on Oct 18, 2025

This PR refactors the custom_call implementation by improving its error messages and returning a status-typed value.

Key Changes:

  • Make tensor_methods::{tpu,}custom_call return StatusOr<vector<XLATensorPtr>>
  • Improve error messages and error handling
    • Create new CheckCustomCallNonEmptyInputs function for checking that there is at least 1 input tensor
    • Create new CheckCustomCallOutputPropertiesSize function for checking that the given shapes and dtypes match in size, i.e. they agree on the number of outputs
    • Create new CustomCallImpl function implementing both custom_call and tpu_custom_call, since they shared mostly the same logic (see the sketch after this list)
    • Delete the CheckIntList function (should have been deleted in #9648)
    • Delete the TpuCustomCall function, inlining it into its corresponding Python binding, matching how the custom_call binding is currently implemented
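
The sketch below shows roughly what these helpers can look like. It is written against plain absl status types so it stands alone; the PR itself uses torch_xla's own status utilities (torch_xla/csrc/status.h), so the stand-in types and exact signatures here are illustrative assumptions, not the actual implementation.

// Minimal sketch of the new validation helpers, using plain absl types.
// The actual PR uses torch_xla's status utilities (torch_xla/csrc/status.h),
// so the stand-in types and exact signatures here are assumptions.
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"

// Stand-ins for torch_xla types so the sketch is self-contained.
using XLATensorPtr = std::shared_ptr<void>;
using ScalarType = int;

// Error if no input tensor was given (Example 1 below).
absl::Status CheckCustomCallNonEmptyInputs(
    const std::string& target, const std::vector<XLATensorPtr>& inputs) {
  if (inputs.empty()) {
    return absl::InvalidArgumentError(absl::StrCat(
        "custom_call(", target, "): expected at least 1 input tensor."));
  }
  return absl::OkStatus();
}

// Error if shapes and dtypes disagree on the number of outputs (Example 2).
absl::Status CheckCustomCallOutputPropertiesSize(
    const std::string& target,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<ScalarType>& output_dtypes) {
  if (output_shapes.size() != output_dtypes.size()) {
    return absl::InvalidArgumentError(absl::StrCat(
        "custom_call(", target, "): expected the given output shapes (size=",
        output_shapes.size(),
        ") to be of the same size as the given output dtypes (size=",
        output_dtypes.size(), ")."));
  }
  return absl::OkStatus();
}

// Shared implementation behind both custom_call and tpu_custom_call:
// validate first, then build the custom-call IR node (elided here).
absl::StatusOr<std::vector<XLATensorPtr>> CustomCallImpl(
    const std::vector<XLATensorPtr>& inputs, const std::string& target,
    const std::vector<std::vector<int64_t>>& output_shapes,
    const std::vector<ScalarType>& output_dtypes) {
  if (auto status = CheckCustomCallNonEmptyInputs(target, inputs);
      !status.ok()) {
    return status;
  }
  if (auto status = CheckCustomCallOutputPropertiesSize(target, output_shapes,
                                                        output_dtypes);
      !status.ok()) {
    return status;
  }
  // ... lowering elided: construct the custom-call IR node, wrap outputs ...
  return std::vector<XLATensorPtr>{};
}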

Example 1: no inputs

import torch
import torch_xla

output_shapes = [[1]]
output_dtypes = [torch.int8]
torch_xla._XLAC._xla_custom_call([], "custom_op_target", output_shapes, output_dtypes, False, "", 0, {})
Comparison

Before:

Traceback (most recent call last):
  File "examples/test.py", line 8, in <module>
    torch_xla._XLAC._xla_custom_call([], "custom_op_target", output_shapes, output_dtypes, False, "", 0, {})
RuntimeError: Check failed: inputs.size() > 0: inputs are empty (at torch_xla/csrc/tensor_methods.cpp:895)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/test.py", line 8, in <module>
    torch_xla._XLAC._xla_custom_call([], "custom_op_target", output_shapes, output_dtypes, False, "", 0, {})
RuntimeError: custom_call(custom_op_target): expected at least 1 input tensor.

Status Propagation Trace:
    From: CheckCustomCallNonEmptyInputs at torch_xla/csrc/tensor_methods.cpp:659 (error: custom_call(custom_op_target): expected at least 1 input tensor.)
    From: CustomCallImpl at torch_xla/csrc/tensor_methods.cpp:695
    From: custom_call at torch_xla/csrc/tensor_methods.cpp:961
    From: operator() at torch_xla/csrc/init_python_bindings.cpp:3111

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):

Example 2: output properties size mismatch

import torch
import torch_xla

input = torch.rand(10, device=torch_xla.device())
output_shapes = [[1], [1]]
output_dtypes = [torch.int8]
torch_xla._XLAC._xla_custom_call([input], "custom_op_target", output_shapes, output_dtypes, False, "", 0, {})
Comparison

Before:

Traceback (most recent call last):
  File "examples/test.py", line 17, in <module>
    torch_xla._XLAC._xla_custom_call([input], "custom_op_target", output_shapes, output_dtypes, False, "", 0, {})
RuntimeError: Check failed: output_shapes.size() == output_dtypes.size() (2 vs. 1) (at torch_xla/csrc/tensor_methods.cpp:903)

Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:26 (most recent call first):

After:

Traceback (most recent call last):
  File "examples/test.py", line 17, in <module>
    torch_xla._XLAC._xla_custom_call([input], "custom_op_target", output_shapes, output_dtypes, False, "", 0, {})
RuntimeError: custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1).

Status Propagation Trace:
    From: CheckCustomCallOutputPropertiesSize at torch_xla/csrc/tensor_methods.cpp:677 (error: custom_call(custom_op_target): expected the given output shapes (size=2) to be of the same size as the given output dtypes (size=1).)
    From: CustomCallImpl at torch_xla/csrc/tensor_methods.cpp:696
    From: custom_call at torch_xla/csrc/tensor_methods.cpp:961
    From: operator() at torch_xla/csrc/init_python_bindings.cpp:3111

Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:128 (most recent call first):
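
Both traces above end at ThrowStatusError in torch_xla/csrc/status.cpp, i.e. the Python binding is where the returned status gets unwrapped and converted into the RuntimeError seen by the user. As a rough sketch of that consuming side (the helper name and exception plumbing below are stand-ins, not the binding's actual code):

// Hypothetical unwrap helper on the binding side. The real conversion goes
// through ThrowStatusError (torch_xla/csrc/status.cpp:128, per the traces);
// this stand-in only illustrates the StatusOr -> Python exception shape.
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

template <typename T>
T ValueOrThrow(absl::StatusOr<T> result) {
  if (!result.ok()) {
    // pybind11 translates std::runtime_error into a Python RuntimeError,
    // which is what the "After" tracebacks above show.
    throw std::runtime_error(std::string(result.status().message()));
  }
  return *std::move(result);
}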

