Skip to content

Commit

Permalink
move adhoc executor first for priority in autograd (#1569)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Dec 18, 2024
1 parent eff3594 commit e37bec2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
2 changes: 1 addition & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def jit(
# Resolve names of executors
executors = resolve_executors(executors)
ad_hoc_executor = extend.AdHocExecutor()
executors = (*executors, ad_hoc_executor)
executors = (ad_hoc_executor, *executors)

# TODO: verify that tutorials don't have false positives and enable warning by default
# # Make sharp_edges == warn default if not supplied and if in the general jit
Expand Down
33 changes: 33 additions & 0 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,3 +1853,36 @@ def forward(x):
actual_grad = torch.autograd.grad(actual, x, grad_o)
expected_grad = torch.autograd.grad(expected, x, grad_o)
torch.testing.assert_close(actual_grad, expected_grad)


@instantiate(
    dtypes=NOTHING,
)
def test_adhoc_executor_grad(executor, device, _):
    """Verify that a custom ``torch.autograd.Function`` jitted by thunder uses
    the user-supplied backward (presumably via the ad-hoc executor — the
    deliberately scaled gradient would expose a fallback to the stock sin grad).
    """
    import torch
    import thunder

    inp = torch.ones(2, device=device, requires_grad=True)

    class ScaledSin(torch.autograd.Function):
        # Forward is a plain sin; backward scales cos by 200 so that any
        # silent substitution of the regular sin derivative fails the check.
        @staticmethod
        def forward(ctx, t):
            ctx.save_for_backward(t)
            return torch.sin(t)

        @staticmethod
        def backward(ctx, grad_out):
            (t,) = ctx.saved_tensors
            return grad_out * torch.cos(t) * 200

    def fn(t):
        return ScaledSin.apply(t)

    jitted = thunder.jit(fn)

    actual = jitted(inp)
    expected = fn(inp)
    (actual_grad,) = torch.autograd.grad(actual.sum(), inp)
    (expected_grad,) = torch.autograd.grad(expected.sum(), inp)

    torch.testing.assert_close(actual, expected)
    torch.testing.assert_close(actual_grad, expected_grad)

0 comments on commit e37bec2

Please sign in to comment.