
Quantizing whisper-large-v3-turbo doesn't improve inference latency on H100/A100 #1551

Describe the bug
I created a quantized W8A8 checkpoint of openai/whisper-large-v3-turbo using GPTQModifier with the following script (adapted from the llmcompressor Whisper example):

import torch
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "openai/whisper-large-v3-turbo"
device = torch.device("cuda")

model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
processor = WhisperProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.set_prefix_tokens(task="transcribe")

DATASET_ID = "MLCommons/peoples_speech"
DATASET_SUBSET = "test"
DATASET_SPLIT = "test"

NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(
    DATASET_ID,
    DATASET_SUBSET,
    split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
    trust_remote_code=True,
)

def preprocess(example):
    return {
        "array": example["audio"]["array"],
        "sampling_rate": example["audio"]["sampling_rate"],
        "text": " " + example["text"].capitalize(),
    }

ds = ds.map(preprocess, remove_columns=ds.column_names)

def process(sample):
    inputs = processor(
        audio=sample["array"],
        sampling_rate=sample["sampling_rate"],
        text=sample["text"],
        add_special_tokens=True,
        return_tensors="pt",
    )

    inputs["input_features"] = inputs["input_features"].to(dtype=model.dtype)
    inputs["decoder_input_ids"] = inputs["labels"]
    del inputs["labels"]

    return inputs


ds = ds.map(process, remove_columns=ds.column_names)

def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}

recipe = GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
)

SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
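
As a quick sanity check that the export really contains int8 weights, something like the snippet below can inspect the saved checkpoint. This is only a sketch: it assumes compressed-tensors records its quantization settings in config.json, and the exact key name may differ between versions.

import json
import os

SAVE_DIR = "whisper-large-v3-turbo-W8A8-G128"  # directory produced by the script above

# Print whatever quantization metadata compressed-tensors wrote into config.json
# (the key name has varied across versions, hence the loop).
with open(os.path.join(SAVE_DIR, "config.json")) as f:
    saved_config = json.load(f)
for key in ("quantization_config", "compression_config"):
    if key in saved_config:
        print(key, json.dumps(saved_config[key], indent=2))

# The int8 checkpoint should be roughly half the size of the ~1.6 GB fp16 original.
size_mb = sum(
    os.path.getsize(os.path.join(SAVE_DIR, f))
    for f in os.listdir(SAVE_DIR)
    if f.endswith(".safetensors")
) / 1e6
print(f"total safetensors size: {size_mb:.0f} MB")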

I see the resulting model.safetensors is 908 MB, so the quantization to int8 does seem to have succeeded. This is my script to check inference latency on 100 audio samples I have on disk:

from vllm import LLM, SamplingParams
import torch

quantized_model_path = "/home/backend/whisper-large-v3-turbo-W4A16-G128"

stt_model = LLM(
    model=quantized_model_path,
    max_model_len=256,
    max_num_seqs=64,
    limit_mm_per_prompt={"audio": 1},
    enforce_eager=False,
    gpu_memory_utilization=1,
)

stt_sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    max_tokens=64,
)

def get_stt_prompt(audio):
    return [
        {
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {
                "audio": audio,
            },
        }
    ]

import time
import numpy as np

def get_stt_response(audio_pcm_bytes):
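    # Interpret the raw bytes as 16-bit PCM and normalize to float32 in [-1, 1].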
    audio_np = np.frombuffer(audio_pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
    while audio_np.ndim != 1:
        audio_np = audio_np.squeeze(0)

    start_time = time.perf_counter()
    output = stt_model.generate(get_stt_prompt((audio_np, 16_000)), stt_sampling_params)[0].outputs[0]
    
    stt_time = time.perf_counter() - start_time
    rtf = stt_time / (len(audio_np) / 16_000)
    time_per_output_token = stt_time / len(output.token_ids)
    
    return output, stt_time * 1000, rtf, time_per_output_token * 1000

complete_stt_times = []
rtf_ratios = []
time_per_output_tokens = []
transcripts = []

from utils import audio_file_to_pcm
import os
from tqdm import tqdm

audio_dir = "/home/audiofiles/"

for input_file in tqdm(sorted(os.listdir(audio_dir))):
    audio_pcm_bytes = audio_file_to_pcm(os.path.join(audio_dir, input_file))
    output, stt_time, rtf, time_per_output_token = get_stt_response(audio_pcm_bytes)
    
    transcripts.append(output.text)
    complete_stt_times.append(stt_time)
    rtf_ratios.append(rtf)
    time_per_output_tokens.append(time_per_output_token)

def display_stats(name, data):
    if len(data) == 0:
        print(f"{name}: No data available.")
        return
    arr = np.array(data)
    print(f"{name}:")
    print(f"  Mean:     {arr.mean():.4f}")
    print(f"  Min:      {arr.min():.4f}")
    print(f"  Max:      {arr.max():.4f}")
    print(f"  Std.dev.: {arr.std():.4f}")

display_stats("complete_stt_times(ms)", complete_stt_times)
display_stats("rtf_ratios", rtf_ratios)
display_stats("time_per_output_tokens(ms)", time_per_output_tokens)

In addition to quantizing with llmcompressor's oneshot flow (which I call llmcompressor quantization), I also tried vLLM's on-the-fly quantization by passing quantization="fp8" to the LLM constructor above (which I call vllm internal quantization).
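
For clarity, the vllm internal quantization run is a minimal variation of the constructor above; the sketch below assumes the unquantized base model is loaded and that all other arguments stay the same.

from vllm import LLM

# Same serving settings as the benchmark above, but loading the unquantized
# checkpoint and letting vLLM quantize it to fp8 at load time.
stt_model_fp8 = LLM(
    model="openai/whisper-large-v3-turbo",  # unquantized base model (assumed path)
    quantization="fp8",
    max_model_len=256,
    max_num_seqs=64,
    limit_mm_per_prompt={"audio": 1},
    enforce_eager=False,
    gpu_memory_utilization=1,
)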

These are the mean complete_stt_times in milliseconds:

GPU     No quantization   llmcompressor W8A8   llmcompressor W4A16   vllm internal (fp8)
A100    124               148                  156                   144
H100    99                102                  102                   104

Expected behavior
I expected inference time to drop with llmcompressor quantization on A100 (since W8A8 quantizes to int8, which A100 supports natively), and to drop with both llmcompressor and vllm internal quantization on H100. That is not what I observe: every quantized configuration is as slow as or slower than the unquantized baseline.

Environment
Include all relevant environment information:
H100 -

  1. OS [Ubuntu 22.04]:
  2. Python version [3.10.12]:
  3. LLM Compressor version or commit hash [e.g. 0.5.1]:
  4. ML framework version(s) [torch 2.5.1]:
  5. Other Python package versions [vLLM (0.7.3), compressed-tensors(0.9.1), numpy(1.26.4)]:
  6. Other relevant environment information [CUDA version 12.2]:

A100 -

  1. OS [Ubuntu 24.04]:
  2. Python version [3.12.3]:
  3. LLM Compressor version or commit hash [e.g. 0.5.1]:
  4. ML framework version(s) [torch 2.5.1]:
  5. Other Python package versions [vLLM (0.7.3), compressed-tensors(0.9.1), numpy(1.26.4)]:
  6. Other relevant environment information [CUDA version 12.8]:
