Skip to content

Commit

Permalink
FEAT Add Bone method (#2172)
Browse files Browse the repository at this point in the history
Implements the method: "Block Affine Transformation as Parameter
Efficient Fine-tuning Methods for Large Language Models" described in
https://arxiv.org/abs/2409.15371.
  • Loading branch information
JL-er authored Nov 5, 2024
1 parent 7295b33 commit 13fb29f
Show file tree
Hide file tree
Showing 21 changed files with 1,127 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@
title: VB-LoRA
- local: package_reference/hra
title: HRA
- local: package_reference/bone
title: Bone

title: Adapters
- sections:
Expand Down
11 changes: 10 additions & 1 deletion docs/source/conceptual_guides/adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,13 @@ To avoid adding noise to the tokens, the adapter uses zero-initialized attention

HRA constructs a chain of `r` trainable Householder reflections (HRs). Because the Householder reflection matrix is an orthogonal matrix and the product of orthogonal matrices is also an orthogonal matrix, HRA satisfies the theoretical guarantee of Orthogonal Finetuning (OFT). Meanwhile, HRA can also be viewed as an low-rank fine-tuning adapter by rewriting formula.

The higher `r`, the more trainable parameters, resulting in a larger model capacity and better performance. Besides, due to the chain structure, the orthogonality of HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between the model capacity and regularity, an orthogonality regularizer of the HR planes is added to the loss function. The weight \\(\lambda\\) can control the strength of the regularizer.
The higher `r`, the more trainable parameters, resulting in a larger model capacity and better performance. Besides, due to the chain structure, the orthogonality of HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between the model capacity and regularity, an orthogonality regularizer of the HR planes is added to the loss function. The weight \\(\lambda\\) can control the strength of the regularizer.

## Bone
[Bone](https://huggingface.co/papers/2409.15371) is a new PEFT technology different from the LoRA series, reducing training resources and requiring no complex initialization. Bone not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting.

<small><a href="https://huggingface.co/papers/2409.15371">Bone: Block Affine Transformation as Parameter Efficient Fine-tuning Methods for Large Language Models</a></small>

Bone reduces the number of trainable parameters by using a block grouping method, and the block computation of weight W effectively promotes information exchange in the original weights, enhancing data fitting capability during fine-tuning. The experiment mentions controlling the size of trainable parameters through b (block size), similar to r (rank) in LoRA. For consistency within PEFT, we also name b as r. Note: Bone's r (b) is special and requires that weight W satisfies the conditions `in_features % r == 0` and `out_features % r == 0`. Additionally, when `in_features == out_features` and Bone-r equals LoRA-r, Bone's number of trainable parameters is only half that of LoRA.

From the experiments in the paper, it is evident that Bone, with only half the parameters of LoRA, can outperform LoRA variants across various metrics. However, Bone currently has some issues: it is slower than LoRA and requires checkpointing to address excessive memory usage due to intermediate values, which further reduces training speed. We plan to address this in the future. Contributions are welcome.
31 changes: 31 additions & 0 deletions docs/source/package_reference/bone.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Bone

Block Affine ([Bone](https://huggingface.co/papers/2409.15371)) is a new PEFT technology different from the LoRA series, reducing training resources and requiring no complex initialization. Bone not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting.

The abstract from the paper is:

Low-Rank Adaptation (LoRA) has achieved remarkable training results by freezing the original weights and training only low-rank matrices, establishing itself as the predominant fine-tuning method for LLMs. In pursuit of performance closer to full-parameter training, a series of LoRA variants have emerged, such as LoRA+, PISSA, Olora, and LoRA-GA. However, these improvements complicate the initial setup of model training and increase initialization time. More importantly, they overlook the internal interactions of the original weight information. To address these issues, we introduce a novel theory, ``Weight Guide'' aimed at continuously guiding trainable matrices through the original weights during training to enhance the utilization of weight information. Based on this theory, we designed a new PEFT technique called Bone (Block Affine), which not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting. Experimental comparisons across two different LLM architectures (LLaMA2, RWKV6) and various parameter scales demonstrate that the Bone structure can achieve rapid convergence and superior data fitting without the need for complex initialization. For example, when fine-tuning LLaMA2-7B on the MetaMathQA dataset and validating on GSM8k and math benchmarks, Bone achieved fine-tuning scores of 49.36 and 8.8, respectively, outperforming PISSA by 5.84\% and 1.96\%.

## BoneConfig

[[autodoc]] tuners.bone.config.BoneConfig

## BoneModel

[[autodoc]] tuners.bone.model.BoneModel
87 changes: 87 additions & 0 deletions examples/bone_finetuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# BONE: BLOCK AFFINE TRANSFORMATION AS PARAMETER EFFICIENT FINE-TUNING METHODS FOR LARGE LANGUAGE MODELS
## Introduction ([Paper](https://arxiv.org/pdf/2409.15371), [code](https://github.com/JL-er/Bone))
Low-Rank Adaptation (LoRA) has achieved remarkable training results by freezing the original weights and training only low-rank matrices, establishing itself as the predominant fine-tuning method for LLMs. In pursuit of performance closer to full-parameter training, a series of LoRA variants have emerged, such as LoRA+, PISSA, Olora, and LoRA-GA. However, these improvements complicate the initial setup of model training and increase initialization time. More importantly, they overlook the internal interactions of the original weight information. To address these issues, we introduce a novel theory, ``Weight Guide'' aimed at continuously guiding trainable matrices through the original weights during training to enhance the utilization of weight information. Based on this theory, we designed a new PEFT technique called Bone Block Affine, which not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting. Experimental comparisons across two different LLM architectures (LLaMA2, RWKV6) and various parameter scales demonstrate that the Bone structure can achieve rapid convergence and superior data fitting without the need for complex initialization. For example, when fine-tuning LLaMA2-7B on the MetaMathQA dataset and validating on GSM8k and math benchmarks, Bone achieved fine-tuning scores of 49.36 and 8.8, respectively, outperforming PISSA by 5.84% and 1.96%.

## Quick Start
```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
bone_config = BoneConfig(
r = 64
)
peft_model = get_peft_model(model, bone_config)

peft_model.print_trainable_parameters()

dataset = load_dataset("imdb", split="train[:1%]")

training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
model=peft_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("bone-llama-2-7b")
```


To utilize the fine-tuned Bone modules, simply run the following command:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(model, "bone-llama-2-7b")
```

## Advanced Usage

### Fine-tune
```shell
python bone_finetuning.py \
--base_model_name_or_path meta-llama/Llama-2-7b-hf \
--output_dir output/bone-llama-2-7b-metamath-10k \
--bits bf16 \
--data_path meta-math/MetaMathQA \
--dataset_split train[:100000] \
--dataset_field query response \
--bf16 True \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--logging_steps 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--tf32 True \
--report_to none
```



# Citation
```bib
@misc{kang2024boneblockaffinetransformation,
title={Bone: Block Affine Transformation as Parameter Efficient Fine-tuning Methods for Large Language Models},
author={Jiale Kang},
year={2024},
eprint={2409.15371},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2409.15371},
}
99 changes: 99 additions & 0 deletions examples/bone_finetuning/bone_finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import SFTConfig, SFTTrainer

from peft import BoneConfig, get_peft_model


@dataclass
class ScriptArguments(SFTConfig):
# model configs
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name or path of the fp32/16 base model."}
)
bits: str = field(default="bf16", metadata={"help": "(`['bf16', 'fp16', fp32]`)"})
init_bone_weights: str = field(default="False")
bone_r: int = field(default=16)
merge_and_save: bool = field(default=False)
# dataset configs
data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"})
dataset_field: list[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
print(script_args)

print(f"Load pre-processed residual model in {script_args.bits} bits.")
if script_args.bits in ["nf4", "fp4", "int8"]:
print("Bone currently does not support quantization.")

elif script_args.base_model_name_or_path is not None:
print(f"No available pre-processed model, manually initialize a Bone using {script_args.base_model_name_or_path}.")
model = AutoModelForCausalLM.from_pretrained(
script_args.base_model_name_or_path,
torch_dtype=(
torch.float16
if script_args.bits == "fp16"
else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
),
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = BoneConfig(
r=script_args.bone_r,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
bias="none",
task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)

print(peft_model)
peft_model.print_trainable_parameters()

print(f"Training Bone with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.")
dataset = load_dataset(script_args.data_path, split=script_args.dataset_split)
dataset = dataset.map(
lambda example: {
"text": f"### USER: {example[script_args.dataset_field[0]]}\n### ASSISTANT: {example[script_args.dataset_field[1]]}"
}
)

trainer = SFTTrainer(
model=peft_model,
args=script_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
trainer.save_state()

peft_model.save_pretrained(
os.path.join(script_args.output_dir, "bone_ft"),
)

if script_args.merge_and_save:
model = peft_model.merge_and_unload()
model.save_pretrained(os.path.join(script_args.output_dir, "bone_merged"))
tokenizer.save_pretrained(os.path.join(script_args.output_dir, "bone_merged"))
2 changes: 2 additions & 0 deletions src/peft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@
HRAConfig,
HRAModel,
VBLoRAConfig,
BoneConfig,
BoneModel,
)
from .utils import (
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
Expand Down
4 changes: 4 additions & 0 deletions src/peft/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
AdaptionPromptConfig,
BOFTConfig,
BOFTModel,
BoneConfig,
BoneModel,
FourierFTConfig,
FourierFTModel,
HRAConfig,
Expand Down Expand Up @@ -104,6 +106,7 @@
"XLORA": XLoraConfig,
"HRA": HRAConfig,
"VBLORA": VBLoRAConfig,
"BONE": BoneConfig,
}

PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = {
Expand All @@ -121,6 +124,7 @@
"XLORA": XLoraModel,
"HRA": HRAModel,
"VBLORA": VBLoRAModel,
"BONE": BoneModel,
}


Expand Down
2 changes: 2 additions & 0 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
AdaLoraModel,
AdaptionPromptModel,
BOFTModel,
BoneModel,
FourierFTModel,
HRAModel,
IA3Model,
Expand Down Expand Up @@ -104,6 +105,7 @@
PeftType.XLORA: XLoraModel,
PeftType.HRA: HRAModel,
PeftType.VBLORA: VBLoRAModel,
PeftType.BONE: BoneModel,
}


Expand Down
1 change: 1 addition & 0 deletions src/peft/tuners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@
from .xlora import XLoraConfig, XLoraModel
from .hra import HRAConfig, HRAModel
from .vblora import VBLoRAConfig, VBLoRAModel
from .bone import BoneConfig, BoneModel
20 changes: 20 additions & 0 deletions src/peft/tuners/bone/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import BoneConfig
from .layer import BoneLayer, BoneLinear
from .model import BoneModel


__all__ = ["BoneConfig", "BoneModel", "BoneLinear", "BoneLayer"]
Loading

0 comments on commit 13fb29f

Please sign in to comment.