diff --git a/docs/source/perf/assume_pure.md b/docs/source/perf/assume_pure.md
new file mode 100644
index 000000000000..912b83ea8115
--- /dev/null
+++ b/docs/source/perf/assume_pure.md
@@ -0,0 +1,131 @@
+# Use `@assume_pure` to speed up lazy tensor tracing
+
+This document explains how to use `torch_xla.experimental.assume_pure` to
+eliminate lazy tensor tracing overhead. See [this blog post][lazy-tensor] for a
+primer on how lazy tensor tracing (operation recording) works.
+
+## Background and motivation
+
+PyTorch/XLA's lazy tensor tracing ensures correct execution by recording an
+operation graph (lazy tensor IR) when running PyTorch operations. For complex
+models, this tracing overhead can exceed the execution time of the graph,
+leading to performance bottlenecks. When training a model, the layers in the
+model must be re-traced on every training step. That's because there's no
+guarantee that the layers will do the same thing in different training steps.
+As an extreme example, a layer's `forward()` function may call
+`random.random()` and decide what code to run based on a pseudo-random number.
+
+Re-tracing can introduce unnecessary overhead. In many cases, the layers in your
+model will do exactly the same thing when given the same input tensor shapes. In
+other words, given the same input, the function returns the same output. Often,
+the layers also will not perform side effects such as saving the tensor to a
+file or adding it to a global list. Such functions are called
+"[pure functions][pure-function]".
+
+Any PyTorch/XLA function decorated with `@assume_pure` will only be traced once
+for each unique input tensor shape and dtype combination. PyTorch/XLA will cache
+the traced computation instead of repeatedly tracing the same operations.
+
+## How to use `@assume_pure`
+
+### Using `@assume_pure` with a function
+
+If you know your function is pure, decorate your function with `@assume_pure`:
+
+```py
+import torch
+import torch_xla
+from torch_xla.experimental.assume_pure import assume_pure
+
+@assume_pure
+def do_some_math(
+    # You can pass any number of XLA tensors.
+    a: torch.Tensor,
+    b: torch.Tensor,
+
+    # Non-tensor arguments are also supported, and passing different values will
+    # trigger re-tracing and caching more computations.
+    c: int,
+):
+  # Evaluate some pure expressions.
+  return a @ b + c
+
+# Simulate a training loop.
+# Even if we run this function ten times, it will only be traced once.
+for i in range(10):
+  v = do_some_math(
+      torch.tensor([1.0], device='xla'),
+      torch.tensor([2.0], device='xla'),
+      c=42,
+  )
+  print(v)
+```
+
+### Using `@assume_pure` with a `nn.Module`
+
+If you have a pure `nn.Module`, i.e. its `forward` behavior only depends on the
+input arguments and the model parameters, you can use
+`torch.func.functional_call` to convert the module into a pure function and
+pass that to `assume_pure`:
+
+```python
+import torch
+import torch.nn as nn
+from torch.func import functional_call
+from torch_xla.experimental.assume_pure import assume_pure
+
+class MyModule(nn.Module):
+  def __init__(self):
+    super().__init__()
+    self.linear = nn.Linear(10, 10)
+  def forward(self, x):
+    return self.linear(x)
+
+# Create the module and move it to the XLA device.
+module = MyModule()
+module = module.to('xla')
+
+# Convert the module's forward pass into a pure function.
+pure_forward = lambda params, buffers, x: functional_call(module, (params, buffers), (x,))
+
+# Wrap the pure function with @assume_pure.
+cached_forward = assume_pure(pure_forward)
+
+# Simulate a training loop.
+# Even if we run the model ten times, its forward function will only be traced once.
+params = dict(module.named_parameters())
+buffers = dict(module.named_buffers())
+for i in range(10):
+  x = torch.randn(5, 10, device='xla')
+  y = cached_forward(params, buffers, x)
+  print(y)
+```
+
+## Benchmarks
+
+The unit tests contain a benchmark that traces an example 100-layer decoder-only
+language model:
+
+```sh
+~/pytorch/xla
+❯ TESTBRIDGE_TEST_ONLY=test_trace_transformer_with_spda_attention python3 test/test_assume_pure.py --benchmark_iterations 100
+[...]
+No `@assume_pure` time: 140.1342 ms
+`@assume_pure` time: 24.1658 ms
+```
+
+The version with `@assume_pure` is roughly 5.8 times faster in this benchmark.
+
+Importantly, the `@assume_pure` running time does not scale with increasing
+complexity inside the model. That's because we only trace the model once, paying
+a fixed up-front cost, and then later runs reuse the cached XLA computation.
+
+## Limitations
+
+Currently, all operations in a function wrapped with `@assume_pure` must be
+PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...). More
+PyTorch/XLA operations (e.g. `mark_sharding`) will be supported in the future.
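+
+Also note that `@assume_pure` does not verify that the wrapped function is
+actually pure. Because the function is only traced once per unique input shape
+and dtype combination, any Python side effects in its body will only run during
+that initial trace. The following minimal sketch (the `call_log` list and
+`impure_add` function are hypothetical) shows the kind of function that should
+not be wrapped:
+
+```python
+import torch
+from torch_xla.experimental.assume_pure import assume_pure
+
+call_log = []
+
+@assume_pure  # Incorrect: this function has a side effect.
+def impure_add(a, b):
+  # This append only happens while the function is being traced. Later calls
+  # with the same input shapes reuse the cached computation and never update
+  # the list.
+  call_log.append(len(call_log))
+  return a + b
+```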
+ + + +[lazy-tensor]: https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/ +[pure-function]: https://en.wikipedia.org/wiki/Pure_function diff --git a/test/run_tests.sh b/test/run_tests.sh index 62d24f4fe720..0e096148ae62 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -211,6 +211,7 @@ function run_xla_op_tests2 { run_test "$CDIR/test_callback.py" XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py" run_test "$CDIR/test_jax_interop.py" + run_test "$CDIR/test_assume_pure.py" } # All the new xla op tests should go to run_xla_op_tests3 diff --git a/test/test_assume_pure.py b/test/test_assume_pure.py new file mode 100644 index 000000000000..28e2165b9f43 --- /dev/null +++ b/test/test_assume_pure.py @@ -0,0 +1,440 @@ +from copy import deepcopy +from absl.testing import absltest +from absl import flags +import time +import unittest + +import torch +import torch.nn as nn +import torch_xla +import torch_xla.core.xla_builder as xb +import torch_xla.runtime as xr +from torch_xla.experimental.assume_pure import assume_pure +from torch_xla._internal.jax_workarounds import jax_import_guard + + +def assert_gradients_close(test_case, actual, expected): + """Checks that the gradients of the `actual` tensor is close to the gradients of the `expected` tensor.""" + + grad1 = actual.grad + grad2 = expected.grad + if grad1 is None and grad2 is None: + test_case.fail("Both gradients are None, which is unexpected") + elif grad1 is None or grad2 is None: + test_case.fail( + f"Gradient mismatch: one is None, the other is not. Grad1: {grad1}, Grad2: {grad2}" + ) + else: + torch.testing.assert_close( + grad1.detach(), + grad2.detach(), + msg=lambda s: f"Gradients do not match {s}", + check_device=False) + + +class TestAssumePure(absltest.TestCase): + + def test_assume_pure_basic(self): + # Arrange + @assume_pure + def simple_torch_function(a, b): + return torch.sin(a @ b) + + # Act + a = torch.ones((3, 3), device='xla', requires_grad=True) + actual = simple_torch_function(a, a) + actual.sum().backward() + torch_xla.sync() + + # Assert + expected = torch.sin(torch.ones(3, 3) @ torch.ones(3, 3)) + torch.testing.assert_close(actual, expected, check_device=False) + + @unittest.skipUnless(xr.global_runtime_device_count() >= 2, + "Multiple devices required") + def test_assume_pure_other_xla_devices(self): + # Preconditions: ensure we have at least two XLA devices. + assert torch.device('xla:0') != torch.device('xla:1') + + # Arrange + @assume_pure + def simple_torch_function(a, b): + return torch.sin(a @ b) + + # Act: use an XLA device with ID 1. 
+ a = torch.ones((3, 3), device='xla:1', requires_grad=True) + actual = simple_torch_function(a, a) + actual.sum().backward() + torch_xla.sync() + + # Assert + expected = torch.sin(torch.ones(3, 3) @ torch.ones(3, 3)) + torch.testing.assert_close(actual, expected, check_device=False) + + def test_assume_pure_module(self): + # Arrange + model = nn.Linear(3, 3).to('xla') + + @assume_pure + def simple_torch_function(params, x): + return torch.func.functional_call(model, params, x) + + # Act + a = torch.ones((3, 3), device='xla', requires_grad=True) + actual = simple_torch_function(dict(model.named_parameters()), a) + actual.sum().backward() + torch_xla.sync() + + # Assert + expected = model(torch.ones(3, 3).to('xla')) + torch.testing.assert_close(actual, expected, check_device=False) + + @unittest.skipIf(xr.device_type() == 'TPU', + "Bug: https://github.com/pytorch/xla/issues/8974") + def test_assume_pure_complex_module(self): + """Test a module comprising of some linear, conv, and relu layers.""" + + # Arrange: define module and prepare inputs. + class MyModule(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(9, 9) + self.conv = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=(1, 1)) + self.layer_norm = nn.LayerNorm(9) + self.relu = nn.ReLU() + self.flatten = nn.Flatten() + self.fc = nn.Linear(9 * 9 * 3, 3) + + def forward(self, x): + x = self.linear(x) + x = self.conv(x) + x = self.layer_norm(x) + x = self.relu(x) + x = self.flatten(x) + x = self.fc(x) + return x + + orig_model = MyModule() + pure_model = deepcopy(orig_model) + orig_model = orig_model.to('xla') + pure_model = pure_model.to('xla') + orig_params = dict(orig_model.named_parameters()) + pure_params = dict(pure_model.named_parameters()) + orig_x = torch.randn((5, 3, 9, 9), device='xla', requires_grad=True) + pure_x = orig_x.clone().detach().requires_grad_(True) + torch_xla.sync() + + # Act: call module in a pure way. + orig_output = orig_model(orig_x) + pure_call = lambda params, x: torch.func.functional_call( + pure_model, params, x) + pure_output = assume_pure(pure_call)(pure_params, pure_x) + torch_xla.sync() + + # Assert + # Check that the outputs are close + torch.testing.assert_close(pure_output, orig_output, check_device=False) + + # Check that the gradients are close + orig_output.sum().backward() + pure_output.sum().backward() + torch_xla.sync() + assert_gradients_close(self, pure_x, orig_x) + for name, _ in orig_model.named_parameters(): + orig_param = orig_params[name] + pure_param = pure_params[name] + assert_gradients_close(self, orig_param, pure_param) + + def test_assume_pure_avoid_retracing_avoid_rejit(self): + """Tests that we avoid retracing and re-jitting when using assume_pure.""" + + # Arrange: first clear the cache to prevent contamination from other tests. + xb._JAX_TO_XLA_COMPUTATION_CACHE.clear() + starting_lowerings = xb._jax_to_xla_computation_cache_num_misses() + trace_counter = 0 + + @assume_pure + def simple_torch_function(a, b): + nonlocal trace_counter + trace_counter += 1 + return torch.sin(a @ b) + + # Act: simulate a training loop. + for _ in range(5): + a = torch.ones((3, 3), device='xla', requires_grad=True) + o = simple_torch_function(a, a) + o.sum().backward() + torch_xla.sync() + + # Assert + ending_lowerings = xb._jax_to_xla_computation_cache_num_misses() + + # Check that we only trace once. + self.assertEqual(trace_counter, 1) + + # Check that we only lower to HLO twice (once for forward, once for backward). 
+ self.assertEqual(ending_lowerings - starting_lowerings, 2) + + def test_assume_pure_matmul_grads(self): + """Tests matmul with all inputs requiring gradients.""" + + # Arrange + def matmul_fn(a, b): + return a @ b + + # Prepare inputs (cloned for independent grad computation) + a_orig = torch.randn(4, 5, device='xla', requires_grad=True) + b_orig = torch.randn(5, 3, device='xla', requires_grad=True) + a_pure = a_orig.clone().detach().requires_grad_(True) + b_pure = b_orig.clone().detach().requires_grad_(True) + + # Act + # Forward pass + output_orig = matmul_fn(a_orig, b_orig) + output_pure = assume_pure(matmul_fn)(a_pure, b_pure) + + # Backward pass + loss_orig = output_orig.sum() + loss_pure = output_pure.sum() + + loss_orig.backward() + loss_pure.backward() + torch_xla.sync() + + # Assert + # Check forward pass equivalence + torch.testing.assert_close( + output_pure, + output_orig, + msg="Forward outputs do not match", + check_device=False) + + # Check gradients + assert_gradients_close(self, a_pure, a_orig) + assert_gradients_close(self, b_pure, b_orig) + + @unittest.skipIf(xr.device_type() == 'TPU', + "Bug: https://github.com/pytorch/xla/issues/8975") + def test_assume_pure_einsum_grads(self): + """Tests einsum with all inputs requiring gradients.""" + + # Arrange + def einsum_fn(x, y): + return torch.einsum('bij,bjk->bik', x, y) + + # Prepare inputs + x_orig = torch.randn(2, 3, 4, device='xla', requires_grad=True) + y_orig = torch.randn(2, 4, 5, device='xla', requires_grad=True) + x_pure = x_orig.clone().detach().requires_grad_(True) + y_pure = y_orig.clone().detach().requires_grad_(True) + + # Act + # Forward pass + output_orig = einsum_fn(x_orig, y_orig) + output_pure = assume_pure(einsum_fn)(x_pure, y_pure) + torch.testing.assert_close( + output_pure, + output_orig, + msg=lambda msg: f"Forward outputs do not match: {msg}", + check_device=False) + + # Backward pass + output_orig.sum().backward() + output_pure.sum().backward() + torch_xla.sync() + + # Assert + # Check gradients + assert_gradients_close(self, x_pure, x_orig) + assert_gradients_close(self, y_pure, y_orig) + + def test_assume_pure_partial_grads_args(self): + """Tests a function where only some positional inputs require gradients. + + In this test, tensor a, c require grad; b does not. 
+ """ + + # Arrange + def fn(a, b, c): + return a * torch.tanh(b) + c**2 + + # Prepare inputs + torch_xla.manual_seed(42) + a_orig = torch.randn(3, 3, device='xla', requires_grad=True) + # No grad for b + b_orig = torch.randn(3, 3, device='xla', requires_grad=False) + c_orig = torch.randn(3, 3, device='xla', requires_grad=True) + + a_pure = a_orig.clone().detach().requires_grad_(True) + # No grad for b + b_pure = b_orig.clone().detach().requires_grad_(False) + c_pure = c_orig.clone().detach().requires_grad_(True) + + # Act + # Forward pass + output_orig = fn(a_orig, b_orig, c_orig) + output_pure = assume_pure(fn)(a_pure, b_pure, c_pure) + torch.testing.assert_close( + output_pure, + output_orig, + msg="Forward outputs do not match", + check_device=False) + + # Backward pass + output_orig.sum().backward() + output_pure.sum().backward() + torch_xla.sync() + + # Assert + # Check gradients + assert_gradients_close(self, a_pure, a_orig) + assert_gradients_close(self, c_pure, c_orig) + + self.assertIsNotNone(a_orig.grad, "a_orig should have grad") + self.assertIsNone(b_orig.grad, "b_orig should not have grad") + self.assertIsNone(b_pure.grad, "b_pure should not have grad") + self.assertIsNotNone(c_orig.grad, "a_orig should have grad") + + def test_assume_pure_partial_grads_kwargs(self): + """Tests a function where inputs requiring gradients are passed via kwargs.""" + + # Arrange + def fn(x, *, factor, bias): + # x, bias require grad; factor does not + # factor is a non-tensor kwarg, bias is a tensor kwarg + return x * factor + bias + + # Prepare inputs + x_orig = torch.randn(3, 3, device='xla', requires_grad=True) + bias_orig = torch.randn(3, 3, device='xla', requires_grad=True) + factor_val = 2.5 # Non-tensor kwarg + + x_pure = x_orig.clone().detach().requires_grad_(True) + bias_pure = bias_orig.clone().detach().requires_grad_(True) + + # Act + # Forward pass + output_orig = fn(x_orig, factor=factor_val, bias=bias_orig) + output_pure = assume_pure(fn)(x_pure, factor=factor_val, bias=bias_pure) + torch.testing.assert_close( + output_pure, + output_orig, + msg="Forward outputs do not match", + check_device=False) + + # Backward pass + output_orig.sum().backward() + output_pure.sum().backward() + torch_xla.sync() + + # Assert + # Check gradients + assert_gradients_close(self, x_pure, x_orig) + assert_gradients_close(self, bias_pure, bias_orig) + # Factor is not a tensor, so it won't have a .grad attribute. Nothing to check here. 
+ + def test_assume_pure_no_grads_needed(self): + """Tests a function where no inputs require gradients.""" + + # Arrange + def original_func(a, b): + return torch.cos(a) + torch.sin(b) + + # Prepare inputs + a_orig = torch.randn(3, 3, device='xla', requires_grad=False) + b_orig = torch.randn(3, 3, device='xla', requires_grad=False) + a_pure = a_orig.clone().detach().requires_grad_(False) + b_pure = b_orig.clone().detach().requires_grad_(False) + + # Act + # Forward pass + output_orig = original_func(a_orig, b_orig) + output_pure = assume_pure(original_func)(a_pure, b_pure) + torch_xla.sync() + + # Assert + # Check outputs + torch.testing.assert_close( + output_pure, + output_orig, + msg="Forward outputs do not match", + check_device=False) + + # Check gradients + self.assertFalse(output_orig.requires_grad) + self.assertFalse(output_pure.requires_grad) + self.assertIsNone(a_orig.grad) + self.assertIsNone(b_orig.grad) + self.assertIsNone(a_pure.grad) + self.assertIsNone(b_pure.grad) + + +FLAGS = flags.FLAGS +flags.DEFINE_integer( + name='benchmark_iterations', + default=3, + help='Number of iterations to run the tracing benchmark test.') + + +class TracingBenchmark(absltest.TestCase): + + def test_trace_transformer_with_spda_attention(self): + num_iterations = FLAGS.benchmark_iterations + print(f"\nRunning benchmark with {num_iterations} iterations") + + import sys + import os + example_folder = os.path.dirname(os.path.dirname(__file__)) + "/examples" + sys.path.append(example_folder) + from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore + + config = DecoderOnlyConfig( + hidden_size=128, + num_hidden_layers=100, + intermediate_size=8 * 128, + vocab_size=256) + model = DecoderOnlyModel(config=config).to('xla') + batch_size = 2 + sequence_length = 8 + + # Generate random input_ids within the range of the vocabulary size + input_ids = torch.randint(0, config.vocab_size, + (batch_size, sequence_length)).to('xla') + + pure_model = deepcopy(model) + torch_xla.sync() + + # Test tracing the model normally. + model(input_ids) # Warm up + start_time = time.time() + for _ in range(num_iterations): + model(input_ids) + end_time = time.time() + model_time = (end_time - start_time) / num_iterations + print(f"No `@assume_pure` time: {model_time * 1000:.4f} ms") + + # Test tracing the model with assume_pure. 
+ @assume_pure + def pure_call(params, x): + return torch.func.functional_call(pure_model, params, x) + + params = dict(pure_model.named_parameters()) + pure_call(params, input_ids) # Warm up + start_time = time.time() + for _ in range(num_iterations): + pure_call(params, input_ids) + end_time = time.time() + pure_model_time = (end_time - start_time) / num_iterations + print(f"`@assume_pure` time: {pure_model_time * 1000:.4f} ms") + + +if __name__ == "__main__": + torch.set_default_dtype(torch.float32) + torch.manual_seed(42) + torch_xla.manual_seed(42) + torch_xla._XLAC._xla_set_mat_mul_precision('highest') + jax_import_guard() + import torchax + torchax.enable_accuracy_mode() + absltest.main() diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 775d76ba6f41..5cd1b4a45f9d 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -36,6 +36,7 @@ python3 "$TEST_CDIR/scan/test_scan_spmd.py" python3 "$TEST_CDIR/scan/test_scan_pallas.py" python3 "$TEST_CDIR/scan/test_scan_layers.py" python3 "$TEST_CDIR/test_gru.py" +python3 "$TEST_CDIR/test_assume_pure.py" python3 "$TEST_CDIR/test_as_stride_use_slice.py" run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py" python3 "$TEST_CDIR/test_pallas.py" -v diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py index 97b49df75604..c060a76a2ee8 100644 --- a/torch_xla/core/xla_builder.py +++ b/torch_xla/core/xla_builder.py @@ -874,13 +874,10 @@ def jax_func_to_xla_computation(jax_func, args, kwargs, name=None): flattened_inputs, spec = jax.tree.flatten((args, kwargs)) def abstractify(a): # make a pytree leaf abstract - import jax - import torch_xla if a is None: return None if isinstance(a, torch.Tensor): - assert a.device == torch_xla.device( - ), f"Inputs must be XLA tensors. Got {a.device}" + assert a.device.type == 'xla', f"Inputs must be XLA tensors. Got {a.device}" return jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype)) return a diff --git a/torch_xla/experimental/assume_pure.py b/torch_xla/experimental/assume_pure.py new file mode 100644 index 000000000000..c52ba997845e --- /dev/null +++ b/torch_xla/experimental/assume_pure.py @@ -0,0 +1,36 @@ +from torch_xla._internal.jax_workarounds import requires_jax +import torch_xla.core.xla_builder as xb + + +@requires_jax +def assume_pure(fn): + """Decorates a pure PyTorch/XLA function to skip expensive re-tracing. + + Returns a new function that will only be traced once for each unique + input tensor shapes or non-tensor input argument values. This is useful + for removing Lazy Tensor tracing overhead. + + The decorated function must be pure (i.e. no side-effects, behavior + only depends on inputs). + + Limitations: + - The decorated function can only use upstream PyTorch operators e.g. + `torch.einsum`, `torch.nn.functional.layer_norm`. Custom PyTorch/XLA + operations such as `mark_sharding` are not supported. This limitation + may be lifted in the future. + """ + from torchax.interop import jax_view + return j2t_autograd(jax_view(fn)) + + +@requires_jax +def j2t_autograd(fn): + """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. + + It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate + activations). The wrapped function is then run via `call_jax` and integrated into + the PyTorch autograd framework by saving the residuals into the context object. 
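+
+  Example (an illustrative sketch, assuming JAX is available):
+
+    import jax.numpy as jnp
+
+    # `pure_fn` accepts XLA torch tensors and participates in PyTorch autograd.
+    pure_fn = j2t_autograd(lambda a, b: jnp.sin(a @ b))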
+ """ + import torchax.interop + return torchax.interop.j2t_autograd( + fn, call_jax=lambda fn, *args: xb.call_jax(fn, args)) diff --git a/torchax/test/test_interop.py b/torchax/test/test_interop.py index 32427602bc4f..e115db0fcfea 100644 --- a/torchax/test/test_interop.py +++ b/torchax/test/test_interop.py @@ -1,5 +1,6 @@ import torch import unittest +import torchax from torchax import interop class M1(torch.nn.Module): @@ -41,6 +42,38 @@ def test_mod_attr(self): self.assertEqual(m.a.weight.item(), 0) self.assertEqual(m.m1.x.item(), 0) + def test_j2t_autograd_forward(self): + with torchax.default_env(): + # Setup + def fn(x): + return x + 1 + + j2t_fn = interop.j2t_autograd(fn) + x = torch.ones(2, 2, requires_grad=True, device='jax') + + # Act + actual = j2t_fn(x) + + # Assert + expected = torch.ones(2, 2) + 1 + torch.testing.assert_close(actual, expected, check_device=False) + + def test_j2t_autograd_backward(self): + with torchax.default_env(): + # Setup + def fn(x): + return x * 2 + + j2t_fn = interop.j2t_autograd(fn) + x = torch.ones(2, 2, device='jax').requires_grad_() + + # Act + actual = j2t_fn(x) + actual.sum().backward() + + # Assert + expected = torch.ones(2, 2) * 2 + torch.testing.assert_close(x.grad, expected, check_device=False) if __name__ == '__main__': diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index a8c7ea5fa8cc..11ac604e8f60 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -1,6 +1,8 @@ import copy import functools import torch +from inspect import signature +from functools import wraps from torch.nn.utils import stateless as torch_stateless import jax import jax.numpy as jnp @@ -180,6 +182,99 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) - return jax_view(res) +def j2t_autograd(fn, call_jax=call_jax): + """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. + + It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate + activations). The wrapped function is then run via `call_jax` and integrated into + the PyTorch autograd framework by saving the residuals into the context object. + """ + + @wraps(fn) + def inner(*args, **kwargs): + from jax.tree_util import tree_flatten, tree_unflatten + from jax.util import safe_zip + + class JaxFun(torch.autograd.Function): + @staticmethod + def forward(ctx, tree_def, *flat_inputs): + # Reconstruct the original args and kwargs + args, kwargs = tree_unflatten(tree_def, flat_inputs) + + # Execute the JAX computation + # Pass the reconstructed args/kwargs tuple as the primal + y, fun_vjp = call_jax( + _jax_forward, + fn, + (args, kwargs), + ) + + # Save necessary information for backward + # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass. + # `residuals` contains the tensors needed for the backward pass.` + residuals, vjp_spec = tree_flatten(fun_vjp) + ctx.vjp_spec = vjp_spec + ctx.save_for_backward(*residuals) + + return y + + @staticmethod + def backward(ctx, *grad_out): + assert len(grad_out) > 0 + grad_out = grad_out if len(grad_out) > 1 else grad_out[0] + + input_grads_structured = call_jax( + _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out + ) + + # Flatten the gradients to match the flat inputs to forward + flat_input_grads, _ = tree_flatten(input_grads_structured) + + # Construct the gradient tuple to be returned. + # It needs to match the inputs to forward: (tree_def, *flat_inputs) + # The first gradient (for tree_def) is None. 
+ # The subsequent gradients correspond to flat_inputs. + # We need to put a None for inputs that did not require gradients. + final_grads = [None] + for needs_grad, grad in safe_zip( + ctx.needs_input_grad[1:], flat_input_grads + ): + final_grads.append(grad if needs_grad else None) + + return tuple(final_grads) + + sig = signature(fn) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs)) + return JaxFun.apply(tree_def, *flat_args_kwargs) + + return inner + + +def _jax_forward(fn, primals): + """JAX function to compute output and vjp function. + + primals should be a tuple (args, kwargs). + """ + import jax + + def fn_wrapper(a, kw): + return fn(*a, **kw) + + return jax.vjp(fn_wrapper, *primals) + + +def _jax_backward(vjp_spec, saved_tensors, grad_out): + """JAX function to compute input gradients. + + Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. + """ + from jax.tree_util import tree_unflatten + fun_vjp = tree_unflatten(vjp_spec, saved_tensors) + return fun_vjp(grad_out) + + fori_loop = torch_view(jax.lax.fori_loop)