
[FEAT] New LoRA Initialization Method: Explained Variance Adaptation #2142

Merged
merged 65 commits into from Nov 12, 2024

Changes from 48 commits

Commits (65)
ec63093
initial commit
lukash-jku Oct 3, 2024
dc26e7e
add target modules based on eva_state_dict
lukash-jku Oct 5, 2024
70171c8
remove default args
lukash-jku Oct 5, 2024
f537b8e
cleanup
lukash-jku Oct 5, 2024
a3ae4ba
set correct prefix for state dict loading
lukash-jku Oct 6, 2024
dd5a38b
revert
lukash-jku Oct 7, 2024
2cfdf0c
update docstrings, minor changes
sirluk Oct 8, 2024
f3f89c6
update docstrings, integrate peft changes
sirluk Oct 8, 2024
b71b422
fix issues related to config having either type PeftConfig or dict
sirluk Oct 8, 2024
571eca8
move state_dict modification to function
sirluk Oct 8, 2024
bec1f80
update paper link
sirluk Oct 10, 2024
fc34367
remove comments
sirluk Oct 10, 2024
0e8d29c
add documentation
sirluk Oct 10, 2024
c41e820
add docstrings
sirluk Oct 10, 2024
eff0c9d
update
sirluk Oct 15, 2024
1c91272
fix docs and add default evaconfig
sirluk Oct 15, 2024
5779971
add eva init to peft namespace
sirluk Oct 19, 2024
b5968d2
add test for lowrank argument
sirluk Oct 19, 2024
a397418
add check
sirluk Oct 20, 2024
57b72b4
simplify arguments
sirluk Oct 20, 2024
073f0a1
update docstrings
sirluk Oct 21, 2024
6d0e0d7
optimize indices calc
sirluk Oct 21, 2024
3574109
Update src/peft/tuners/lora/eva.py
sirluk Oct 21, 2024
91e87e6
add warning
sirluk Oct 21, 2024
a133e8a
update
sirluk Oct 21, 2024
523f1cd
add check if all hooks have been removed
sirluk Oct 21, 2024
261c81e
extend documentation
sirluk Oct 21, 2024
566ebf0
make style
sirluk Oct 21, 2024
e940efa
add tests for eva
sirluk Oct 22, 2024
ba82bd5
add licence notice
sirluk Oct 22, 2024
202f933
add tests for lora_config with eva
sirluk Oct 22, 2024
c6a5fc5
update
sirluk Oct 25, 2024
413be29
fix tau range
sirluk Oct 25, 2024
5ad5f00
Merge branch 'main' into main
sirluk Oct 25, 2024
b1bcf02
update tau tests
sirluk Oct 25, 2024
e890839
add validity checks to initialize_lora_eva_weights
sirluk Oct 26, 2024
ebb0ac6
style
sirluk Oct 26, 2024
81fbc28
extend documentation and small fixes
sirluk Oct 27, 2024
62ae35c
improve customization options
sirluk Oct 28, 2024
0b316d3
Merge pull request #3 from sirluk/simplify_entrypoints
sirluk Oct 28, 2024
7c753b6
update documentation
sirluk Oct 29, 2024
404022f
update docs
sirluk Oct 29, 2024
724fd1c
error to warning
sirluk Oct 29, 2024
d79672b
make style
sirluk Oct 29, 2024
ad646da
fix type
sirluk Oct 29, 2024
5187466
fix potential issues
sirluk Oct 30, 2024
b6b7b7b
add option to adjust alpha after redistribution
sirluk Oct 31, 2024
13ffc56
Merge pull request #4 from sirluk/alpha_pattern
sirluk Oct 31, 2024
17f5bf1
Update src/peft/tuners/lora/eva.py
sirluk Oct 31, 2024
1f03a95
fix edge cases
sirluk Nov 1, 2024
f74b044
Merge pull request #5 from sirluk/alpha_pattern
sirluk Nov 1, 2024
12d497f
Merge branch 'main' into main
sirluk Nov 1, 2024
e609969
account for layer pattern
sirluk Nov 3, 2024
7e4505f
split up print
sirluk Nov 3, 2024
bec670f
fix rank_budget in case of rank_pattern
sirluk Nov 4, 2024
b13d014
Merge branch 'huggingface:main' into main
sirluk Nov 4, 2024
b24500c
small fixes
sirluk Nov 4, 2024
3119dc4
missing return statement
sirluk Nov 4, 2024
ec40897
Merge branch 'main' into main
sirluk Nov 5, 2024
05f99ba
move dataloader none check
sirluk Nov 5, 2024
c628afe
adjust default value
sirluk Nov 5, 2024
6989662
update test threshold
sirluk Nov 5, 2024
5982139
update docs
sirluk Nov 6, 2024
b5c6b8f
remove speed test and update docs
sirluk Nov 9, 2024
2c6fb37
fix typo
sirluk Nov 9, 2024
31 changes: 31 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -63,6 +63,37 @@
from peft import LoraConfig
config = LoraConfig(init_lora_weights="olora", ...)
```
For more advanced usage, please refer to our [documentation](https://github.com/huggingface/peft/tree/main/examples/olora_finetuning).

### EVA
[EVA](https://arxiv.org/pdf/2410.07170) performs SVD on the input activations of each layer and uses the right-singular vectors to initialize LoRA weights, making it a data-driven initialization scheme. Furthermore, EVA adaptively allocates ranks across layers based on their "explained variance ratio", a metric derived from the SVD analysis.

You can use EVA by setting `init_lora_weights="eva"` and defining [`EvaConfig`] in [`LoraConfig`]:
```python
from peft import LoraConfig, EvaConfig
peft_config = LoraConfig(
    init_lora_weights="eva",
    eva_config=EvaConfig(rho=2.0),
    ...
)
```
`rho` controls the degree of rank redistribution and must be >= 1.0: each adapted layer can receive at most `rho * r` ranks. With `r=16` and `rho=1.0`, every layer is capped at 16 ranks, so no redistribution is possible.
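
For illustration, here is a minimal sketch of two configurations (assuming the `rho * r` cap described above):
```python
from peft import EvaConfig, LoraConfig

# rho=1.0: every adapted layer is capped at r=16 ranks, so no redistribution happens
config_uniform = LoraConfig(r=16, init_lora_weights="eva", eva_config=EvaConfig(rho=1.0))

# rho=2.0: a layer may receive up to 32 ranks if it explains more variance,
# while ranks are shifted away from layers that explain less
config_redistributed = LoraConfig(r=16, init_lora_weights="eva", eva_config=EvaConfig(rho=2.0))
```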

It is recommended to run the SVD computation on a GPU, as it is much faster. To leave as much memory as possible available for the EVA computation, you can set the `low_cpu_mem_usage` flag in [`get_peft_model`]:
```python
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
```
Then, call [`initialize_lora_eva_weights`] to initialize the EVA weights (in most cases the dataloader used for EVA initialization can be the same as the one used for finetuning):
```python
initialize_lora_eva_weights(peft_model, dataloader)
```
EVA works out of the box with bitsandbytes. Simply initialize the model with `quantization_config` and call [`initialize_lora_eva_weights`] as usual.
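
A minimal sketch of this workflow, mirroring the bitsandbytes example from the EVA finetuning README (the model name and 4-bit settings below are placeholders):
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, initialize_lora_eva_weights, prepare_model_for_kbit_training

# load the base model in 4-bit (placeholder quantization settings)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = prepare_model_for_kbit_training(model)

# create the PEFT model and run the EVA initialization as usual
# (peft_config and dataloader are defined as above)
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
initialize_lora_eva_weights(peft_model, dataloader)
```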

<Tip>

For further instructions on using EVA, please refer to our [documentation](https://github.com/huggingface/peft/tree/main/examples/eva_finetuning).

</Tip>

### LoftQ

#### Standard approach
20 changes: 20 additions & 0 deletions docs/source/package_reference/lora.md
@@ -32,4 +32,24 @@
The abstract from the paper is:

## Utility

### LoftQ

[[autodoc]] utils.loftq_utils.replace_lora_weights_loftq

### Eva

#### EvaConfig

[[autodoc]] tuners.lora.config.EvaConfig

#### initialize_lora_eva_weights

[[autodoc]] tuners.lora.eva.initialize_lora_eva_weights

#### get_eva_state_dict

[[autodoc]] tuners.lora.eva.get_eva_state_dict

#### load_eva_state_dict

[[autodoc]] tuners.lora.eva.load_eva_state_dict
155 changes: 155 additions & 0 deletions examples/eva_finetuning/README.md
@@ -0,0 +1,155 @@
# EVA: Explained Variance Adaptation
## Introduction ([Paper](https://arxiv.org/abs/2410.07170), [code](https://github.com/ml-jku/EVA))
Explained Variance Adaptation (EVA) is a novel initialization method for LoRA-style adapters which initializes adapter weights in a data-driven manner and adaptively allocates ranks according to the variance they explain. EVA improves average performance on a multitude of tasks across various domains, such as language generation and understanding, image classification, and decision making.

The abstract from the paper is:

*Foundation models (FMs) are pre-trained on large-scale datasets and then fine-tuned on a downstream task for a specific application. The most successful and most commonly used fine-tuning method is to update the pre-trained weights via a low-rank adaptation (LoRA). LoRA introduces new weight matrices that are usually initialized at random with a uniform rank distribution across model weights. Recent works focus on weight-driven initialization or learning of adaptive ranks during training. Both approaches have only been investigated in isolation, resulting in slow convergence or a uniform rank distribution, in turn leading to sub-optimal performance. We propose to enhance LoRA by initializing the new weights in a data-driven manner by computing singular value decomposition on minibatches of activation vectors. Then, we initialize the LoRA matrices with the obtained right-singular vectors and re-distribute ranks among all weight matrices to explain the maximal amount of variance and continue the standard LoRA fine-tuning procedure. This results in our new method **E**xplained **V**ariance **A**daptation (EVA). We apply EVA to a variety of fine-tuning tasks ranging from language generation and understanding to image classification and reinforcement learning. EVA exhibits faster convergence than competitors and attains the highest average score across a multitude of tasks per domain.*

## Quick Start
Below is an example of how to use EVA with a causal language model. For a more detailed example see [eva_finetuning.py](https://github.com/huggingface/peft/blob/main/examples/eva_finetuning/eva_finetuning.py).
```python
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import EvaConfig, LoraConfig, get_peft_model, initialize_lora_eva_weights


# config
model_name = "meta-llama/Llama-3.1-8B"
max_seq_len = 512
rank = 16
alpha = 1
rho = 2.0
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
svd_batch_size = 4 # can be different from the batch size used in finetuning

# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# load dataset
dataset = load_dataset("Rowan/hellaswag")
dataset = dataset.map(
    lambda x: tokenizer(x["ctx"], padding="max_length", truncation=True, max_length=max_seq_len),
    batched=True,
    remove_columns=dataset["train"].column_names,
)
dataset.set_format(type="torch")

# create dataloader for SVD
# typically this is the same as the dataloader used for finetuning
dataloader = DataLoader(
    dataset["train"],
    batch_size=svd_batch_size,
    collate_fn=lambda examples: {k: torch.stack([v[k] for v in examples], dim=0) for k in examples[0].keys()},
)

# setup peft config
eva_config = EvaConfig(
    rho=rho
)
peft_config = LoraConfig(
    r=rank,
    lora_alpha=alpha,
    target_modules=target_modules,
    init_lora_weights="eva",
    eva_config=eva_config
)

# move model to GPU
model = model.cuda()

# to optimize memory usage during EVA initialization, set low_cpu_mem_usage=True
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)

initialize_lora_eva_weights(peft_model, dataloader)
```
`initialize_lora_eva_weights` will compute the SVD and load the components into the model. After this, continue with standard LoRA finetuning.
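
A minimal sketch of the finetuning step using the Trainer (the collator and training arguments below are assumptions for this example; see [eva_finetuning.py](https://github.com/huggingface/peft/blob/main/examples/eva_finetuning/eva_finetuning.py) for a complete script):
```python
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

# causal-LM collator that copies input_ids to labels (an assumption for this sketch;
# use whatever collator matches your task)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(output_dir="outputs", per_device_train_batch_size=4, num_train_epochs=1)
trainer = Trainer(model=peft_model, args=training_args, train_dataset=dataset["train"], data_collator=collator)
trainer.train()
```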

## Using EVA with Bitsandbytes
EVA is fully compatible with bitsandbytes. Simply initialize the pretrained model with a `BitsAndBytesConfig`, create the PEFT model, and initialize the EVA weights as usual.
```python
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
quantization_config=BitsAndBytesConfig(load_in_4bit=True)
)
model = prepare_model_for_kbit_training(model)
peft_model = get_peft_model(model, peft_config)
initialize_lora_eva_weights(peft_model, dataloader)
```

## Getting the EVA state_dict without loading the adapter weights
In some cases you might just want to get the state_dict after EVA initialization without loading the adapter weights. This can be useful for example if:
- you want to precompute and store the state_dict for different downstream tasks.
- you need to quantize the model for finetuning but want to perform EVA initialization with model weights in full/half precision.
- you do not intend to use a peft model for LoRA finetuning.

You can do this by calling `get_eva_state_dict` directly:
```python
from peft import get_eva_state_dict

eva_state_dict = get_eva_state_dict(model, peft_config, dataloader)
```
Later you can load the state_dict into a model without adapter weights by calling `load_eva_state_dict`:
```python
from peft import load_eva_state_dict

load_eva_state_dict(model, eva_state_dict)
```
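
As a sketch of the second use case above, you could compute the EVA state_dict while the model is in half precision and only afterwards load a quantized model for finetuning (the quantization settings are placeholders, and it is assumed here that `load_eva_state_dict` is called on the PEFT model):
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_eva_state_dict, get_peft_model, load_eva_state_dict

# 1. compute the EVA state_dict on the half-precision model
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16).cuda()
eva_state_dict = get_eva_state_dict(model, peft_config, dataloader)

# 2. reload the base model quantized and create the PEFT model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
peft_model = get_peft_model(model, peft_config)

# 3. load the precomputed EVA components into the adapter weights
load_eva_state_dict(peft_model, eva_state_dict)
```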

## Customizing EVA

By default, EVA is designed to work with standard transformer language models. However, we provide three parameters that can be used to customize EVA for other types of models:
1. `forward_fn`: defines how the forward pass during EVA initialization is computed.
2. `prepare_model_inputs_fn`: can be used if information contained in the original model input is needed to prepare the inputs for SVD in individual layers.
3. `prepare_layer_inputs_fn`: defines how layer inputs are prepared for SVD.

All three parameters can be passed to `initialize_lora_eva_weights` and `get_eva_state_dict`.
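
For example (the `my_*` functions below are placeholders for your own implementations):
```python
initialize_lora_eva_weights(
    peft_model,
    dataloader,
    forward_fn=my_forward_fn,
    prepare_model_inputs_fn=my_prepare_model_inputs_fn,
    prepare_layer_inputs_fn=my_prepare_layer_inputs_fn,
)
```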

### forward_fn

`forward_fn` defines how the forward pass during EVA initialization is computed. It receives two arguments: `model` and `inputs`. By default, it is set to `forward_fn_dict`, which simply returns `model(**inputs)`.
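
For example, for a model that expects positional rather than keyword arguments, a custom `forward_fn` could look like this sketch (the argument handling is an assumption about your model):
```python
def forward_fn_positional(model, inputs):
    # unpack the batch dict into the positional arguments this (hypothetical) model expects
    return model(inputs["input_ids"], inputs["attention_mask"])

initialize_lora_eva_weights(peft_model, dataloader, forward_fn=forward_fn_positional)
```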

### prepare_model_inputs_fn

`prepare_model_inputs_fn` can be used if information contained in the original model input is needed to prepare the inputs for SVD in individual layers. It receives two arguments: `model_input` and `peft_config`. This component is separate from `prepare_layer_inputs_fn` because its output only needs to be computed once per batch. By default, this parameter is set to `prepare_model_inputs_fn_language_modeling`, which computes a subset of indices based on the attention mask and label mask so that padding tokens are not included in the SVD computation. If you do not want to use this component, set `prepare_model_inputs_fn` to `None`. The default logic is:
```python
def prepare_model_inputs_fn_language_modeling(model_input, peft_config: LoraConfig):
    mask = model_input.get("attention_mask", torch.ones_like(model_input["input_ids"])).bool()
    if peft_config.eva_config.use_label_mask and hasattr(model_input, "labels"):
        mask = torch.logical_and(mask, model_input["labels"] != peft_config.eva_config.label_mask_value)
    return mask.nonzero()
```
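
As an illustration, a custom variant that only masks out padding tokens (ignoring labels entirely) could look like this sketch:
```python
import torch

def prepare_model_inputs_fn_attention_only(model_input, peft_config):
    # keep every position covered by the attention mask; fall back to all positions if no mask is present
    mask = model_input.get("attention_mask", torch.ones_like(model_input["input_ids"])).bool()
    return mask.nonzero()

initialize_lora_eva_weights(peft_model, dataloader, prepare_model_inputs_fn=prepare_model_inputs_fn_attention_only)
```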

### prepare_layer_inputs_fn

`prepare_layer_inputs_fn` can be used to preprocess the layer inputs before passing them to the SVD algorithm. It receives three arguments: `layer_input`, `model_input` and `layer_name`. It can either be a callable or a dictionary where the keys are layer names and the values are callables; in the dictionary case, functions are assigned to adapter layers based on their names. By default, a language modeling setting is assumed where `model_input` is the output of `prepare_model_inputs_fn_language_modeling`, i.e. a mask of indices. If this parameter is set to `None`, only two modifications are made to the layer inputs:
- take the first element in case of a tuple or list.
- if the input has more than 2 dimensions, flatten all but the last dimension.

The function must always return a tensor. The default logic is:
```python
def prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> torch.Tensor:
    if isinstance(layer_input, (tuple, list)):
        layer_input = layer_input[0]
    return layer_input[model_input.T.unbind()]
```
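
For instance, to handle specific layers differently, a dictionary of callables can be passed; the layer names below are placeholders and depend on your model:
```python
def flatten_layer_input(layer_input, model_input, layer_name):
    # minimal handling: unpack tuples/lists and flatten all but the last dimension
    if isinstance(layer_input, (tuple, list)):
        layer_input = layer_input[0]
    return layer_input.flatten(0, -2)

prepare_layer_inputs_fn = {
    "model.layers.0.self_attn.q_proj": flatten_layer_input,
    "model.layers.0.self_attn.v_proj": flatten_layer_input,
}
initialize_lora_eva_weights(peft_model, dataloader, prepare_layer_inputs_fn=prepare_layer_inputs_fn)
```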

## Citation
If you find our work useful, please consider citing it:

```
@article{paischer2024eva,
title={One Initialization to Rule them All: Fine-tuning via Explained Variance Adaptation},
author={Fabian Paischer and Lukas Hauzenberger and Thomas Schmied and Benedikt Alkin and Marc Peter Deisenroth and Sepp Hochreiter},
journal={arXiv preprint arXiv:2410.07170},
year={2024}
}
```
87 changes: 87 additions & 0 deletions examples/eva_finetuning/eva_finetuning.py
@@ -0,0 +1,87 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from utils import DataCollator, TokenizerMetaMath

from peft import EvaConfig, LoraConfig, get_peft_model, initialize_lora_eva_weights


# config
model_name = "meta-llama/Llama-3.1-8B"
max_seq_len = 512
rank = 16
alpha = 1
rho = 2.0
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
svd_batch_size = 4 # can be different from the batch size used in finetuning
batch_size = 4
num_epochs = 1
output_dir = "outputs"
device = "cuda:0"

# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# load dataset
dataset = load_dataset("meta-math/MetaMathQA")
dataset = dataset.map(
    TokenizerMetaMath(model_name),
    batched=True,
    remove_columns=dataset["train"].column_names,
)
dataset.set_format(type="torch")

# data collator
data_collator = DataCollator(tokenizer.eos_token_id, max_length=max_seq_len)

# dataloader
dataloader = DataLoader(
    dataset["train"],
    batch_size=svd_batch_size,
    collate_fn=data_collator,
)

# setup peft config
eva_config = EvaConfig(rho=rho)
peft_config = LoraConfig(
    r=rank, lora_alpha=alpha, target_modules=target_modules, init_lora_weights="eva", eva_config=eva_config
)

# move model to GPU
model = model.to(device)

# to optimize memory usage during eva initialization, set low_cpu_mem_usage=True
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
initialize_lora_eva_weights(peft_model, dataloader)

# setup training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=batch_size,
    num_train_epochs=num_epochs,
    output_dir=output_dir,
    remove_unused_columns=False,
)

# continue with standard finetuning
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=data_collator,
)
trainer.train()
76 changes: 76 additions & 0 deletions examples/eva_finetuning/utils.py
@@ -0,0 +1,76 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import AutoTokenizer


class TokenizerMetaMath:
    PROMPT_NO_INPUT = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query}\n\n### Response: "
    )
    PROMPT = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query}\n\n### Input:\n{input}\n\n### Response: "
    )

    def format_prompt(self, query):
        query = query.split("\n", 1)
        if len(query) == 1 or query[1].strip("\n") == "":
            return self.PROMPT_NO_INPUT.format(query=query[0])
        else:
            return self.PROMPT.format(query=query[0], input=query[1])

    def __init__(self, tokenizer_path):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, examples):
        prompts = [self.format_prompt(text) for text in examples["query"]]
        completions = examples["response"]
        return self._tokenize_fn(prompts, completions)

    def _tokenize_fn(self, prompts, completions):
        prompt_tokens = self.tokenizer(prompts, add_special_tokens=False)["input_ids"]
        input_tokens = self.tokenizer([x + y for x, y in zip(prompts, completions)], add_special_tokens=False)[
            "input_ids"
        ]
        input_tokens = [[self.tokenizer.bos_token_id] + x + [self.tokenizer.eos_token_id] for x in input_tokens]
        prompt_length = [len(x) + 1 for x in prompt_tokens]  # +1 for the bos token
        input_length = [len(x) for x in input_tokens]
        return {"input_ids": input_tokens, "prompt_length": prompt_length, "input_length": input_length}


class DataCollator:
    def __init__(self, eos_token_id, max_length=None):
        self.eos_token_id = eos_token_id
        self.max_length = max_length

    def __call__(self, batch):
        # convert a list of example dicts into a dict of lists
        batch = {k: [item[k] for item in batch] for k in batch[0]}
        input_lengths = torch.stack(batch["input_length"])
        prompt_lengths = torch.stack(batch["prompt_length"])
        # pad all sequences in the batch to the same length with the eos token
        input_ids = torch.nn.utils.rnn.pad_sequence(
            batch["input_ids"], batch_first=True, padding_value=self.eos_token_id
        )
        col_indices = torch.arange(input_ids.size(1)).unsqueeze(0)
        # attend only to real tokens; compute the loss only on completion tokens
        attention_mask = col_indices < input_lengths.unsqueeze(1)
        label_mask = torch.logical_or(col_indices < prompt_lengths.unsqueeze(1), ~attention_mask)
        labels = input_ids.masked_fill(label_mask, -100)
        if self.max_length is not None:
            input_ids = input_ids[:, : self.max_length]
            attention_mask = attention_mask[:, : self.max_length]
            labels = labels[:, : self.max_length]
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}