High Peak Memory with CUDAGraphTransform #1533
Comments
Another example: Qwen2

```python
import torch
from thunder.dynamo import ThunderCompiler
from transformers import AutoConfig, AutoModelForCausalLM
from thunder.transforms.cudagraph import CUDAGraphTransform

model_id = "Qwen/Qwen2.5-7B-Instruct"
configuration = AutoConfig.from_pretrained(
    model_id,
    num_hidden_layers=5,
)
configuration.hidden_size = configuration.num_attention_heads

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(configuration).to(torch.bfloat16)

# backend = ThunderCompiler()
backend = ThunderCompiler(transforms=[CUDAGraphTransform()])
compiled_model = torch.compile(model, backend=backend)

input_ids = torch.randint(0, configuration.vocab_size, (1, 4096), device="cuda")
compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
compiled_output.loss.backward()

# Without CUDAGraphTransform - 13312.685568
# With CUDAGraphTransform - 26071.434752
print(torch.cuda.max_memory_allocated() / 1e6)
```
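Not part of the original report, but one way to see which allocations account for the extra ~13 GB of peak memory is PyTorch's memory-history recording (the underscore-prefixed helpers documented in the "Understanding CUDA Memory Usage" guide; snapshots open at pytorch.org/memory_viz). A minimal sketch wrapped around the forward/backward above, assuming a recent PyTorch; allocations made inside graph capture may not show full stacks:

```python
import torch

# Record allocation history (with stack traces) before the run...
torch.cuda.memory._record_memory_history(max_entries=100_000)

compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
compiled_output.loss.backward()

# ...then dump a snapshot for the memory_viz tool and stop recording.
torch.cuda.memory._dump_snapshot("cudagraph_transform_peak.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```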
This is likely an issue where we create static input buffers for things that we would not want to do that for (e.g., maybe the saved-for-backward tensors when the forward is computed by the CUDA graph?).
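For context, a minimal standalone sketch (plain PyTorch, not lightning-thunder code) of why graph inputs get dedicated static buffers: replay always reads the addresses that were captured, so callers copy into a persistent buffer before each replay. If a tensor is already owned by graph memory (for example a forward-graph output that is saved for backward), an additional static input buffer for it in the backward graph is pure duplication:

```python
import torch

static_input = torch.zeros(8, 8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2.0
torch.cuda.current_stream().wait_stream(s)

# Capture: the graph bakes in the addresses of static_input/static_output.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2.0

def run(x: torch.Tensor) -> torch.Tensor:
    # Every caller copies into the captured buffer before replay.  When x is
    # itself graph-owned memory (a saved-for-backward forward output), this
    # buffer only duplicates memory that already exists.
    static_input.copy_(x)
    g.replay()
    return static_output
```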
So there are actually two things:
```python
import thunder
import thunder.core.proxies
import thunder.core.trace
import thunder.core.vjp_utils
from thunder.transforms.cudagraph import CUDAGraphTransform


class MyCUDAGraphTransform(CUDAGraphTransform):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.outputs_from_forward = None

    def transform_trace_post_optimization(self, trace, **kwargs):
        if trace.siginfo().name == 'backward_fn':
            # todo: have a backward tag
            assert self.outputs_from_forward is not None, "called on backward without forward before"
            # make this more generic or have a utility?
            assert len(trace.bound_symbols[2].args) == 2 and trace.bound_symbols[2].args[0].name == 'saved_for_backward'
            assert trace.bound_symbols[8].sym.name == 'unpack_sequence' and trace.bound_symbols[8].args[0] is trace.bound_symbols[2].output[0]
            saved_for_backwards_unpacked = trace.bound_symbols[8].output
            assert len(saved_for_backwards_unpacked) == len(self.outputs_from_forward)
            # Mark saved-for-backward tensors that already live in static memory
            # (e.g. because the forward CUDA graph produced them) so the backward
            # transform does not allocate fresh static input buffers for them.
            for (_, is_static), p_bw in zip(self.outputs_from_forward, saved_for_backwards_unpacked):
                if is_static:
                    p_bw.tags.add(thunder.core.proxies.ProxyTag.STATIC_MEMORY_LOCATION)
            self.outputs_from_forward = None

        new_trace = super().transform_trace_post_optimization(trace, **kwargs)

        if thunder.core.trace.TraceTag.AUGMENTED_FORWARD in new_trace.tags:
            assert self.outputs_from_forward is None, "called on augmented forward twice without backward in between"
            # apparently, it is safer to go by name than assume we have the same proxies here. :(
            cudagraph_output_names = set()
            for bsym in new_trace.bound_symbols:
                if bsym.sym.name.startswith('CUDAGraph'):
                    for o in bsym.flat_proxy_outs:
                        cudagraph_output_names.add(o.name)
            saved_for_backward = thunder.core.vjp_utils.get_saved_for_backward_tensors(new_trace)
            self.outputs_from_forward = [
                (o.name, o.name in cudagraph_output_names or thunder.core.proxies.ProxyTag.STATIC_MEMORY_LOCATION in o.tags)
                for o in saved_for_backward
            ]
        return new_trace
```

If we decide the sharing is the right thing, we should do that.

The other part is that the memory for the graphs is not currently shared (lightning-thunder/thunder/transforms/cudagraph.py, lines 98 to 100 at d3b2276). Changing that to

```python
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=self.pool, stream=stream):
    static_outputs = fn(*static_inputs)
if self.pool is None:
    self.pool = graph.pool()
```

does not do it (and makes things worse instead?)... I would highly appreciate any hint on how this should work.
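For what it's worth, the pattern the PyTorch docs describe for sharing memory across captures is to create a single pool handle up front (torch.cuda.graph_pool_handle()) and pass it to every capture, with the caveat that graphs sharing a pool should be replayed in the same order they were captured. A standalone sketch of that pattern, with warmup omitted for brevity (not a drop-in fix for the transform):

```python
import torch

# One pool handle created before any capture, shared by every graph.
pool = torch.cuda.graph_pool_handle()

x = torch.randn(1 << 20, device="cuda")

g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    y1 = x * 2

g2 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g2, pool=pool):
    y2 = x + 1

# Graphs sharing a pool reuse each other's memory; replay in capture order.
g1.replay()
g2.replay()
```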
Peak memory is very high when CUDAGraphTransform is used.

Example -

Tested with internal image dated 20241209 on RTX 6000 Ada.