
Update augmented_forward+backward implementation to attach residuals to all outputs #1834
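In rough terms: `_vjp_impl` now wraps every primal output in a `VJPDual` carrying the saved residuals, instead of attaching them only to the first output. The sketch below illustrates the before/after behavior; `VJPDual`, `sequencify`, and the attach helpers here are simplified stand-ins for illustration, not thunder's actual implementations.

```python
# Illustrative sketch only: VJPDual, sequencify, and the attach helpers below
# are simplified stand-ins, not thunder's actual implementations.
from collections import namedtuple

VJPDual = namedtuple("VJPDual", ["primal", "residuals"])


def sequencify(x):
    # Wrap a single output in a tuple so single- and multi-output symbols
    # are handled uniformly.
    return x if isinstance(x, (tuple, list)) else (x,)


def old_attach(out_primal, out_residuals):
    # Before this PR: residuals ride only on the first output; the backward
    # pass has to look them up there.
    if isinstance(out_primal, (tuple, list)):
        return (VJPDual(out_primal[0], out_residuals),
                *(VJPDual(o, tuple()) for o in out_primal[1:]))
    return (VJPDual(out_primal, out_residuals),)


def new_attach(out_primal, out_residuals):
    # After this PR: every output carries the same residuals, so they are
    # recoverable from whichever output is actually consumed downstream.
    return tuple(VJPDual(o, out_residuals) for o in sequencify(out_primal))


outs = ("chunk0", "chunk1")          # e.g. the two outputs of a split
residuals = ("saved_for_backward",)
print(old_attach(outs, residuals))   # residuals only on chunk0
print(new_attach(outs, residuals))   # residuals on both chunks
```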

Open · wants to merge 5 commits into main
10 changes: 4 additions & 6 deletions thunder/core/transforms.py
@@ -56,6 +56,7 @@
     OrderedSet,
     ProxyDict,
 )
+from thunder.core.codeutils import is_literal
 import thunder.clang as clang
 from thunder.clang import (
     empty,
@@ -2329,6 +2330,8 @@ def iter_bound_symbols(bound_symbols):
     for symbol in bound_symbols:
         if symbol.sym.id in trace_interpreter_skip_list:
             continue
+        elif all(is_literal(sym_out) for sym_out in symbol.flat_outs):
+            continue
         elif symbol.output is None:
             continue
         else:
@@ -2626,12 +2629,7 @@ def vjp_impl_const(symbol, *args, **kwargs):
     def _vjp_impl(*args, **kwargs):
         primals, kwargs = tree_map(lambda x: x.primal if isinstance(x, VJPDual) else x, (args, kwargs))
         out_primal, out_residuals = vjp_impl(*primals, **kwargs)
-        # We are saving the residuals and pullback only in the first output
-        # backward_pass then retrieves the residuals and pullback from the first output
-        if isinstance(out_primal, Sequence):
-            return (VJPDual(out_primal[0], out_residuals), *(VJPDual(o, tuple()) for o in out_primal[1:]))
-
-        return (VJPDual(out_primal, out_residuals),)
+        return tree_map(lambda x: VJPDual(x, out_residuals), sequencify(out_primal))
 
     return _vjp_impl
 
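Regarding the removed comment above ("backward_pass then retrieves the residuals and pullback from the first output"): if a program drops the first output and consumes only a later one, as in the new test below, the residuals under the old scheme were attached to a value the backward pass never saw. A rough sketch of that failure mode, again with simplified stand-ins rather than thunder's actual `backward_pass`:

```python
# Simplified stand-ins, not thunder's backward_pass / VJPDual.
from collections import namedtuple

VJPDual = namedtuple("VJPDual", ["primal", "residuals"])

# A two-output symbol where only the second output is consumed downstream,
# mirroring `_, x_2 = torch.split(x, 2)` in the new test.
old_duals = (VJPDual("chunk0", ("saved",)), VJPDual("chunk1", tuple()))
new_duals = (VJPDual("chunk0", ("saved",)), VJPDual("chunk1", ("saved",)))

consumed_old = old_duals[1]    # chunk0 is dropped by the program
consumed_new = new_duals[1]

print(consumed_old.residuals)  # () -> the saved values are unreachable
print(consumed_new.residuals)  # ('saved',) -> still available for the pullback
```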
20 changes: 20 additions & 0 deletions thunder/tests/test_grad.py
@@ -1939,6 +1939,26 @@ def func(x):
     torch.testing.assert_close(actual_gr, expected_gr)
 
 
+def test_unused_first_output():
+    def forward(x):
+        _, x_2 = torch.split(x, 2)
+        return x_2
+
+    jforward = thunder.jit(forward)
+
+    x = make_tensor([4, 2], dtype=torch.bfloat16, device="cpu", requires_grad=True)
+
+    actual = jforward(x)
+    expected = forward(x)
+    torch.testing.assert_close(actual, expected)
+
+    grad_o = torch.randn_like(actual)
+
+    actual_grad = torch.autograd.grad(actual, x, grad_o)
+    expected_grad = torch.autograd.grad(expected, x, grad_o)
+    torch.testing.assert_close(actual_grad, expected_grad)
+
+
 @pytest.mark.parametrize("device", ("cuda", "cpu"))
 def test_backward_recomputation_decomposed_ops(device):
     if device == "cuda" and not torch.cuda.is_available():