
How to run AWQ-W4Afp8 quantization? #1368


Open
wanzhenchn opened this issue Apr 22, 2025 · 2 comments
wanzhenchn commented Apr 22, 2025

How to run AWQ-W4Afp8 quantization on MoE models?

I ran AWQ-W4AFP8 quantization on Qwen1.5-MoE-A2.7B, but it failed with the ValueError shown below:

[screenshot: ValueError traceback]

# llmcompressor and compressed-tensors are both installed from source (main branch).

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import QuantizationModifier

model_path = "Qwen/Qwen1.5-MoE-A2.7B"
save_path = "Qwen1.5-MoE-A2.7B-W4AFP8"  # output directory for the quantized model

tokenizer = AutoTokenizer.from_pretrained(
    model_path, trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map="auto", torch_dtype="auto",
    trust_remote_code=True
)

ignore_layers = ["lm_head", "re:.*mlp.gate", "re:.*mlp.shared_expert_gate"]

recipe = [
    AWQModifier(bits=4, symmetric=False),
    QuantizationModifier(
        ignore=ignore_layers,
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(
                    num_bits=4,
                    type=QuantizationType.INT,
                    dynamic=False,
                    symmetric=False,
                    strategy=QuantizationStrategy.GROUP,
                    group_size=128,
                ),
                input_activations=QuantizationArgs(
                    num_bits=8,
                    type=QuantizationType.FLOAT,
                    strategy=QuantizationStrategy.TENSOR,
                    dynamic=False,
                    symmetric=True,
                ),
            ),
        },
    ),
]


oneshot(
    model=model,
    tokenizer=tokenizer,
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2948,
    num_calibration_samples=2,
    save_compressed=True,
    trust_remote_code_model=True,
    output_dir=save_path,
)

tokenizer.save_pretrained(save_path)
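
For reference, once the recipe runs to completion, the serialized scheme can be double-checked from the saved checkpoint. A minimal sketch, assuming compressed-tensors writes a quantization_config block into config.json when save_compressed=True (save_path is the output directory from the script above):

import json
import os

# Inspect the quantization settings serialized into the saved config.json
with open(os.path.join(save_path, "config.json")) as f:
    config = json.load(f)
print(json.dumps(config.get("quantization_config", {}), indent=2))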
brian-dellabetta (Collaborator) commented

Hi @wanzhenchn , thanks for taking an interest in our AWQ feature! We have merged most of the AWQ logic but we have a few TODOs related to the issues you are hitting (here and here). We wanted to add these in a separate PR so that the initial PR is largely a port of the code in AutoAWQ, and so we have an example for how additional mappings can be added for other architectures.

We will wrap this up in the next couple of weeks, then make a release and a more public announcement that AWQ is ready for consumption.
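
In the meantime, a rough sketch of what architecture-specific mappings could look like for a MoE model is below. It assumes AWQModifier accepts a mappings argument as in the current main branch; the AWQMapping import path, its field names, and the regex targets for the Qwen1.5-MoE expert projections are assumptions, not a tested recipe:

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.awq.mappings import AWQMapping  # assumed module path

# Hypothetical mappings pairing each "smooth" layer with the linear layers it balances;
# the expert regexes are guesses at the Qwen1.5-MoE module names.
qwen_moe_mappings = [
    AWQMapping(
        smooth_layer="re:.*input_layernorm",
        balance_layers=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
    ),
    AWQMapping(smooth_layer="re:.*v_proj", balance_layers=["re:.*o_proj"]),
    AWQMapping(
        smooth_layer="re:.*post_attention_layernorm",
        balance_layers=["re:.*experts.*gate_proj", "re:.*experts.*up_proj"],
    ),
    AWQMapping(
        smooth_layer="re:.*experts.*up_proj",
        balance_layers=["re:.*experts.*down_proj"],
    ),
]

awq = AWQModifier(bits=4, symmetric=False, mappings=qwen_moe_mappings)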

brian-dellabetta self-assigned this Apr 22, 2025
wanzhenchn (Author) commented


Thanks for the feedback; looking forward to AWQ supporting more models.
