diff --git a/advanced_source/cpp_custom_ops.rst b/advanced_source/cpp_custom_ops.rst index 435ff088bc..ffabd6eff7 100644 --- a/advanced_source/cpp_custom_ops.rst +++ b/advanced_source/cpp_custom_ops.rst @@ -174,6 +174,8 @@ To add ``torch.compile`` support for an operator, we must add a FakeTensor kerne known as a "meta kernel" or "abstract impl"). FakeTensors are Tensors that have metadata (such as shape, dtype, device) but no data: the FakeTensor kernel for an operator specifies how to compute the metadata of output tensors given the metadata of input tensors. +The FakeTensor kernel should return dummy Tensors of your choice with +the correct Tensor metadata (shape/strides/``dtype``/device). We recommend that this be done from Python via the `torch.library.register_fake` API, though it is possible to do this from C++ as well (see diff --git a/advanced_source/python_custom_ops.py b/advanced_source/python_custom_ops.py index 1e429b76b3..0b3bf6e474 100644 --- a/advanced_source/python_custom_ops.py +++ b/advanced_source/python_custom_ops.py @@ -66,7 +66,7 @@ def display(img): ###################################################################### # ``crop`` is not handled effectively out-of-the-box by # ``torch.compile``: ``torch.compile`` induces a -# `"graph break" `_ +# `"graph break" `_ # on functions it is unable to handle and graph breaks are bad for performance. # The following code demonstrates this by raising an error # (``torch.compile`` with ``fullgraph=True`` raises an error if a @@ -85,9 +85,9 @@ def f(img): # # 1. wrap the function into a PyTorch custom operator. # 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator. -# Given the metadata (e.g. shapes) -# of the input Tensors, this function says how to compute the metadata -# of the output Tensor(s). +# Given some ``FakeTensors`` inputs (dummy Tensors that don't have storage), +# this function should return dummy Tensors of your choice with the correct +# Tensor metadata (shape/strides/``dtype``/device). from typing import Sequence @@ -130,6 +130,11 @@ def f(img): # ``autograd.Function`` with PyTorch operator registration APIs can lead to (and # has led to) silent incorrectness when composed with ``torch.compile``. # +# If you don't need training support, there is no need to use +# ``torch.library.register_autograd``. +# If you end up training with a ``custom_op`` that doesn't have an autograd +# registration, we'll raise an error message. +# # The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the # derivation as an exercise to the reader). Let's first wrap ``paste`` into a # custom operator: @@ -203,7 +208,7 @@ def setup_context(ctx, inputs, output): ###################################################################### # Mutable Python Custom operators # ------------------------------- -# You can also wrap a Python function that mutates its inputs into a custom +# You can also wrap a Python function that mutates its inputs into a custom # operator. # Functions that mutate inputs are common because that is how many low-level # kernels are written; for example, a kernel that computes ``sin`` may take in