
Commit 2b901ee

Add LoftQ initialization method for LoRA (#1150)

yxli2123, pacman100, and BenjaminBossan authored
Co-authored-by: Sourab Mangrulkar <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>

1 parent 8298f1a commit 2b901ee

File tree

12 files changed (+1514, -12 lines)

README.md

+1
@@ -34,6 +34,7 @@ Supported methods:
7. MultiTask Prompt Tuning: [Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://arxiv.org/abs/2303.02861)
8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098)
9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization: From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation
10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659)

## Getting started

examples/loftq_finetuning/README.md

+69
@@ -0,0 +1,69 @@

# LoftQ: LoRA-fine-tuning-aware Quantization

## Introduction

LoftQ provides a better initialization for the LoRA adapters A and B,
together with a quantization of the pre-trained weights W.
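
As a minimal sketch of how this initialization is requested through PEFT (it mirrors `quantize_save_load.py` in this directory; the model name and hyperparameters are placeholders):

```python
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

# Load the full-precision backbone (placeholder model name).
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# LoftQ settings: number of quantization bits and number of alternating steps.
loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=5)

# Request the LoftQ initialization for the LoRA adapters.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    init_lora_weights="loftq",
    loftq_config=loftq_config,
)

peft_model = get_peft_model(model, lora_config)
```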

## Quantization

We recommend saving the quantized backbone model as fp16/fp32
and loading it as [NormalFloat4](https://arxiv.org/abs/2305.14314).

We provide a simple example showing how to quantize the llama-2-7b model and save/load it.

```sh
python quantize_save_load.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --token HF_TOKEN \
    --bits 4 --iter 5 --rank 16 \
    --save_dir model_zoo/loftq/
```

- `HF_TOKEN` is the token used to access the [LLAMA models](https://huggingface.co/meta-llama).
- The `quantize_and_save()` function quantizes the backbone and initializes the LoRA adapters.
  It creates two folders under `$save_dir`: the quantized backbone is at `Llama-2-7b-hf-4bit-16rank`,
  and the LoRA adapters are in the sub-folder `Llama-2-7b-hf-4bit-16rank/loftq_init`.
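
For the example command above, the resulting layout under `model_zoo/loftq/` should look roughly like this (names follow the `<model>-<bits>bit-<rank>rank` pattern used by the script):

```
model_zoo/loftq/
└── Llama-2-7b-hf-4bit-16rank/      # quantized backbone (weights, config, tokenizer)
    └── loftq_init/                 # LoftQ-initialized LoRA adapters
```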

## Fine-tuning

Here is an example of loading the quantized backbone and the LoRA adapters:

```python
import os

from transformers import AutoModelForCausalLM
from peft import PeftModel


# args.save_dir is the --save_dir passed to quantize_save_load.py above
base_model = AutoModelForCausalLM.from_pretrained(
    os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank"),
    load_in_4bit=True,
)
peft_model = PeftModel.from_pretrained(
    base_model,
    os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank", "loftq_init"),
    is_trainable=True,
)
```
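
If you prefer to be explicit about the NF4 settings instead of relying on `load_in_4bit=True`, the backbone can also be loaded with a `BitsAndBytesConfig`, as `quantize_save_load.py` does; a sketch with the same paths hard-coded:

```python
import os

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

save_dir = "model_zoo/loftq/"  # the --save_dir used during quantization

base_model = AutoModelForCausalLM.from_pretrained(
    os.path.join(save_dir, "Llama-2-7b-hf-4bit-16rank"),
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
    ),
)
peft_model = PeftModel.from_pretrained(
    base_model,
    os.path.join(save_dir, "Llama-2-7b-hf-4bit-16rank", "loftq_init"),
    is_trainable=True,
)
```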

We also provide an example of fine-tuning LoftQ on GSM8K.
We load the quantized backbone and LoRA adapters from the [LoftQ Hugging Face hub](https://huggingface.co/LoftQ).

```sh
python train_gsm8k_llama.py \
    --model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \
    --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
    --learning_rate 3e-4 \
    --seed 202 \
    --dataset_name gsm8k \
    --dataset_config main \
    --pad_to_max_length \
    --max_source_length 128 \
    --max_target_length 256 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --with_tracking \
    --report_to tensorboard
```

examples/loftq_finetuning/quantize_save_load.py

+244

@@ -0,0 +1,244 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model


class Shell(nn.Module):
    """A minimal container that keeps only the frozen weight (and bias) of a layer."""

    def __init__(self, weight, bias=None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)


def unwrap_model(model, sub_module_name=".base_layer"):
    # Replace every LoRA-wrapped module (identified by its ".base_layer" attribute)
    # with a plain Shell holding only the original weight and bias, so the backbone
    # can be saved without the PEFT wrappers.
    sub_module_name_list = [k.split(sub_module_name)[0] for k in model.state_dict().keys() if sub_module_name in k]
    sub_module_name_set = set(sub_module_name_list)
    for name in sub_module_name_set:
        # get the parent of the submodule
        name_parent = ".".join(name.split(".")[:-1])
        name_child = name.split(".")[-1]
        sub_module = model.get_submodule(name_parent)
        print(sub_module)

        # replace with shell
        child = getattr(sub_module, name_child)
        weight = getattr(child.base_layer, "weight", None)
        bias = getattr(child.base_layer, "bias", None)
        shell = Shell(weight, bias)

        setattr(sub_module, name_child, shell)

    print("You have unwrapped the model. Use it at your own risk.")


def print_model(model, name):
    print("=" * 10 + name + "=" * 10)
    print(model)
    for name, param in model.named_parameters():
        if torch.is_tensor(param):
            if param.dtype in [torch.float32, torch.float16]:
                print(
                    name,
                    param.shape,
                    param.device,
                    param.dtype,
                    param.requires_grad,
                    param.mean().item(),
                    param.max().item(),
                )
            else:
                print(name, param.shape, param.device, param.dtype, param.requires_grad)


def arg_parse():
    parser = argparse.ArgumentParser(description="Quantize a model with LoftQ.")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="The name or path of the fp32/16 model.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="The access token to download model from HuggingFace Hub.",
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=4,
        help="The number of bits for quantization.",
    )
    parser.add_argument(
        "--iter",
        type=int,
        default=1,
        help="The number of alternating steps in LoftQ.",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=16,
        help="The rank of the LoRA adapter.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./model_zoo/loftq/",
        help="The directory to save the quantized backbone and LoRA adapters.",
    )
    args = parser.parse_args()
    return args


def quantize_and_save():
    args = arg_parse()

    # Download weights and configure LoRA
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token, trust_remote_code=True)
    if any(name in args.model_name_or_path.lower() for name in ["llama", "mistral", "falcon"]):
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path, token=args.token, trust_remote_code=True, device_map="auto"
        )
        task_type = TaskType.CAUSAL_LM
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]

    elif any(name in args.model_name_or_path.lower() for name in ["bart", "t5"]):
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto")
        task_type = TaskType.SEQ_2_SEQ_LM
        target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"]

    elif any(name in args.model_name_or_path.lower() for name in ["deberta", "roberta", "bert"]):
        model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token)
        model = model.cuda()
        task_type = TaskType.SEQ_CLS
        target_modules = ["query_proj", "key_proj", "value_proj", "dense"]  # embeddings not supported by peft
    else:
        raise NotImplementedError("Other models not supported yet.")

    # Config of LoftQ
    loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter)

    lora_config = LoraConfig(
        task_type=task_type,
        inference_mode=True,
        r=args.rank,
        lora_alpha=16 if task_type is TaskType.CAUSAL_LM else args.rank,
        lora_dropout=0.1,
        target_modules=target_modules,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )

    # Obtain LoftQ model
    lora_model = get_peft_model(model, lora_config)
    base_model = lora_model.get_base_model()

    # Save LoftQ model
    model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank"
    base_model_dir = os.path.join(args.save_dir, model_name)
    lora_model_dir = os.path.join(args.save_dir, model_name, "loftq_init")

    # save lora adapters first
    lora_model.base_model.peft_config[
        "default"
    ].base_model_name_or_path = base_model_dir  # This can be a local path or Hub model id
    lora_model.base_model.peft_config["default"].init_lora_weights = True  # Don't apply LoftQ when loading again

    lora_model.save_pretrained(lora_model_dir)
    print_model(lora_model, "lora_model")

    # remove lora adapters and save the backbone
    unwrap_model(base_model)
    base_model.save_pretrained(base_model_dir)
    tokenizer.save_pretrained(base_model_dir)

    print_model(base_model, "base_model")

    return base_model_dir, lora_model_dir


def load_loftq(base_model_path, lora_adapter_path):
    if any(name in base_model_path.lower() for name in ["llama", "mistral", "falcon"]):
        model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["bart", "t5"]):
        model = AutoModelForSeq2SeqLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["deberta", "roberta", "bert"]):
        model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    else:
        raise NotImplementedError("Other models not supported yet.")

    lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True)

    # Do training or inference below
    print_model(lora_model, "lora_model")
    print_model(model, "base_model")


if __name__ == "__main__":
    base_dir, lora_dir = quantize_and_save()
    load_loftq(base_dir, lora_dir)

# example command:
# python quantize_save_load.py \
#     --model_name_or_path meta-llama/Llama-2-7b-hf \
#     --token XXX \
#     --bits 4 --iter 5 --rank 16 \
#     --save_dir ./model_zoo/loftq/
