
TE: Fix redundant compute for PEFT using transform #2138


Open · wants to merge 20 commits into main
Conversation

@kshitij12345 (Collaborator) commented on May 26, 2025

Fixes: #2076

TODO

  • Tested with RTX 6000; verify for sanity on H100 and B200.

The fix adds a pass that inspects the backward trace to determine whether wgrad and bgrad are actually computed, and updates the forward trace accordingly.
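For reference, these are the three backward products of a linear layer; skipping wgrad (and bgrad) for frozen parameters removes the redundant GEMMs and means the tensors that are only needed for wgrad (e.g. the FP8 copy of the input) no longer have to be saved for backward. A plain-PyTorch illustration, not TE code:

    import torch

    # Shapes for y = x @ W.T with x: [4, 8], W: [16, 8], grad_y: [4, 16]
    x = torch.randn(4, 8)
    W = torch.randn(16, 8)
    grad_y = torch.randn(4, 16)

    dgrad = grad_y @ W            # grad wrt the input: always needed to keep backpropagating
    wgrad = grad_y.T @ x          # grad wrt the weight: redundant when the weight is frozen
    bgrad = grad_y.sum(dim=0)     # grad wrt the bias: redundant when there is no trainable bias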

A test has been added for this (and the existing single-GPU and distributed tests have been verified).

NOTE:

  • Changes are only for the v1 executor; similar changes are needed for the v2 executor.

Example Program:

    import torch

    with torch.device("cuda"):
        model = torch.nn.Sequential(*(torch.nn.Linear(32, 32, bias=False) for _ in range(4)))
        x = torch.randn(32, 32, requires_grad=True)

    for idx, parameters in enumerate(model.parameters()):
        # Every even linear layer's weight is frozen (PEFT-style partial freezing).
        if idx % 2 == 0:
            parameters.requires_grad = False
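For completeness, a rough sketch of how the traces below can be obtained (not part of this PR): the executor import path and the trace-inspection calls are assumptions about the setup and may differ, and the FP8 recipe details are elided.

    import thunder
    # Assumed import path for the TransformerEngine executor; adjust if it differs.
    from thunder.executors.transformer_engineex import transformer_engine_ex

    jmodel = thunder.jit(model, executors=[transformer_engine_ex])
    out = jmodel(x)
    out.sum().backward()

    # Final forward and backward execution traces (as shown below).
    fwd_trace = thunder.last_traces(jmodel)[-1]
    bwd_trace = thunder.last_backward_traces(jmodel)[-1]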

Forward Trace

@transformer_engine.fp8_autocast(fp8_recipe=te_fp8_recipe)
@torch.no_grad()
@no_autocast
def computation(input, t_0_weight, t_1_weight, t_2_weight, t_3_weight):
  # input: "cuda:0 f32[32, 32]"
  # t_0_weight: "cuda:0 f32[32, 32]"
  # t_1_weight: "cuda:0 f32[32, 32]"
  # t_2_weight: "cuda:0 f32[32, 32]"
  # t_3_weight: "cuda:0 f32[32, 32]"

  # /usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py:125:             return F.linear(input, self.weight, self.bias)
  (t27, (t19, t20, t21, t22, t23, t24), ctx_te_1418) = te_linear_13(input, t_0_weight, None, input_requires_grad=True, weight_requires_grad=False, bias_requires_grad=False)
  (t42, (t34, t35, t36, t37, t38, t39, t40, t41), ctx_te_1533) = te_linear_14(t27, t_1_weight, None, input_requires_grad=True, weight_requires_grad=True, bias_requires_grad=False)
  del t27

  # /usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py:125:             return F.linear(input, self.weight, self.bias)
  (t57, (t49, t50, t51, t52, t53, t54), ctx_te_1648) = te_linear_15(t42, t_2_weight, None, input_requires_grad=True, weight_requires_grad=False, bias_requires_grad=False)
  del t42

  # /usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py:125:             return F.linear(input, self.weight, self.bias)
  (t72, (t64, t65, t66, t67, t68, t69, t70, t71), ctx_te_1763) = te_linear_16(t57, t_3_weight, None, input_requires_grad=True, weight_requires_grad=True, bias_requires_grad=False)
  del t57
  return {'output': (t72,), 'flat_args': [input, t_0_weight, t_1_weight, t_2_weight, t_3_weight], 'flat_output': (t72,)}, ((t19, t20, t21, t22, t23, t24, t34, t35, t36, t37, t38, t39, t40, t41, t49, t50, t51, t52, t53, t54, t64, t65, t66, t67, t68, t69, t70, t71), (ctx_te_1418, ctx_te_1533, ctx_te_1648, ctx_te_1763))

Backward Trace

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  # C0: "Collection"
  # C1: "Collection"
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t73, = cotangents
  # t73: "cuda:0 f32[32, 32]"
  clear_mutable_collection(cotangents)
  del cotangents
  t19, t20, t21, t22, t23, t24, t34, t35, t36, t37, t38, t39, t40, t41, t49, t50, \
  t51, t52, t53, t54, t64, t65, t66, t67, t68, t69, t70, t71, = C0
  clear_mutable_collection(C0)
  del C0
  ctx_te_1418, ctx_te_1533, ctx_te_1648, ctx_te_1763, = C1
  clear_mutable_collection(C1)
  del C1
  (bw_t74, grad_for_t_3_weight, _) = te_functional_linear_backward((32, 32), (32, 32), None, ctx_te_1763, (t64, t65, t66, t67, t68, t69, t70, t71), t73, input_requires_grad=True, weight_requires_grad=True, bias_requires_grad=False)
  del ctx_te_1763, t64, t65, t66, t67, t68, t69, t70, t71, t73
  (bw_t59, _, _) = te_functional_linear_backward((32, 32), (32, 32), None, ctx_te_1648, (t49, t50, t51, t52, t53, t54), bw_t74, input_requires_grad=True, weight_requires_grad=False, bias_requires_grad=False)
  del ctx_te_1648, t49, t50, t51, t52, t53, t54, bw_t74
  (bw_t44, grad_for_t_1_weight, _) = te_functional_linear_backward((32, 32), (32, 32), None, ctx_te_1533, (t34, t35, t36, t37, t38, t39, t40, t41), bw_t59, input_requires_grad=True, weight_requires_grad=True, bias_requires_grad=False)
  del ctx_te_1533, t34, t35, t36, t37, t38, t39, t40, t41, bw_t59
  (grad_for_input, _, _) = te_functional_linear_backward((32, 32), (32, 32), None, ctx_te_1418, (t19, t20, t21, t22, t23, t24), bw_t44, input_requires_grad=True, weight_requires_grad=False, bias_requires_grad=False)
  del ctx_te_1418, t19, t20, t21, t22, t23, t24, bw_t44
  te_sync_fp8_meta_bwd()
  return (grad_for_input, None, grad_for_t_1_weight, None, grad_for_t_3_weight)
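Continuing the sketch after the example program, the gradients that land on the original parameters should mirror the backward trace's return value (a quick sanity check under the same assumptions as that sketch, e.g. that the jitted module shares parameters with model):

    # Mirrors: return (grad_for_input, None, grad_for_t_1_weight, None, grad_for_t_3_weight)
    assert x.grad is not None
    assert model[0].weight.grad is None and model[2].weight.grad is None          # frozen
    assert model[1].weight.grad is not None and model[3].weight.grad is not None  # trainable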

@kshitij12345 marked this pull request as draft on May 26, 2025, 13:25
@riccardofelluga (Collaborator) left a comment


This is a nice fix; however, we cannot rely on the assumption that requires_grad is always propagated throughout the trace, and I think we should move away from that assumption unless we can make sure the propagation is always guaranteed.

A more involved alternative would be to pick up on the runtime proxy idea.

Comment on lines 599 to 601
dgrad, wgrad, bgrad = bsym.output
w_requires_grad = True if wgrad is not None else False
b_requires_grad = True if bgrad is not None else False

Interesting hack for requires_grad propagation, though if the symbol before the one captured by the TE executor did not propagate requires_grad, this might not work as intended.

@kshitij12345 (Collaborator, Author)

> This is a nice fix; however, we cannot rely on the assumption that requires_grad is always propagated throughout the trace, and I think we should move away from that assumption unless we can make sure the propagation is always guaranteed.

This fix doesn't rely on requires_grad being propagated correctly, but on whether or not the gradient is returned from the backward trace.

# Update the backward trace to only compute gradients for the
# inputs that require gradients
assert bw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN
filtered_grads = tuple(
    (arg_grad if requires_grad else None)
    for arg_grad, requires_grad in utils.safe_zip(bw_trace.bound_symbols[-1].args[0], requires_grad_mask)
)
# autograd.Function.backward expects a flat tuple of gradients
bw_trace.bound_symbols[-1] = replace(bw_trace.bound_symbols[-1], args=(filtered_grads,))

If the gradient is not returned from the backward trace, we update both traces: the forward trace so that the FP8 copy is not saved for backward, and the backward trace so that wgrad is not computed.
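In other words, the only signal used is whether a (non-None) gradient for a given input is returned from the backward trace. A minimal, self-contained stand-in for that decision (plain Python, not the actual Thunder pass; names are illustrative):

    def requires_grad_from_bwd_return(flat_args, returned_grads):
        """True for an arg iff the backward trace actually returns a gradient for it."""
        return {name: grad is not None for name, grad in zip(flat_args, returned_grads)}

    # Applied to the example program's backward trace RETURN:
    flat_args = ("input", "t_0_weight", "t_1_weight", "t_2_weight", "t_3_weight")
    returned = ("grad_for_input", None, "grad_for_t_1_weight", None, "grad_for_t_3_weight")
    assert requires_grad_from_bwd_return(flat_args, returned) == {
        "input": True, "t_0_weight": False, "t_1_weight": True,
        "t_2_weight": False, "t_3_weight": True,
    }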

> A more involved alternative would be to pick up on the runtime proxy idea.

As far as I can tell, the RuntimeProxy idea (#1599) will just prevent us from fetching requires_grad from intermediate TensorProxys. However, it won't fix the problem of correctly propagating it (#1768). I could be wrong though, cc @IvanYashchuk as the author of #1599 to clarify.

@kshitij12345 marked this pull request as ready for review on June 4, 2025, 23:29
@kshitij12345 (Collaborator, Author)

TODO: Understand the interaction of this PR with #2102

@nvMelissa (Collaborator)

@kshitij12345 - this PR is ready, yes? Who needs to approve this please?

@kshitij12345 (Collaborator, Author)

Need to update this PR to work correctly with the changes from #2102.

Also, the TE integration is broken after #2102, so #2222 needs to be merged first, followed by this PR.

Moving it back to draft to avoid confusion.

@kshitij12345 marked this pull request as draft on June 18, 2025, 08:49
@kshitij12345 marked this pull request as ready for review on June 30, 2025, 15:42

Successfully merging this pull request may close this issue: TE: Redundant backward computation in PEFT setting.