diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f46a4a67b8..0fd81f127c 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -136,6 +136,8 @@ title: RoAd - local: package_reference/waveft title: WaveFT + - local: package_reference/delora + title: DeLoRA title: Adapters - sections: diff --git a/docs/source/package_reference/delora.md b/docs/source/package_reference/delora.md new file mode 100644 index 0000000000..6dd5bf9129 --- /dev/null +++ b/docs/source/package_reference/delora.md @@ -0,0 +1,35 @@ + + +# DeLoRA: Decoupled Low-rank Adaptation + +[DeLoRA](https://huggingface.co/papers/2503.18225) is a parameter-efficient fine-tuning technique that leverages effectively decouples the learning of angles and magnitudes. + +Note: +- use 10-100x larger learning rate than standard LoRA variants +- the boundary parameter lambda sets an upper bound to the Frobenius norm of the weight change. Using different lambdas for different layers is possible + +The abstract from the paper is: + +> Parameter-Efficient FineTuning (PEFT) methods have recently gained significant popularity thanks to the widespread availability of large-scale pretrained models. These methods allow for quick adaptation to downstream tasks with minimal computational cost. However, popular finetuning methods such as LoRA exhibit limited robustness when it comes to hyperparameter choices or extended training regimes, preventing optimal out-of-the-box performance. In contrast, bounded approaches, such as ETHER, provide greater robustness but are limited to extremely low-rank adaptations and fixed-strength transformations, reducing their adaptation expressive power. In this work, we propose Decoupled Low-rank Adaptation (DeLoRA), a novel finetuning method that normalizes and scales learnable low-rank matrices. By bounding the distance of the transformation, DeLoRA effectively decouples the angular learning from the adaptation strength, enhancing robustness without compromising performance. Through evaluations on subject-driven image generation, natural language understanding, and instruction tuning, we show that DeLoRA matches or surpasses performance of competing PEFT methods, while exhibiting stronger robustness. + +## DeloraConfig + +[[autodoc]] tuners.delora.config.DeloraConfig + +## DeloraModel + +[[autodoc]] tuners.delora.model.DeloraModel diff --git a/examples/boft_controlnet/test_controlnet.py b/examples/boft_controlnet/test_controlnet.py index 2080deb0a7..9624b7c341 100644 --- a/examples/boft_controlnet/test_controlnet.py +++ b/examples/boft_controlnet/test_controlnet.py @@ -22,7 +22,6 @@ import numpy as np import torch -import torch.utils.checkpoint from accelerate import Accelerator from diffusers import DDIMScheduler from diffusers.utils import check_min_version diff --git a/method_comparison/MetaMathQA/experiments/delora/llama-3.2-3B-rank32/adapter_config.json b/method_comparison/MetaMathQA/experiments/delora/llama-3.2-3B-rank32/adapter_config.json new file mode 100644 index 0000000000..6217e8f494 --- /dev/null +++ b/method_comparison/MetaMathQA/experiments/delora/llama-3.2-3B-rank32/adapter_config.json @@ -0,0 +1,20 @@ +{ + "lambda_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": null, + "bias": "none", + "exclude_modules": null, + "inference_mode": false, + "init_weights": true, + "layers_pattern": null, + "layers_to_transform": null, + "delora_lambda": 15, + "module_dropout": 0.0, + "modules_to_save": null, + "peft_type": "DELORA", + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": null, + "task_type": "CAUSAL_LM" +} diff --git a/method_comparison/MetaMathQA/experiments/delora/llama-3.2-3B-rank32/training_params.json b/method_comparison/MetaMathQA/experiments/delora/llama-3.2-3B-rank32/training_params.json new file mode 100644 index 0000000000..8a120ad9a8 --- /dev/null +++ b/method_comparison/MetaMathQA/experiments/delora/llama-3.2-3B-rank32/training_params.json @@ -0,0 +1,6 @@ +{ + "optimizer_kwargs": { + "lr": 1e-3 + } +} + diff --git a/src/peft/__init__.py b/src/peft/__init__.py index af26f8309b..af7254d88a 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -59,6 +59,8 @@ C3AModel, CPTConfig, CPTEmbedding, + DeloraConfig, + DeloraModel, EvaConfig, FourierFTConfig, FourierFTModel, @@ -154,6 +156,8 @@ "C3AModel", "CPTConfig", "CPTEmbedding", + "DeloraConfig", + "DeloraModel", "EvaConfig", "FourierFTConfig", "FourierFTModel", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index c193f89d49..ab56ae1d85 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -18,6 +18,7 @@ from .bone import BoneConfig, BoneModel from .c3a import C3AConfig, C3AModel from .cpt import CPTConfig, CPTEmbedding +from .delora import DeloraConfig, DeloraModel from .fourierft import FourierFTConfig, FourierFTModel from .hra import HRAConfig, HRAModel from .ia3 import IA3Config, IA3Model @@ -67,6 +68,8 @@ "C3AModel", "CPTConfig", "CPTEmbedding", + "DeloraConfig", + "DeloraModel", "EvaConfig", "FourierFTConfig", "FourierFTModel", diff --git a/src/peft/tuners/delora/__init__.py b/src/peft/tuners/delora/__init__.py new file mode 100644 index 0000000000..982801ca56 --- /dev/null +++ b/src/peft/tuners/delora/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from peft.utils import register_peft_method + +from .config import DeloraConfig +from .layer import DeloraLayer, DeloraLinear +from .model import DeloraModel + + +__all__ = ["DeloraConfig", "DeloraLayer", "DeloraLinear", "DeloraModel"] + +register_peft_method(name="delora", model_cls=DeloraModel, config_cls=DeloraConfig) diff --git a/src/peft/tuners/delora/config.py b/src/peft/tuners/delora/config.py new file mode 100644 index 0000000000..0a28cc94be --- /dev/null +++ b/src/peft/tuners/delora/config.py @@ -0,0 +1,154 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class DeloraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`DeloraModel`]. + + Args: + r (`int`): + The rank of the DeLoRA adapter. + delora_lambda (`int`): + The initial value of the boundary of the DeLoRA adapter. This variable sets an upper bound to the Frobenius + norm of the weight change, avoiding the finetuned model to deviate too much from the original model. + module_dropout (`float`): + The dropout probability for disabling DeLoRA modules during training. + target_modules (`Optional[Union[List[str], str]]`): + The names of the modules to apply the adapter to. If this is specified, only the modules with the specified + names will be replaced. When passing a string, a regex match will be performed. When passing a list of + strings, either an exact match will be performed or it is checked if the name of the module ends with any + of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen, + excluding the output layer. If this is not specified, modules will be chosen according to the model + architecture. If the architecture is not known, an error will be raised -- in this case, you should specify + the target modules manually. + exclude_modules (`Optional[Union[List[str], str]]`): + The names of the modules to not apply the adapter. When passing a string, a regex match will be performed. + When passing a list of strings, either an exact match will be performed or it is checked if the name of the + module ends with any of the passed strings. + bias (`str`): + Bias type for DeLoRA. Can be 'none', 'all' or 'delora_only'. If 'all' or 'delora_only', the corresponding + biases will be updated during training. Be aware that this means that, even when disabling the adapters, + the model will not produce the same output as the base model would have without adaptation. + init_weights (`bool`): + Whether to perform initialization of adapter weights. If `True` (default): A is initialized with kaiming + uniform initialization, while B is initialized with zeros. If `False`: A and B are both initialized with + kaiming uniform, immediately contributing a non-zero delta. This is generally discouraged for normal use. + layers_to_transform (`Union[List[int], int]`): + The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices + that are specified in this list. If a single integer is passed, it will apply the transformations on the + layer at this index. + layers_pattern (`Optional[Union[List[str], str]]`): + The layer pattern name, used only if `layers_to_transform` is different from `None`. This should target the + `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`. + lambda_pattern (`dict`): + The mapping from layer names or regexp expression to lambdas which are different from the default lambda + specified by `delora_lambda`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`. + modules_to_save (`Optional[List[str]]`): + List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. + """ + + r: int = field(default=8, metadata={"help": "DeLoRA rank"}) + delora_lambda: int = field( + default=15, + metadata={ + "help": "The initial value of the boundary of the DeLoRA adapter. This variable sets an upper bound to the " + "Frobenius norm of the weight change, avoiding the finetuned model to deviate too much from the original model." + }, + ) + module_dropout: float = field( + default=0.0, metadata={"help": "The dropout probability for disabling DeLoRA modules during training"} + ) + target_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "List of module names or regex expression of the module names to replace with DeLoRA." + "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " + "This can also be a wildcard 'all-linear' which matches all linear layers except the output layer." + }, + ) + exclude_modules: Optional[Union[list[str], str]] = field( + default=None, + metadata={"help": "List of module names or regex expression of the module names to exclude from DeLoRA."}, + ) + bias: str = field(default="none", metadata={"help": "Bias type for DeLoRA. Can be 'none' or 'all'"}) + init_weights: bool = field( + default=True, + metadata={ + "help": "Whether to perform initialization of adapter weights. If `True` (default): A is initialized with kaiming uniform " + "initialization, while B is initialized with zeros. If `False`: A and B are both initialized with kaiming uniform, " + "immediately contributing a non-zero delta. This is generally discouraged for normal use." + }, + ) + layers_to_transform: Optional[Union[list[int], int]] = field( + default=None, + metadata={ + "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that " + "are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." + }, + ) + layers_pattern: Optional[Union[list[str], str]] = field( + default=None, + metadata={ + "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the " + "common layers pattern. This should target the `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`." + }, + ) + rank_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": "The mapping from layer names or regexp expression to ranks which are different from the default rank specified " + "by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`." + }, + ) + lambda_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": "The mapping from layer names or regexp expression to lambdas which are different from the default lambda specified by `delora_lambda`." + }, + ) + modules_to_save: Optional[list[str]] = field( + default=None, + metadata={ + "help": "List of modules apart from DeLoRA layers to be set as trainable and saved in the final checkpoint. " + "For example, in Sequence Classification or Token Classification tasks, the final layer `classifier/score` " + "are randomly initialized and as such need to be trainable and saved." + }, + ) + + def __post_init__(self): + super().__post_init__() + # PeftType enum members are uppercase; use DELORA + self.peft_type = PeftType.DELORA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + # if target_modules is a regex expression, then layers_to_transform should be None + if isinstance(self.target_modules, str) and self.layers_to_transform is not None: + raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.") + + # check for layers_to_transform and layers_pattern + if self.layers_pattern and not self.layers_to_transform: + raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ") diff --git a/src/peft/tuners/delora/layer.py b/src/peft/tuners/delora/layer.py new file mode 100644 index 0000000000..6a225d4e1a --- /dev/null +++ b/src/peft/tuners/delora/layer.py @@ -0,0 +1,269 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import warnings +from typing import Any, Optional + +import torch +import torch.nn as nn + +from peft.tuners._buffer_dict import BufferDict +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge + + +class DeloraLayer(BaseTunerLayer): + # All names of layers that may contain (trainable) adapter weights + adapter_layer_names = ( + "delora_A", + "delora_B", + "delora_lambda", + ) + # All names of other parameters that may contain adapter-related parameters + other_param_names = ( + "r", + "module_dropout", + "delora_w_norm", + ) + + def __init__(self, base_layer: nn.Module, **kwargs) -> None: + self.base_layer = base_layer + self.r = {} + self.module_dropout = nn.ModuleDict({}) + self.delora_A = nn.ParameterDict({}) + self.delora_B = nn.ParameterDict({}) + self.delora_lambda = nn.ParameterDict({}) + # Use persistent buffers so they are included in state_dict and saved. + self.delora_w_norm = BufferDict({}, persistent=True) + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + self.kwargs = kwargs + + base_layer_mod = self.get_base_layer() + if isinstance(base_layer_mod, nn.Linear): + self.in_features, self.out_features = base_layer_mod.in_features, base_layer_mod.out_features + else: + raise ValueError(f"Unsupported layer type {type(base_layer_mod)}") + + @staticmethod + def _compute_delta( + A: torch.Tensor, B: torch.Tensor, delora_lambda: torch.Tensor, r: int, w_norm: torch.Tensor + ) -> torch.Tensor: + """Compute delta = B @ diag(delora_lambda/r / (||A_i||*||B^j||)) @ A, scaled by provided w_norm (per-input channel)""" + An = torch.clamp(A.norm(dim=1), min=1e-4) + Bn = torch.clamp(B.norm(dim=0), min=1e-4) + diag = torch.diag_embed(delora_lambda / r / (An * Bn)) + delta = B @ diag @ A + delta = delta * w_norm.unsqueeze(0) + return delta + + def get_delta_weight(self, adapter: str) -> torch.Tensor: + if adapter not in self.delora_A or adapter not in self.delora_B: + raise ValueError(f"Adapter {adapter} not found.") + + delta = self._compute_delta( + self.delora_A[adapter], + self.delora_B[adapter], + self.delora_lambda[adapter], + self.r[adapter], + self.delora_w_norm[adapter], + ) + return delta + + def update_layer( + self, + adapter_name: str, + r: int, + delora_lambda: float, + module_dropout: float, + init_weights: bool = True, + inference_mode: bool = False, + **kwargs: Any, + ) -> None: + """Internal function to create delora adapter + + Args: + adapter_name (`str`): Name for the adapter to add. + r (`int`): Rank for the added adapter. + delora_lambda (`float`): Boundary for the adapter's norm. + module_dropout (`float`): The dropout probability for disabling adapter during training. + init_weights (`bool`): Whether to initialize weights. + """ + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.delora_A[adapter_name] = nn.Parameter(torch.empty(r, self.in_features)) + self.delora_B[adapter_name] = nn.Parameter(torch.empty(self.out_features, r)) + self.delora_lambda[adapter_name] = nn.Parameter(torch.empty(1)) + if module_dropout > 0.0: + module_dropout_layer = nn.Dropout(p=module_dropout) + else: + module_dropout_layer = nn.Identity() + self.module_dropout.update(nn.ModuleDict({adapter_name: module_dropout_layer})) + + # Initialize weights + self.reset_delora_parameters(adapter_name, init_weights, delora_lambda) + + # Move new weights to device + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) + + def reset_delora_parameters( + self, + adapter_name: str, + init_weights: bool = True, + delora_lambda: float = 15.0, + ) -> None: + if adapter_name not in self.delora_A.keys(): + return + + if init_weights is True: + nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5)) + nn.init.zeros_(self.delora_B[adapter_name]) + else: + nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.delora_B[adapter_name], a=math.sqrt(5)) + + self.delora_lambda[adapter_name].data.fill_(float(delora_lambda)) + + # capture a fixed norm for this adapter to use for future delta computations + with torch.no_grad(): + w = self.get_base_layer().weight + if w.device.type != "meta": + w_norm = torch.norm(w.data, dim=0).detach() + else: + # For meta tensors, we can't compute the norm, so use a default value + w_norm = torch.ones(w.shape[1], device=w.device) + self.delora_w_norm[adapter_name] = w_norm + + +class DeloraLinear(nn.Module, DeloraLayer): + # DeLoRA implemented in a dense layer + def __init__( + self, + base_layer, + adapter_name: str, + r: int, + delora_lambda: float, + module_dropout: float, + init_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + DeloraLayer.__init__(self, base_layer, **kwargs) + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, delora_lambda, module_dropout, init_weights) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + return + + for active_adapter in adapter_names: + if active_adapter in self.delora_A.keys(): + base_layer = self.get_base_layer() + delta_weight = ( + self.get_delta_weight(active_adapter) + .detach() + .to(dtype=base_layer.weight.dtype, device=base_layer.weight.device) + ) + with torch.no_grad(): + if safe_merge: + orig_weights = base_layer.weight.data.clone() + orig_weights = orig_weights + delta_weight + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in merged weights for adapter {active_adapter}; aborting merge" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data.add_(delta_weight) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + Unmerge all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.delora_A.keys(): + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + if not self.active_adapters: + return self.base_layer(x, *args, **kwargs).to(previous_dtype) + + base_out = self.base_layer(x, *args, **kwargs) + add_out = torch.zeros_like(base_out) + + for adapter in self.active_adapters: + if adapter not in self.delora_A: + continue + + x_d = self.module_dropout[adapter](x) + + # Decomposed delta calculation + # 1. (x * w_norm) @ A.T + h = nn.functional.linear(x_d * self.delora_w_norm[adapter], self.delora_A[adapter]) + + # 2. h @ diag + An = torch.clamp(self.delora_A[adapter].norm(dim=1), min=1e-4) + Bn = torch.clamp(self.delora_B[adapter].norm(dim=0), min=1e-4) + scaling = (self.delora_lambda[adapter] / self.r[adapter]) / (An * Bn) + + h = h * scaling + + # 3. h @ B.T + h = nn.functional.linear(h, self.delora_B[adapter]) + + add_out += h + + result = base_out + add_out.to(base_out.dtype) + + result = result.to(previous_dtype) + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "delora." + rep diff --git a/src/peft/tuners/delora/model.py b/src/peft/tuners/delora/model.py new file mode 100644 index 0000000000..04492c351c --- /dev/null +++ b/src/peft/tuners/delora/model.py @@ -0,0 +1,105 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import torch + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer +from peft.utils import ( + TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING, +) +from peft.utils.other import get_pattern_key + +from .config import DeloraConfig +from .layer import DeloraLayer, DeloraLinear + + +class DeloraModel(BaseTuner): + """ + Creates DeLoRA model from a pretrained transformers model. + + The method is described in detail in [TODO]. + + Args: + model ([`torch.nn.Module`]): The model to be adapted. + config ([`DeloraConfig`]): The configuration of the DeLoRA model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The DeLoRA model. + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`DeloraConfig`]): The configuration of the DeLoRA model. + """ + + prefix: str = "delora_" + tuner_layer_cls = DeloraLayer + target_module_mapping = TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING + + def _check_new_adapter_config(self, config: DeloraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + super()._check_new_adapter_config(config) + + def _create_and_replace( + self, + delora_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + # Regexp matching - Find key which matches current target_name in patterns provided + r_key = get_pattern_key(delora_config.rank_pattern.keys(), current_key) + lambda_key = get_pattern_key(delora_config.lambda_pattern.keys(), current_key) + r = delora_config.rank_pattern.get(r_key, delora_config.r) + delora_lambda = delora_config.lambda_pattern.get(lambda_key, delora_config.delora_lambda) + + kwargs = { + "r": r, + "delora_lambda": delora_lambda, + "module_dropout": delora_config.module_dropout, + "init_weights": delora_config.init_weights, + } + + if isinstance(target, DeloraLinear): + target.update_layer(adapter_name, **kwargs) + else: + new_module = self._create_new_module(delora_config, adapter_name, target, **kwargs) + if adapter_name != self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _create_new_module(delora_config, adapter_name, target, **kwargs): + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + new_module = DeloraLinear(target, adapter_name, **kwargs) + + return new_module diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 781495465e..271f0b21b1 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -22,6 +22,7 @@ TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, @@ -77,6 +78,7 @@ "TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 79c7d92b00..ff66556e0c 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -106,6 +106,7 @@ def starcoder_model_postprocess_past_key_value(past_key_values): TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() +TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 6f8437152e..09923216b4 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -46,6 +46,7 @@ TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, @@ -86,6 +87,7 @@ "TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 8815aa4684..4700d7fb78 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -46,6 +46,7 @@ class PeftType(str, enum.Enum): - C3A - ROAD - WAVEFT + - DELORA """ PROMPT_TUNING = "PROMPT_TUNING" @@ -76,6 +77,7 @@ class PeftType(str, enum.Enum): SHIRA = "SHIRA" C3A = "C3A" WAVEFT = "WAVEFT" + DELORA = "DELORA" class TaskType(str, enum.Enum): diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 5116919978..e03b18f976 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -36,6 +36,7 @@ BOFTConfig, BoneConfig, C3AConfig, + DeloraConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -848,6 +849,19 @@ WaveFTConfig, {"target_modules": "lin0", "n_frequency": 16, "wavelet_family": "db1", "proportional_parameters": True}, ), + ########## + # DeLoRA # + ########## + ("Vanilla MLP 1 DeLoRA", "MLP", DeloraConfig, {"target_modules": "lin0"}), + ("Vanilla MLP 2 DeLoRA", "MLP", DeloraConfig, {"target_modules": ["lin0"]}), + ("Vanilla MLP 3 DeLoRA", "MLP", DeloraConfig, {"target_modules": ["lin1"]}), + ("Vanilla MLP 4 DeLoRA", "MLP", DeloraConfig, {"target_modules": ["lin0", "lin1"]}), + ( + "Vanilla MLP 5 DeLoRA", + "MLP", + DeloraConfig, + {"target_modules": ["lin0"], "module_dropout": 0.1}, + ), ] ALL_PEFT_CONFIG_CLASSES = sorted({row[2] for row in TEST_CASES}, key=lambda cls: cls.__name__) @@ -1118,6 +1132,20 @@ {"target_modules": ["lin0"], "init_weights": False, "n_frequency": 8}, {"target_modules": ["lin1"], "init_weights": False, "n_frequency": 8}, ), + ( + "DeLoRA Same", + "delora", + DeloraConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin0"], "init_weights": False}, + ), + ( + "DeLoRA Different", + "delora", + DeloraConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin1"], "init_weights": False}, + ), ] PREFIXES = { @@ -1138,6 +1166,7 @@ BoneConfig: "bone_", RoadConfig: "road_", MissConfig: "miss_", + DeloraConfig: "delora_", TrainableTokensConfig: "trainable_tokens_", WaveFTConfig: "waveft_", } diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index c0e3710194..a6069af89e 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -32,6 +32,7 @@ BoneConfig, C3AConfig, CPTConfig, + DeloraConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -119,6 +120,14 @@ "cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4], }, ), + ( + DeloraConfig, + { + "task_type": "CAUSAL_LM", + "target_modules": None, + "r": 2, + }, + ), ( FourierFTConfig, { @@ -290,8 +299,9 @@ def _skip_if_not_conv1d_supported(model_id, config_cls): ShiraConfig, C3AConfig, MissConfig, + DeloraConfig, ]: - pytest.skip("Skipping BOFT/HRA/OFT/Bone/Road/SHiRA/C3A/MiSS for GPT2LMHeadModel") + pytest.skip("Skipping BOFT/HRA/OFT/Bone/Road/SHiRA/C3A/MiSS/DeLoRA for GPT2LMHeadModel") def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls): @@ -304,8 +314,9 @@ def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls): C3AConfig, RoadConfig, MissConfig, + DeloraConfig, ]: - pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone/MiSS for GPT2LMHeadModel") + pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone/MiSS/DeLoRA for GPT2LMHeadModel") def _skip_alora_no_activation(config_cls, config_kwargs): diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index d940d0f9a1..1ec0aa0668 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -22,6 +22,7 @@ BOFTConfig, BoneConfig, C3AConfig, + DeloraConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -82,6 +83,14 @@ "task_type": "SEQ_2_SEQ_LM", }, ), + ( + DeloraConfig, + { + "task_type": "SEQ_2_SEQ_LM", + "target_modules": None, + "r": 2, + }, + ), ( FourierFTConfig, { diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index d7dd604c97..a5377827f4 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -20,6 +20,7 @@ BOFTConfig, BoneConfig, C3AConfig, + DeloraConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -81,6 +82,14 @@ "r": 2, }, ), + ( + DeloraConfig, + { + "task_type": "FEATURE_EXTRACTION", + "target_modules": None, + "r": 2, + }, + ), ( FourierFTConfig, { diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 8937f4b0c1..5029b56a27 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -36,6 +36,7 @@ from peft import ( AdaLoraConfig, C3AConfig, + DeloraConfig, EvaConfig, IA3Config, LoftQConfig, @@ -2093,6 +2094,68 @@ def test_road_with_conv2d_layer(self): get_peft_model(model, config) +class TestDeLoRAInitialization: + """Basic sanity tests for the DeLoRA tuner.""" + + torch_device = infer_device() + + def get_model(self, bias=True): + class MLP(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.lin0 = nn.Linear(10, 30, bias=bias) + self.lin1 = nn.Linear(30, 2, bias=bias) + + def forward(self, X): + X = self.lin0(X) + X = self.lin1(X) + return X + + return MLP(bias=bias).to(self.torch_device).eval() + + @pytest.fixture + def data(self): + torch.manual_seed(0) + return torch.randn(4, 10, device=self.torch_device) + + def test_delora_injection_keeps_output_default(self, data): + # With init_weights=True (default), initial forward should match base model + torch.manual_seed(0) + base = self.get_model() + y_base = base(data) + + cfg = DeloraConfig(target_modules=["lin0"], r=8, delora_lambda=15, init_weights=True) + model = get_peft_model(base, cfg) + y_peft = model(data) + + assert torch.allclose(y_base, y_peft, atol=1e-6, rtol=1e-6) + + def test_delora_param_shapes(self): + base = self.get_model() + in_f, out_f = base.lin0.in_features, base.lin0.out_features + r = 4 + cfg = DeloraConfig(target_modules=["lin0"], r=r, delora_lambda=15, init_weights=True) + model = get_peft_model(base, cfg) + + layer = model.lin0 # DeloraLinear wrapper + assert hasattr(layer, "delora_A") and hasattr(layer, "delora_B") and hasattr(layer, "delora_lambda") + A = layer.delora_A["default"] + B = layer.delora_B["default"] + delora_lambda = layer.delora_lambda["default"] + assert tuple(A.shape) == (r, in_f) + assert tuple(B.shape) == (out_f, r) + assert tuple(delora_lambda.shape) == (1,) + + def test_init_weights_false_shifts_output(self, data): + # With init_weights=False, there should be an initial delta to the base model output + base = self.get_model() + y_base = base(data) + cfg = DeloraConfig(target_modules=["lin0"], r=8, delora_lambda=15, init_weights=False) + model = get_peft_model(base, cfg) + y_peft = model(data) + assert not torch.allclose(y_base, y_peft, atol=1e-6, rtol=1e-6) + + class TestNoInfiniteRecursionDeepspeed: # see #1892 for details classes = [ diff --git a/tests/test_seq_classifier.py b/tests/test_seq_classifier.py index eb0a3d38a4..03869c3a7a 100644 --- a/tests/test_seq_classifier.py +++ b/tests/test_seq_classifier.py @@ -20,6 +20,7 @@ BOFTConfig, BoneConfig, C3AConfig, + DeloraConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -82,6 +83,14 @@ "r": 2, }, ), + ( + DeloraConfig, + { + "task_type": "SEQ_CLS", + "target_modules": None, + "r": 2, + }, + ), ( FourierFTConfig, { diff --git a/tests/testing_common.py b/tests/testing_common.py index 9c49119bf2..d520d8cb5c 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -35,6 +35,7 @@ BOFTConfig, BoneConfig, CPTConfig, + DeloraConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -168,6 +169,12 @@ "cpt_mask": [1, 1, 1, 1, 1, 1, 1, 1], "cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4], }, + # DeLoRA + { + "r": 8, + "target_modules": None, + "bias": "none", + }, ) CLASSES_MAPPING = { @@ -187,6 +194,7 @@ "miss": (MissConfig, CONFIG_TESTING_KWARGS[12]), "lora+trainable_tokens": (LoraConfig, CONFIG_TESTING_KWARGS[13]), "randlora": (RandLoraConfig, CONFIG_TESTING_KWARGS[14]), + "delora": (DeloraConfig, CONFIG_TESTING_KWARGS[17]), } DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[15])}