Changes from 7 commits
2 changes: 1 addition & 1 deletion examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
@@ -41,7 +41,7 @@
max_memory = f"{free_in_GB - 2}GB"

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
max_memory = dict.fromkeys(range(n_gpus), max_memory)

model = AutoModelForCausalLM.from_pretrained(
"facebook/opt-350m",
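As a quick sanity check on the dict.fromkeys rewrite above: the two forms build identical mappings when the shared value is immutable, as it is here (a memory string). The values below are made up for illustration:

n_gpus = 4                # assumed GPU count, illustration only
max_memory = "22GB"       # assumed per-device budget
assert {i: max_memory for i in range(n_gpus)} == dict.fromkeys(range(n_gpus), max_memory)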
6 changes: 3 additions & 3 deletions src/peft/peft_model.py
@@ -721,9 +721,9 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -
# If we don't apply this, prefix-tuning fails to update cross-attn cache
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values.cross_attention_cache = DynamicCache()
past_key_values.is_updated = {
layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
}
past_key_values.is_updated = dict.fromkeys(
range(len(past_key_values.cross_attention_cache.key_cache)), False
)
map_cache_to_layer_device_map(self.get_base_model(), past_key_values) # no-op if not a Cache instance
return past_key_values
else:
34 changes: 34 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -297,6 +297,14 @@ class LoraConfig(PeftConfig):
ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure
LoRA, so it is recommended to merge weights for inference. For more information, see
https://arxiv.org/abs/2402.09353.
use_sinelora (`bool`):
Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA). This technique applies a sine activation to the
low-rank adapter, which can boost the effective rank of the low-rank matrices and enhance their capacity.
For more information, see https://arxiv.org/pdf/2403.19243.
sinelora_frequency (`float`):
The frequency factor for the sine activation. If not specified, it will be set to the default value of 200.
sinelora_scaling (`float`):
The scaling factor for the sine activation. If not specified, it will be set to the default value of sqrt(in_features).
Comment on lines +307 to +308
Collaborator: If this value is optional, it should be marked as type Optional[float]
Collaborator: change type of sinelora_scaling here to Optional[float] as it is defined in code.
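A minimal sketch of what both comments are asking for, i.e. the docstring type matching the Optional[float] field defined further down in this file; the stand-alone dataclass below is a hypothetical stand-in for illustration, not peft.LoraConfig itself:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class LoraConfigSketch:
    """
    sinelora_scaling (`Optional[float]`):
        The scaling factor for the sine activation. If not specified, it defaults to sqrt(in_features).
    """

    sinelora_scaling: Optional[float] = field(
        default=None,
        metadata={"help": "Scaling factor for the sine activation; defaults to sqrt(in_features) when None."},
    )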
layer_replication (`List[Tuple[int, int]]`):
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
@@ -493,6 +501,32 @@ class LoraConfig(PeftConfig):
)
},
)
use_sinelora: bool = field(
default=False,
metadata={
"help": (
"Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA). This technique applies a sine activation to the "
"low-rank adapter, which can boost the effective rank of the low-rank matrices and enhance their capacity. "
"For more information, see https://arxiv.org/pdf/2403.19243."
)
)
},
)
sinelora_frequency: float = field(
default=200.0,
metadata={
"help": (
"The frequency factor for the sine activation. If not specified, it will be set to the default value of 200."
)
},
)
sinelora_scaling: Optional[float] = field(
default=None,
metadata={
"help": (
"The scaling factor for the sine activation. If not specified, it will be set to the default value of sqrt(in_features)."
)
},
)
# Enables replicating layers in a model to expand it to a larger model.
layer_replication: Optional[list[tuple[int, int]]] = field(
default=None,
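For context, a hedged usage sketch of the new options added to LoraConfig above; it assumes this branch is installed, and the target module names are placeholders for a typical decoder model, not part of this diff:

from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # placeholder module names, illustration only
    use_sinelora=True,                    # enable the Sine-LoRA variant added in this PR
    sinelora_frequency=200.0,             # default frequency factor
    sinelora_scaling=None,                # falls back to sqrt(in_features) per layer
)

Passing this config to get_peft_model should then route supported Linear and Embedding layers through the SineLora variants added in variants.py.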
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/eva.py
@@ -409,7 +409,7 @@ def _get_rank_distribution(hooks, layer_hook_map, equal_inputs_map, rank_budget,
else:
pbar = iter(cycle(dataloader))
use_tqdm = False
convergence_dict = {k: False for k in hooks.keys()}
convergence_dict = dict.fromkeys(hooks.keys(), False)
rank_dist = max_components.copy()
for inputs in pbar:
if device is not None:
85 changes: 69 additions & 16 deletions src/peft/tuners/lora/layer.py
@@ -158,7 +158,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.in_features = in_features
self.out_features = out_features

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
@@ -170,6 +170,7 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari
convention, and not here.

"""

return None

def update_layer(
@@ -181,17 +182,24 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_sinelora: bool = False,
sinelora_frequency: float = 200.0,
sinelora_scaling: Optional[float] = None,
lora_bias: bool = False,
):
# collect the kwargs
kwargs = locals().copy()
del kwargs["self"]

if use_sinelora:
self.sinelora_frequency = sinelora_frequency
self.sinelora_scaling = sinelora_scaling

# This code works for linear layers, override for other layer types
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -570,7 +578,10 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
sinelora_frequency: float = 200.0,
sinelora_scaling: Optional[float] = None,
lora_bias: bool = False,
use_sinelora: bool = False,
**kwargs,
) -> None:
super().__init__()
@@ -587,16 +598,24 @@
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
use_sinelora=use_sinelora,
sinelora_frequency=sinelora_frequency,
sinelora_scaling=sinelora_scaling,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None
def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
if use_dora:
from .variants import DoraLinearVariant

return DoraLinearVariant()

from .variants import DoraLinearVariant
elif use_sinelora:
from .variants import SineLoraLinearVariant

return DoraLinearVariant()
return SineLoraLinearVariant()

return None

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
@@ -770,6 +789,9 @@ def __init__(
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
use_sinelora: bool = False,
sinelora_frequency: float = 200.0,
sinelora_scaling: Optional[float] = None,
**kwargs,
) -> None:
if lora_bias:
@@ -789,28 +811,50 @@
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_sinelora=use_sinelora,
lora_bias=lora_bias,
sinelora_frequency=sinelora_frequency,
sinelora_scaling=sinelora_scaling,
)

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None
def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
if use_dora:
from .variants import DoraEmbeddingVariant

from .variants import DoraEmbeddingVariant
return DoraEmbeddingVariant()
elif use_sinelora:
from .variants import SineLoraEmbeddingVariant

return DoraEmbeddingVariant()
return SineLoraEmbeddingVariant()
else:
return None

def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
self,
adapter_name,
r,
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
use_dora,
use_sinelora,
lora_bias,
sinelora_frequency,
sinelora_scaling,
):
# collect the kwargs
kwargs = locals().copy()
del kwargs["self"]

if use_sinelora:
self.sinelora_frequency = sinelora_frequency
self.sinelora_scaling = sinelora_scaling

if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1063,7 +1107,16 @@ def __init__(
)

def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
self,
adapter_name,
r,
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
use_dora,
use_sinelora,
Collaborator @githubnemo (Jun 4, 2025): Make sure to change the update_layer call in ConvNd.__init__ as well (currently misses all sinelora arguments). But since we're skipping convolutions for now I suggest to remove it entirely.
lora_bias,
):
# collect the kwargs
kwargs = locals().copy()
@@ -1072,7 +1125,7 @@ def update_layer(
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

3 changes: 3 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -199,6 +199,9 @@ def _create_and_replace(
"init_lora_weights": lora_config.init_lora_weights,
"use_rslora": lora_config.use_rslora,
"use_dora": lora_config.use_dora,
"use_sinelora": lora_config.use_sinelora,
"sinelora_frequency": lora_config.sinelora_frequency,
"sinelora_scaling": lora_config.sinelora_scaling,
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
"lora_bias": lora_config.lora_bias,
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
49 changes: 49 additions & 0 deletions src/peft/tuners/lora/variants.py
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import math
from typing import Any

import torch
@@ -287,3 +288,51 @@
def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
dora_layer = DoraConv3dLayer(fan_in_fan_out=False)
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)


class SineLoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs) -> None:
module.sinelora_frequency = kwargs['sinelora_frequency']

module.sinelora_scaling = kwargs['sinelora_scaling']
if module.sinelora_scaling is None:
module.sinelora_scaling = math.sqrt(module.in_features)

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:

lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
lora_scaling = module.scaling[active_adapter]
sine_output = (
x
@ torch.sin(module.sinelora_frequency * lora_A.weight.T @ lora_B.weight.T)
/ module.sinelora_scaling
* lora_scaling
)
result = result + sine_output

Collaborator: We're missing a return here.
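A sketch of the fix the comment points at, written as a free function so it stands alone; the module attributes mirror the ones set up in init above, and this is illustrative rather than the committed code:

import torch


def sinelora_linear_forward(module, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # Same computation as SineLoraLinearVariant.forward, but the accumulated
    # result is returned instead of being dropped.
    lora_A = module.lora_A[active_adapter]
    lora_B = module.lora_B[active_adapter]
    lora_scaling = module.scaling[active_adapter]
    sine_output = (
        x
        @ torch.sin(module.sinelora_frequency * lora_A.weight.T @ lora_B.weight.T)
        / module.sinelora_scaling
        * lora_scaling
    )
    return result + sine_output  # the missing return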

class SineLoraEmbeddingVariant(SineLoraLinearVariant):
@staticmethod
def init(module: Embedding, adapter_name: str, **kwargs) -> None:
module.sinelora_frequency = kwargs['sinelora_frequency']

module.sinelora_scaling = kwargs['sinelora_scaling']
if module.sinelora_scaling is None:
module.sinelora_scaling = math.sqrt(module.in_features)

@staticmethod
def forward(module: Embedding, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
lora_embedding_A = module.lora_embedding_A[active_adapter]
lora_embedding_B = module.lora_embedding_B[active_adapter]
lora_scaling = module.scaling[active_adapter]
sine_output = (
module._embed(x)
@ torch.sin(module.sinelora_frequency * lora_embedding_A.weight.T @ lora_embedding_B.weight.T)
/ module.sinelora_scaling
* lora_scaling
)
Collaborator: I wonder, is this correct? The way I see it, _embed takes two parameters, x and weight, but we're only supplying x here. weight should probably be lora_embedding_A.T?
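To illustrate the two-argument signature the comment refers to, here is a minimal stand-in for _embed (an assumption about its behavior, modeled on torch.nn.functional.embedding), together with one plausible way a weight could be supplied in the Sine-LoRA case; a hedged sketch, not the resolved change:

import torch
import torch.nn.functional as F


def embed_sketch(input_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Assumed behavior of Embedding._embed: look up a row of `weight` for each
    # token id, which is why a weight tensor is required as the second argument.
    return F.embedding(input_ids, weight)


# One plausible reading (an assumption, not the merged code): embed against the
# sine-activated low-rank delta weight rather than calling _embed(x) alone.
# delta_w = torch.sin(module.sinelora_frequency * lora_embedding_A.T @ lora_embedding_B.T) / module.sinelora_scaling
# sine_output = embed_sketch(x, delta_w) * lora_scaling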
result = result + sine_output
return result
2 changes: 1 addition & 1 deletion src/peft/utils/integrations.py
@@ -145,7 +145,7 @@ def get_layer_device_map(model):
return None

if len(execution_device_map) == 1 and "" in execution_device_map:
return {idx: execution_device_map[""] for idx in range(model.config.num_hidden_layers)}
return dict.fromkeys(range(model.config.num_hidden_layers), execution_device_map[""])

layer_device_map = {}
for layer in execution_device_map:
4 changes: 4 additions & 0 deletions tests/test_custom_models.py
@@ -499,6 +499,10 @@
TrainableTokensConfig,
{"target_modules": ["emb"], "token_indices": [0, 1, 3], "init_weights": False},
),
###################
# LoRA + SineLoRA #
###################
("Vanilla MLP LoRA + SineLoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_sinelora": True}),
]

# For this test matrix, each tuple consists of:
2 changes: 1 addition & 1 deletion tests/test_initialization.py
@@ -2653,7 +2653,7 @@ def fn(x, *args):
if prepare_layer_inputs_keys is None:
prepare_layer_inputs_fn = fn
else:
prepare_layer_inputs_fn = {k: fn for k in prepare_layer_inputs_keys}
prepare_layer_inputs_fn = dict.fromkeys(prepare_layer_inputs_keys, fn)

shuffled_dataset = dataset.shuffle(seed=0)
dataloader = self.get_dataloader(shuffled_dataset)