@assume_pure #8962
Conversation
This is GREAT! Looking forward to this feature!
Another issue to think about: the naming. Some decorators change the function's behavior, like @staticmethod. Others change the behavior of the function for the caller, like torch's
It may be too hard to think through all the ways we expose lazy tensors and harmonize them in the timeframe of this PR, so consider this optional; we may have to refactor all the ways we discuss lazy tensors and compilation at some point in the future (e.g. mark_step, xm.optimizer_step, torch_xla.compile, etc.).
Thanks for the great work! Just confirming that when using
Depends on how those two are combined.
Ack. I don't have a great immediate thought (maybe
This is ready for another look.
Tests appear to confirm the expected behavior of the decorator.
Fixes #8805.
We introduce a decorator, `@assume_pure`, that can be placed on PyTorch/XLA functions to easily eliminate lazy tensor tracing overhead. If you have a pure function that only uses torch upstream ops, that function can be decorated with `@assume_pure` and will only be traced once for each unique combination of input tensor shapes.
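For orientation, here is a minimal usage sketch. It is not taken from the PR: the import path `torch_xla.experimental.assume_pure` and the example function are assumptions, so check the actual module layout in the change.

```python
import torch
import torch_xla.core.xla_model as xm

# Assumed import path for the new decorator; the real location is defined in this PR.
from torch_xla.experimental.assume_pure import assume_pure

@assume_pure
def mlp(x, w1, w2):
    # Pure: only upstream torch ops, no side effects, no in-place mutation of inputs.
    return torch.nn.functional.gelu(x @ w1) @ w2

device = xm.xla_device()
x = torch.randn(8, 128, device=device)
w1 = torch.randn(128, 512, device=device)
w2 = torch.randn(512, 128, device=device)

# Traced once for this combination of input shapes/dtypes; later calls with the
# same shapes/dtypes skip lazy tensor tracing and reuse the cached computation.
out = mlp(x, w1, w2)
```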
Design

`@assume_pure` brings together three pieces of existing technology:

- `jax.vjp`, which takes a JAX function and gives you the autograd forward and backward pass
- `torchax`, which converts a pure PyTorch function to a JAX function
- `xb.call_jax`, which can call any JAX function from PyTorch/XLA and integrate it into the HLO graph

It works by:
1. Using `torchax.interop.jax_view` to obtain a JAX function from the input PyTorch function
2. Calling `jax.vjp` to get the forward and backward pass
3. Wrapping them in a `torch.autograd.Function` instance, where the forward implementation is `xb.call_jax(forward_pass)` and the backward implementation is `xb.call_jax(backward_pass)`, respectively

The core logic is actually just a single line:
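The line itself did not survive this page scrape. A sketch consistent with the description above, assuming the `jax_view` and `j2t_autograd` helpers live in `torchax.interop` (the exact module layout is an assumption), would be:

```python
from torchax.interop import jax_view, j2t_autograd  # assumed helper locations

def assume_pure(fn):
    # jax_view: view the pure PyTorch function as a JAX function.
    # j2t_autograd: wrap its jax.vjp-derived forward/backward passes in a
    # torch.autograd.Function that calls back into PyTorch/XLA via xb.call_jax.
    return j2t_autograd(jax_view(fn))
```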
How is the HLO cached
`xb.call_jax` caches the HLO if all the input shapes/dtypes and non-tensor arguments are the same. Therefore, subsequent `xb.call_jax` will just reuse the cached HLO instead of retracing. The same kind of caching happens in both the forward and backward pass.
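To make the caching contract concrete, here is an illustration continuing the hypothetical `mlp` example from the usage sketch above (variable names are made up):

```python
# Same shapes/dtypes as the first call -> the cached HLO is reused, no retracing.
out2 = mlp(torch.randn_like(x), w1, w2)

# A new input shape combination -> traced once more and cached separately.
x_big = torch.randn(16, 128, device=device)
out3 = mlp(x_big, w1, w2)
```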
Different from the JAX wrapper we used in `splash_attention`, the `j2t_autograd` function saves the residuals (intermediate activations) during the forward pass and reuses them during the backward pass by plugging those into the `vjp_fun` again. This means it won't force a rematerialization (rerunning the forward) during the backward pass.
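To illustrate the residual-saving behavior in plain JAX terms (a generic `jax.vjp` example, not code from the PR):

```python
import jax
import jax.numpy as jnp

def f(x, w):
    return jnp.tanh(x @ w)

x = jnp.ones((4, 8))
w = jnp.ones((8, 2))

# jax.vjp runs the forward pass once and returns the primal output plus a
# closure (vjp_fun) that captures the residuals needed for the backward pass.
out, vjp_fun = jax.vjp(f, x, w)

# Applying vjp_fun reuses those residuals; the forward pass is not rerun.
grad_x, grad_w = vjp_fun(jnp.ones_like(out))
```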
Alternatives

Instead of `jax.vjp`, we could also use AOTAutograd to get the forward and backward pass. However, AOTAutograd has a number of downsides, among them the decomposition and custom operation issues mentioned again under torch.compile below, and the need to map `xp.Trace(...)` to `jax.named_scope(...)`.
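For reference, these are the two annotation APIs being compared; a small illustration (not code from the PR):

```python
import jax
import torch_xla.debug.profiler as xp

# PyTorch/XLA-side profiler annotation used in lazy tensor code today:
with xp.Trace('attention'):
    pass  # lazy tensor ops would go here

# JAX-side named scope that an AOTAutograd/JAX-based path would need to emit instead:
with jax.named_scope('attention'):
    pass  # jax ops would go here
```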
Instead of `assume_pure`, we could also use `torch.compile` to cache the XLA executable of the compiled function and skip the lazy tensor tracing. However, `torch.compile` has its own downsides:

- `torch.compile` itself uses AOTAutograd and will suffer from the decomposition and custom operation issues, etc.
- `torch.compile` has a general perception of "either it works, or debugging will be complicated", which has been corroborated by experiments by people on the PyTorch/XLA team. See PyTorch team members' own recommendation [1]. In contrast, `@assume_pure` has very simple rules for determining if it will work: if your function is pure, then it works.
- `torch.compile` will graph break when entering and leaving the compiled region. In contrast, `@assume_pure` can avoid tracing overhead without even breaking the graph; the cached HLO is inlined into the overall HLO.

Benchmarks
I tested tracing an example 100-layer decoder-only model:
Importantly, the tracing overhead under `@assume_pure` does not scale with increasing complexity inside the model. That's because we only trace the model once, paying a fixed up-front cost, and later runs reuse the cached XLA computation object.

Anecdotally, @bhavya01 reported saving >200 ms of tracing time in an SDXL experiment. That's very significant since each training step is sub-1 second.