[Benchmark] Reproduce GPTQv2 results #1545

Open
eldarkurtic opened this issue Apr 16, 2025 · 7 comments
Labels: bug (Something isn't working)

Comments

@eldarkurtic

Hi, I would like to reproduce GPTQv2 W4g128 evals shown in the README.

[Image: GPTQv2 W4g128 eval results table from the README]

Could you help me by:

  1. releasing the model or providing the exact command to recreate it
  2. providing the lm-evaluation-harness command to reproduce the evaluation process
eldarkurtic added the bug label on Apr 16, 2025
@Qubitium
Collaborator

Qubitium commented Apr 16, 2025

Hi @eldarkurtic Use tests/models/test_llama3_2.py as the base and make the following changes:

  • Change the model to Llama-3.1-8B-Instruct.
  • Override V2 = True in the test, and set QUANTIZE_CONFIG_BITS = 3 for the 3-bit/v2 run.
  • Add/override TASK_NAME = [EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.GSM8K_PLATINUM_COT]. The test uses the GPTQModel (Marlin) kernel to run the evals; 3-bit tests will automatically fall back to the only usable kernel (Torch), so eval will be very slow.
  • Make sure you install the main version of lm-eval, as the GSM8K_PLATINUM PR was merged but not yet released (see the example install command right below this list).
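
One way to get lm-eval from main, assuming a standard pip + git setup (the exact command used here may differ):

  pip install -U git+https://github.com/EleutherAI/lm-evaluation-harness.git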

That's the script I used, with the above changes, to create the test results in the README bench.

I commented the four lines I changed; everything else is the same. Tests were performed on an A100.

# based on tests/models/test_llama3_2.py; ModelTest and EVAL come from the GPTQModel test suite
class TestLlama3_2(ModelTest):
    NATIVE_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # change
    NATIVE_ARC_CHALLENGE_ACC = 0.3567
    NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36
    APPLY_CHAT_TEMPLATE = True
    V2 = True  # change
    QUANTIZE_CONFIG_BITS = 4  # change
    TASK_NAME = [EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.GSM8K_PLATINUM_COT]  # change

    def test_llama3_2(self):
        self.quant_lm_eval()

Let me know if you run into issues. The above is 100% of the script used, so you should be able to replicate the results.
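
Assuming the repo's standard pytest setup (the exact invocation may vary), the modified test can be run with:

  pytest tests/models/test_llama3_2.py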

@eldarkurtic
Author

Thanks a lot, will give it a try. Do you by any chance have the models already available somewhere? Are GPTQv2 models runnable in vLLM?

@Qubitium Qubitium changed the title [BUG] How to reproduce GPTQv2 results from the README.md ? [Benchmark] Reproduce GPTQv2 results Apr 16, 2025
@Qubitium
Collaborator

Qubitium commented Apr 16, 2025

@eldarkurtic I did not store the quantized models, but I should have; let me add this to my to-do list. GPTQ v2 only differs in the quantization process and its output is 100% GPTQ compliant, so it will run on all kernels/inference engines that support GPTQ (v1), including vLLM and SGLang.
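
For example, a v2 checkpoint should load in vLLM like any other GPTQ model. A minimal sketch, using the ModelCloud repo id that appears later in this thread (swap in your own checkpoint path if you quantized locally; prompt and sampling settings are arbitrary):

# Minimal sketch: run a GPTQ v2 checkpoint through vLLM's offline LLM API.
from vllm import LLM, SamplingParams

llm = LLM(model="ModelCloud/GPTQ-v2-Llama-3.1-8B-Instruct", quantization="gptq")
outputs = llm.generate(["What is 12 * 7?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)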

@eldarkurtic
Author

Any chance you could share this GPTQv2 W4g128 model?

@Qubitium
Collaborator

Any chance you could share this GPTQv2 W4g128 model?

Yes. I will requant and push the 4 models (3-bit/4-bit × v1/v2) to HF later today.

@Qubitium
Collaborator

Qubitium commented Apr 17, 2025

Env:

HW: A100
Driver: 570
CUDA: 12.8
torch: 2.8.0.dev20250415+cu128
transformers: 4.51.2

Reproduced models, quant script, and eval script:

https://huggingface.co/ModelCloud/GPTQ-v1-Llama-3.1-8B-Instruct
https://huggingface.co/ModelCloud/GPTQ-v2-Llama-3.1-8B-Instruct
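
A minimal loading/inference sanity check for these checkpoints, assuming GPTQModel's standard load/generate interface (prompt and generation settings are arbitrary):

from gptqmodel import GPTQModel

# Quick generation sanity check for the published checkpoints (v2 shown; the
# v1 repo works the same). Assumes a CUDA device, matching the A100 setup above.
model = GPTQModel.load("ModelCloud/GPTQ-v2-Llama-3.1-8B-Instruct")
inputs = model.tokenizer("The capital of France is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=16)
print(model.tokenizer.decode(output[0]))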

I included updated benchmark results in the HF repos since these requantized models have slightly different numbers, but the huge v1-vs-v2 gap on GSM8K_PLATINUM is still there.

Quantization uses the aforementioned code with c4/en calibration data, 256 samples, group size 128.

Below is the exact code used to evaluate the models. I use GPTQModel's close integration with lm-eval, with the Marlin kernel for inference on an A100: GPTQModel main + lm-eval main branch.

# eval
import tempfile

from gptqmodel import GPTQModel
from gptqmodel.utils.eval import EVAL
from lm_eval.tasks import TaskManager
from lm_eval.utils import make_table

# QUANT_SAVE_PATH is the output path defined in the quantization script below
with tempfile.TemporaryDirectory() as tmp_dir:
    results = GPTQModel.eval(
        QUANT_SAVE_PATH,
        tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.GSM8K_PLATINUM_COT],
        apply_chat_template=True,
        random_seed=898,
        output_path=tmp_dir,
    )

    print(make_table(results))
    if "groups" in results:
        print(make_table(results, "groups"))

v1:

| Tasks              | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|--------------------|---------|------------------|--------|-------------|--------|----------|
| arc_challenge      | 1       | none             | 0      | acc         | 0.5000 | ± 0.0146 |
|                    |         | none             | 0      | acc_norm    | 0.5128 | ± 0.0146 |
| gsm8k_platinum_cot | 3       | flexible-extract | 8      | exact_match | 0.3995 | ± 0.0141 |
|                    |         | strict-match     | 8      | exact_match | 0.2548 | ± 0.0125 |

v2:

| Tasks              | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|--------------------|---------|------------------|--------|-------------|--------|----------|
| arc_challenge      | 1       | none             | 0      | acc         | 0.5034 | ± 0.0146 |
|                    |         | none             | 0      | acc_norm    | 0.5068 | ± 0.0146 |
| gsm8k_platinum_cot | 3       | flexible-extract | 8      | exact_match | 0.7601 | ± 0.0123 |
|                    |         | strict-match     | 8      | exact_match | 0.5211 | ± 0.0144 |

So v2 offers only slightly higher accuracy when measured purely by log-probability metrics like PPL/ARC, but appears to offer a substantial improvement on generation-based benchmarks like GSM8K_Platinum.

Full quantization code below:

import tempfile

from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.quantization import FORMAT
from gptqmodel.utils.eval import EVAL
from logbar import LogBar

log = LogBar.shared()

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
CFG_BITS = 4
CFG_GROUPSIZE = 128
CFG_V2 = True
INPUTS_MAX_LENGTH = 2048 # in tokens
QUANT_SAVE_PATH = f"/your_path/gptq_v2_{CFG_V2}_bit_{CFG_BITS}_gpsize_{CFG_GROUPSIZE}_llama_3.1_8B_Instruct"

def get_calib_data(tokenizer, rows: int):
    # calibration_dataset = load_dataset(
    #     "allenai/c4",
    #     data_files="en/c4-train.00000-of-01024.json.gz",
    #     split="train"
    # )

    calibration_dataset = load_dataset(
        "json",
        data_files="/your_path/dataset/c4-train.00000-of-01024.json.gz",
        split="train")

    datas = []
    for index, sample in enumerate(calibration_dataset):
        tokenized = tokenizer(sample["text"])
        if len(tokenized.data['input_ids']) <= INPUTS_MAX_LENGTH:
            datas.append(tokenized)
            if len(datas) >= rows:
                break

    return datas

quant_config = QuantizeConfig(
    bits=CFG_BITS,
    group_size=CFG_GROUPSIZE,
    format=FORMAT.GPTQ,
    desc_act=True,
    sym=True,
    v2=CFG_V2,
)

log.info(f"QuantConfig: {quant_config}")
log.info(f"Save Path: {QUANT_SAVE_PATH}")

# load un-quantized native model
model = GPTQModel.load(MODEL_ID, quant_config)

# load calibration data
calibration_dataset = get_calib_data(tokenizer=model.tokenizer, rows=256)

model.quantize(calibration_dataset, batch_size=1)

model.save(QUANT_SAVE_PATH)
log.info(f"Quant Model Saved to: {QUANT_SAVE_PATH}")

Note: Both the v1 and v2 quants and evals were produced via the code above.

@Stonesjtu

I found the losses of v2 are much higher than the losses of v1. Does it make sense to directly compare them?

Also, the sample counts are different even though I only changed the quant config (v2=True).

loss v1

[Image: v1 quantization loss log]

loss v2

[Image: v2 quantization loss log]
