diff --git a/_static/img/onnx/custom_addandround_function.png b/_static/img/onnx/custom_addandround_function.png deleted file mode 100644 index a0c7000161e..00000000000 Binary files a/_static/img/onnx/custom_addandround_function.png and /dev/null differ diff --git a/_static/img/onnx/custom_addandround_model.png b/_static/img/onnx/custom_addandround_model.png deleted file mode 100644 index 793d8cfbb5d..00000000000 Binary files a/_static/img/onnx/custom_addandround_model.png and /dev/null differ diff --git a/_static/img/onnx/custom_aten_add_function.png b/_static/img/onnx/custom_aten_add_function.png deleted file mode 100644 index 8ef05a747a0..00000000000 Binary files a/_static/img/onnx/custom_aten_add_function.png and /dev/null differ diff --git a/_static/img/onnx/custom_aten_add_model.png b/_static/img/onnx/custom_aten_add_model.png deleted file mode 100644 index e5ef1c71742..00000000000 Binary files a/_static/img/onnx/custom_aten_add_model.png and /dev/null differ diff --git a/_static/img/onnx/custom_aten_gelu_model.png b/_static/img/onnx/custom_aten_gelu_model.png deleted file mode 100644 index 5b326690eb7..00000000000 Binary files a/_static/img/onnx/custom_aten_gelu_model.png and /dev/null differ diff --git a/beginner_source/onnx/onnx_registry_tutorial.py b/beginner_source/onnx/onnx_registry_tutorial.py index 0f64ba9c4d4..8266f912a13 100644 --- a/beginner_source/onnx/onnx_registry_tutorial.py +++ b/beginner_source/onnx/onnx_registry_tutorial.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- - """ `Introduction to ONNX `_ || `Exporting a PyTorch model to ONNX `_ || -**Extending the ONNX Registry** +**Extending the ONNX Exporter Operator Support** -Extending the ONNX Registry -=========================== +Extending the ONNX Exporter Operator Support +============================================ -**Authors:** Ti-Tai Wang (titaiwang@microsoft.com) +**Authors:** Ti-Tai Wang (titaiwang@microsoft.com), Justin Chu (justinchu@microsoft.com) """ @@ -16,430 +14,327 @@ # 
Overview # -------- # -# This tutorial is an introduction to ONNX registry, which empowers users to implement new ONNX operators -# or even replace existing operators with a new implementation. -# -# During the model export to ONNX, the PyTorch model is lowered to an intermediate -# representation composed of `ATen operators `_. -# While ATen operators are maintained by PyTorch core team, it is the responsibility of the ONNX exporter team -# to independently implement each of these operators to ONNX through `ONNX Script `_. -# The users can also replace the behavior implemented by the ONNX exporter team with their own implementation -# to fix bugs or improve performance for a specific ONNX runtime. -# -# The ONNX Registry manages the mapping between PyTorch operators and the ONNX operators counterparts and provides -# APIs to extend the registry. +# This tutorial describes how you can create ONNX implementation for unsupported PyTorch operators +# or replace existing implementation with your own. # -# In this tutorial, we will cover three scenarios that require extending the ONNX registry with custom operators: +# We will cover three scenarios that require extending the ONNX exporter's operator support: # -# * Unsupported ATen operators -# * Custom operators with existing ONNX Runtime support -# * Custom operators without ONNX Runtime support +# * Overriding the implementation of an existing PyTorch operator +# * Using custom ONNX operators +# * Supporting a custom PyTorch operator # -# Unsupported ATen operators -# -------------------------- +# Overriding the implementation of an existing PyTorch operator +# ------------------------------------------------------------- # -# Although the ONNX exporter team does their best efforts to support all ATen operators, some of them +# Although the ONNX exporter team does their best efforts to support all PyTorch operators, some of them # might not be supported yet. 
In this section, we will demonstrate how you can add -# unsupported ATen operators to the ONNX Registry. +# unsupported PyTorch operators to the ONNX Registry. # # .. note:: -# The steps to implement unsupported ATen operators are the same to replace the implementation of an existing -# ATen operator with a custom implementation. -# Because we don't actually have an unsupported ATen operator to use in this tutorial, we are going to leverage -# this and replace the implementation of ``aten::add.Tensor`` with a custom implementation the same way we would -# if the operator was not present in the ONNX Registry. +# The steps to implement unsupported PyTorch operators are the same as those to replace the implementation of an existing +# PyTorch operator with a custom implementation. +# Because we don't actually have an unsupported PyTorch operator to use in this tutorial, we are going to leverage +# this and replace the implementation of ``torch.ops.aten.add.Tensor`` with a custom implementation the same way we would +# if the operator was not implemented by the ONNX exporter. # # When a model cannot be exported to ONNX due to an unsupported operator, the ONNX exporter will show an error message # similar to: # # .. code-block:: python # -# RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}. +# No decompositions registered for [...] # -# The error message indicates that the fully qualified name of unsupported ATen operator is ``aten::add.Tensor``. -# The fully qualified name of an operator is composed of the namespace, operator name, and overload following -# the format ``namespace::operator_name.overload``. +# The error message indicates that the unsupported PyTorch operator is ``torch.ops.aten.add.Tensor``. +# The operator is of type ``<class 'torch._ops.OpOverload'>``, and this operator is what we will use as the +# target to register our custom implementation.
# -# To add support for an unsupported ATen operator or to replace the implementation for an existing one, we need: +# To add support for an unsupported PyTorch operator or to replace the implementation for an existing one, we need: # -# * The fully qualified name of the ATen operator (e.g. ``aten::add.Tensor``). -# This information is always present in the error message as show above. +# * The target PyTorch operator. # * The implementation of the operator using `ONNX Script `__. # ONNX Script is a prerequisite for this tutorial. Please make sure you have read the # `ONNX Script tutorial `_ # before proceeding. -# -# Because ``aten::add.Tensor`` is already supported by the ONNX Registry, we will demonstrate how to replace it with a -# custom implementation, but keep in mind that the same steps apply to support new unsupported ATen operators. -# -# This is possible because the :class:`OnnxRegistry` allows users to override an operator registration. -# We will override the registration of ``aten::add.Tensor`` with our custom implementation and verify it exists. -# import torch import onnxruntime import onnxscript -from onnxscript import opset18 # opset 18 is the latest (and only) supported version for now -class Model(torch.nn.Module): - def forward(self, input_x, input_y): - return torch.ops.aten.add(input_x, input_y) # generates a aten::add.Tensor node +# Opset 18 is the standard supported version as of PyTorch 2.6 +from onnxscript import opset18 as op -input_add_x = torch.randn(3, 4) -input_add_y = torch.randn(3, 4) -aten_add_model = Model() +# Create a model that uses the operator torch.ops.aten.add.Tensor +class Model(torch.nn.Module): + def forward(self, input_x, input_y): + return torch.ops.aten.add.Tensor(input_x, input_y) -# Now we create a ONNX Script function that implements ``aten::add.Tensor``. -# The function name (e.g. ``custom_aten_add``) is displayed in the ONNX graph, so we recommend to use intuitive names. 
-custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1) -# NOTE: The function signature must match the signature of the unsupported ATen operator. +# NOTE: The function signature (including param names) must match the signature of the unsupported PyTorch operator. # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml # NOTE: All attributes must be annotated with type hints. -@onnxscript.script(custom_aten) -def custom_aten_add(input_x, input_y, alpha: float = 1.0): - input_y = opset18.Mul(input_y, alpha) - return opset18.Add(input_x, input_y) - - -# Now we have everything we need to support unsupported ATen operators. -# Let's register the ``custom_aten_add`` function to ONNX registry, and export the model to ONNX again. -onnx_registry = torch.onnx.OnnxRegistry() -onnx_registry.register_op( - namespace="aten", op_name="add", overload="Tensor", function=custom_aten_add - ) -print(f"aten::add.Tensor is supported by ONNX registry: \ - {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}" - ) -export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry) -onnx_program = torch.onnx.dynamo_export( - aten_add_model, input_add_x, input_add_y, export_options=export_options - ) +def custom_aten_add(self, other, alpha: float = 1.0): + if alpha != 1.0: + alpha = op.CastLike(alpha, other) + other = op.Mul(other, alpha) + # To distinguish the custom implementation from the builtin one, we switch the order of the inputs + return op.Add(other, self) + + +x = torch.tensor([1.0]) +y = torch.tensor([2.0]) + +# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``. 
+onnx_program = torch.onnx.export( + Model().eval(), + (x, y), + dynamo=True, + custom_translation_table={ + torch.ops.aten.add.Tensor: custom_aten_add, + }, +) +# Optimize the ONNX graph to remove redundant nodes +onnx_program.optimize() ###################################################################### -# Now let's inspect the model and verify the model has a ``custom_aten_add`` instead of ``aten::add.Tensor``. -# The graph has one graph node for ``custom_aten_add``, and inside of it there are four function nodes, one for each -# operator, and one for constant attribute. -# - -# graph node domain is the custom domain we registered -assert onnx_program.model_proto.graph.node[0].domain == "custom.aten" -assert len(onnx_program.model_proto.graph.node) == 1 -# graph node name is the function name -assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add" -# function node domain is empty because we use standard ONNX operators -assert {node.domain for node in onnx_program.model_proto.functions[0].node} == {""} -# function node name is the standard ONNX operator name -assert {node.op_type for node in onnx_program.model_proto.functions[0].node} == {"Add", "Mul", "Constant"} +# Now let's inspect the model and verify the model is using the custom implementation. +print(onnx_program.model) ###################################################################### -# This is how ``custom_aten_add_model`` looks in the ONNX graph using Netron: -# -# .. image:: /_static/img/onnx/custom_aten_add_model.png -# :width: 70% -# :align: center -# -# Inside the ``custom_aten_add`` function, we can see the three ONNX nodes we -# used in the function (``CastLike``, ``Add``, and ``Mul``), and one ``Constant`` attribute: -# -# .. image:: /_static/img/onnx/custom_aten_add_function.png -# :width: 70% -# :align: center +# We get # -# This was all that we needed to register the new ATen operator into the ONNX Registry. 
-# As an additional step, we can use ONNX Runtime to run the model, and compare the results with PyTorch. -# - - -# Use ONNX Runtime to run the model, and compare the results with PyTorch -onnx_program.save("./custom_add_model.onnx") -ort_session = onnxruntime.InferenceSession( - "./custom_add_model.onnx", providers=['CPUExecutionProvider'] - ) - -def to_numpy(tensor): - return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() - -onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_add_x, input_add_y) -onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)} -onnxruntime_outputs = ort_session.run(None, onnxruntime_input) - -torch_outputs = aten_add_model(input_add_x, input_add_y) -torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs) - -assert len(torch_outputs) == len(onnxruntime_outputs) -for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs): - torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output)) +# .. code-block:: python +# < +# ir_version=10, +# opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18}, +# producer_name='pytorch', +# producer_version='2.7.0.dev20250124+cu124', +# domain=None, +# model_version=None, +# > +# graph( +# name=main_graph, +# inputs=( +# %"input_x", +# %"input_y" +# ), +# outputs=( +# %"add" +# ), +# ) { +# 0 | # node_Add_0 +# %"add" ⬅️ ::Add(%"input_y", %"input_x") +# return %"add" +# } +# +# The translation is using our custom implementation: In node ``node_Add_0``, ``input_y`` now +# comes first, and ``input_x`` comes second. +# +# We can use ONNX Runtime to run the model and verify the results by calling +# the ONNXProgram directly on the input tensors. 
+ +result = onnx_program(x, y)[0] +torch.testing.assert_close(result, torch.tensor([3.0])) ###################################################################### -# Custom operators with existing ONNX Runtime support -# --------------------------------------------------- +# Using custom ONNX operators +# --------------------------- # -# In this case, the user creates a model with standard PyTorch operators, but the ONNX runtime +# In this case, we create a model with standard PyTorch operators, but the runtime # (e.g. Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the -# existing implementation in the ONNX Registry. Another use case is when the user wants to use a custom implementation -# of an existing ONNX operator to fix a bug or improve performance of a specific operator. -# To achieve this, we only need to register the new implementation with the existing ATen fully qualified name. +# existing implementation. # -# In the following example, we use the ``com.microsoft.Gelu`` from ONNX Runtime, -# which is not the same ``Gelu`` from ONNX spec. Thus, we register the Gelu with -# the namespace ``com.microsoft`` and operator name ``Gelu``. -# -# Before we begin, let's check whether ``aten::gelu.default`` is really supported by the ONNX registry. - -onnx_registry = torch.onnx.OnnxRegistry() -print(f"aten::gelu.default is supported by ONNX registry: \ - {onnx_registry.is_registered_op(namespace='aten', op_name='gelu', overload='default')}") +# In the following example, we use the ``com.microsoft.Gelu`` operator provided by ONNX Runtime, +# which is not the same ``Gelu`` from ONNX spec. -###################################################################### -# In our example, ``aten::gelu.default`` operator is supported by the ONNX registry, -# so :meth:`onnx_registry.is_registered_op` returns ``True``. 
- -class CustomGelu(torch.nn.Module): +class GeluModel(torch.nn.Module): def forward(self, input_x): return torch.ops.aten.gelu(input_x) -# com.microsoft is an official ONNX Runtime namspace -custom_ort = onnxscript.values.Opset(domain="com.microsoft", version=1) -# NOTE: The function signature must match the signature of the unsupported ATen operator. +# Create a namespace for the custom operator using ONNX Script +# com.microsoft is an official ONNX Runtime namespace +microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1) + +# NOTE: The function signature (including param names) must match the signature of the unsupported PyTorch operator. # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml # NOTE: All attributes must be annotated with type hints. -@onnxscript.script(custom_ort) -def custom_aten_gelu(input_x, approximate: str = "none"): - # We know com.microsoft::Gelu is supported by ONNX Runtime - # It's only not supported by ONNX - return custom_ort.Gelu(input_x) +# The function must be scripted using the ``@onnxscript.script()`` decorator when +# using operators from custom domains. This may be improved in future versions. 
+from onnxscript import FLOAT -onnx_registry = torch.onnx.OnnxRegistry() -onnx_registry.register_op( - namespace="aten", op_name="gelu", overload="default", function=custom_aten_gelu) -export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry) +@onnxscript.script(microsoft_op) +def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT: + return microsoft_op.Gelu(self) -aten_gelu_model = CustomGelu() -input_gelu_x = torch.randn(3, 3) -onnx_program = torch.onnx.dynamo_export( - aten_gelu_model, input_gelu_x, export_options=export_options - ) +onnx_program = torch.onnx.export( + GeluModel().eval(), + (x,), + dynamo=True, + custom_translation_table={ + torch.ops.aten.gelu.default: custom_aten_gelu, + }, +) + +# Optimize the ONNX graph to remove redundant nodes +onnx_program.optimize() ###################################################################### # Let's inspect the model and verify the model uses op_type ``Gelu`` # from namespace ``com.microsoft``. # -# .. note:: -# :func:`custom_aten_gelu` does not exist in the graph because -# functions with fewer than three operators are inlined automatically. -# - -# graph node domain is the custom domain we registered -assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft" -# graph node name is the function name -assert onnx_program.model_proto.graph.node[0].op_type == "Gelu" +print(onnx_program.model) ###################################################################### -# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron, -# we can see the ``Gelu`` node from module ``com.microsoft`` used in the function: -# -# .. image:: /_static/img/onnx/custom_aten_gelu_model.png +# We get # -# That is all we need to do. As an additional step, we can use ONNX Runtime to run the model, -# and compare the results with PyTorch. 
-# - -onnx_program.save("./custom_gelu_model.onnx") -ort_session = onnxruntime.InferenceSession( - "./custom_gelu_model.onnx", providers=['CPUExecutionProvider'] - ) +# .. code-block:: python +# < +# ir_version=10, +# opset_imports={'pkg.onnxscript.torch_lib.common': 1, 'com.microsoft': 1, '': 18}, +# producer_name='pytorch', +# producer_version='2.7.0.dev20250124+cu124', +# domain=None, +# model_version=None, +# > +# graph( +# name=main_graph, +# inputs=( +# %"input_x" +# ), +# outputs=( +# %"gelu" +# ), +# ) { +# 0 | # n0 +# %"gelu" ⬅️ com.microsoft::Gelu(%"input_x") +# return %"gelu" +# } -def to_numpy(tensor): - return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() -onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_gelu_x) -onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)} -onnxruntime_outputs = ort_session.run(None, onnxruntime_input) +###################################################################### +# Similar to the previous example, we can use ONNX Runtime to run the model and verify the results. -torch_outputs = aten_gelu_model(input_gelu_x) -torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs) +result = onnx_program(x)[0] +torch.testing.assert_close(result, torch.ops.aten.gelu(x)) -assert len(torch_outputs) == len(onnxruntime_outputs) -for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs): - torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output)) ###################################################################### -# Custom operators without ONNX Runtime support -# --------------------------------------------- +# Supporting a custom PyTorch operator +# ------------------------------------ # -# In this case, the operator is not supported by any ONNX runtime, but we -# would like to use it as custom operator in ONNX graph. Therefore, we need to implement -# the operator in three places: -# -# 1. 
PyTorch FX graph -# 2. ONNX Registry -# 3. ONNX Runtime +# In this case, the operator is an operator that is user implemented and registered to PyTorch. # # In the following example, we would like to use a custom operator # that takes one tensor input, and returns one output. The operator adds # the input to itself, and returns the rounded result. # -# -# Custom Ops Registration in PyTorch FX Graph (Beta) -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# Firstly, we need to implement the operator in PyTorch FX graph. -# This can be done by using ``torch._custom_op``. -# - -# NOTE: This is a beta feature in PyTorch, and is subject to change. -from torch._custom_op import impl as custom_op - -@custom_op.custom_op("mylibrary::addandround_op") -def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor: - ... - -@addandround_op.impl_abstract() -def addandround_op_impl_abstract(tensor_x): - return torch.empty_like(tensor_x) - -@addandround_op.impl("cpu") -def addandround_op_impl(tensor_x): - return torch.round(tensor_x + tensor_x) # add x to itself, and round the result - -torch._dynamo.allow_in_graph(addandround_op) +# Firstly, we assume the custom operator is implemented and registered with ``torch.library.custom_op()``. +# You can refer to `Creating new custom ops in Python `_ +# for a detailed guide on how to create custom operators. -class CustomFoo(torch.nn.Module): - def forward(self, tensor_x): - return addandround_op(tensor_x) -input_addandround_x = torch.randn(3) -custom_addandround_model = CustomFoo() +# Define and use the operator in PyTorch +@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=()) +def add_and_round_op(input: torch.Tensor) -> torch.Tensor: + return torch.round(input + input) -###################################################################### -# -# Custom Ops Registration in ONNX Registry -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# For the step 2 and 3, we need to implement the operator in ONNX registry. 
-# In this example, we will implement the operator in ONNX registry -# with the namespace ``test.customop`` and operator name ``CustomOpOne``, -# and ``CustomOpTwo``. These two ops are registered and built in -# `cpu_ops.cc `__. -# - - -custom_opset = onnxscript.values.Opset(domain="test.customop", version=1) - -# NOTE: The function signature must match the signature of the unsupported ATen operator. -# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml -# NOTE: All attributes must be annotated with type hints. -@onnxscript.script(custom_opset) -def custom_addandround(input_x): - # The same as opset18.Add(x, x) - add_x = custom_opset.CustomOpOne(input_x, input_x) - # The same as opset18.Round(x, x) - round_x = custom_opset.CustomOpTwo(add_x) - # Cast to FLOAT to match the ONNX type - return opset18.Cast(round_x, to=1) +@add_and_round_op.register_fake +def _add_and_round_op_fake(tensor_x): + return torch.empty_like(tensor_x) -onnx_registry = torch.onnx.OnnxRegistry() -onnx_registry.register_op( - namespace="mylibrary", op_name="addandround_op", overload="default", function=custom_addandround - ) +class AddAndRoundModel(torch.nn.Module): + def forward(self, input): + return add_and_round_op(input) -export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry) -onnx_program = torch.onnx.dynamo_export( - custom_addandround_model, input_addandround_x, export_options=export_options - ) -onnx_program.save("./custom_addandround_model.onnx") +# Implement the custom operator in ONNX using ONNX Script +def onnx_add_and_round(input): + return op.Round(op.Add(input, input)) -###################################################################### -# The ``onnx_program`` exposes the exported model as protobuf through ``onnx_program.model_proto``. -# The graph has one graph nodes for ``custom_addandround``, and inside ``custom_addandround``, -# there are two function nodes, one for each operator. 
-# -assert onnx_program.model_proto.graph.node[0].domain == "test.customop" -assert onnx_program.model_proto.graph.node[0].op_type == "custom_addandround" -assert onnx_program.model_proto.functions[0].node[0].domain == "test.customop" -assert onnx_program.model_proto.functions[0].node[0].op_type == "CustomOpOne" -assert onnx_program.model_proto.functions[0].node[1].domain == "test.customop" -assert onnx_program.model_proto.functions[0].node[1].op_type == "CustomOpTwo" +onnx_program = torch.onnx.export( + AddAndRoundModel().eval(), + (x,), + dynamo=True, + custom_translation_table={ + torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round, + }, +) +# Optimize the ONNX graph to remove redundant nodes +onnx_program.optimize() +print(onnx_program) ###################################################################### -# This is how ``custom_addandround_model`` ONNX graph looks using Netron: -# -# .. image:: /_static/img/onnx/custom_addandround_model.png -# :width: 70% -# :align: center -# -# Inside the ``custom_addandround`` function, we can see the two custom operators we -# used in the function (``CustomOpOne``, and ``CustomOpTwo``), and they are from module -# ``test.customop``: -# -# .. image:: /_static/img/onnx/custom_addandround_function.png -# -# Custom Ops Registration in ONNX Runtime -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# -# To link your custom op library to ONNX Runtime, you need to -# compile your C++ code into a shared library and link it to ONNX Runtime. -# Follow the instructions below: -# -# 1. Implement your custom op in C++ by following -# `ONNX Runtime instructions <`https://github.com/microsoft/onnxruntime/blob/gh-pages/docs/reference/operators/add-custom-op.md>`__. -# 2. Download ONNX Runtime source distribution from -# `ONNX Runtime releases `__. -# 3. Compile and link your custom op library to ONNX Runtime, for example: -# -# .. 
code-block:: bash -# -# $ gcc -shared -o libcustom_op_library.so custom_op_library.cc -L /path/to/downloaded/ort/lib/ -lonnxruntime -fPIC -# -# 4. Run the model with ONNX Runtime Python API and compare the results with PyTorch. +# We get # # .. code-block:: python +# < +# ir_version=10, +# opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18}, +# producer_name='pytorch', +# producer_version='2.7.0.dev20250124+cu124', +# domain=None, +# model_version=None, +# > +# graph( +# name=main_graph, +# inputs=( +# %"input" +# ), +# outputs=( +# %"add_and_round_op" +# ), +# ) { +# 0 | # node_Add_0 +# %"val_0" ⬅️ ::Add(%"input", %"input") +# 1 | # node_Round_1 +# %"add_and_round_op" ⬅️ ::Round(%"val_0") +# return %"add_and_round_op" +# } +# +# And exported program # -# ort_session_options = onnxruntime.SessionOptions() -# -# # NOTE: Link the custom op library to ONNX Runtime and replace the path -# # with the path to your custom op library -# ort_session_options.register_custom_ops_library( -# "/path/to/libcustom_op_library.so" -# ) -# ort_session = onnxruntime.InferenceSession( -# "./custom_addandround_model.onnx", providers=['CPUExecutionProvider'], sess_options=ort_session_options) -# -# def to_numpy(tensor): -# return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() +# .. 
code-block:: python +# ExportedProgram: +# class GraphModule(torch.nn.Module): +# def forward(self, input: "f32[1]"): +# input_1 = input # -# onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_addandround_x) -# onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)} -# onnxruntime_outputs = ort_session.run(None, onnxruntime_input) +# add_and_round_op: "f32[1]" = torch.ops.mylibrary.add_and_round_op.default(input_1); input_1 = None +# return (add_and_round_op,) # -# torch_outputs = custom_addandround_model(input_addandround_x) -# torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs) +# Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_and_round_op'), target=None)]) +# Range constraints: {} # -# assert len(torch_outputs) == len(onnxruntime_outputs) -# for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs): -# torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output)) +# The translation is using our custom implementation to translate the ``torch.ops.mylibrary.add_and_round_op.default`` +# operator in the ExportedProgram to the ONNX operators ``Add`` and ``Round``. # + +###################################################################### +# Finally we verify the results. + +result = onnx_program(x)[0] +torch.testing.assert_close(result, add_and_round_op(x)) + +###################################################################### # Conclusion # ---------- # -# Congratulations! In this tutorial, we explored the :class:`ONNXRegistry` API and -# discovered how to create custom implementations for unsupported or existing ATen operators +# Congratulations!
In this tutorial, we explored the ``custom_translation_table`` option and +# discovered how to create custom implementations for unsupported or existing PyTorch operators # using ONNX Script. +# # Finally, we leveraged ONNX Runtime to execute the model and compare the results with PyTorch, # providing us with a comprehensive understanding of handling unsupported # operators in the ONNX ecosystem.