
Sidestep NaN caused by div by zero where calculated scale == 0 #1529

@ywlq

Description


I'm trying to quantize the Qwen2-57B-A14B model using the following configuration:

from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig

pretrained_model_id = "/mnt/82_store/xj/modelzoo/qwen/Qwen/Qwen2-57B-A14B"
quantized_model_id = "/mnt/82_store/xj/GPTQModel/quantmodel/Qwen2-57B-A14B-2bit-128g-c4"

calibration_dataset = load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train"
).select(range(1024))["text"]

quantize_config = QuantizeConfig(
    bits=2,
    group_size=128,
    dynamic={
        r".*mlp\.shared_expert.*": { "bits": 8, "group_size": 128 },
        r".*mlp\.experts.*down_proj.*": { "bits": 8, "group_size": 128 },
    }
)

model = GPTQModel.load(pretrained_model_id, quantize_config)
model.quantize(calibration_dataset, auto_gc=False, batch_size=4)
model.save(quantized_model_id)

Unfortunately, it doesn't work :(

However, the quantization fails with NaN errors. After tracing the issue, I found that it originates from this function during weight quantization:

def quantize(x, scale, zero, maxq, requires_groupwise_processing: bool):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    if requires_groupwise_processing:
        q = torch.clamp(torch.round(x / scale), -maxq, maxq)
        return scale * q
    else:
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)
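
For reference, within this function the NaN comes specifically from the 0/0 case: when a group ends up with scale == 0 (e.g. all-zero weights), the zero entries hit 0/0 = NaN, and the NaN survives both the clamp and the final multiply. Below is a minimal standalone sketch with made-up tensors (not the actual GPTQ call path; maxq = 3 is just an illustrative 2-bit range):

import torch

# Hypothetical repro: a group whose weights are all zero and whose computed scale is 0
x = torch.zeros(4)
scale = torch.tensor(0.0)
maxq = 3  # illustrative symmetric 2-bit range

q = torch.clamp(torch.round(x / scale), -maxq, maxq)  # 0/0 -> nan; clamp keeps nan
print(scale * q)  # tensor([nan, nan, nan, nan])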

As the sketch above shows, the problem occurs when scale == 0: the division produces 0/0 = NaN for zero-valued weights, and the NaN propagates through the clamp and the final multiply into the dequantized weights. I temporarily fixed it with this patch:

def quantize(x, scale, zero, maxq, requires_groupwise_processing: bool):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    if requires_groupwise_processing:
        scale = torch.where(scale == 0, torch.tensor(1e-8, device=scale.device), scale)  # guard against division by zero
        q = torch.clamp(torch.round(x / scale), -maxq, maxq)
        return scale * q
    else:
        scale = torch.where(scale == 0, torch.tensor(1e-8, device=scale.device), scale)  # guard against division by zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

With this change, the quantization works well :)
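
As a quick sanity check on the same made-up repro as above (standalone, not the real GPTQ path), the guarded scale turns the all-zero group into plain zeros instead of NaN:

import torch

x = torch.zeros(4)
scale = torch.tensor(0.0)
maxq = 3

safe_scale = torch.where(scale == 0, torch.tensor(1e-8), scale)  # same guard as the patch above
q = torch.clamp(torch.round(x / safe_scale), -maxq, maxq)        # 0 / 1e-8 -> 0, no NaN
print(safe_scale * q)  # tensor([0., 0., 0., 0.])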

Is this a valid fix, or is there a better or more correct way to handle NaN loss during quantization?
