
Commit 6e30991

Authored by jiqing-feng, LRL-ModelCloud, Qubitium, ZX-ModelCloud, and stevhliu
FEAT Add gptqmodel support (#2247)
Add support for gptqmodel quantization. This is a replacement for auto-gptq. For now, both packages are supported, but since auto-gptq is no longer being developed, it will be deprecated and removed at some point in the future.

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: LRL-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
1 parent 1b9bcb2 commit 6e30991

File tree: 11 files changed, +508 -37 lines changed

Makefile (+1)

@@ -34,6 +34,7 @@ tests_core_single_gpu:
 tests_common_gpu:
 	python -m pytest tests/test_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",)
 	python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",)
+	python -m pytest tests/test_gptqmodel.py $(if $(IS_GITHUB_CI),--report-log "gptqmodel_gpu.log",)
 
 tests_examples_multi_gpu_bnb:
 	python -m pytest -m "multi_gpu_tests and bitsandbytes" tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)

docs/source/developer_guides/quantization.md (+26)

@@ -107,6 +107,32 @@ QLoRA adds trainable weights to all the linear layers in the transformer architecture
 config = LoraConfig(target_modules="all-linear", ...)
 ```
 
+## GPTQ quantization
+
+You can learn more about GPTQ-based 2-, 3-, 4-, and 8-bit quantization at [GPTQModel](https://github.com/ModelCloud/GPTQModel) and in the Transformers [GPTQ](https://huggingface.co/docs/transformers/quantization/gptq) docs. For post-quantization training, PEFT can use either the [GPTQModel](https://github.com/ModelCloud/GPTQModel) or the [AutoGPTQ](https://github.com/autogptq/autogptq) library, but we recommend GPTQModel because AutoGPTQ will be deprecated in a future release.
+
+```bash
+# install gptqmodel
+pip install gptqmodel --no-build-isolation
+```
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
+model_id = "facebook/opt-125m"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+gptq_config = GPTQConfig(bits=4, group_size=128, dataset="wikitext2", tokenizer=tokenizer)
+
+quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=gptq_config)
+
+# save the quantized model
+quantized_model.save_pretrained("./opt-125m-gptq")
+tokenizer.save_pretrained("./opt-125m-gptq")
+```
+
+Once quantized, you can post-train GPTQ models with the PEFT APIs.
+
 ## AQLM quantization
 
 Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a Large Language Models compression method. It quantizes multiple weights together and takes advantage of interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes. This allows it to compress models down to as low as 2-bit with considerably low accuracy losses.
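The new section ends by noting that quantized models can be post-trained with the PEFT APIs. A minimal sketch of that step (not part of this diff), assuming the `./opt-125m-gptq` checkpoint saved above and a compatible gptqmodel/optimum install:

```py
# Sketch: attach a LoRA adapter to the GPTQ-quantized checkpoint saved above.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("./opt-125m-gptq")

peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))
peft_model.print_trainable_parameters()  # only the LoRA parameters are trainable
# peft_model can now be passed to transformers.Trainer or a custom training loop.
```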

src/peft/import_utils.py (+27)

@@ -49,6 +49,33 @@ def is_auto_gptq_available():
             )
 
 
+@lru_cache
+def is_gptqmodel_available():
+    if importlib.util.find_spec("gptqmodel") is not None:
+        GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.7.0")
+        OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99")
+        version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel"))
+        if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel:
+            if is_optimum_available():
+                version_optimum = packaging.version.parse(importlib_metadata.version("optimum"))
+                if OPTIMUM_MINIMUM_VERSION <= version_optimum:
+                    return True
+                else:
+                    raise ImportError(
+                        f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher. Found version {version_optimum}, "
+                        f"but only versions above {OPTIMUM_MINIMUM_VERSION} are supported"
+                    )
+            else:
+                raise ImportError(
+                    f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher to be installed."
+                )
+        else:
+            raise ImportError(
+                f"Found an incompatible version of gptqmodel. Found version {version_gptqmodel}, "
+                f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported"
+            )
+
+
 @lru_cache
 def is_optimum_available() -> bool:
     return importlib.util.find_spec("optimum") is not None
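A hedged sketch of how this helper is consumed by the tuner changes below: prefer gptqmodel when a compatible version (plus optimum) is installed, otherwise fall back to auto-gptq. The `backend` variable is purely illustrative.

```py
# Illustrative backend selection, mirroring the tuner changes in this commit.
from peft.import_utils import is_auto_gptq_available, is_gptqmodel_available

if is_gptqmodel_available():  # raises ImportError if gptqmodel/optimum versions are incompatible
    backend = "gptqmodel"
elif is_auto_gptq_available():  # legacy path, slated for deprecation
    backend = "auto-gptq"
else:
    backend = None  # no GPTQ backend installed
print(f"GPTQ backend: {backend}")
```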

src/peft/tuners/adalora/model.py (+11 -5)

@@ -17,14 +17,15 @@
 import torch
 from transformers.pytorch_utils import Conv1D
 
-from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_gptqmodel_available
 from peft.tuners.lora import LoraConfig, LoraModel
 from peft.tuners.tuners_utils import BaseTunerLayer
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
     _freeze_adapter,
     _get_submodules,
     get_auto_gptq_quant_linear,
+    get_gptqmodel_quant_linear,
     get_quantization_config,
 )
 from peft.utils.integrations import gather_params_ctx
@@ -135,7 +136,8 @@ def _create_and_replace(
 
         # If it is not an AdaLoraLayer, create a new module, else update it with new adapters
         if not isinstance(target, AdaLoraLayer):
-            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
+            device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
+            new_module = self._create_new_module(lora_config, adapter_name, target, device_map=device_map, **kwargs)
             if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
@@ -150,7 +152,7 @@ def _create_and_replace(
             )
 
     @staticmethod
-    def _create_new_module(lora_config, adapter_name, target, **kwargs):
+    def _create_new_module(lora_config, adapter_name, target, device_map=None, **kwargs):
         # avoid eager bnb import
         if is_bnb_available():
             import bitsandbytes as bnb
@@ -160,7 +162,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
             from .bnb import SVDLinear4bit
 
         gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
-        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
+
+        if is_gptqmodel_available():
+            QuantLinear = get_gptqmodel_quant_linear(gptq_quantization_config, device_map=device_map)
+        else:
+            QuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
 
         loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
         loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
@@ -189,7 +195,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 }
             )
             new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
-        elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
+        elif QuantLinear is not None and isinstance(target, QuantLinear):
            new_module = SVDQuantLinear(target, adapter_name, **kwargs)
        else:
            if isinstance(target_base_layer, torch.nn.Linear):

src/peft/tuners/lora/gptq.py (+10 -4)

@@ -16,9 +16,10 @@
 
 import torch
 
+from peft.import_utils import is_gptqmodel_available
 from peft.tuners.lora.layer import LoraLayer
 from peft.tuners.tuners_utils import BaseTunerLayer
-from peft.utils import get_auto_gptq_quant_linear
+from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear
 
 
 class QuantLinear(torch.nn.Module, LoraLayer):
@@ -106,10 +107,15 @@ def dispatch_gptq(
     else:
         target_base_layer = target
 
-    gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
-    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
+    cfg = kwargs.get("gptq_quantization_config", None)
 
-    if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
+    if is_gptqmodel_available():
+        device_map = kwargs.get("device_map", None)
+        quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map)
+    else:
+        quant_linear = get_auto_gptq_quant_linear(cfg)
+
+    if quant_linear is not None and isinstance(target_base_layer, quant_linear):
         new_module = QuantLinear(target, adapter_name, **kwargs)
         target.qweight = target_base_layer.qweight
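At the user level, the effect of the updated `dispatch_gptq` can be sketched roughly as follows (assuming the `./opt-125m-gptq` checkpoint from the documentation example above): GPTQ projections matched by the LoRA config are wrapped in PEFT's `QuantLinear`, which keeps `qweight` frozen while adding trainable LoRA weights.

```py
# Sketch: check that a GPTQ layer was wrapped by peft.tuners.lora.gptq.QuantLinear.
from peft import LoraConfig, get_peft_model
from peft.tuners.lora.gptq import QuantLinear
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map="auto")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

# For OPT, the default LoRA target modules include q_proj, so dispatch_gptq replaced it.
q_proj = peft_model.base_model.model.model.decoder.layers[0].self_attn.q_proj
print(isinstance(q_proj, QuantLinear), list(q_proj.lora_A.keys()))  # expected: True ['default']
```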

src/peft/tuners/lora/model.py (+2 -1)

@@ -232,7 +232,8 @@ def _create_and_replace(
                 lora_bias=lora_config.lora_bias,
             )
         else:
-            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
+            device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
+            new_module = self._create_new_module(lora_config, adapter_name, target, device_map=device_map, **kwargs)
             if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)

src/peft/utils/__init__.py (+2)

@@ -39,6 +39,7 @@
     bloom_model_postprocess_past_key_value,
     cast_mixed_precision_params,
     get_auto_gptq_quant_linear,
+    get_gptqmodel_quant_linear,
     get_quantization_config,
     id_tensor_storage,
     infer_device,
@@ -77,6 +78,7 @@
     "bloom_model_postprocess_past_key_value",
     "cast_mixed_precision_params",
     "get_auto_gptq_quant_linear",
+    "get_gptqmodel_quant_linear",
     "get_peft_model_state_dict",
     "get_quantization_config",
     "id_tensor_storage",

src/peft/utils/other.py (+66 -23)

@@ -30,7 +30,7 @@
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 
-from ..import_utils import is_auto_gptq_available, is_torch_tpu_available
+from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available
 from .constants import (
     CONFIG_NAME,
     EMBEDDING_LAYER_NAMES,
@@ -610,30 +610,73 @@ def get_auto_gptq_quant_linear(gptq_quantization_config):
     """
     Get the right AutoGPTQQuantLinear class based on the quantization config file
     """
-    if gptq_quantization_config is not None and is_auto_gptq_available():
+    if gptq_quantization_config is None:
+        return None
+
+    if is_auto_gptq_available():
         from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+    else:
+        return None
 
-        desc_act = gptq_quantization_config.desc_act
-        group_size = gptq_quantization_config.group_size
-        bits = gptq_quantization_config.bits
-        if hasattr(gptq_quantization_config, "use_exllama"):
-            use_exllama = gptq_quantization_config.use_exllama
-        else:
-            use_exllama = not gptq_quantization_config.disable_exllama
-        if hasattr(gptq_quantization_config, "exllama_config"):
-            exllama_version = gptq_quantization_config.exllama_config["version"]
-        else:
-            exllama_version = 1
-        AutoGPTQQuantLinear = dynamically_import_QuantLinear(
-            use_triton=False,
-            desc_act=desc_act,
-            group_size=group_size,
-            bits=bits,
-            disable_exllama=not (use_exllama and exllama_version == 1),
-            disable_exllamav2=not (use_exllama and exllama_version == 2),
-        )
-        return AutoGPTQQuantLinear
-    return None
+    desc_act = gptq_quantization_config.desc_act
+    group_size = gptq_quantization_config.group_size
+    bits = gptq_quantization_config.bits
+    if hasattr(gptq_quantization_config, "use_exllama"):
+        use_exllama = gptq_quantization_config.use_exllama
+    else:
+        use_exllama = not gptq_quantization_config.disable_exllama
+    if hasattr(gptq_quantization_config, "exllama_config"):
+        exllama_version = gptq_quantization_config.exllama_config["version"]
+    else:
+        exllama_version = 1
+
+    QuantLinear = dynamically_import_QuantLinear(
+        use_triton=False,
+        desc_act=desc_act,
+        group_size=group_size,
+        bits=bits,
+        disable_exllama=not (use_exllama and exllama_version == 1),
+        disable_exllamav2=not (use_exllama and exllama_version == 2),
+    )
+
+    return QuantLinear
+
+
+def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None):
+    """
+    Get the right GPTQQuantLinear class based on the quantization config file
+    """
+    if gptq_quantization_config is None:
+        return None
+
+    if not is_gptqmodel_available():
+        return None
+
+    from gptqmodel.utils.importer import hf_select_quant_linear
+
+    desc_act = gptq_quantization_config.desc_act
+    group_size = gptq_quantization_config.group_size
+    bits = gptq_quantization_config.bits
+    checkpoint_format = (
+        gptq_quantization_config.checkpoint_format
+        if hasattr(gptq_quantization_config, "checkpoint_format")
+        else "gptq"
+    )
+    sym = gptq_quantization_config.sym
+    meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None
+
+    QuantLinear = hf_select_quant_linear(
+        bits=bits,
+        group_size=group_size,
+        desc_act=desc_act,
+        sym=sym,
+        device_map=device_map,
+        checkpoint_format=checkpoint_format,
+        meta=meta,
+        backend="auto_trainable",
+    )
+
+    return QuantLinear
 
 
 def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
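As a hedged illustration (not in the diff), the new `get_gptqmodel_quant_linear` helper can be called directly with a Transformers `GPTQConfig` to resolve the kernel class that the LoRA/AdaLoRA tuners check targets against; the tuners additionally pass the model's `hf_device_map` as `device_map`.

```py
# Sketch: resolve the gptqmodel quant-linear class for a 4-bit GPTQ config.
from transformers import GPTQConfig

from peft.utils import get_gptqmodel_quant_linear

gptq_config = GPTQConfig(bits=4, group_size=128)  # config of an already-quantized checkpoint
QuantLinear = get_gptqmodel_quant_linear(gptq_config)
print(QuantLinear)  # None when gptqmodel is not installed or no config is given
```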

tests/test_common_gpu.py (+3 -3)

@@ -406,19 +406,19 @@ def test_lora_gptq_quantization_from_pretrained_safetensors(self):
 
         config = LoraConfig(task_type="CAUSAL_LM")
         peft_model = get_peft_model(model, config)
-        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             peft_model.save_pretrained(tmp_dir)
             model = AutoModelForCausalLM.from_pretrained(**kwargs)
             model = PeftModel.from_pretrained(model, tmp_dir)
             model = prepare_model_for_kbit_training(model)
-            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
 
             # loading a 2nd adapter works, #1239
             model.load_adapter(tmp_dir, "adapter2")
             model.set_adapter("adapter2")
-            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
 
             # check that both adapters are in the same layer
             assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A
