fully_shard() for huggingface model: pytorch caches too much GPU memory #1126


Open
mingdianliu opened this issue Apr 21, 2025 · 4 comments
Labels: module: fsdp, question (Further information is requested)

Comments


mingdianliu commented Apr 21, 2025

Dear Community,

I'm working on fine-tuning the Qwen2-VL model using fully_shard() and wrote a script for it. However, I noticed that GPU memory usage stays high (around 50-60 GB) even as I scale up the number of GPUs. In addition, I run into OOM when I try to fine-tune the 72B model with 128 GPUs.

I'm wondering if there might be any issues with my code or configuration. I'd really appreciate any insights or suggestions you might have. Thanks in advance!

My code:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, AutoModelForVision2Seq, AutoConfig
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import numpy as np
from PIL import Image
import io
import logging
import os

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer, Qwen2VLVisionBlock
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
    fully_shard,
)


# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# init dist
distributed_backend = "nccl" # gloo for cpu
dist.init_process_group(distributed_backend)

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)


# model_name = "Qwen/Qwen2-VL-2B-Instruct"
# revision = "895c3a49bc3fa70a340399125c650a463535e71c"
model_name = "Qwen/Qwen2-VL-7B-Instruct"
revision = "a28a094eb66a9f2ac70eef346f040d8a79977472"
# model_name = "Qwen/Qwen2-VL-72B-Instruct"
# revision = "f9b556a74d58e6d9915f73227c21045c87342b42"

dataset_id = "HuggingFaceM4/ChartQA"
processor = Qwen2VLProcessor.from_pretrained(model_name, 
                                             revision=revision,
                                             )


# Configuration
class Config:
    dataset_id = "HuggingFaceM4/ChartQA"
    output_dir = "/tmp_ckpt"
    batch_size = 2
    num_epochs = 3
    learning_rate = 5e-5
    max_seq_length = 512
    lora_rank = 32
    lora_alpha = 64
    lora_dropout = 0.1
    device = "cuda" if torch.cuda.is_available() else "cpu"




system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""

def format_data(sample):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": sample["image"],
                },
                {
                    "type": "text",
                    "text": sample["query"],
                },
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": sample["label"][0]}],
        },
    ]

# Training function
def train_model(model, train_loader, optimizer, config):
    model.train()
    total_steps = len(train_loader) * config.num_epochs
    step = 0

    scaler = torch.amp.GradScaler("cuda", enabled=True)

    for epoch in range(config.num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(train_loader):

            inputs, labels = batch
            inputs = inputs.to(config.device)
            labels = labels.to(config.device)

            # Mixed precision training
            loss = model(**inputs, labels=labels).loss
            loss.backward() # no scaler
            optimizer.step()
            optimizer.zero_grad()
            
            step += 1
            logger.info(f"Epoch {epoch+1}/{config.num_epochs}, Step {step}/{total_steps}, Loss: {loss.item():.4f}")

            del loss



# Create a data collator to encode text and image pairs
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(example, tokenize=False) for example in examples
    ]  # Prepare texts for processing
    image_inputs = [process_vision_info(example)[0] for example in examples]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )  # Encode texts and images into tensors

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels

    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):  # Check if the processor is Qwen2VLProcessor
        image_tokens = [151652, 151653, 151655]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    return batch, labels



# Main function
def main():

    config = Config()

    # Load model and processor
    logger.info("Loading model and processor...")

    hf_config = AutoConfig.from_pretrained(
                model_name, 
                revision=revision, 
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                )

    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(hf_config, torch_dtype=torch.bfloat16)

    mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, 
                                   reduce_dtype=torch.bfloat16, 
                                   output_dtype=torch.bfloat16, 
                                   cast_forward_inputs=True)
    offload_policy = CPUOffloadPolicy(pin_memory=False)

    # apply FSDP2
    device_mesh = init_device_mesh("cuda", (world_size,))
    for module in model.modules():
        if isinstance(module, Qwen2VLDecoderLayer):
            fully_shard(module, 
                        mesh=device_mesh, 
                        reshard_after_forward=True,
                        mp_policy=mp_policy,
                        # offload_policy=offload_policy,
                        )
    
    model = fully_shard(model, 
                        mesh=device_mesh, 
                        reshard_after_forward=True,
                        mp_policy=mp_policy,
                        # offload_policy=offload_policy,
                        )

    model.to_empty(device='cuda')

    model_state_dict = model.state_dict()

    model_dir = "/cache/fsdp_test/72B_8_files"

    # load qwen2-vl model
    dcp.load(
        state_dict=model_state_dict,
        checkpoint_id=model_dir,
        planner=DefaultLoadPlanner(allow_partial_load=True),
    )

    model = model.to(torch.bfloat16).cuda()
    

    # Load dataset
    logger.info("Loading dataset...")

    train_dataset, eval_dataset, test_dataset = load_dataset(
        config.dataset_id, split=['train[:10%]', 'val[:10%]', 'test[:10%]'])
    train_dataset = [format_data(sample) for sample in train_dataset]

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        collate_fn=collate_fn,
        shuffle=True,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Train
    logger.info("Starting training...")
    train_model(model, train_dataloader, optimizer, config)


if __name__ == "__main__":
    main()
    destroy_process_group()
    logger.info("Training completed.")

Running commands (16, 32, and 64 GPUs):
torchrun --nnodes=2 --nproc_per_node=8 qwenvl_train_fsdp.py
torchrun --nnodes=4 --nproc_per_node=8 qwenvl_train_fsdp.py
torchrun --nnodes=8 --nproc_per_node=8 qwenvl_train_fsdp.py

Screenshots of nvidia-smi output for the 16-, 32-, and 64-GPU runs (images omitted).


fegin (Contributor) commented Apr 22, 2025

@mingdianliu could it be possible that the activations dominate the memory usage under such a setting? For a 7B model, even in float32, the parameters + gradients + optimizer states come to roughly 112 GB, so with 16 GPUs each GPU holds roughly 7 GB. If you freeze some modules for fine-tuning, this number is even smaller. The same applies to the 72B model: you will have to apply other techniques to reduce the memory consumed by activations, such as TP or activation checkpointing.
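
For reference, a minimal sketch (not from the thread) of what applying activation checkpointing to each decoder layer could look like before sharding; it assumes the same model object and Qwen2VLDecoderLayer class used in the script above:

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Recompute each decoder layer's activations during backward instead of
# keeping them alive across the whole forward pass.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, Qwen2VLDecoderLayer),
)

# ...then apply fully_shard() to the decoder layers and the root model as in the original script.

This trades extra recomputation in the backward pass for a substantial reduction in activation memory.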

mingdianliu (Author) commented:


Hi @fegin

Thanks for your follow-up. I found that it is due to the PyTorch caching allocator. The allocated and reserved GPU memory is quite small, while the cached GPU memory is higher than 50 GB. I tried calling torch.cuda.empty_cache() after each training iteration, but the GPU memory cached during each iteration is still high (~20 GB). I wonder whether this is a bug in FSDP2. If not, is there any method that can mitigate this issue?
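
As an aside, a minimal sketch (not from the thread) of a hypothetical per-rank logging helper that separates live-tensor usage from the caching allocator's pool, using the standard torch.cuda introspection calls:

import torch
import torch.distributed as dist

def log_memory(tag: str):
    # memory_allocated(): bytes currently held by live tensors on this device
    # memory_reserved():  bytes held by the caching allocator (live tensors + cached free blocks)
    alloc = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    rank = dist.get_rank() if dist.is_initialized() else 0
    print(f"[rank {rank}] {tag}: allocated={alloc:.1f} GiB, "
          f"reserved={reserved:.1f} GiB, peak allocated={peak:.1f} GiB")

# e.g. call log_memory("after optimizer.step") inside the training loop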

@mingdianliu mingdianliu changed the title fully_shard() with huggingface model: GPU memory doesn't drop while increasing GPU number fully_shard() for huggingface model: pytorch caches too much GPU memory Apr 22, 2025
fegin (Contributor) commented Apr 22, 2025

Caching is not an issue, because that memory will be reused for other tensor allocations, and it will not cause OOM: when new tensors are created, PyTorch first looks for free space in its cache, and only if no cached space is available does it ask CUDA for more memory. OOM happens only when CUDA cannot provide enough memory.

So, if you are not seeing OOM but only seeing high cache memory, that should not be an issue. If you actually are seeing OOM, you can try exporting the environment variable PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True". If this doesn't help, you need to reduce the memory usage (reducing batch size, model size, using TP, activation checkpointing, ...).
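
For reference, a minimal sketch of setting that allocator option; the assumption is that the variable must be in the environment before the first CUDA allocation (for the in-script variant, before anything touches the GPU):

# Option 1: set it in the launch environment so every torchrun worker inherits it, e.g.
#   PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" torchrun --nnodes=2 --nproc_per_node=8 qwenvl_train_fsdp.py
# Option 2: set it at the very top of the training script, before any CUDA allocation:
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")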

@tianyu-l tianyu-l added question Further information is requested module: fsdp labels May 13, 2025