Commit 003cf20

FEAT Add LoRA INC support (#2499)
Adds LoRA support for models quantized with Intel Neural Compressor (INC). --------- Signed-off-by: Daniel Socek <[email protected]>
1 parent 453a6ff commit 003cf20

6 files changed: +195 -1 lines changed

docs/source/developer_guides/quantization.md
+40 -1
@@ -192,7 +192,7 @@ model = get_peft_model(model, config)

## HQQ quantization

-The models that is quantized using Half-Quadratic Quantization of Large Machine Learning Models ([HQQ](https://mobiusml.github.io/hqq_blog/)) support LoRA adapter tuning. To tune the quantized model, you'll need to install the `hqq` library with: `pip install hqq`.
+The models that are quantized using Half-Quadratic Quantization of Large Machine Learning Models ([HQQ](https://mobiusml.github.io/hqq_blog/)) support LoRA adapter tuning. To tune the quantized model, you'll need to install the `hqq` library with: `pip install hqq`.
@@ -237,6 +237,45 @@ model = get_peft_model(base_model, peft_config)
- DoRA only works with `quant_type = "int8_weight_only"` at the moment.
- There is explicit support for torchao when used with LoRA. However, when torchao quantizes a layer, its class does not change, only the type of the underlying tensor. For this reason, PEFT methods other than LoRA will generally also work with torchao, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA and with `quant_type = "int8_weight_only"`**. If you use a different PEFT method or dtype, merging will likely result in an error, and even if it doesn't, the results will still be incorrect.

## INC quantization

Intel Neural Compressor ([INC](https://github.com/intel/neural-compressor)) enables model quantization for various devices,
including Intel Gaudi accelerators (also known as HPU devices). You can perform LoRA fine-tuning on models that have been
quantized using INC. To use INC with PyTorch models, install the library with: `pip install neural-compressor[pt]`.
Quantizing a model to FP8 precision for HPU devices can be done with the following single-step quantization workflow:

```python
import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare

quant_configs = {
    ...
}
config = FP8Config(**quant_configs)
```

Pass the config to the `prepare` method, run inference to gather calibration stats, and call the `finalize_calibration`
and `convert` methods to quantize the model to FP8 precision:

```python
model = prepare(model, config)
# Run inference to collect calibration statistics
...
# Finalize calibration and convert the model to FP8 precision
finalize_calibration(model)
model = convert(model)
# Load PEFT LoRA adapter as usual
...
```

An example demonstrating how to load a PEFT LoRA adapter into an INC-quantized FLUX text-to-image model for HPU
devices is provided [here](https://github.com/huggingface/peft/blob/main/examples/stable_diffusion/inc_flux_lora_hpu.py).

### Caveats:

- `merge()` and `unmerge()` methods are currently not supported for INC-quantized models.
- Currently, only **Linear** INC-quantized layers are supported when loading PEFT adapters.
## Other Supported PEFT Methods

Besides LoRA, the following PEFT methods also support quantization:
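
The "Load PEFT LoRA adapter as usual" step in the INC section above works the same way for language models as for diffusers pipelines. A minimal sketch of the full path for a causal LM, assuming an HPU/Gaudi environment with `neural-compressor[pt]` installed; the base model and adapter ids are placeholders, and the `FP8Config` arguments are a reduced version of the config keys shown in the example script below:

```python
import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
from transformers import AutoModelForCausalLM

from peft import PeftModel

# Placeholder model id; any causal LM can be used in the same way.
base = AutoModelForCausalLM.from_pretrained("<base-model-id>", torch_dtype=torch.bfloat16)

config = FP8Config(mode="AUTO", observer="maxabs", dump_stats_path="/tmp/hqt_output/measure")
base = prepare(base, config)
# ... run a few representative batches through `base` to collect calibration statistics ...
finalize_calibration(base)
base = convert(base)

# Load the PEFT LoRA adapter onto the INC-quantized model (placeholder adapter id).
model = PeftModel.from_pretrained(base, "<lora-adapter-id>")
```

Training or inference then proceeds as with any other PEFT model; per the caveats above, `merge()` and `unmerge()` are not available on INC layers.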
examples/stable_diffusion/inc_flux_lora_hpu.py
+67

@@ -0,0 +1,67 @@
"""
2+
This exampe demonstrates loading of LoRA adapter (via PEFT) into an FP8 INC-quantized FLUX model.
3+
4+
More info on Intel Neural Compressor (INC) FP8 quantization is available at:
5+
https://github.com/intel/neural-compressor/tree/master/examples/helloworld/fp8_example
6+
7+
Requirements:
8+
pip install optimum-habana sentencepiece neural-compressor[pt] peft
9+
"""
10+
11+
import importlib
12+
13+
import torch
14+
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
15+
16+
17+
# Checks if HPU device is available
18+
# Adapted from https://github.com/huggingface/accelerate/blob/b451956fd69a135efc283aadaa478f0d33fcbe6a/src/accelerate/utils/imports.py#L435
19+
def is_hpu_available():
20+
if (
21+
importlib.util.find_spec("habana_frameworks") is None
22+
or importlib.util.find_spec("habana_frameworks.torch") is None
23+
):
24+
return False
25+
26+
import habana_frameworks.torch # noqa: F401
27+
28+
return hasattr(torch, "hpu") and torch.hpu.is_available()
29+
30+
31+
# Ensure HPU device is available before proceeding
32+
if is_hpu_available():
33+
from optimum.habana.diffusers import GaudiFluxPipeline
34+
else:
35+
raise RuntimeError("HPU device not found. This code requires Intel Gaudi device to run.")
36+
37+
# Example: FLUX model inference on HPU via optimum-habana pipeline
38+
hpu_configs = {
39+
"use_habana": True,
40+
"use_hpu_graphs": True,
41+
"sdp_on_bf16": True,
42+
"gaudi_config": "Habana/stable-diffusion",
43+
}
44+
pipe = GaudiFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **hpu_configs)
45+
prompt = "A picture of sks dog in a bucket"
46+
47+
# Quantize FLUX transformer to FP8 using INC (Intel Neural Compressor)
48+
quant_configs = {
49+
"mode": "AUTO",
50+
"observer": "maxabs",
51+
"scale_method": "maxabs_hw",
52+
"allowlist": {"types": [], "names": []},
53+
"blocklist": {"types": [], "names": []},
54+
"dump_stats_path": "/tmp/hqt_output/measure",
55+
}
56+
config = FP8Config(**quant_configs)
57+
pipe.transformer = prepare(pipe.transformer, config)
58+
pipe(prompt)
59+
finalize_calibration(pipe.transformer)
60+
pipe.transformer = convert(pipe.transformer)
61+
62+
# Load LoRA weights with PEFT
63+
pipe.load_lora_weights("dsocek/lora-flux-dog", adapter_name="user_lora")
64+
65+
# Run inference
66+
image = pipe(prompt).images[0]
67+
image.save("dog.png")

src/peft/import_utils.py
+5

@@ -118,6 +118,11 @@ def is_hqq_available():
     return importlib.util.find_spec("hqq") is not None


+@lru_cache
+def is_inc_available():
+    return importlib.util.find_spec("neural_compressor") is not None
+
+
 @lru_cache
 def is_torchao_available():
     if importlib.util.find_spec("torchao") is None:

src/peft/tuners/lora/inc.py
+78

@@ -0,0 +1,78 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: PEFT tests related to INC are handled under the Optimum-Habana repository:
# - LLMs: https://github.com/huggingface/optimum-habana/blob/main/tests/test_peft_inference.py
# - Diffusers: https://github.com/huggingface/optimum-habana/blob/main/tests/test_diffusers.py

from typing import Optional

import torch

from peft.import_utils import is_inc_available
from peft.tuners.tuners_utils import BaseTunerLayer

from .layer import Linear


if is_inc_available():

    class IncLoraLinear(Linear):
        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            **kwargs,
        ):
            super().__init__(base_layer, adapter_name, **kwargs)

        def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
            """
            Merge the active adapter weights into the base weights

            Args:
                safe_merge (`bool`, *optional*):
                    If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                    before merging the weights. This is useful if you want to check if the merge operation will produce
                    NaNs. Defaults to `False`.
                adapter_names (`list[str]`, *optional*):
                    The list of adapter names that should be merged. If None, all active adapters will be merged.
                    Defaults to `None`.
            """
            raise NotImplementedError("Merging LoRA with INC layers is not yet implemented")

        def unmerge(self) -> None:
            """
            This method unmerges all merged adapter layers from the base weights.
            """
            raise NotImplementedError("Unmerging LoRA from INC layers is not yet implemented")


def dispatch_inc(target: torch.nn.Module, adapter_name: str, **kwargs):
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_inc_available():
        from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import (
            PatchedLinear,
        )

        if isinstance(target_base_layer, PatchedLinear):
            new_module = IncLoraLinear(target, adapter_name, **kwargs)

    return new_module
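
`dispatch_inc` follows the same contract as the other LoRA dispatchers: it returns a wrapped layer when the base layer is an INC `PatchedLinear`, and `None` otherwise so the next dispatcher in the chain can take over. A rough illustration, assuming `neural_compressor` is installed; `quantized_linear` is a placeholder for a linear submodule taken from an INC-converted model, and the LoRA hyperparameters are arbitrary:

```python
import torch

from peft.tuners.lora.inc import dispatch_inc

# Placeholder: `quantized_linear` stands for a module whose base layer is
# neural_compressor's PatchedLinear (i.e. taken from an INC-converted model).
wrapped = dispatch_inc(quantized_linear, adapter_name="default", r=8, lora_alpha=16)
print(type(wrapped).__name__)  # IncLoraLinear

# Any other module type is left for the dispatchers that come later in the chain
# (dispatch_torchao, dispatch_megatron, dispatch_default).
assert dispatch_inc(torch.nn.ReLU(), adapter_name="default") is None
```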

src/peft/tuners/lora/layer.py
+3

@@ -145,6 +145,9 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
         elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
             # HQQ layers
             in_features, out_features = base_layer.in_features, base_layer.out_features
+        elif base_layer.__class__.__name__ == "PatchedLinear":
+            # INC layers
+            in_features, out_features = base_layer.in_features, base_layer.out_features
         else:
             # possibly support user provided custom layer types using dynamic dispatch
             if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
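
The new branch in `LoraLayer.__init__` relies on duck typing: an INC layer is recognized purely by its class name, and only `in_features`/`out_features` are read from it. A toy illustration of that detection logic, using a stand-in class rather than the real `neural_compressor` module:

```python
import torch.nn as nn


class PatchedLinear(nn.Module):
    # Stand-in that mimics the attributes PEFT reads from INC's PatchedLinear.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features


base_layer = PatchedLinear(16, 32)
if base_layer.__class__.__name__ == "PatchedLinear":
    # INC layers: same attribute access as the HQQ branch above
    in_features, out_features = base_layer.in_features, base_layer.out_features
print(in_features, out_features)  # 16 32
```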

src/peft/tuners/lora/model.py
+2

@@ -52,6 +52,7 @@
 from .eetq import dispatch_eetq
 from .gptq import dispatch_gptq
 from .hqq import dispatch_hqq
+from .inc import dispatch_inc
 from .layer import Conv2d, LoraLayer, dispatch_default
 from .torchao import dispatch_torchao
 from .tp_layer import dispatch_megatron
@@ -331,6 +332,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
                 dispatch_awq,
                 dispatch_gptq,
                 dispatch_hqq,
+                dispatch_inc,
                 dispatch_torchao,
                 dispatch_megatron,
                 dispatch_default,
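
With `dispatch_inc` registered in this chain, `get_peft_model` can attach trainable LoRA adapters directly to an INC-converted model, which is the fine-tuning path described in the documentation above. A minimal sketch, assuming `model` has already gone through the FP8 workflow (prepare, calibration, `finalize_calibration`, `convert`) so that its linear projections are `PatchedLinear` modules; the rank, alpha, and `target_modules` values are illustrative:

```python
from peft import LoraConfig, get_peft_model

# `model` is a placeholder for an INC-converted transformer (see the FP8 workflow above).
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(model, lora_config)

# Each targeted PatchedLinear is routed through dispatch_inc and replaced by an IncLoraLinear;
# per the caveats above, merge() and unmerge() on these layers raise NotImplementedError.
peft_model.print_trainable_parameters()
```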
