
Commit 8111699

yfeng95, zqiu24, and YuliangXiu authored
BOFT: Orthogonal Finetuning via Butterfly Factorization (#1326)
Implements https://hf.co/papers/2311.06243. --------- Co-authored-by: Zeju Qiu <[email protected]> Co-authored-by: Yuliang Xiu <[email protected]> Co-authored-by: Yao Feng <[email protected]>
1 parent b0f1bb4 commit 8111699


56 files changed: +6650 -49 lines

.gitignore

+1 -1

@@ -138,4 +138,4 @@ dmypy.json
.DS_Store

# More test things
-wandb
+wandb

README.md

+1 -1

@@ -155,4 +155,4 @@ To use 🤗 PEFT in your publication, please cite it by using the following BibT
howpublished = {\url{https://github.com/huggingface/peft}},
year = {2022}
}
-```
+```

docs/source/_toctree.yml

+4

@@ -57,6 +57,8 @@
  title: Soft prompts
- local: conceptual_guides/ia3
  title: IA3
+ - local: conceptual_guides/oft
+ title: OFT/BOFT

- sections:
- sections:
@@ -90,6 +92,8 @@
  title: Multitask Prompt Tuning
- local: package_reference/oft
  title: OFT
+ - local: package_reference/boft
+ title: BOFT
- local: package_reference/poly
  title: Polytropon
- local: package_reference/p_tuning

docs/source/conceptual_guides/adapter.md

+6

@@ -71,6 +71,12 @@ LoHa uses the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_

OFT preserves the hyperspherical energy by learning an orthogonal transformation for neurons to keep the cosine similarity between them unchanged. In practice, this means taking the matrix product of an orthogonal matrix with the pretrained weight matrix. However, to be parameter-efficient, the orthogonal matrix is represented as a block-diagonal matrix with rank `r` blocks. Whereas LoRA reduces the number of trainable parameters with low-rank structures, OFT reduces the number of trainable parameters with a sparse block-diagonal matrix structure.

+ ## Orthogonal Butterfly (BOFT)
+
+ [BOFT](https://hf.co/papers/2311.06243) is a method that primarily focuses on preserving a pretrained model's generative performance in the finetuned model. It tries to maintain the same cosine similarity (hyperspherical energy) between all pairwise neurons in a layer because this better captures the semantic information among neurons. This means BOFT, like OFT, is better at preserving the subject and better suited for controllable generation (similar to [ControlNet](https://huggingface.co/docs/diffusers/using-diffusers/controlnet)).
+
+ Like OFT, BOFT preserves the hyperspherical energy by learning an orthogonal transformation for neurons that keeps the cosine similarity between them unchanged. However, instead of a single block-diagonal orthogonal matrix, BOFT parameterizes the orthogonal transformation with a butterfly factorization (a product of sparse structured matrices), which makes it even more parameter-efficient. Whereas LoRA reduces the number of trainable parameters with low-rank structures, OFT and BOFT reduce the number of trainable parameters with sparse orthogonal matrix structures.
+
## Adaptive Low-Rank Adaptation (AdaLoRA)

[AdaLoRA](https://hf.co/papers/2303.10512) manages the parameter budget introduced from LoRA by allocating more parameters - in other words, a higher rank `r` - for important weight matrices that are better adapted for a task and pruning less important ones. The rank is controlled by a method similar to singular value decomposition (SVD). The ∆W is parameterized with two orthogonal matrices and a diagonal matrix which contains singular values. This parametrization method avoids iteratively applying SVD which is computationally expensive. Based on this method, the rank of ∆W is adjusted according to an importance score. ∆W is divided into triplets and each triplet is scored according to its contribution to model performance. Triplets with low importance scores are pruned and triplets with high importance scores are kept for finetuning.

docs/source/conceptual_guides/oft.md

+107

@@ -0,0 +1,107 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Orthogonal Finetuning (OFT and BOFT)

This conceptual guide gives a brief overview of [OFT](https://arxiv.org/abs/2306.07280) and [BOFT](https://arxiv.org/abs/2311.06243), parameter-efficient fine-tuning techniques that use an orthogonal matrix to multiplicatively transform the pretrained weight matrices.

To achieve efficient fine-tuning, OFT represents the weight update as an orthogonal transformation, parameterized by an orthogonal matrix that is multiplied with the pretrained weight matrix. This new matrix can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive any further adjustments. To produce the final results, the original and the adapted weights are multiplied together.

Orthogonal Butterfly (BOFT) generalizes OFT with butterfly factorization and further improves its parameter efficiency and finetuning flexibility. In short, OFT can be viewed as a special case of BOFT. Unlike LoRA, which uses additive low-rank weight updates, BOFT uses multiplicative orthogonal weight updates. The comparison is shown below.

<div class="flex justify-center">
<img src="https://github.com/wy1iu/butterfly-oft/blob/main/assets/BOFT_comparison.png"/>
</div>

BOFT has some advantages compared to LoRA:

* BOFT proposes a simple yet generic way to finetune pretrained models on downstream tasks, yielding better preservation of pretraining knowledge and better parameter efficiency.
* Through orthogonality, BOFT introduces a structural constraint, i.e., keeping the [hyperspherical energy](https://arxiv.org/abs/1805.09298) unchanged during finetuning. This can effectively reduce the forgetting of pretraining knowledge.
* BOFT uses the butterfly factorization to efficiently parameterize the orthogonal matrix, which yields a compact yet expressive learning space (i.e., hypothesis class).
* The sparse matrix decomposition in BOFT brings in additional inductive biases that are beneficial to generalization.

In principle, BOFT can be applied to any subset of weight matrices in a neural network to reduce the number of trainable parameters. Given the target layers for injecting BOFT parameters, the number of trainable parameters can be determined based on the size of the weight matrices.
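
To make the multiplicative orthogonal update concrete, here is a toy sketch (illustrative only, not the PEFT implementation). It shows that rotating a weight matrix with a block-diagonal orthogonal matrix leaves the pairwise cosine similarities between its rows (the neurons, in this toy convention) unchanged, which is the hyperspherical-energy property described above:

```py
import torch

torch.manual_seed(0)
out_features, in_features, block_size = 6, 8, 4

# Stand-in for a pretrained weight matrix of a linear layer.
W = torch.randn(out_features, in_features)

# Block-diagonal orthogonal matrix built from small per-block rotations
# (the Q factor of a QR decomposition is orthogonal).
blocks = [
    torch.linalg.qr(torch.randn(block_size, block_size))[0]
    for _ in range(in_features // block_size)
]
R = torch.block_diag(*blocks)

# Multiplicative update, in contrast to LoRA's additive update W + B @ A.
W_adapted = W @ R

def cosine_similarities(M):
    M = M / M.norm(dim=1, keepdim=True)
    return M @ M.T

# The orthogonal rotation preserves pairwise cosine similarities between neurons.
print(torch.allclose(cosine_similarities(W), cosine_similarities(W_adapted), atol=1e-6))
```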

## Merge OFT/BOFT weights into the base model

Similar to LoRA, the weights learned by OFT/BOFT can be integrated into the pretrained weight matrices using the `merge_and_unload()` function. This function merges the adapter weights with the base model, which allows you to effectively use the newly merged model as a standalone model.

<div class="flex justify-center">
<img src="https://github.com/wy1iu/butterfly-oft/blob/main/assets/boft_merge.png"/>
</div>

This works because during training, the orthogonal weight matrix (R in the diagram above) and the pretrained weight matrices are kept separate. Once training is complete, these weights can be merged (multiplied) into a new, equivalent weight matrix.

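A minimal sketch of what merging could look like, assuming a BOFT adapter has already been trained and saved; the adapter path `"my-boft-dinov2"` and the output directory below are hypothetical:

```py
import transformers
from peft import PeftModel

# Reload the base model the adapter was trained on.
base_model = transformers.Dinov2ForImageClassification.from_pretrained(
    "facebook/dinov2-large", num_labels=100
)

# Attach the trained BOFT adapter (hypothetical local path or Hub repo).
peft_model = PeftModel.from_pretrained(base_model, "my-boft-dinov2")

# Multiply the learned orthogonal matrices into the frozen base weights and
# drop the adapter modules, leaving a plain transformers model.
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("dinov2-large-boft-merged")
```
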
## Utils for OFT / BOFT

### Common OFT / BOFT parameters in PEFT

As with other methods supported by PEFT, to fine-tune a model using OFT or BOFT, you need to:

1. Instantiate a base model.
2. Create a configuration (`OFTConfig` or `BOFTConfig`) where you define OFT/BOFT-specific parameters.
3. Wrap the base model with `get_peft_model()` to get a trainable `PeftModel`.
4. Train the `PeftModel` as you normally would train the base model.

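Putting these four steps together, here is a minimal sketch; the checkpoint, `target_modules`, and hyperparameters below are illustrative assumptions rather than recommended settings:

```py
from transformers import AutoModelForSequenceClassification
from peft import BOFTConfig, get_peft_model

# 1. Instantiate a base model (a hypothetical choice for illustration).
base_model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-base", num_labels=2
)

# 2. Create a BOFT configuration; "query"/"value" target the attention projections,
#    whose in_features (768) is divisible by the chosen block size.
config = BOFTConfig(
    task_type="SEQ_CLS",
    boft_block_size=4,
    target_modules=["query", "value"],
    modules_to_save=["classifier"],
)

# 3. Wrap the base model to get a trainable PeftModel.
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()

# 4. Train peft_model as you normally would, e.g. with transformers.Trainer.
```
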
### BOFT-specific parameters

`BOFTConfig` allows you to control how OFT/BOFT is applied to the base model through the following parameters:

- `boft_block_size`: the BOFT matrix block size across different layers, expressed as an `int`. A smaller block size results in sparser update matrices with fewer trainable parameters. **Note**: choose a `boft_block_size` that divides most layers' input dimension (`in_features`), e.g., 4, 8, 16. Also, specify either `boft_block_size` or `boft_block_num`, but not both simultaneously, and don't leave both at 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed as an `int`. Fewer blocks result in sparser update matrices with fewer trainable parameters. **Note**: choose a `boft_block_num` that divides most layers' input dimension (`in_features`). Also, specify either `boft_block_size` or `boft_block_num`, but not both simultaneously, and don't leave both at 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
- `boft_n_butterfly_factor`: the number of butterfly factors. **Note**: with `boft_n_butterfly_factor=1`, BOFT is the same as vanilla OFT; with `boft_n_butterfly_factor=2`, the effective block size of OFT becomes twice as big and the number of blocks becomes half.
- `bias`: specifies whether the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"boft_only"`.
- `boft_dropout`: the probability of multiplicative dropout.
- `target_modules`: the modules (for example, attention blocks) to inject the OFT/BOFT matrices into.
- `modules_to_save`: list of modules apart from the OFT/BOFT matrices to be set as trainable and saved in the final checkpoint. These typically include the model's custom head that is randomly initialized for the fine-tuning task.

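For instance, assuming layers whose `in_features` is 768 (as in a base-size transformer, an illustrative assumption), the following two configurations are two ways of describing the same block structure, once via the block size and once via the number of blocks:

```py
from peft import BOFTConfig

# Specify the block size; the number of blocks is then 768 / 4 = 192.
config_by_size = BOFTConfig(
    boft_block_size=4,
    boft_block_num=0,
    target_modules=["query", "value"],
)

# Or specify the number of blocks; the block size is then 768 / 192 = 4.
config_by_num = BOFTConfig(
    boft_block_size=0,
    boft_block_num=192,
    target_modules=["query", "value"],
)
```
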
## BOFT Example Usage

For examples of applying BOFT to various downstream tasks, take a look at the following step-by-step guides:
- [Dreambooth finetuning with BOFT](../task_guides/boft_dreambooth)
- [Controllable generation finetuning with BOFT (ControlNet)](../task_guides/boft_controlnet)

For the task of image classification, you can initialize the BOFT config for a DinoV2 model as follows:

```py
import transformers
from peft import BOFTConfig, get_peft_model

config = BOFTConfig(
    boft_block_size=4,
    boft_n_butterfly_factor=2,
    target_modules=["query", "value", "key", "output.dense", "mlp.fc1", "mlp.fc2"],
    boft_dropout=0.1,
    bias="boft_only",
    modules_to_save=["classifier"],
)

model = transformers.Dinov2ForImageClassification.from_pretrained(
    "facebook/dinov2-large",
    num_labels=100,
)

boft_model = get_peft_model(model, config)
```
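
After training, the adapter can be saved on its own; `save_pretrained()` stores the BOFT adapter weights (and any `modules_to_save`) rather than the full base model. Continuing from the example above, with a hypothetical output path:

```py
# Save only the BOFT adapter; the base DinoV2 weights are not duplicated.
boft_model.save_pretrained("dinov2-boft-adapter")
```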

docs/source/package_reference/boft.md

+31

@@ -0,0 +1,31 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# BOFT

[Orthogonal Butterfly (BOFT)](https://hf.co/papers/2311.06243) is a generic method designed for finetuning foundation models. It improves the parameter efficiency of the Orthogonal Finetuning (OFT) paradigm by taking inspiration from the Cooley-Tukey fast Fourier transform, showing favorable results across finetuning different foundation models, including large vision transformers, large language models and text-to-image diffusion models.

The abstract from the paper is:

*Large foundation models are becoming ubiquitous, but training them from scratch is prohibitively expensive. Thus, efficiently adapting these powerful models to downstream tasks is increasingly important. In this paper, we study a principled finetuning paradigm -- Orthogonal Finetuning (OFT) -- for downstream task adaptation. Despite demonstrating good generalizability, OFT still uses a fairly large number of trainable parameters due to the high dimensionality of orthogonal matrices. To address this, we start by examining OFT from an information transmission perspective, and then identify a few key desiderata that enable better parameter-efficiency. Inspired by how the Cooley-Tukey fast Fourier transform algorithm enables efficient information transmission, we propose an efficient orthogonal parameterization using butterfly structures. We apply this parameterization to OFT, creating a novel parameter-efficient finetuning method, called Orthogonal Butterfly (BOFT). By subsuming OFT as a special case, BOFT introduces a generalized orthogonal finetuning framework. Finally, we conduct an extensive empirical study of adapting large vision transformers, large language models, and text-to-image diffusion models to various downstream tasks in vision and language*.

## BOFTConfig

[[autodoc]] tuners.boft.config.BOFTConfig

## BOFTModel

[[autodoc]] tuners.boft.model.BOFTModel

examples/boft_controlnet/__init__.py

Whitespace-only changes.
