
High Peak Memory with CUDAGraphTransform #1533

Open
kshitij12345 opened this issue Dec 9, 2024 · 3 comments

@kshitij12345 (Collaborator)

Peak memory is very high when CUDAGraphTransform is used:

# With CUDAGraphTransform    - 27517.101568 MB
# Without CUDAGraphTransform - 11917.129728 MB

Example:

import torch
import thunder
import litgpt
from torch.testing import make_tensor
from functools import partial
from thunder.dynamo import ThunderCompiler
from thunder.transforms.cudagraph import CUDAGraphTransform

device = torch.device("cuda")


cfg = litgpt.Config.from_name("open_llama_3b", n_layer=10)
with device:
    make = partial(make_tensor, low=0, high=255, device=device, dtype=torch.long, requires_grad=False)
    shape = (1,) + (cfg.block_size,)

    x = make(shape)
    m = litgpt.GPT(cfg)

# Toggle between the two lines below to compare peak memory:
# m = thunder.jit(m)
m = thunder.jit(m, transforms=[CUDAGraphTransform()])

o = m(x)
o.sum().backward()

# With CUDAGraphTransform - 27517.101568
# Without CUDAGraphTransform - 11917.129728
print(torch.cuda.max_memory_allocated() / 1e6)

Tested with internal image dated 20241209 on RTX 6000 Ada.
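
A minimal measurement sketch (not from the original report): the peak statistic can be reset right before the region being measured, so the printed number reflects only that forward/backward.

torch.cuda.reset_peak_memory_stats()  # clear the previously recorded peak
o = m(x)
o.sum().backward()
torch.cuda.synchronize()
print(torch.cuda.max_memory_allocated() / 1e6)  # peak in MB since the reset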

@kshitij12345 (Collaborator, Author)

Another Example:

Qwen2

import torch
from thunder.dynamo import ThunderCompiler
from transformers import AutoConfig, AutoModelForCausalLM
from thunder.transforms.cudagraph import CUDAGraphTransform
model_id = "Qwen/Qwen2.5-7B-Instruct"

configuration = AutoConfig.from_pretrained(
    model_id,
    num_hidden_layers=5,
)
configuration.hidden_size = configuration.num_attention_heads
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(configuration).to(torch.bfloat16)

# Toggle between the two backends below to compare peak memory:
# backend = ThunderCompiler()
backend = ThunderCompiler(transforms=[CUDAGraphTransform()])
compiled_model = torch.compile(model, backend=backend)

input_ids = torch.randint(0, configuration.vocab_size, (1, 4096), device="cuda")

compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
compiled_output.loss.backward()

# Without CUDAGraphTransform - 13312.685568
# With CUDAGraphTransform - 26071.434752
print(torch.cuda.max_memory_allocated() / 1e6)

@t-vi (Collaborator) commented Dec 10, 2024

This is likely an issue where we create static input buffers for things that we would not want to do that for (e.g. maybe the tensors saved for backward when the forward is computed by the CUDA graph?).
The parameters seem to be correctly marked as not needing static input buffers.
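
A rough sketch to check this, assuming a jitted module m that has already run forward and backward (names as in the first repro): it lists which tensor proxies in the last forward execution trace carry the STATIC_MEMORY_LOCATION tag.

import thunder
from thunder.core.proxies import TensorProxy, ProxyTag

# last computed forward execution trace of the jitted module `m`
fwd_trace = thunder.last_traces(m)[-1]
for bsym in fwd_trace.bound_symbols:
    for p in bsym.flat_proxy_outs:
        if isinstance(p, TensorProxy) and ProxyTag.STATIC_MEMORY_LOCATION in p.tags:
            print(bsym.sym.name, p.name)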

t-vi self-assigned this Dec 11, 2024
@t-vi (Collaborator) commented Dec 17, 2024

So there are actually two things:

  • The first is the static input buffers in the backward; the following subclass improves on that:
import thunder
import thunder.core.proxies
import thunder.core.trace
import thunder.core.vjp_utils
from thunder.transforms.cudagraph import CUDAGraphTransform


class MyCUDAGraphTransform(CUDAGraphTransform):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.outputs_from_forward = None

    def transform_trace_post_optimization(self, trace, **kwargs):
        if trace.siginfo().name == 'backward_fn':
            # todo: have a backward tag
            assert self.outputs_from_forward is not None, "called on backward without forward before"
            # make this more generic or have an utility?
            assert len(trace.bound_symbols[2].args) == 2 and trace.bound_symbols[2].args[0].name == 'saved_for_backward'
            assert trace.bound_symbols[8].sym.name == 'unpack_sequence' and trace.bound_symbols[8].args[0] is trace.bound_symbols[2].output[0]
            saved_for_backwards_unpacked = trace.bound_symbols[8].output
            assert len(saved_for_backwards_unpacked) == len(self.outputs_from_forward)
            for (_, is_static), p_bw in zip(self.outputs_from_forward, saved_for_backwards_unpacked):
                if is_static:
                    p_bw.tags.add(thunder.core.proxies.ProxyTag.STATIC_MEMORY_LOCATION)
            self.outputs_from_forward = None

        new_trace = super().transform_trace_post_optimization(trace, **kwargs)

        if thunder.core.trace.TraceTag.AUGMENTED_FORWARD in new_trace.tags:
            assert self.outputs_from_forward is None, "called on augmented forward twice without backward in between"
            # apparently, it is safer to go by name than assume we have the same proxies here. :(
            cudagraph_output_names = set()
            for bsym in new_trace.bound_symbols:
                if bsym.sym.name.startswith('CUDAGraph'):
                    for o in bsym.flat_proxy_outs:
                        cudagraph_output_names.add(o.name)
            saved_for_backward = thunder.core.vjp_utils.get_saved_for_backward_tensors(new_trace)
            self.outputs_from_forward = [(o.name, o.name in cudagraph_output_names or thunder.core.proxies.ProxyTag.STATIC_MEMORY_LOCATION in o.tags) for o in saved_for_backward]
        return new_trace

If we decide that this sharing (reusing the forward CUDA graph's output buffers as the backward's static inputs) is the right thing, we should do that.
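
A usage sketch, mirroring the first repro above (the subclass is a drop-in replacement for CUDAGraphTransform):

m = thunder.jit(m, transforms=[MyCUDAGraphTransform()])
o = m(x)
o.sum().backward()
print(torch.cuda.max_memory_allocated() / 1e6)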

The other part is that the memory for the graphs is not currently shared.
I thought it would be as easy as initializing self.pool to None and then changing

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
    static_outputs = fn(*static_inputs)

to

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=self.pool, stream=stream):
    static_outputs = fn(*static_inputs)
if self.pool is None:
    self.pool = graph.pool()

but this does not do it (and makes things worse instead?)...

I would highly appreciate any hint on how this should work.
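
For context, a minimal standalone sketch (independent of Thunder) of the pool-sharing mechanism being attempted: torch.cuda.graph accepts a pool argument, graph.pool() returns the pool of an already captured graph, and torch.cuda.graph_pool_handle() creates a fresh handle that several captures can share.

import torch

assert torch.cuda.is_available(), "requires a CUDA device"

x = torch.randn(1024, 1024, device="cuda")

# warm up the op on a side stream before capture (recommended for graph capture)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        x @ x
torch.cuda.current_stream().wait_stream(s)

pool = torch.cuda.graph_pool_handle()  # one handle shared by both captures

g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    y1 = x @ x

g2 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g2, pool=pool):
    y2 = x @ x

g1.replay()
g2.replay()
torch.cuda.synchronize()
print(torch.cuda.max_memory_allocated() / 1e6)  # peak in MB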
