@assume_pure #8962

Merged
merged 14 commits into master on Apr 16, 2025

Conversation

@tengyifei (Collaborator) commented Apr 11, 2025

Fixes #8805.

We introduce a decorator, @assume_pure, that can be placed on PyTorch/XLA functions to eliminate lazy tensor tracing overhead. If you have a pure function that only uses torch upstream ops, that function can be decorated with @assume_pure and will only be traced once for each unique combination of input tensor shapes.
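
For illustration, here is a minimal usage sketch. The import path torch_xla.experimental.assume_pure is an assumption and may differ from the final location in the tree:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.assume_pure import assume_pure  # assumed import path

@assume_pure
def pure_mlp(x, w1, w2):
    # Only upstream torch ops and no side effects, so the function is pure.
    return torch.relu(x @ w1) @ w2

device = xm.xla_device()
x = torch.randn(8, 128, device=device)
w1 = torch.randn(128, 256, device=device)
w2 = torch.randn(256, 128, device=device)

# The first call traces pure_mlp once; later calls with the same input
# shapes reuse the cached HLO instead of re-tracing.
out = pure_mlp(x, w1, w2)
out = pure_mlp(x, w1, w2)  # no additional lazy tensor tracing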

Design

@assume_pure brings together three pieces of existing technology:

  • jax.vjp, which takes a JAX function and gives you the autograd forward and backward pass
  • torchax, which converts a pure PyTorch function to a JAX function
  • xb.call_jax, which can call any JAX function from PyTorch/XLA and integrate it into the HLO graph

It works by:

  • Using torchax.interop.jax_view to obtain a JAX function from the input PyTorch function
  • Using jax.vjp to get the forward and backward pass
  • Returning a torch.autograd.Function instance whose forward implementation is xb.call_jax(forward_pass) and whose backward implementation is xb.call_jax(backward_pass)

The core logic is actually just a single line:

def assume_pure(fn):
  from torchax.interop import jax_view
  return j2t_autograd(jax_view(fn))
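
To make the wrapping concrete, here is a self-contained sketch of the torch.autograd.Function pattern that a j2t_autograd-style wrapper relies on. The names make_autograd_fn, fwd, and bwd are illustrative only, and the plain Python callables stand in for the JAX computations that the real decorator dispatches through xb.call_jax:

import torch

def make_autograd_fn(forward_pass, backward_pass):
    # Illustrative only: wrap a forward/backward pair into a
    # torch.autograd.Function. In the real decorator the two callables
    # would be JAX computations invoked via xb.call_jax.
    class Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *inputs):
            out, residuals = forward_pass(*inputs)
            # Residuals (intermediate activations) are saved for the
            # backward pass instead of re-running the forward.
            ctx.save_for_backward(*residuals)
            return out

        @staticmethod
        def backward(ctx, grad_out):
            return backward_pass(ctx.saved_tensors, grad_out)

    return lambda *inputs: Wrapped.apply(*inputs)

# Toy example: y = x * w, saving (x, w) as residuals.
fwd = lambda x, w: (x * w, (x, w))
bwd = lambda residuals, g: (g * residuals[1], g * residuals[0])
mul = make_autograd_fn(fwd, bwd)

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
mul(x, w).sum().backward()  # populates x.grad and w.grad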

How the HLO is cached

xb.call_jax caches the HLO if all the input shapes/dtypes and non-tensor arguments are the same.

Therefore, subsequent xb.call_jax calls will just reuse the cached HLO instead of retracing.

The same kind of caching happens in both the forward and backward pass.
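
For context, a minimal sketch of this caching behavior at the xb.call_jax level; the exact call_jax signature is assumed from its current usage in torch_xla and may differ:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb

def jax_add(a, b):
    # A trivial pure JAX function (jax.Array in, jax.Array out).
    return a + b

device = xm.xla_device()
x = torch.ones(4, 4, device=device)
y = torch.ones(4, 4, device=device)

# The first call traces jax_add and lowers it to HLO; the second call sees
# identical input shapes/dtypes and non-tensor arguments, so the cached HLO
# is reused instead of retracing.
out1 = xb.call_jax(jax_add, (x, y))
out2 = xb.call_jax(jax_add, (x, y))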

Unlike the JAX wrapper we used in splash_attention, the j2t_autograd function saves the residuals (intermediate activations) during the forward pass and reuses them during the backward pass by plugging them back into vjp_fun. This means it won't force a rematerialization (re-running the forward) during the backward pass.
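
For reference, here is a minimal pure-JAX illustration of how jax.vjp exposes the forward output together with a backward function that reuses the saved residuals (this is generic jax.vjp usage, not the PR's code):

import jax
import jax.numpy as jnp

def f(x, w):
    return jnp.tanh(x @ w)

x = jnp.ones((2, 3))
w = jnp.ones((3, 4))

# Forward pass: jax.vjp returns the primal output together with vjp_fn,
# which closes over the residuals (intermediate activations).
out, vjp_fn = jax.vjp(f, x, w)

# Backward pass: feeding the output cotangent into vjp_fn yields the
# gradients w.r.t. x and w without re-running the forward computation.
grad_x, grad_w = vjp_fn(jnp.ones_like(out))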

Alternatives

Instead of jax.vjp we could also use AOTAutograd to get the forward and backward pass. However, AOTAutograd has a number of downsides:

  • It does more than just derive the backward pass: it also forcefully decomposes all operations into the "aten" op set. This decomposition negatively impacts performance, especially in the case of einsum.
  • There is no straightforward path to support profiler trace spans. In contrast, in the proposed approach we could translate xp.Trace(...) to jax.named_scope(...).
  • Supporting custom operations such as pallas kernels will be cumbersome. We'll need to wrap every kernel into a PyTorch custom operator in order for AOTAutograd to not crash on those functions. In contrast, in the proposed approach we could augment our pallas kernels to directly jump into JAX when the input tensor is a torchax tensor.

Instead of assume_pure, we could also use torch.compile to cache the XLA executable of the compiled function and skip the lazy tensor tracing. However, torch.compile has its own downsides:

  • torch.compile itself uses AOTAutograd and will suffer from the same decomposition and custom-operation issues described above.
  • torch.compile has a general perception of "either it works, or debugging will be complicated", which has been corroborated by experiments from people on the PyTorch/XLA team; see PyTorch team members' own recommendation. In contrast, @assume_pure has a very simple rule for determining whether it will work: if your function is pure, then it works.
  • torch.compile will graph break when entering and leaving the compiled region. In contrast, @assume_pure can avoid tracing overhead without even breaking the graph. The cached HLO is inlined into the overall HLO.

Benchmarks

I tested tracing an example 100-layer decoder-only model:

~/pytorch/xla
❯ TESTBRIDGE_TEST_ONLY=test_trace_transformer_with_spda_attention python3 test/test_assume_pure.py --benchmark_iterations 100
[...]
No `@assume_pure` time: 140.1342 ms
`@assume_pure` time: 24.1658 ms

Importantly, the @assume_pure tracing cost does not scale with increasing complexity inside the model. That's because we only trace the model once, paying a fixed up-front cost, and later runs reuse the cached XLA computation object.

Anecdotally, @bhavya01 reported saving >200 ms of tracing time in an SDXL experiment. That's very significant since each training step takes under one second.

@tengyifei marked this pull request as ready for review on April 11, 2025 01:15
@tengyifei force-pushed the yifeit/vjp-in-xla branch 5 times, most recently from 514c163 to cfe2e41 on April 11, 2025 06:04
@tengyifei requested review from bhavya01, qihqi and zpcore on April 11, 2025 06:05
@zpcore (Collaborator) commented Apr 11, 2025

This is GREAT! Looking forward to this feature!

@yaoshiang (Collaborator) commented Apr 14, 2025

Another issue to think about: the naming of @assume_pure...

Some decorators change the function's behavior, like @staticmethod.

Others change the function's behavior with respect to some caller, like torch's @torch.no_grad.

@assume_pure is closer to @no_grad... the function's behavior itself isn't changing, but it's a directive to some callers to treat this function differently. no_grad is great because we immediately know the caller: the autograd system.

@assume_pure doesn't really indicate which caller is being directed. Is there a name that includes the specific caller? Maybe something like @torch_xla.compile(pure=True)? Then it's really explicit who this decorator is for.

It may be too hard to think through all the ways we expose lazy tensors and harmonize them within the timeframe of this PR, so consider this optional; we may have to refactor all the ways we discuss lazy tensors and compilation at some point in the future (e.g. mark_step, xm.optimizer_step, torch_xla.compile, etc.).

@bhavya01 (Collaborator) commented:

Thanks for the great work! Just confirming that when using assume_pure with gradient checkpointing, it will still store the residuals for the forward function wrapped by assume_pure, right? Otherwise, LGTM

@tengyifei (Collaborator, Author) commented:

when using assume_pure with gradient checkpointing, it will still store the residuals for the forward function wrapped by assume_pure

Depends on how those two are combined.

  • torch_xla.utils.checkpoint(assume_pure(fn)): PyTorch autograd will discard the residuals and recompute them during the backward pass. It works just the way we want checkpointing to work (see the sketch after this list).
  • assume_pure(torch_xla.utils.checkpoint(fn)): this case doesn't work, because torch_xla.utils.checkpoint will attempt to add an optimization barrier onto torchax tensors. But supporting it falls into the same bucket as supporting xs.mark_sharding, etc.: we just need to capture that function in torchax and handle it accordingly.
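
A hedged sketch of the composition that works today, assuming torch_xla.utils.checkpoint keeps the torch.utils.checkpoint calling convention and assuming the torch_xla.experimental.assume_pure import path; the layer below is a toy stand-in:

import torch
from torch_xla.utils.checkpoint import checkpoint  # XLA-aware gradient checkpointing
from torch_xla.experimental.assume_pure import assume_pure  # assumed import path

def layer(x, w):
    # Pure: only upstream torch ops and no side effects.
    return torch.relu(x @ w)

pure_layer = assume_pure(layer)

def forward(x, w):
    # Checkpoint on the outside, assume_pure on the inside: PyTorch autograd
    # discards the residuals here and recomputes the (cached, pure) forward
    # during the backward pass.
    return checkpoint(pure_layer, x, w)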

@tengyifei (Collaborator, Author) commented:

It may be too hard to think through all the ways we expose lazy tensors and harmonize them within the timeframe of this PR, so consider this optional; we may have to refactor all the ways we discuss lazy tensors and compilation at some point in the future (e.g. mark_step, xm.optimizer_step, torch_xla.compile, etc.).

Ack. I don't have a great immediate thought (maybe @pure_function_avoid_retrace? a mouthful...), but I have circulated an internal proposal about using torchax, JAX, etc., that has a section with naming ideas.

@tengyifei (Collaborator, Author) commented:

This is ready for another look

@yaoshiang (Collaborator) left a review comment:

Tests appear to confirm the expected behavior of the decorator.

@tengyifei merged commit cf764ff into master on Apr 16, 2025
24 checks passed

Successfully merging this pull request may close: Use jax autograd from PyTorch (#8805).