Implements the method: "Block Affine Transformation as Parameter Efficient Fine-tuning Methods for Large Language Models" described in https://arxiv.org/abs/2409.15371.
Showing 21 changed files with 1,127 additions and 8 deletions.
@@ -0,0 +1,31 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Bone

Block Affine ([Bone](https://huggingface.co/papers/2409.15371)) is a new PEFT technique distinct from the LoRA series: it reduces training resources and requires no complex initialization. Bone not only enhances the utilization of the original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting.

The abstract from the paper is:

Low-Rank Adaptation (LoRA) has achieved remarkable training results by freezing the original weights and training only low-rank matrices, establishing itself as the predominant fine-tuning method for LLMs. In pursuit of performance closer to full-parameter training, a series of LoRA variants have emerged, such as LoRA+, PISSA, Olora, and LoRA-GA. However, these improvements complicate the initial setup of model training and increase initialization time. More importantly, they overlook the internal interactions of the original weight information. To address these issues, we introduce a novel theory, "Weight Guide", aimed at continuously guiding trainable matrices through the original weights during training to enhance the utilization of weight information. Based on this theory, we designed a new PEFT technique called Bone (Block Affine), which not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting. Experimental comparisons across two different LLM architectures (LLaMA2, RWKV6) and various parameter scales demonstrate that the Bone structure can achieve rapid convergence and superior data fitting without the need for complex initialization. For example, when fine-tuning LLaMA2-7B on the MetaMathQA dataset and validating on GSM8k and math benchmarks, Bone achieved fine-tuning scores of 49.36 and 8.8, respectively, outperforming PISSA by 5.84% and 1.96%.
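As a quick orientation before the API reference below, here is a minimal sketch of wrapping a causal language model with Bone (it mirrors the example added in this commit; the Llama-2 checkpoint and `r=64` are only illustrative choices):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import BoneConfig, get_peft_model

# Load a base model and wrap it with Bone; only the Bone block matrices are trainable.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
bone_config = BoneConfig(r=64)  # r controls the Bone block size and hence the trainable parameter count
peft_model = get_peft_model(model, bone_config)
peft_model.print_trainable_parameters()
```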

## BoneConfig

[[autodoc]] tuners.bone.config.BoneConfig

## BoneModel

[[autodoc]] tuners.bone.model.BoneModel
@@ -0,0 +1,87 @@
# BONE: BLOCK AFFINE TRANSFORMATION AS PARAMETER EFFICIENT FINE-TUNING METHODS FOR LARGE LANGUAGE MODELS
## Introduction ([Paper](https://arxiv.org/pdf/2409.15371), [code](https://github.com/JL-er/Bone))
Low-Rank Adaptation (LoRA) has achieved remarkable training results by freezing the original weights and training only low-rank matrices, establishing itself as the predominant fine-tuning method for LLMs. In pursuit of performance closer to full-parameter training, a series of LoRA variants have emerged, such as LoRA+, PISSA, Olora, and LoRA-GA. However, these improvements complicate the initial setup of model training and increase initialization time. More importantly, they overlook the internal interactions of the original weight information. To address these issues, we introduce a novel theory, "Weight Guide", aimed at continuously guiding trainable matrices through the original weights during training to enhance the utilization of weight information. Based on this theory, we designed a new PEFT technique called Bone (Block Affine), which not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting. Experimental comparisons across two different LLM architectures (LLaMA2, RWKV6) and various parameter scales demonstrate that the Bone structure can achieve rapid convergence and superior data fitting without the need for complex initialization. For example, when fine-tuning LLaMA2-7B on the MetaMathQA dataset and validating on GSM8k and math benchmarks, Bone achieved fine-tuning scores of 49.36 and 8.8, respectively, outperforming PISSA by 5.84% and 1.96%.

## Quick Start
```python
import torch
from peft import BoneConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id

# r sets the Bone block size; larger values add more trainable parameters.
bone_config = BoneConfig(
    r=64,
)
peft_model = get_peft_model(model, bone_config)

peft_model.print_trainable_parameters()

dataset = load_dataset("imdb", split="train[:1%]")

training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("bone-llama-2-7b")
```

To use the fine-tuned Bone modules, load them on top of the base model as follows:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(model, "bone-llama-2-7b")
```
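If you need a standalone checkpoint without the PEFT wrapper, the Bone weights can be folded back into the base model. Below is a small sketch using the generic `merge_and_unload` call (the same operation the training script exposes through its `merge_and_save` option); the output directory name is just an example:

```python
# Merge the Bone adapter into the base weights and save a plain transformers checkpoint.
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("bone-llama-2-7b-merged")
```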

## Advanced Usage

### Fine-tune
```shell
python bone_finetuning.py \
    --base_model_name_or_path meta-llama/Llama-2-7b-hf \
    --output_dir output/bone-llama-2-7b-metamath-10k \
    --bits bf16 \
    --data_path meta-math/MetaMathQA \
    --dataset_split train[:100000] \
    --dataset_field query response \
    --bf16 True \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --logging_steps 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --tf32 True \
    --report_to none
```
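The script writes the trained adapter to a `bone_ft` subfolder of `--output_dir` (and, if `--merge_and_save True` is passed, a merged model to `bone_merged`). A sketch of loading the adapter produced by the command above, assuming those default paths:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Attach the adapter written by bone_finetuning.py to the base model.
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(base_model, "output/bone-llama-2-7b-metamath-10k/bone_ft")
```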

## Citation
```bib
@misc{kang2024boneblockaffinetransformation,
      title={Bone: Block Affine Transformation as Parameter Efficient Fine-tuning Methods for Large Language Models},
      author={Jiale Kang},
      year={2024},
      eprint={2409.15371},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2409.15371},
}
```
@@ -0,0 +1,99 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import SFTConfig, SFTTrainer

from peft import BoneConfig, get_peft_model

@dataclass
class ScriptArguments(SFTConfig):
    # model configs
    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name or path of the fp32/16 base model."}
    )
    bits: str = field(default="bf16", metadata={"help": "(`['bf16', 'fp16', 'fp32']`)"})
    init_bone_weights: str = field(default="False")
    bone_r: int = field(default=16)
    merge_and_save: bool = field(default=False)
    # dataset configs
    data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
    dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`)"})
    dataset_field: list[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
print(script_args)

print(f"Loading the base model in {script_args.bits}.")
if script_args.bits in ["nf4", "fp4", "int8"]:
    # Quantized loading is not supported for Bone yet.
    raise ValueError("Bone currently does not support quantization.")

elif script_args.base_model_name_or_path is not None:
    print(f"No pre-processed model available, initializing Bone from {script_args.base_model_name_or_path}.")
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name_or_path,
        torch_dtype=(
            torch.float16
            if script_args.bits == "fp16"
            else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
        ),
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    bone_config = BoneConfig(
        r=script_args.bone_r,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, bone_config)

print(peft_model)
peft_model.print_trainable_parameters()

print(f"Training Bone with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.")
dataset = load_dataset(script_args.data_path, split=script_args.dataset_split)
# Concatenate the input/output fields into a single "text" column for SFTTrainer.
dataset = dataset.map(
    lambda example: {
        "text": f"### USER: {example[script_args.dataset_field[0]]}\n### ASSISTANT: {example[script_args.dataset_field[1]]}"
    }
)

trainer = SFTTrainer(
    model=peft_model,
    args=script_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_state()

peft_model.save_pretrained(
    os.path.join(script_args.output_dir, "bone_ft"),
)

if script_args.merge_and_save:
    model = peft_model.merge_and_unload()
    model.save_pretrained(os.path.join(script_args.output_dir, "bone_merged"))
    tokenizer.save_pretrained(os.path.join(script_args.output_dir, "bone_merged"))
@@ -0,0 +1,20 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import BoneConfig
from .layer import BoneLayer, BoneLinear
from .model import BoneModel


__all__ = ["BoneConfig", "BoneModel", "BoneLinear", "BoneLayer"]