# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

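"""
Quantize a full-precision model with LoftQ, save the quantized backbone and the
LoftQ-initialized LoRA adapters to disk, then reload them with bitsandbytes 4-bit
quantization to verify the round trip. See the example command at the bottom of
this file.
"""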
import argparse
import os

import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model


class Shell(nn.Module):
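    """Bare container for frozen weight/bias tensors.

    Used to replace a LoRA-wrapped linear layer so that only the (quantized)
    backbone parameters remain when the base model is saved.
    """
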
    def __init__(self, weight, bias=None):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)


def unwrap_model(model, sub_module_name=".base_layer"):
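    """Replace every LoRA-wrapped submodule with a plain ``Shell``.

    Finds all submodules that contain a ``.base_layer`` (i.e. layers wrapped by a
    LoRA adapter), extracts the backbone weight and bias from that base layer, and
    swaps the wrapper for a ``Shell`` so the model can be saved as a regular
    backbone without the adapter machinery.
    """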
    sub_module_name_list = [k.split(sub_module_name)[0] for k in model.state_dict().keys() if sub_module_name in k]
    sub_module_name_set = set(sub_module_name_list)
    for name in sub_module_name_set:
        # get the parent of the submodule
        name_parent = ".".join(name.split(".")[:-1])
        name_child = name.split(".")[-1]
        sub_module = model.get_submodule(name_parent)
        print(sub_module)

        # replace with shell
        child = getattr(sub_module, name_child)
        weight = getattr(child.base_layer, "weight", None)
        bias = getattr(child.base_layer, "bias", None)
        shell = Shell(weight, bias)

        setattr(sub_module, name_child, shell)

    print("You have unwrapped the model. Use it at your own risk.")


def print_model(model, name):
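    """Print the model structure and a short summary of every named parameter."""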
    print("=" * 10 + name + "=" * 10)
    print(model)
    for name, param in model.named_parameters():
        if torch.is_tensor(param):
            if param.dtype in [torch.float32, torch.float16]:
                print(
                    name,
                    param.shape,
                    param.device,
                    param.dtype,
                    param.requires_grad,
                    param.mean().item(),
                    param.max().item(),
                )
            else:
                print(name, param.shape, param.device, param.dtype, param.requires_grad)


def arg_parse():
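    """Parse the command-line arguments for LoftQ quantization."""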
    parser = argparse.ArgumentParser(description="Quantize a model with LoftQ.")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="The name or path of the fp32/16 model.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="The access token to download the model from the HuggingFace Hub.",
    )
    parser.add_argument(
        "--bits",
        type=int,
        default=4,
        help="The number of bits to quantize to.",
    )
    parser.add_argument(
        "--iter",
        type=int,
        default=1,
        help="The number of alternating steps in LoftQ.",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=16,
        help="The rank of the LoRA adapter.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./model_zoo/loftq/",
        help="The directory to save the quantized backbone and LoRA adapters.",
    )
    args = parser.parse_args()
    return args


def quantize_and_save():
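    """Apply LoftQ to the given model and save the results.

    Loads the full-precision model, applies LoftQ initialization through a LoRA
    config, then saves the LoRA adapters (with LoftQ initialization baked in) and
    the quantized backbone to separate directories under ``--save_dir``.
    """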
    args = arg_parse()

    # Download weights and configure LoRA
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token, trust_remote_code=True)
    if any(name in args.model_name_or_path.lower() for name in ["llama", "mistral", "falcon"]):
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path, token=args.token, trust_remote_code=True, device_map="auto"
        )
        task_type = TaskType.CAUSAL_LM
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]

    elif any(name in args.model_name_or_path.lower() for name in ["bart", "t5"]):
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto")
        task_type = TaskType.SEQ_2_SEQ_LM
        target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"]

    elif any(name in args.model_name_or_path.lower() for name in ["deberta", "roberta", "bert"]):
        model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token)
        model = model.cuda()
        task_type = TaskType.SEQ_CLS
        target_modules = ["query_proj", "key_proj", "value_proj", "dense"]  # embeddings not supported by peft
    else:
        raise NotImplementedError("Other models not supported yet.")

    # Config of LoftQ
    loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter)

    lora_config = LoraConfig(
        task_type=task_type,
        inference_mode=True,
        r=args.rank,
        lora_alpha=16 if task_type is TaskType.CAUSAL_LM else args.rank,
        lora_dropout=0.1,
        target_modules=target_modules,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )

    # Obtain LoftQ model
    lora_model = get_peft_model(model, lora_config)
    base_model = lora_model.get_base_model()

    # Save LoftQ model
    model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank"
    base_model_dir = os.path.join(args.save_dir, model_name)
    lora_model_dir = os.path.join(args.save_dir, model_name, "loft_init")

    # save lora adapters first
    lora_model.base_model.peft_config[
        "default"
    ].base_model_name_or_path = base_model_dir  # This can be a local path or Hub model id
    lora_model.base_model.peft_config["default"].init_lora_weights = True  # Don't apply LoftQ again when loading

    lora_model.save_pretrained(lora_model_dir)
    print_model(lora_model, "lora_model")

    # remove lora adapters and save the backbone
    unwrap_model(base_model)
    base_model.save_pretrained(base_model_dir)
    tokenizer.save_pretrained(base_model_dir)

    print_model(base_model, "base_model")

    return base_model_dir, lora_model_dir


def load_loftq(base_model_path, lora_adapter_path):
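    """Reload the quantized backbone in 4-bit with bitsandbytes and attach the saved LoRA adapters."""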
    if any(name in base_model_path.lower() for name in ["llama", "mistral", "falcon"]):
        model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["bart", "t5"]):
        model = AutoModelForSeq2SeqLM.from_pretrained(
            base_model_path,
            device_map="auto",
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    elif any(name in base_model_path.lower() for name in ["deberta", "roberta", "bert"]):
        model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
        )
    else:
        raise NotImplementedError("Other models not supported yet.")

    lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True)

    # Do training or inference below
    print_model(lora_model, "lora_model")
    print_model(model, "base_model")


if __name__ == "__main__":
    base_dir, lora_dir = quantize_and_save()
    load_loftq(base_dir, lora_dir)

# example command:
# python quantize_save_load.py \
#     --model_name_or_path meta-llama/Llama-2-7b-hf \
#     --token XXX \
#     --bits 4 --iter 5 --rank 16 \
#     --save_dir ./model_zoo/loftq/
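#
# With the arguments above, the quantized backbone is saved to
# ./model_zoo/loftq/Llama-2-7b-hf-4bit-16rank/ and the LoftQ-initialized LoRA
# adapters to ./model_zoo/loftq/Llama-2-7b-hf-4bit-16rank/loft_init/; these are the
# two paths returned by quantize_and_save() and passed to load_loftq().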