tmp debug
kiya00 committed Dec 17, 2024
1 parent f3d358f commit d201c8c
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 2 additions & 0 deletions thunder/core/transforms.py
@@ -2495,6 +2495,8 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
     Returns:
         bool: True if the symbol is constant, False otherwise.
     """
+    if isinstance(symbol.sym.id, str) and symbol.sym.id.startswith("higher_order_autograd_function_apply"):
+        return False
     are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
     # Symbols tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
     # These are treated as constant for VJP.
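The new guard short-circuits is_constant_for_vjp for bound symbols whose id marks a torch.autograd.Function application (these appear to correspond to Dynamo's autograd_function_apply higher-order op). Such symbols carry a user-defined backward, so they should never be folded away as constants for VJP, even when their arguments look non-differentiable. A minimal standalone sketch of just this check; _Sym and _BoundSymbol are hypothetical stand-ins, not Thunder's real symbol types:

from dataclasses import dataclass

@dataclass
class _Sym:
    id: object  # symbol ids may be strings or other objects

@dataclass
class _BoundSymbol:
    sym: _Sym

def is_autograd_function_apply(symbol) -> bool:
    # Mirrors the added check: string ids beginning with
    # "higher_order_autograd_function_apply" are never constant for VJP.
    return isinstance(symbol.sym.id, str) and symbol.sym.id.startswith(
        "higher_order_autograd_function_apply"
    )

assert is_autograd_function_apply(_BoundSymbol(_Sym("higher_order_autograd_function_apply_0")))
assert not is_autograd_function_apply(_BoundSymbol(_Sym("torch.cos")))
assert not is_autograd_function_apply(_BoundSymbol(_Sym(42)))  # non-string ids fall through safely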
19 changes: 13 additions & 6 deletions thunder/tests/test_dynamo.py
@@ -280,27 +280,34 @@ def backward(ctx, g):
             return g * torch.cos(x)
 
     def func(x):
-        y = torch.cos(x) + Sin.apply(x)
-        return torch.matmul(x, y)
+        # y = torch.cos(x) + Sin.apply(x)
+        # return torch.matmul(x, y)
+        return Sin.apply(x)
 
     expected = torch.compile(func, dynamic=dynamic)(x)
 
     backend = ThunderCompiler()
     cfunc = torch.compile(func, backend=backend, dynamic=dynamic)
     actual = cfunc(x)
 
-    targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
-    assert any(target.startswith("thunder_") for target in targets)
-    assert any(target.startswith("inductor_") for target in targets)
+    # targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
+    # assert any(target.startswith("thunder_") for target in targets)
+    # assert any(target.startswith("inductor_") for target in targets)
 
     # Verify forward pass
     torch.testing.assert_close(actual, expected)
 
     # Verify backward pass
     g = torch.rand_like(actual)
-    actual_grad = torch.autograd.grad(actual, x, g)
+    actual_grad = torch.autograd.grad(actual, x, g, allow_unused=True)
     expected_grad = torch.autograd.grad(expected, x, g)
     torch.testing.assert_close(actual_grad, expected_grad)
+    from thunder import last_backward_traces
+
+    assert len(backend.subgraph_infos) == 1
+    assert len(backend.subgraph_infos[0].thunder_compiled_fns) == 1
+    print(last_traces(backend.subgraph_infos[0].thunder_compiled_fns[0]))
+    print(last_backward_traces(backend.subgraph_infos[0].thunder_compiled_fns[0]))
 
 
 @instantiate(
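Two of the test tweaks are worth spelling out. allow_unused=True makes torch.autograd.grad return None for inputs that did not participate in computing the output, instead of raising a RuntimeError; that is a common debugging guard while the custom function's backward may not yet be wired through the compiled graph. And last_traces / last_backward_traces (both importable from thunder) return the recorded forward and backward traces of a Thunder-compiled function for inspection. A minimal illustration of the allow_unused behavior, using only standard PyTorch:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out = (a * 2).sum()  # b never participates in the graph

# Without allow_unused=True, asking for b's gradient raises a RuntimeError;
# with it, the unused input's gradient comes back as None.
ga, gb = torch.autograd.grad(out, (a, b), allow_unused=True)
assert torch.equal(ga, torch.full((3,), 2.0))
assert gb is None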
