Skip to content

Issue Using Qwen2-VL HuggingFace Model with Transformer Engine FP8 Training #1889

Open
@mingdianliu

Description

@mingdianliu

Dear Sir or Madam,

I am currently developing a demo that integrates Transformer Engine FP8 training with the Qwen2-VL model from HuggingFace. However, I have encountered an error during execution:

RuntimeError: Unable to find suitable cuBLAS GEMM algorithm

I have reviewed the documentation and attempted several troubleshooting steps, but I haven't been able to resolve the issue.

Could you kindly provide guidance or suggestions to help address this problem? Any assistance would be greatly appreciated.

Thank you very much for your support.

te_qwen2_vl.py

# te_qwen2_vl.py
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, AutoModelForVision2Seq, AutoConfig
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import numpy as np
from PIL import Image
import io
import logging

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer, Qwen2VLVisionBlock
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed.checkpoint import DefaultLoadPlanner, DefaultSavePlanner
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    fully_shard,
    MixedPrecisionPolicy,
)
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions, 
    get_model_state_dict, 
    get_optimizer_state_dict, 
    set_model_state_dict, 
    set_optimizer_state_dict,
)
import subprocess
from transformers.models.qwen2_vl.modeling_qwen2_vl import rotate_half


from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl,
)

# transformer engine
import transformer_engine.pytorch as te
import torch.nn as nn
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.common import recipe
import transformer_engine.pytorch as te

from torch.nn import Module
import bitsandbytes as bnb


# Set up logging: INFO level, each record prefixed with a timestamp and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialise the default process group. This runs at import time, so merely
# importing this module requires a distributed launcher environment.
distributed_backend = "nccl" # gloo for cpu
dist.init_process_group(distributed_backend)

# LOCAL_RANK / WORLD_SIZE are read straight from the environment and raise
# KeyError when missing (run_main.sh launches this script through torchrun).
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
print("local_rank", local_rank)
print("world_size", world_size)
# Make the GPU matching this process's local rank the current CUDA device.
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)


# Checkpoint to fine-tune, pinned to an exact hub revision.
# Alternative sizes are kept below; switch both lines together.
model_name = "Qwen/Qwen2-VL-2B-Instruct"
revision = "895c3a49bc3fa70a340399125c650a463535e71c"
# model_name = "Qwen/Qwen2-VL-7B-Instruct"
# revision = "a28a094eb66a9f2ac70eef346f040d8a79977472"
# model_name = "Qwen/Qwen2-VL-72B-Instruct"
# revision = "f9b556a74d58e6d9915f73227c21045c87342b42"

# Same value as Config.dataset_id; main() reads the Config attribute, not this name.
dataset_id = "HuggingFaceM4/ChartQA"
# Module-level processor shared with collate_fn (chat template, tokenizer,
# image preprocessing). Loaded from the same pinned revision as the model.
processor = Qwen2VLProcessor.from_pretrained(model_name, 
                                             revision=revision, 
                                             cache_dir="/cache/hf_cache"
                                             )


# Configuration
class Config:
    """Static settings for the fine-tuning demo, exposed as class attributes."""

    # --- data and output locations ---
    dataset_id = "HuggingFaceM4/ChartQA"
    output_dir = "/tmp_ckpt"

    # --- optimisation schedule ---
    num_epochs = 3
    batch_size = 1
    learning_rate = 5e-6
    max_seq_length = 512

    # --- LoRA hyper-parameters ---
    # NOTE(review): nothing in the visible script reads these three values.
    lora_rank = 32
    lora_alpha = 64
    lora_dropout = 0.1

    # Target device string: the GPU whenever PyTorch can see one.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"




# system_message = """You are a Vision Language Model specialized in interpreting visual data from chart images.
# Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
# The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
# Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary."""

# System prompt placed at the start of every conversation built by format_data.
system_message = (
    "You are a Vision Language Model specialized in interpreting visual data "
    "from chart images. Answer concisely."
)


def format_data(sample):
    """Convert one dataset row into a three-turn chat conversation.

    Args:
        sample: Mapping providing ``"image"``, ``"query"`` and ``"label"``;
            only the first element of ``sample["label"]`` is used as the
            answer.

    Returns:
        A list ``[system, user, assistant]`` of message dicts, each with a
        ``"role"`` and a ``"content"`` list of typed parts. The user turn
        carries the image followed by the query text.
    """

    def _text_part(value):
        # Every textual piece of content shares this two-key layout.
        return {"type": "text", "text": value}

    system_turn = {"role": "system", "content": [_text_part(system_message)]}

    image_part = {"type": "image", "image": sample["image"]}
    user_turn = {
        "role": "user",
        "content": [image_part, _text_part(sample["query"])],
    }

    assistant_turn = {
        "role": "assistant",
        "content": [_text_part(sample["label"][0])],
    }

    return [system_turn, user_turn, assistant_turn]

# Training function
def train_model(model, train_loader, optimizer, config, fp8_recipe):
    """Run the training loop under Transformer Engine FP8 autocast.

    Args:
        model: Callable as ``model(**inputs, labels=labels)``; the returned
            object must expose a ``.loss`` tensor.
        train_loader: Sized iterable yielding ``(inputs, labels)`` pairs. Both
            items must support ``.to(device)`` and ``inputs`` must be a
            mapping containing ``"input_ids"``.
        optimizer: Optimizer providing ``step()`` and ``zero_grad()``.
        config: Object providing ``num_epochs`` and ``device``.
        fp8_recipe: Recipe forwarded to ``te.fp8_autocast``. Previously this
            argument was accepted but ignored, so the library default recipe
            was used instead of the caller's.

    Returns:
        None. Progress is reported through ``print``.
    """
    print("strat training model")
    model.train()
    total_steps = len(train_loader) * config.num_epochs
    step = 0

    for epoch in range(config.num_epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(config.device)
            labels = labels.to(config.device)
            # Example shapes recorded by the original author for one sample:
            #   input_ids      torch.Size([1, 171])
            #   attention_mask torch.Size([1, 171])
            #   pixel_values   torch.Size([440, 1176])
            #   image_grid_thw torch.Size([1, 3])

            print(f"input shape: {inputs['input_ids'].shape} labels shape: {labels.shape}") 

            # Only the forward pass runs inside the autocast region; the
            # backward pass is launched after leaving it.
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            step += 1

            print(f"Epoch {epoch+1}/{config.num_epochs}, Step {step}/{total_steps}, Loss: {loss.item():.4f}")
            # Drop the reference so the loss tensor can be released before
            # the next iteration allocates.
            del loss


# Create a data collator to encode text and image pairs
def collate_fn(examples, pad_to_multiple_of=16):
    """Encode a list of chat conversations into a model batch plus labels.

    Args:
        examples: Conversations in the format produced by ``format_data``.
        pad_to_multiple_of: When not ``None``, the padded text sequence
            length is rounded up to a multiple of this value. The default of
            16 targets the reported "Unable to find suitable cuBLAS GEMM
            algorithm" failure in the FP8 weight-gradient GEMM, whose inner
            dimension is the token count (171 in the report).
            NOTE(review): this aligns the text sequence only; the number of
            vision tokens is determined by the image and is not padded here,
            so converted vision layers may still hit the same restriction.
            Confirm against the Transformer Engine FP8 shape requirements.
            Pass ``None`` to restore the previous unpadded behaviour.

    Returns:
        ``(batch, labels)`` where ``batch`` is the processor output and
        ``labels`` is a copy of ``batch["input_ids"]`` with padding and image
        tokens replaced by -100 so they are ignored by the loss.
    """
    # Render each conversation to a prompt string and collect its images.
    texts = [
        processor.apply_chat_template(example, tokenize=False)
        for example in examples
    ]
    image_inputs = [process_vision_info(example)[0] for example in examples]

    # Tokenize text and preprocess images into tensors. Extra positions
    # introduced by pad_to_multiple_of are ordinary padding tokens and are
    # masked out of the loss below together with the rest of the padding.
    batch = processor(
        text=texts,
        images=image_inputs,
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=pad_to_multiple_of,
    )

    # Labels start as a copy of the inputs; padding never contributes to loss.
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Image-related token ids are model specific.
    if isinstance(processor, Qwen2VLProcessor):
        # Hard-coded ids used for the Qwen2-VL processor.
        image_tokens = [151652, 151653, 151655]
    else:
        image_tokens = [
            processor.tokenizer.convert_tokens_to_ids(processor.image_token)
        ]

    # Image placeholder positions are excluded from the loss as well.
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100

    return batch, labels

def _to_te(module: nn.Module):
    for name, child in list(module.named_children()):
        # 1) Linear → te.Linear
        if isinstance(child, nn.Linear) and child.in_features % 16 == 0 and child.out_features % 16 == 0:
            te_linear = te.Linear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                params_dtype=torch.bfloat16,
            )
            te_linear.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                te_linear.bias.data.copy_(child.bias.data)
            setattr(module, name, te_linear)

        # # 2) LayerNorm → te.LayerNorm
        # elif isinstance(child, nn.LayerNorm):
        #     te_ln = te.LayerNorm(
        #         normalized_shape=child.normalized_shape,
        #         eps=child.eps,
        #         params_dtype=torch.bfloat16,
        #     )
        #     te_ln.weight.data.copy_(child.weight.data)
        #     te_ln.bias.data.copy_(child.bias.data)
        #     setattr(module, name, te_ln)
        else:
            _to_te(child)

# Main function
def main():
    """Load Qwen2-VL, convert its Linear layers to Transformer Engine, and train.

    Steps: configure the Hugging Face environment, load the pinned
    checkpoint in bfloat16 on the current CUDA device, run ``_to_te`` over
    the model, build a synthetic dataset by repeating one ChartQA row, and
    hand everything to ``train_model``.

    NOTE(review): the model is not wrapped in DDP/FSDP and the DataLoader has
    no distributed sampler, so every launched process trains its own full
    replica on identical data.
    """
    # One cache location for every Hugging Face call. Previously HF_HOME was
    # first set to the relative path "cache/hf_cache" and later to the
    # absolute one.
    hf_cache_dir = "/cache/hf_cache"
    # NOTE(review): these are assigned after the Hugging Face libraries were
    # imported; confirm whether they must be exported before launch to take
    # effect. The explicit cache_dir arguments below do not depend on them.
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
    os.environ["HF_HOME"] = hf_cache_dir
    config = Config()

    # Load model
    logger.info("Loading model and processor...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_name,
        revision=revision,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        cache_dir=hf_cache_dir,
    )
    model = model.to(torch.bfloat16).cuda()

    # Apply Transformer Engine: replace eligible nn.Linear layers in place.
    _to_te(model)

    # FP8 training recipe.
    fp8_recipe = DelayedScaling(
        fp8_format=Format.HYBRID,  # E4M3 during forward pass, E5M2 during backward pass
        amax_history_len=16,
    )

    # Load dataset
    logger.info("Loading dataset...")
    print("print HF_HOME")
    # Read the variable directly instead of spawning a shell to echo it.
    print(os.environ["HF_HOME"])

    first_row = load_dataset(
        config.dataset_id,
        split="train[:10]",
        cache_dir=hf_cache_dir,
    )[3]

    # Synthetic dataset: the same row formatted 5096 times, one fresh
    # conversation object per entry.
    train_dataset = [format_data(first_row) for _ in range(5096)]

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        collate_fn=collate_fn,
        shuffle=True,
    )

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    # optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.learning_rate)

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Train
    logger.info("Starting training...")
    torch.cuda.empty_cache()
    train_model(model, train_dataloader, optimizer, config, fp8_recipe)

    # Saving is currently disabled:
    # model.save_pretrained(config.output_dir)
    # processor.save_pretrained(config.output_dir)
    # logger.info(f"Final model saved to {config.output_dir}")

if __name__ == "__main__":
    try:
        main()
    finally:
        # Tear down the process group created at import time even when
        # training raises, so a failure does not skip distributed cleanup.
        destroy_process_group()
    # Reached only when main() returned without raising.
    logger.info("Training completed.")

run_main.sh

# run_main.sh
# Usage:
#   bash run_main.sh {file_to_run.py} {num_gpus}
# where file_to_run = script to launch.  Default = 'fsdp_tp_example.py'
# num_gpus = number of local processes (one per GPU) to start. Default = 8


pip install qwen-vl-utils

# pip uninstall transformer-engine
# pip3 install --no-build-isolation transformer_engine[pytorch] -i https://pypi.tuna.tsinghua.edu.cn/simple

# Resolve the positional arguments once, applying the defaults, and quote
# them so paths containing spaces are passed through intact.
script="${1:-fsdp_tp_example.py}"
num_gpus="${2:-8}"

echo "Launching ${script} with ${num_gpus} gpus"
torchrun --nnodes=1 --nproc_per_node="${num_gpus}" --rdzv_id=101 --rdzv_endpoint="localhost:5972" "${script}"

Running command:

bash run_main.sh te_qwen2_vl.py 8

The error is as follows:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/fp8_dev/te_qwen2_vl.py", line 312, in <module>
[rank0]:     main()
[rank0]:   File "/workspace/fp8_dev/te_qwen2_vl.py", line 304, in main
[rank0]:     train_model(model, train_dataloader, optimizer, config, fp8_recipe)
[rank0]:   File "/workspace/fp8_dev/te_qwen2_vl.py", line 166, in train_model
[rank0]:     loss.backward() 
[rank0]:     ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/linear.py", line 518, in backward
[rank0]:     wgrad, grad_bias_, _, rs_out = general_gemm(
[rank0]:                                    ^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 141, in general_gemm
[rank0]:     out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
[rank0]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: /workspace/transformerengine/transformer_engine/common/gemm/cublaslt_gemm.cu:395 in function cublas_gemm: Assertion failed: status != CUBLAS_STATUS_NOT_SUPPORTED. Unable to find suitable cuBLAS GEMM algorithm

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions