[torch_xla2] Wire torch_xla2.compile'd function with torch AutogradFunction #8587
🚀 Feature
Currently, if we wrap a model with `torch_xla2.compile` and try to train it using the traditional torch training loop, similar to https://github.com/pytorch/xla/blob/master/experimental/torch_xla2/examples/basic_training.py, you would notice that it doesn't work.
The reason is that the compile wrapper `JittableModule` will eventually call a `jax.jit`'d callable, and torch doesn't know how to compute the gradient of that callable.

The solution is to create a `torch.autograd.Function` subclass on the fly, with `backward` defined to call `jax.vjp`, similar to this tutorial: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html. The result would be that a model wrapped with `torch_xla2.compile` is still trainable; a rough sketch of the idea is included at the end of this issue.

Motivation
Having the forward and backward passes compiled with `jax.jit` makes them faster to run.
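To illustrate the idea, here is a minimal sketch (not torch_xla2's actual implementation) of building a `torch.autograd.Function` around a `jax.jit`'d callable, with `backward` implemented via `jax.vjp`. The conversion helpers `t2j`/`j2t` and the example `loss_fn` are hypothetical and only for illustration; torch_xla2 has its own tensor-conversion machinery.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def t2j(t: torch.Tensor) -> jax.Array:
    # torch -> jax via numpy (detach so torch autograd does not track the copy)
    return jnp.asarray(t.detach().cpu().numpy())


def j2t(a: jax.Array) -> torch.Tensor:
    # jax -> torch via numpy (np.array makes a writable copy)
    return torch.from_numpy(np.array(a))


def make_autograd_fn(jax_fn):
    """Create a torch.autograd.Function on the fly whose backward calls jax.vjp."""
    jitted = jax.jit(jax_fn)

    class JaxFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *torch_args):
            jax_args = [t2j(t) for t in torch_args]
            # Keep the jax inputs so backward can re-linearize with jax.vjp.
            ctx.jax_args = jax_args
            return j2t(jitted(*jax_args))

        @staticmethod
        def backward(ctx, grad_out):
            # jax.vjp returns (primal_out, vjp_fn); vjp_fn maps the output
            # cotangent to one cotangent per input.
            _, vjp_fn = jax.vjp(jax_fn, *ctx.jax_args)
            grads = vjp_fn(t2j(grad_out))
            return tuple(j2t(g) for g in grads)

    return JaxFn


# Usage: gradients flow back to the torch tensors through jax.vjp.
def loss_fn(w, x):
    return jnp.sum(jnp.tanh(x @ w))


fn = make_autograd_fn(loss_fn)
w = torch.randn(3, 4, requires_grad=True)
x = torch.randn(2, 3, requires_grad=True)
fn.apply(w, x).backward()
print(w.grad.shape, x.grad.shape)
```

In the actual feature, `JittableModule` would presumably generate such a `Function` around its jitted forward so the standard torch training loop keeps working.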