diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index fe66a3d6c4..d76e523f17 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -118,6 +118,8 @@ title: VB-LoRA - local: package_reference/hra title: HRA + - local: package_reference/bone + title: Bone title: Adapters - sections: diff --git a/docs/source/conceptual_guides/adapter.md b/docs/source/conceptual_guides/adapter.md index d72248c47b..978f1ada7e 100644 --- a/docs/source/conceptual_guides/adapter.md +++ b/docs/source/conceptual_guides/adapter.md @@ -117,4 +117,13 @@ To avoid adding noise to the tokens, the adapter uses zero-initialized attention HRA constructs a chain of `r` trainable Householder reflections (HRs). Because the Householder reflection matrix is an orthogonal matrix and the product of orthogonal matrices is also an orthogonal matrix, HRA satisfies the theoretical guarantee of Orthogonal Finetuning (OFT). Meanwhile, HRA can also be viewed as an low-rank fine-tuning adapter by rewriting formula. -The higher `r`, the more trainable parameters, resulting in a larger model capacity and better performance. Besides, due to the chain structure, the orthogonality of HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between the model capacity and regularity, an orthogonality regularizer of the HR planes is added to the loss function. The weight \\(\lambda\\) can control the strength of the regularizer. \ No newline at end of file +The higher `r`, the more trainable parameters, resulting in a larger model capacity and better performance. Besides, due to the chain structure, the orthogonality of HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between the model capacity and regularity, an orthogonality regularizer of the HR planes is added to the loss function. The weight \\(\lambda\\) can control the strength of the regularizer. + +## Bone +[Bone](https://huggingface.co/papers/2409.15371) (Block Affine Transformation as Parameter Efficient Fine-tuning Methods for Large Language Models) is a new PEFT technique that differs from the LoRA family: it reduces training resources and does not require complex initialization. Bone not only makes better use of the information in the original weights but also emphasizes the internal connections between weights, which leads to faster convergence and better data fitting. + +Bone reduces the number of trainable parameters through a block grouping scheme, and the block-wise computation on the weight matrix W promotes information exchange within the original weights, improving the data fitting capability during fine-tuning. The paper controls the number of trainable parameters through the block size b, analogous to the rank r in LoRA; for consistency within PEFT, b is also named r. Note that Bone's r is special: the weight W must satisfy the conditions `in_features % r == 0` and `out_features % r == 0`. Additionally, when `in_features == out_features` and Bone's r equals LoRA's r, Bone trains only half as many parameters as LoRA. + +The experiments in the paper show that Bone, with only half the parameters of LoRA, can outperform LoRA variants across various metrics. However, Bone currently has some limitations: it is slower than LoRA and requires checkpointing to keep the memory used by intermediate values in check, which further reduces training speed. We plan to address this in the future. Contributions are welcome.
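To make the parameter-count claim above concrete, here is a minimal sketch (not taken from the paper; the layer width and block size are arbitrary example values) comparing how many trainable parameters Bone and LoRA add to a single square `nn.Linear` weight:

```python
# Trainable parameters added to one weight matrix of shape (d, d).
# Bone stores a single tensor of shape (out_features // r, r, r);
# LoRA stores two matrices of shapes (out_features, r) and (r, in_features).
d, r = 4096, 64  # arbitrary example values

assert d % r == 0, "Bone requires in_features % r == 0 and out_features % r == 0"

bone_params = (d // r) * r * r  # = d * r
lora_params = d * r + r * d     # = 2 * d * r

print(bone_params, lora_params)  # 262144 524288 -> Bone trains half as many parameters
```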
\ No newline at end of file diff --git a/docs/source/package_reference/bone.md b/docs/source/package_reference/bone.md new file mode 100644 index 0000000000..32144c375d --- /dev/null +++ b/docs/source/package_reference/bone.md @@ -0,0 +1,31 @@ + + +# Bone + +Block Affine ([Bone](https://huggingface.co/papers/2409.15371)) is a new PEFT technology different from the LoRA series, reducing training resources and requiring no complex initialization. Bone not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting. + +The abstract from the paper is: + +Low-Rank Adaptation (LoRA) has achieved remarkable training results by freezing the original weights and training only low-rank matrices, establishing itself as the predominant fine-tuning method for LLMs. In pursuit of performance closer to full-parameter training, a series of LoRA variants have emerged, such as LoRA+, PISSA, Olora, and LoRA-GA. However, these improvements complicate the initial setup of model training and increase initialization time. More importantly, they overlook the internal interactions of the original weight information. To address these issues, we introduce a novel theory, ``Weight Guide'' aimed at continuously guiding trainable matrices through the original weights during training to enhance the utilization of weight information. Based on this theory, we designed a new PEFT technique called Bone (Block Affine), which not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting. Experimental comparisons across two different LLM architectures (LLaMA2, RWKV6) and various parameter scales demonstrate that the Bone structure can achieve rapid convergence and superior data fitting without the need for complex initialization. For example, when fine-tuning LLaMA2-7B on the MetaMathQA dataset and validating on GSM8k and math benchmarks, Bone achieved fine-tuning scores of 49.36 and 8.8, respectively, outperforming PISSA by 5.84\% and 1.96\%. + +## BoneConfig + +[[autodoc]] tuners.bone.config.BoneConfig + +## BoneModel + +[[autodoc]] tuners.bone.model.BoneModel \ No newline at end of file diff --git a/examples/bone_finetuning/README.md b/examples/bone_finetuning/README.md new file mode 100644 index 0000000000..e787089177 --- /dev/null +++ b/examples/bone_finetuning/README.md @@ -0,0 +1,87 @@ +# BONE: BLOCK AFFINE TRANSFORMATION AS PARAMETER EFFICIENT FINE-TUNING METHODS FOR LARGE LANGUAGE MODELS +## Introduction ([Paper](https://arxiv.org/pdf/2409.15371), [code](https://github.com/JL-er/Bone)) +Low-Rank Adaptation (LoRA) has achieved remarkable training results by freezing the original weights and training only low-rank matrices, establishing itself as the predominant fine-tuning method for LLMs. In pursuit of performance closer to full-parameter training, a series of LoRA variants have emerged, such as LoRA+, PISSA, Olora, and LoRA-GA. However, these improvements complicate the initial setup of model training and increase initialization time. More importantly, they overlook the internal interactions of the original weight information. To address these issues, we introduce a novel theory, ``Weight Guide'' aimed at continuously guiding trainable matrices through the original weights during training to enhance the utilization of weight information. 
Based on this theory, we designed a new PEFT technique called Bone (Block Affine), which not only enhances the utilization of original weight information but also emphasizes the internal connections between weights, leading to faster convergence and better data fitting. Experimental comparisons across two different LLM architectures (LLaMA2, RWKV6) and various parameter scales demonstrate that the Bone structure can achieve rapid convergence and superior data fitting without the need for complex initialization. For example, when fine-tuning LLaMA2-7B on the MetaMathQA dataset and validating on GSM8k and math benchmarks, Bone achieved fine-tuning scores of 49.36 and 8.8, respectively, outperforming PISSA by 5.84% and 1.96%. + +## Quick Start +```python +import torch +from peft import BoneConfig, get_peft_model +from transformers import AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer +from datasets import load_dataset + +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") +tokenizer.pad_token_id = tokenizer.eos_token_id +bone_config = BoneConfig( + r=64 +) +peft_model = get_peft_model(model, bone_config) + +peft_model.print_trainable_parameters() + +dataset = load_dataset("imdb", split="train[:1%]") + +training_args = SFTConfig(dataset_text_field="text", max_seq_length=128) +trainer = SFTTrainer( + model=peft_model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() +peft_model.save_pretrained("bone-llama-2-7b") +``` + + +To use the fine-tuned Bone modules, simply load them back into the base model: +```python +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto" +) +peft_model = PeftModel.from_pretrained(model, "bone-llama-2-7b") +``` + +## Advanced Usage + +### Fine-tune +```shell +python bone_finetuning.py \ + --base_model_name_or_path meta-llama/Llama-2-7b-hf \ + --output_dir output/bone-llama-2-7b-metamath-10k \ + --bits bf16 \ + --data_path meta-math/MetaMathQA \ + --dataset_split train[:100000] \ + --dataset_field query response \ + --bf16 True \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 1 \ + --logging_steps 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --tf32 True \ + --report_to none +``` + + + +# Citation +```bib +@misc{kang2024boneblockaffinetransformation, + title={Bone: Block Affine Transformation as Parameter Efficient Fine-tuning Methods for Large Language Models}, + author={Jiale Kang}, + year={2024}, + eprint={2409.15371}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2409.15371}, +} +``` \ No newline at end of file diff --git a/examples/bone_finetuning/bone_finetuning.py b/examples/bone_finetuning/bone_finetuning.py new file mode 100644 index 0000000000..e546a89454 --- /dev/null +++ b/examples/bone_finetuning/bone_finetuning.py @@ -0,0 +1,99 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser +from trl import SFTConfig, SFTTrainer + +from peft import BoneConfig, get_peft_model + + +@dataclass +class ScriptArguments(SFTConfig): + # model configs + base_model_name_or_path: Optional[str] = field( + default=None, metadata={"help": "The name or path of the fp32/16 base model."} + ) + bits: str = field(default="bf16", metadata={"help": "(`['bf16', 'fp16', fp32]`)"}) + init_bone_weights: str = field(default="False") + bone_r: int = field(default=16) + merge_and_save: bool = field(default=False) + # dataset configs + data_path: str = field(default="imdb", metadata={"help": "Path to the training data."}) + dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"}) + dataset_field: list[str] = field(default=None, metadata={"help": "Fields of dataset input and output."}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +print(script_args) + +print(f"Load pre-processed residual model in {script_args.bits} bits.") +if script_args.bits in ["nf4", "fp4", "int8"]: + print("Bone currently does not support quantization.") + +elif script_args.base_model_name_or_path is not None: + print(f"No available pre-processed model, manually initialize a Bone using {script_args.base_model_name_or_path}.") + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name_or_path, + torch_dtype=( + torch.float16 + if script_args.bits == "fp16" + else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32) + ), + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path) + tokenizer.pad_token_id = tokenizer.eos_token_id + lora_config = BoneConfig( + r=script_args.bone_r, + target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"], + bias="none", + task_type="CAUSAL_LM", + ) + peft_model = get_peft_model(model, lora_config) + +print(peft_model) +peft_model.print_trainable_parameters() + +print(f"Training Bone with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.") +dataset = load_dataset(script_args.data_path, split=script_args.dataset_split) +dataset = dataset.map( + lambda example: { + "text": f"### USER: {example[script_args.dataset_field[0]]}\n### ASSISTANT: {example[script_args.dataset_field[1]]}" + } +) + +trainer = SFTTrainer( + model=peft_model, + args=script_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() +trainer.save_state() + +peft_model.save_pretrained( + os.path.join(script_args.output_dir, "bone_ft"), +) + +if script_args.merge_and_save: + model = peft_model.merge_and_unload() + model.save_pretrained(os.path.join(script_args.output_dir, "bone_merged")) + tokenizer.save_pretrained(os.path.join(script_args.output_dir, "bone_merged")) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 6b2908b35d..8c51c3976e 100644 --- 
a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -91,6 +91,8 @@ HRAConfig, HRAModel, VBLoRAConfig, + BoneConfig, + BoneModel, ) from .utils import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 9c0ab95986..0b82e18c64 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -38,6 +38,8 @@ AdaptionPromptConfig, BOFTConfig, BOFTModel, + BoneConfig, + BoneModel, FourierFTConfig, FourierFTModel, HRAConfig, @@ -104,6 +106,7 @@ "XLORA": XLoraConfig, "HRA": HRAConfig, "VBLORA": VBLoRAConfig, + "BONE": BoneConfig, } PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = { @@ -121,6 +124,7 @@ "XLORA": XLoraModel, "HRA": HRAModel, "VBLORA": VBLoRAModel, + "BONE": BoneModel, } diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 66eba17b3b..84f0e16bb5 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -46,6 +46,7 @@ AdaLoraModel, AdaptionPromptModel, BOFTModel, + BoneModel, FourierFTModel, HRAModel, IA3Model, @@ -104,6 +105,7 @@ PeftType.XLORA: XLoraModel, PeftType.HRA: HRAModel, PeftType.VBLORA: VBLoRAModel, + PeftType.BONE: BoneModel, } diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index d58ff9e3e6..3db7376fa5 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -37,3 +37,4 @@ from .xlora import XLoraConfig, XLoraModel from .hra import HRAConfig, HRAModel from .vblora import VBLoRAConfig, VBLoRAModel +from .bone import BoneConfig, BoneModel diff --git a/src/peft/tuners/bone/__init__.py b/src/peft/tuners/bone/__init__.py new file mode 100644 index 0000000000..7d76fe6cb5 --- /dev/null +++ b/src/peft/tuners/bone/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import BoneConfig +from .layer import BoneLayer, BoneLinear +from .model import BoneModel + + +__all__ = ["BoneConfig", "BoneModel", "BoneLinear", "BoneLayer"] diff --git a/src/peft/tuners/bone/config.py b/src/peft/tuners/bone/config.py new file mode 100644 index 0000000000..880cff94f7 --- /dev/null +++ b/src/peft/tuners/bone/config.py @@ -0,0 +1,124 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class BoneConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`BoneModel`]. + + Args: + r (`int`): + The rank of Bone across different layers. It is best to set 'r' to an even number; otherwise, the default + initialization method will not work. + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen, excluding + the output layer. If this is not specified, modules will be chosen according to the model architecture. If + the architecture is not known, an error will be raised -- in this case, you should specify the target + modules manually. + exclude_modules (`Optional[Union[List[str], str]]`): + The names of the modules to not apply the adapter. When passing a string, a regex match will be performed. + When passing a list of strings, either an exact match will be performed or it is checked if the name of the + module ends with any of the passed strings. + init_weights (`bool`): + Whether to perform initialization of Bone weights. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + modules_to_save (`List[str]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + """ + + r: int = field( + default=64, + metadata={ + "help": "The rank of Bone across different layers.", + "note": "It is best to set 'r' to an even number; otherwise, the default initialization method will not work.", + }, + ) + + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with Bone.", + "example": "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' ", + }, + ) + exclude_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={"help": "List of module names or regex expression of the module names to exclude from Bone."}, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the Bone layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[list[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. 
If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." + }, + ) + bias: str = field(default="none", metadata={"help": "Bias type for Bone. Can be 'none', 'all' or 'Bone_only'"}) + modules_to_save: Optional[list[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from Bone layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, " + "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.BONE + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + self.exclude_modules = ( + set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules + ) + # if target_modules is a regex expression, then layers_to_transform should be None + if isinstance(self.target_modules, str) and self.layers_to_transform is not None: + raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + + # if target_modules is a regex expression, then layers_pattern should be None + if isinstance(self.target_modules, str) and self.layers_pattern is not None: + raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") diff --git a/src/peft/tuners/bone/layer.py b/src/peft/tuners/bone/layer.py new file mode 100644 index 0000000000..a718fbd0c3 --- /dev/null +++ b/src/peft/tuners/bone/layer.py @@ -0,0 +1,255 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +import warnings +from typing import Any, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge + + +class BoneLayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names = ("bone_block",) + # All names of other parameters that may contain adapter-related parameters + other_param_names = ("bone_r",) + + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + self.base_layer = base_layer + self.bone_r = {} + self.bone_block = nn.ParameterDict({}) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self.kwargs = kwargs + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + self.in_features, self.out_features = base_layer.in_features, base_layer.out_features + else: + raise ValueError(f"Unsupported layer type {type(base_layer)}") + + def update_layer( + self, + adapter_name: str, + r: int, + init_weights: bool, + **kwargs, + ) -> None: + """Internal function to create bone adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + init_weights (`bool`): Whether to initialize weights. + """ + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + if self.in_features % r != 0 or self.out_features % r != 0: + raise ValueError("The weight matrix must be fully divisible into [r, r] blocks.") + + self.bone_r[adapter_name] = r + + # Determine shape of Bone weights + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + self.bone_block[adapter_name] = nn.Parameter(torch.zeros(self.out_features // r, r, r), requires_grad=True) + + else: + raise TypeError(f"Bone is not implemented for base layers of type {type(base_layer).__name__}") + + # Initialize weights + + if init_weights: + self.reset_bone_parameters(adapter_name, r) + else: + self.reset_bone_parameters_random(adapter_name) + # Move new weights to device + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + def reset_bone_parameters(self, adapter_name: str, r): + self.bone_block[adapter_name] = nn.Parameter(torch.zeros(self.out_features // r, r, r), requires_grad=True) + + def reset_bone_parameters_random(self, adapter_name: str): + nn.init.kaiming_uniform_(self.bone_block[adapter_name], a=math.sqrt(5)) + + def scale_layer(self, scale: float) -> None: + if scale == 1: + return + + for active_adapter in self.active_adapters: + if active_adapter not in self.bone_block.keys(): + continue + + warnings.warn("Scaling operation for Bone not supported! Automatically set scale to 1.") + + def unscale_layer(self, scale=None) -> None: + for active_adapter in self.active_adapters: + if active_adapter not in self.bone_block.keys(): + continue + + warnings.warn("Unscaling operation for Bone not supported! Keeping scale at 1.") + + +class BoneLinear(nn.Module, BoneLayer): + """ + Bone implemented in a dense layer. 
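    Note (summary of the implementation below): for an `nn.Linear` base layer with a weight of shape
    `(out_features, in_features)`, the adapter keeps a single trainable tensor `bone_block` of shape
    `(out_features // r, r, r)`. The frozen weight is treated as a grid of `r x r` blocks; merging applies a
    block-wise affine update to each block (see `get_delta_weight`), and `unmerge` inverts that update exactly,
    up to numerical precision.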
+ """ + + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + init_weights: Union[bool, str] = True, + **kwargs, + ) -> None: + super().__init__() + BoneLayer.__init__(self, base_layer, **kwargs) + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, init_weights, **kwargs) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.bone_block.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter, orig_weight) + orig_weight += delta_weight + + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.base_layer.weight.data = orig_weight + else: + delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data) + self.base_layer.weight.data += delta_weight + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.bone_block.keys(): + orig_weight = self.get_base_layer().weight.data.clone() + delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True) + self.get_base_layer().weight.data = delta_weight + + def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + device = self.bone_block[adapter].device + dtype = self.bone_block[adapter].dtype + # In case users wants to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # (b)float16 because some CPUs have slow bf16/fp16 matmuls. 
+ cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_bone = self.bone_block[adapter] + + if cast_to_fp32: + weight_bone = weight_bone.float() + + r = weight_bone.size(-1) + if re: + o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) + one = torch.eye(weight_bone.size(-1)).to(weight_bone.device) + inv_I_plus_b = torch.inverse(one + weight_bone) + w = (o - weight_bone) @ inv_I_plus_b + output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) + else: + w = ( + orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3) + @ weight_bone + + weight_bone + ) + output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.bone_block[adapter].data = weight_bone.to(dtype) + + return output_tensor + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + orig_weight = self.base_layer.weight.data.clone() + for active_adapter in self.active_adapters: + if active_adapter not in self.bone_block.keys(): + continue + delta_weight = self.get_delta_weight(active_adapter, orig_weight) + orig_weight = orig_weight + delta_weight + + result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias) + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "bone." + rep diff --git a/src/peft/tuners/bone/model.py b/src/peft/tuners/bone/model.py new file mode 100644 index 0000000000..07809a5d38 --- /dev/null +++ b/src/peft/tuners/bone/model.py @@ -0,0 +1,336 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import asdict +from enum import Enum +from typing import List, Optional + +import torch +from torch import nn +from tqdm import tqdm + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from .config import BoneConfig +from .layer import BoneLayer, BoneLinear + + +class BoneModel(BaseTuner): + """ + Creates a Block Affine (Bone) model from a pretrained model.
The method is described in + https://arxiv.org/abs/2409.15371 + + Args: + model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. + config ([`BoneConfig`]): The configuration of the Bone model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. + + Returns: + `torch.nn.Module`: The Bone model. + + Example: + ```py + >>> from diffusers import StableDiffusionPipeline + >>> from peft import BoneModel, BoneConfig + + >>> config_te = BoneConfig( + ... r=8, + ... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + ... init_weights=True, + ... ) + >>> config_unet = BoneConfig( + ... r=8, + ... target_modules=[ + ... "proj_in", + ... "proj_out", + ... "to_k", + ... "to_q", + ... "to_v", + ... "to_out.0", + ... "ff.net.0.proj", + ... "ff.net.2", + ... ], + ... init_weights=True, + ... ) + + >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> model.text_encoder = BoneModel(model.text_encoder, config_te, "default") + >>> model.unet = BoneModel(model.unet, config_unet, "default") + ``` + + **Attributes**: + - **model** ([`~torch.nn.Module`]) -- The model to be adapted. + - **peft_config** ([`BoneConfig`]): The configuration of the Bone model. + """ + + prefix: str = "bone_" + + def _check_new_adapter_config(self, config: BoneConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." 
+ ) + + @staticmethod + def _check_target_module_exists(bone_config, key): + return check_target_module_exists(bone_config, key) + + def _create_and_replace( + self, + bone_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "r": bone_config.r, + "init_weights": bone_config.init_weights, + } + kwargs["bias"] = bias + + # If it is not a BoneLayer, create a new module, else update it with new adapters + if not isinstance(target, BoneLayer): + new_module = self._create_new_module(bone_config, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + else: + target.update_layer( + adapter_name, + r=bone_config.r, + init_weights=bone_config.init_weights, + ) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + meta = torch.device("meta") + # dispatch to correct device + for name, module in new_module.named_modules(): + if self.prefix in name: + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "bone_only": + for name, m in model.named_modules(): + if isinstance(m, BoneLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(bone_config, adapter_name, target, **kwargs): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + new_module = BoneLinear(target, adapter_name, **kwargs) + else: + raise ValueError( + f"Target module {target} is not supported. " "Currently, only `torch.nn.Linear` is supported." 
+ ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "base_model": + raise + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config_dict + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the base model would without adaptation." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, BoneLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[List[str]] = None, + ): + self._unloading_checks(adapter_names) + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted.
+ """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, BoneLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None + ) -> torch.nn.Module: + r""" + This method merges the Bone layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the bone modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 4365c878fe..986be87bd6 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -300,6 +300,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values): PeftType.FOURIERFT: "fourierft_", PeftType.HRA: "hra_", PeftType.VBLORA: "vblora_", + PeftType.BONE: "bone_", } WEIGHTS_NAME = "adapter_model.bin" diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 4072878700..b0fa8de9ba 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -42,6 +42,7 @@ class PeftType(str, enum.Enum): - VERA - FOURIERFT - HRA + - BONE """ PROMPT_TUNING = "PROMPT_TUNING" @@ -63,6 +64,7 @@ class PeftType(str, enum.Enum): XLORA = "XLORA" HRA = "HRA" VBLORA = "VBLORA" + BONE = "BONE" class TaskType(str, enum.Enum): diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index c4f8ccb810..44a1cad5ff 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -208,6 +208,8 @@ def renamed_dora_weights(k): to_return["base_model.vblora_vector_bank." + adapter_name] = state_dict[ "base_model.vblora_vector_bank." 
+ adapter_name ] + elif config.peft_type == PeftType.BONE: + to_return = {k: state_dict[k] for k in state_dict if "bone_" in k} else: raise ValueError(f"Unknown PEFT type passed: {config.peft_type}") diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 611b07bf97..ede819d9e6 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -35,6 +35,7 @@ from peft import ( AdaLoraConfig, BOFTConfig, + BoneConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -294,6 +295,13 @@ ("Vanilla MLP 3 HRA", "MLP", HRAConfig, {"target_modules": ["lin0", "lin1"]}), ("Vanilla MLP 5 HRA", "MLP", HRAConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}), ("Conv2d 1 HRA", "Conv2d", HRAConfig, {"target_modules": ["conv2d"]}), + ######## + # Bone # + ######## + ("Vanilla MLP 1 Bone", "MLP", BoneConfig, {"target_modules": "lin0", "r": 2}), + ("Vanilla MLP 2 Bone", "MLP", BoneConfig, {"target_modules": ["lin0"], "r": 2}), + ("Vanilla MLP 3 Bone", "MLP", BoneConfig, {"target_modules": ["lin0", "lin1"], "r": 2}), + ("Vanilla MLP 5 Bone", "MLP", BoneConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"], "r": 2}), ############# # LN Tuning # ############# @@ -558,6 +566,20 @@ {"target_modules": ["lin0"], "init_weights": False}, {"target_modules": ["lin1"], "init_weights": False}, ), + ( + "Bone Same", + "bone", + BoneConfig, + {"target_modules": ["lin0"], "init_weights": False, "r": 2}, + {"target_modules": ["lin0"], "init_weights": False, "r": 2}, + ), + ( + "Bone Different", + "bone", + BoneConfig, + {"target_modules": ["lin0"], "init_weights": False, "r": 2}, + {"target_modules": ["lin1"], "init_weights": False, "r": 2}, + ), ( "VBLoRA Same", "vblora", @@ -600,6 +622,7 @@ FourierFTConfig: "fourierft_", HRAConfig: "hra_", VBLoRAConfig: "vblora_", + BoneConfig: "bone_", } @@ -1420,7 +1443,7 @@ def test_multiple_adapters_automatic_modules_to_save(self): assert "default" in model.base_model.classifier.modules_to_save assert "other" in model.base_model.classifier.modules_to_save - @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig]) + @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig]) def test_multiple_adapters_mixed_modules_to_save(self, config_cls): # See issue 1574 # Check that we can have a model where one adapter has modules_to_save and the other doesn't. It should be @@ -1428,6 +1451,9 @@ def test_multiple_adapters_mixed_modules_to_save(self, config_cls): if hasattr(config_cls, "feedforward_modules"): # IA³ config_cls = partial(config_cls, feedforward_modules=["lin0"]) + if config_cls == BoneConfig: + config_cls = partial(config_cls, r=2) + config0 = config_cls(target_modules=["lin0"], modules_to_save=["lin1"]) config1 = config_cls(target_modules=["lin0"]) model = MLP() @@ -1445,13 +1471,16 @@ def test_multiple_adapters_mixed_modules_to_save(self, config_cls): model.set_adapter("other") model(**inputs) - @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig]) + @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig]) def test_multiple_adapters_mixed_modules_to_save_order_switched(self, config_cls): # See issue 1574 # Same test as test_multiple_adapters_mixed_modules_to_save, but this time the 2nd adapter has modules_to_save. 
if hasattr(config_cls, "feedforward_modules"): # IA³ config_cls = partial(config_cls, feedforward_modules=["lin0"]) + if config_cls == BoneConfig: + config_cls = partial(config_cls, r=2) + config0 = config_cls(target_modules=["lin0"]) config1 = config_cls(target_modules=["lin0"], modules_to_save=["lin1"]) model = MLP() @@ -1651,6 +1680,7 @@ def test_load_resized_embedding_ignore_mismatched_sizes(self): OFTConfig(target_modules=["lin0"], init_weights=False, r=2), BOFTConfig(target_modules=["lin0"], init_weights=False, boft_block_size=2), HRAConfig(target_modules=["lin0"], init_weights=False), + BoneConfig(target_modules=["lin0"], init_weights=False, r=2), ] ) def test_adapter_name_makes_no_difference(self, config0): @@ -2887,6 +2917,83 @@ def test_requires_grad_hra_same_targets(self): "base_model.model.lin0.hra_u.adapter1", ) + def test_requires_grad_bone_different_targets(self): + # test two different Bone adapters that target different modules + config0 = BoneConfig(target_modules=["lin0"], r=2) + peft_model = get_peft_model(MLP(), config0) + + config1 = BoneConfig(target_modules=["lin1"], r=2, inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.bone_block.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.bone_block.default", + ) + + # change active adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.bone_block.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.bone_block.adapter1", + ) + + def test_requires_grad_bone_same_targets(self): + # same as previous test, except that the Bone adapters target the same layer + config0 = BoneConfig(target_modules=["lin0"], r=2) + peft_model = get_peft_model(MLP(), config0) + + config1 = BoneConfig(target_modules=["lin0"], r=2, inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.bone_block.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.bone_block.default", + ) + + # change active adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.bone_block.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.bone_block.adapter1", + ) + + def test_requires_grad_boft_different_targets(self): + # test two different OFT adapters that target different modules config0 = BOFTConfig(target_modules=["lin0"], boft_block_size=2) diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index f8c9d2c65a..5be158e3a7 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -30,6 +30,7 @@ from peft import ( AdaLoraConfig, BOFTConfig, + BoneConfig, HRAConfig, LoraConfig, OFTConfig, @@ -83,6 +84,7 @@ def
skip_oft_or_hra_and_gpt2(test_list): if not ( ("GPT2LMHeadModel" in test[1]) and ((test[2] == BOFTConfig) or (test[2] == HRAConfig) or (test[2] == OFTConfig)) + or (test[2] == BoneConfig) ) ] @@ -98,6 +100,7 @@ def skip_adalora_or_oft_or_hra_and_gpt2(test_list): or (test[2] == BOFTConfig) or (test[2] == HRAConfig) or (test[2] == OFTConfig) + or (test[2] == BoneConfig) ) ) ] @@ -224,9 +227,6 @@ def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) - def test_load_model_low_cpu_mem_usage(self): - self._test_load_model_low_cpu_mem_usage(PEFT_DECODER_MODELS_TO_TEST[0], LoraConfig, {}) - @parameterized.expand( PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_oft_or_hra_and_gpt2) ) @@ -245,6 +245,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -263,6 +264,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, filter_params_func=skip_oft_or_hra_and_gpt2, @@ -279,6 +281,7 @@ def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -291,6 +294,7 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): { "model_ids": PEFT_DECODER_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -379,6 +383,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, filter_params_func=skip_adalora_or_oft_or_hra_and_gpt2, @@ -395,6 +400,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -418,6 +424,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, filter_params_func=skip_oft_or_hra_and_gpt2, diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index d013bbad99..e22f010089 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -98,6 +98,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "ia3_kwargs": {"init_ia3_weights": [False]}, 
"vera_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -179,6 +180,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -192,6 +194,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -214,6 +217,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index 884c8cf7e0..2bffb935ef 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -47,7 +47,7 @@ def skip_deberta_lora_tests(test_list): Skip tests that are checkpointing with lora/ia3/boft/vera/fourierft for Deberta models (couldn't find much info on the error) """ - to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra"] + to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra", "bone"] return [test for test in test_list if not (any(k in test[0] for k in to_skip) and "Deberta" in test[0])] @@ -117,6 +117,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) @@ -171,6 +172,7 @@ def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_k "oft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) @@ -187,6 +189,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): "boft_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "hra_kwargs": {"init_weights": [False]}, + "bone_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) diff --git a/tests/testing_common.py b/tests/testing_common.py index 77dd529862..1c74b2746e 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -32,6 +32,7 @@ from peft import ( AdaLoraConfig, BOFTConfig, + BoneConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -119,6 +120,11 @@ { "target_modules": None, }, + # Bone + { + "target_modules": None, + "r": 2, + }, ) CLASSES_MAPPING = { @@ -134,6 +140,7 @@ "hra": (HRAConfig, CONFIG_TESTING_KWARGS[9]), "vblora": (VBLoRAConfig, CONFIG_TESTING_KWARGS[10]), "oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]), + "bone": (BoneConfig, CONFIG_TESTING_KWARGS[12]), } @@ -732,6 +739,7 @@ def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): PeftType.OFT, PeftType.BOFT, PeftType.HRA, + PeftType.BONE, ] if ("gpt2" in model_id.lower()) and (config_cls == IA3Config): @@ -1207,6 +1215,7 @@ def _test_delete_adapter(self, 
model_id, config_cls, config_kwargs): PeftType.FOURIERFT, PeftType.HRA, PeftType.VBLORA, + PeftType.BONE, ] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters @@ -1255,6 +1264,7 @@ def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs): PeftType.FOURIERFT, PeftType.HRA, PeftType.VBLORA, + PeftType.BONE, ] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters @@ -1300,7 +1310,18 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): model = get_peft_model(model, config) model = model.to(self.torch_device) - if config.peft_type not in ("LORA", "ADALORA", "IA3", "BOFT", "OFT", "VERA", "FOURIERFT", "HRA", "VBLORA"): + if config.peft_type not in ( + "LORA", + "ADALORA", + "IA3", + "BOFT", + "OFT", + "VERA", + "FOURIERFT", + "HRA", + "VBLORA", + "BONE", + ): with pytest.raises(AttributeError): model = model.unload() else: