Fix memory leak in saved tensors (#1688)
t-vi authored Jan 24, 2025
1 parent c951fdb commit 52ee541
Showing 4 changed files with 54 additions and 25 deletions.
2 changes: 1 addition & 1 deletion notebooks/writing_a_trace_transform_cpu_offloading.ipynb
@@ -525,7 +525,7 @@
"\n",
"# Verify that saved tensors are on CPU.\n",
"saved_tensor_devices = set()\n",
"for t in actual.grad_fn.next_functions[0][0].saved_tensors:\n",
"for t in actual.grad_fn.saved_tensors:\n",
" saved_tensor_devices.add(str(t.device))\n",
"\n",
"assert \"cpu\" in saved_tensor_devices # Verify that we actually have saved tensors on CPU\n",
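The notebook check above relies on the fact that a custom `torch.autograd.Function` exposes the tensors it saved directly on the output's `grad_fn`. A minimal stand-alone sketch of that inspection pattern in plain PyTorch (the `Square` function is invented for illustration and is not part of Thunder):

```python
import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # x is needed to compute the gradient 2 * x * g
        return x * x

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        return 2 * x * g


x = torch.randn(3, requires_grad=True)
out = Square.apply(x)

# The backward node of a custom autograd.Function exposes its saved tensors,
# so the device check can be done on the output's own grad_fn.
devices = {str(t.device) for t in out.grad_fn.saved_tensors}
assert devices == {"cpu"}
```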
46 changes: 26 additions & 20 deletions thunder/executors/torch_autograd.py
@@ -105,16 +105,18 @@ def detach_if_tensor(t):

         saved_tensors = tuple(map(detach_if_tensor, saved_tensors))
 
-        # We must save tensors using ctx.save_for_backward
-        ctx.save_for_backward(*saved_tensors)
-
         ctx.side_channel = side_channel
         if side_channel is not None:
             assert not side_channel
             ctx.side_channel["fw"] = flat_output
-
+            # We must save tensors using ctx.save_for_backward but
+            # we want to save the tensors in the function returning the outputs to avoid memory leaks
+            # (basically ref-cycles via output.grad_fn.next_functions[0, 0].saved_tensors[0] == output
+            # PyTorch autograd handles this gracefully for output.grad_fn.saved_tensors)
+            ctx.side_channel["tensors_to_save"] = saved_tensors
             return torch.randn(1, device="meta", requires_grad=True)
         else:
+            ctx.save_for_backward(*saved_tensors)
             return flat_output
 
     # NOTE: If `torch.autograd.function.once_differentiable` is to be removed,
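The comment added above refers to a reference cycle that only appears when the tensors are saved on a node upstream of the function that actually returns the outputs. A hedged sketch of the shape of that cycle, with two invented functions `Producer`/`Emitter` standing in for the forward/output pair (this is not the Thunder implementation, just an illustration of the pattern it avoids):

```python
import torch

side_channel = {}


class Producer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        side_channel["out"] = y
        # Saving the *other* function's output here is the problematic pattern:
        # out -> out.grad_fn -> next_functions -> this node -> saved out
        # is a cycle held by C++ autograd objects, which no cycle collector frees.
        ctx.save_for_backward(y)
        return torch.randn(1, device="meta", requires_grad=True)  # dummy output

    @staticmethod
    def backward(ctx, g):
        return None  # never called in this sketch


class Emitter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy):
        return side_channel.pop("out")

    @staticmethod
    def backward(ctx, g):
        return g


x = torch.randn(3, requires_grad=True)
out = Emitter.apply(Producer.apply(x))

# The node one step upstream of the output holds the output's own storage:
saved = out.grad_fn.next_functions[0][0].saved_tensors[0]
assert saved.data_ptr() == out.data_ptr()
```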
@@ -125,25 +127,26 @@ def detach_if_tensor(t):
     def backward(ctx, *raw_args):
         if ctx.side_channel is not None:
             args = ctx.side_channel.pop("bw")
+            saved_tensors_list = ctx.side_channel.pop("saved_tensors")
             assert not ctx.side_channel
         else:
             args = list(raw_args)
-        # ctx.saved_tensors is a tuple of tensors saved in forward. Our compiled
-        # backward is a really long function that takes all the tensors saved in
-        # forward and gradually uses them to compute the gradients of the
-        # inputs. Unfortunately, Python holds a reference to all arguments of a
-        # function until the function returns, even if we delete the variable
-        # "saved_tensors" inside the function, the tensors will still be held in
-        # memory until the function returns. Fortunately, Python passes mutable
-        # objects by reference, so we can just replace the saved_tensors with an
-        # empty list and the memory will be freed immediately. We must also
-        # delete the reference to the saved_tensors in the context, otherwise
-        # the memory will be freed only when the context is deleted.
-        saved_tensors_list = list(ctx.saved_tensors) # Make a copy as we will mutate it
-
-        # This is an undocumented API, but it's the only way to clear the
-        # reference to the saved tensors in the context
-        ctx.maybe_clear_saved_tensors() # Delete the reference to all saved tensors in the context
+            # ctx.saved_tensors is a tuple of tensors saved in forward. Our compiled
+            # backward is a really long function that takes all the tensors saved in
+            # forward and gradually uses them to compute the gradients of the
+            # inputs. Unfortunately, Python holds a reference to all arguments of a
+            # function until the function returns, even if we delete the variable
+            # "saved_tensors" inside the function, the tensors will still be held in
+            # memory until the function returns. Fortunately, Python passes mutable
+            # objects by reference, so we can just replace the saved_tensors with an
+            # empty list and the memory will be freed immediately. We must also
+            # delete the reference to the saved_tensors in the context, otherwise
+            # the memory will be freed only when the context is deleted.
+            saved_tensors_list = list(ctx.saved_tensors) # Make a copy as we will mutate it
+
+            # This is an undocumented API, but it's the only way to clear the
+            # reference to the saved tensors in the context
+            ctx.maybe_clear_saved_tensors() # Delete the reference to all saved tensors in the context
         grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
 
         assert not args
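The comment block explains why the compiled backward receives its saved tensors as a mutable list rather than a tuple: Python keeps call arguments alive until the callee returns, but a list can be emptied by the callee so each tensor is released as soon as it has been used. A small stand-alone sketch of that calling convention (`consume_early` is an invented name):

```python
import torch


def consume_early(tensors: list) -> torch.Tensor:
    # The caller hands over the list itself and keeps no other reference,
    # so popping an entry here drops the last reference to that tensor
    # long before this (potentially very long) function returns.
    total = torch.zeros(())
    while tensors:
        t = tensors.pop()
        total = total + t.sum()
        del t  # release the local reference immediately
    return total


saved = [torch.randn(1024, 1024) for _ in range(4)]
out = consume_early(saved)  # pass the list, not tuple(saved)
assert not saved  # the callee emptied it; the large tensors are already freed
```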
@@ -165,13 +168,16 @@ def forward(ctx, dummy, side_channel, *args):
         ctx.side_channel = side_channel
         ctx.num_args = len(args)
         res = ctx.side_channel.pop("fw")
+        ctx.save_for_backward(*ctx.side_channel.pop("tensors_to_save"))
         assert not ctx.side_channel
         return res
 
     @staticmethod
     def backward(ctx, *args):
         assert not ctx.side_channel
         ctx.side_channel["bw"] = list(args)
+        ctx.side_channel["saved_tensors"] = list(ctx.saved_tensors) # see above
+        ctx.maybe_clear_saved_tensors() # Delete the reference to all saved tensors in the context
         return torch.randn(1, device="meta"), None, *([None] * ctx.num_args)


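For contrast with the cycle sketched earlier, saving a tensor on its own `grad_fn` (which is what moving `ctx.save_for_backward` into the output-returning function achieves) is a case PyTorch's saved-variable machinery handles without keeping the graph alive. A hedged sketch using an invented `Double` function, checked with a weak reference:

```python
import weakref

import torch


class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        # Saving the node's own output: PyTorch special-cases this so it does
        # not create a strong output -> grad_fn -> saved output cycle.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, g):
        return 2 * g


def make():
    x = torch.randn(16, requires_grad=True)
    return weakref.ref(Double.apply(x))


wr = make()
assert wr() is None  # the output and its graph were freed despite being saved
```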
23 changes: 23 additions & 0 deletions thunder/tests/test_core.py
@@ -3187,3 +3187,26 @@ def fn(x):
 ): # prims is unpack_sequence and any output is TensorProxy
 # Verify that we print information about the unpacked TensorProxy.
 assert "cpu f32[3]" in str(bsym)
+
+
+def test_apply_autograd_memory():
+    from thunder.executors.torch_autograd import connect_to_autograd
+
+    def foo():
+        def backward(*args):
+            return None
+
+        x = torch.randn(2, 2, requires_grad=True)
+        o = x.sum()
+
+        connect_to_autograd(
+            backward_fn=backward,
+            flat_args=(x,),
+            flat_output=(o,),
+            saved_tensors=(o,),
+            saved_other=(),
+            return_none_instead_of_grads=True,
+        )
+        return [weakref.ref(x), weakref.ref(o)]
+
+    assert not any(wr() for wr in foo())
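The new test's idiom (build the graph inside a helper, return only weak references, and assert they are all dead afterwards) generalizes to a small reusable check. A sketch; the `assert_all_collected` helper is invented here and is not part of the test suite:

```python
import weakref

import torch


def assert_all_collected(make_objects):
    # `make_objects` builds some objects and returns them; we keep only weak
    # references, so once this expression finishes the objects should be freed
    # unless something (e.g. a C++-side reference cycle) keeps them alive.
    refs = [weakref.ref(obj) for obj in make_objects()]
    leaked = [r for r in refs if r() is not None]
    assert not leaked, f"{len(leaked)} object(s) were not freed"


# Example: a plain PyTorch graph should not keep its output alive.
assert_all_collected(lambda: [torch.randn(2, 2, requires_grad=True).sum()])
```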
8 changes: 4 additions & 4 deletions thunder/tests/test_grad.py
Expand Up @@ -1752,10 +1752,10 @@ def f(x, y):

     # With activation checkpointing, we are saving only the original input.
     # The intermediate values are recomputed during backward pass.
-    assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2
+    assert len(out.grad_fn.saved_tensors) == 2
     # We detach the saved tensors (which returns a new Python tensor backed by same storage)
     # the order seems to be non-deterministic sometimes
-    assert {t.data_ptr() for t in out.grad_fn.next_functions[0][0].saved_tensors} == {x.data_ptr(), y.data_ptr()}
+    assert {t.data_ptr() for t in out.grad_fn.saved_tensors} == {x.data_ptr(), y.data_ptr()}
 
     g = torch.ones_like(out)
     out.backward(g)
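The `data_ptr()` comparison in the updated assertion works because `detach()` returns a new Python tensor object that shares the original storage, so pointer equality identifies "the same tensor" even though object identity does not hold. A quick stand-alone illustration:

```python
import torch

x = torch.randn(4, 4)
d = x.detach()

assert d is not x                    # a different Python tensor object...
assert d.data_ptr() == x.data_ptr()  # ...backed by the same memory
```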
@@ -1948,8 +1948,8 @@ def fn(a):
     a = torch.randn(2, 2, device=device, requires_grad=True)
     res = jfn(a)
     res2 = jfn2(a)
-    assert len(res.grad_fn.next_functions[0][0].saved_tensors) == 3 # should be decomposed
-    assert len(res2.grad_fn.next_functions[0][0].saved_tensors) == 1
+    assert len(res.grad_fn.saved_tensors) == 3 # should be decomposed
+    assert len(res2.grad_fn.saved_tensors) == 1
 
     if NVFUSER_AVAILABLE and device == "cuda":
         # check everything is fused
