5 changes: 5 additions & 0 deletions =3
@@ -0,0 +1,5 @@
Collecting oauthlib
Using cached oauthlib-3.2.2-py3-none-any.whl.metadata (7.5 kB)
Using cached oauthlib-3.2.2-py3-none-any.whl (151 kB)
Installing collected packages: oauthlib
Successfully installed oauthlib-3.2.2
Collaborator
Remove :)

76 changes: 76 additions & 0 deletions posters/.$Untitled Diagram.drawio.dtmp

Large diffs are not rendered by default.

76 changes: 76 additions & 0 deletions posters/Untitled Diagram.drawio

Large diffs are not rendered by default.

Binary file added posters/Untitled Diagram.drawio.pdf
Binary file not shown.
35 changes: 35 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -297,6 +297,14 @@ class LoraConfig(PeftConfig):
ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure
LoRA, so it is recommended to merge weights for inference. For more information, see
https://arxiv.org/abs/2402.09353.
use_sinelora (`bool`):
Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA). This technique applies a sine activation to the
low-rank adapter, which can boost the effective rank of the low-rank matrices and enhance their capacity.
For more information, see https://arxiv.org/pdf/2403.19243.
sinelora_frequency (`float`):
The frequency factor for the sine activation. If not specified, it will be set to the default value of 200.
sinelora_scaling (`float`):
The scaling factor for the sine activation. If not specified, it will be set to the default value of sqrt(in_features).
Comment on lines +307 to +308
Collaborator
If this value is optional, it should be marked as type Optional[float]

Collaborator
change type of sinelora_scaling here to Optional[float] as it is defined in code.

layer_replication (`List[Tuple[int, int]]`):
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
@@ -493,6 +501,32 @@ class LoraConfig(PeftConfig):
)
},
)
use_sinelora: bool = field(
default=False,
metadata={
"help": (
"Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA). This technique introduce to apply sine activation "
"on the low-rank adaptor. This can be beneficial for rank boosting for low-rank matrices and enhancing its "
"capacity. For more information, see https://arxiv.org/pdf/2403.19243. "
)
},
)
sinelora_frequency: float = field(
default=200.0,
metadata={
"help": (
"The frequency factor for the sine activation. If not specified, it will be set to the default value of 200."
)
},
)
sinelora_scaling: float = field(
default=None,
metadata={
"help": (
"The scaling factor for the sine activation. If not specified, it will be set to the default value of sqrt(in_features)."
)
},
)
# Enables replicating layers in a model to expand it to a larger model.
layer_replication: Optional[list[tuple[int, int]]] = field(
default=None,
@@ -597,6 +631,7 @@ def __post_init__(self):
)
if self.use_dora:
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")


# Using post training conversion of modified base weights to restore their initial values PiSSA/CorDA/OLoRA cannot
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
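Taken together, the new LoraConfig fields would be used roughly as follows. This is a minimal usage sketch, not an example from the PR: the base model, target modules, and hyperparameter values are illustrative, and only use_sinelora, sinelora_frequency, and sinelora_scaling come from the fields added above.

# Minimal sketch of enabling the proposed Sine-LoRA flags (illustrative values).
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    use_sinelora=True,          # resolves to the SineLora* variants added in layer.py
    sinelora_frequency=200.0,   # frequency factor for the sine activation
    sinelora_scaling=None,      # None -> sqrt(in_features) is used per layer
)
model = get_peft_model(base, config)
model.print_trainable_parameters()
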
45 changes: 28 additions & 17 deletions src/peft/tuners/lora/layer.py
@@ -158,7 +158,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.in_features = in_features
self.out_features = out_features

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
@@ -170,6 +170,9 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari
convention, and not here.

"""
if use_sinelora:
from .variants import SineLoraLinearVariant
return SineLoraLinearVariant()
return None

def update_layer(
@@ -181,6 +184,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_sinelora: bool = False,
lora_bias: bool = False,
):
# collect the kwargs
@@ -191,7 +195,7 @@
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -590,13 +594,18 @@ def __init__(
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None
def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
if use_dora:
from .variants import DoraLinearVariant
return DoraLinearVariant()

elif use_sinelora:
from .variants import SineLoraLinearVariant
return SineLoraLinearVariant()

return None

from .variants import DoraLinearVariant

return DoraLinearVariant()

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
@@ -792,16 +801,18 @@ def __init__(
lora_bias=lora_bias,
)

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
if use_dora:
from .variants import DoraEmbeddingVariant
return DoraEmbeddingVariant()
elif use_sinelora:
from .variants import SineLoraEmbeddingVariant
return SineLoraEmbeddingVariant()
else:
return None

from .variants import DoraEmbeddingVariant

return DoraEmbeddingVariant()

def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, use_sinelora, lora_bias
):
# collect the kwargs
kwargs = locals().copy()
@@ -810,7 +821,7 @@ def update_layer(
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1063,7 +1074,7 @@ def __init__(
)

def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, use_sinelora, lora_bias
):
# collect the kwargs
kwargs = locals().copy()
@@ -1072,7 +1083,7 @@
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

3 changes: 3 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -199,6 +199,9 @@ def _create_and_replace(
"init_lora_weights": lora_config.init_lora_weights,
"use_rslora": lora_config.use_rslora,
"use_dora": lora_config.use_dora,
"use_sinelora": lora_config.use_sinelora,
"sinelora_frequency": lora_config.sinelora_frequency,
"sinelora_scaling": lora_config.sinelora_scaling,
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
"lora_bias": lora_config.lora_bias,
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
33 changes: 32 additions & 1 deletion src/peft/tuners/lora/variants.py
@@ -14,7 +14,7 @@
from __future__ import annotations

from typing import Any

import math
import torch
from accelerate.utils.imports import is_xpu_available
from torch import nn
@@ -287,3 +287,34 @@ class DoraConv3dVariant(_DoraConvNdVariant):
def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
dora_layer = DoraConv3dLayer(fan_in_fan_out=False)
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)


class SineLoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name:str) -> None:
Collaborator
Suggested change
def init(module: Linear, adapter_name:str) -> None:
def init(module: Linear, adapter_name:str, **kwargs) -> None:
With PR #2455 now merged, init() receives all the parameters that update_layer receives.

Author
hmmmm I did not use that and do you think that is ok?

Collaborator
I don't think so, no.

Currently the tests do not work because of the changes necessary in Linear.__init__ and Embedding.__init__. Once the changes are in place you'll see that calls to init will complain about unexpected arguments passed to init(). That's because all the config args are passed to init and without the wildcard **kwargs you have to define them all (which we don't want, of course).

Also you need a place to set module.sinelora_scaling and module.sinelora_frequency. This is here, from the kwargs, e.g.

module.sinelora_frequency = kwargs['sinelora_frequency']

For sinelora_scaling you need to check if kwargs['sinelora_scaling'] is None.

if module.sinelora_scaling is None:
module.sinelora_scaling = math.sqrt(module.in_features)

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:

lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
lora_scaling = module.scaling[active_adapter]
sine_output = x @ torch.sin(module.sinelora_frequency * lora_A.weight.T @ lora_B.weight.T) / module.sinelora_scaling * lora_scaling
result = result + sine_output

Collaborator
We're missing a return here.

class SineLoraEmbeddingVariant(SineLoraLinearVariant):
@staticmethod
def init(module: Linear, adapter_name:str) -> None:
if module.sinelora_scaling is None:
module.sinelora_scaling = math.sqrt(module.in_features)

@staticmethod
def forward(module: Embedding, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
lora_embedding_A = module.lora_embedding_A[active_adapter]
lora_embedding_B = module.lora_embedding_B[active_adapter]
lora_scaling = module.scaling[active_adapter]
sine_output = module._embed(x, torch.sin(module.sinelora_frequency * lora_embedding_A.T @ lora_embedding_B.T) / module.sinelora_scaling) * lora_scaling
result = result + sine_output
return result
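
Pulling the review comments together, a corrected Linear variant could look roughly like the sketch below. This is illustrative only: it assumes PEFT's usual lora_A / lora_B weight shapes ((r, in_features) and (out_features, r)) and the kwargs-based init() convention described in the review; the class name is made up, and the LoraVariant base class is omitted to keep the sketch self-contained.

# Sketch only: what SineLoraLinearVariant might look like once the review comments are applied.
# In variants.py this would subclass LoraVariant; here it is a plain class so the snippet runs on its own.
import math
import torch

class SineLoraLinearVariantSketch:
    @staticmethod
    def init(module, adapter_name: str, **kwargs) -> None:
        # init() receives all update_layer() arguments, so the Sine-LoRA settings come from kwargs.
        module.sinelora_frequency = kwargs["sinelora_frequency"]
        module.sinelora_scaling = kwargs["sinelora_scaling"]
        if module.sinelora_scaling is None:
            module.sinelora_scaling = math.sqrt(module.in_features)

    @staticmethod
    def forward(module, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
        lora_A = module.lora_A[active_adapter]  # weight shape: (r, in_features)
        lora_B = module.lora_B[active_adapter]  # weight shape: (out_features, r)
        lora_scaling = module.scaling[active_adapter]
        # Sine-LoRA applies sin elementwise to the full low-rank product: delta_W = sin(freq * B @ A) / s.
        # x @ sin(freq * A^T @ B^T) computes x @ delta_W^T, matching the usual x @ A^T @ B^T layout.
        delta = torch.sin(module.sinelora_frequency * lora_A.weight.T @ lora_B.weight.T)
        return result + x @ delta / module.sinelora_scaling * lora_scaling

The Embedding variant would follow the same pattern, but would build the (num_embeddings, embedding_dim) lookup table from the transposed low-rank product and pass it through the layer's embedding helper rather than a matmul with x.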
4 changes: 4 additions & 0 deletions tests/test_custom_models.py
@@ -499,6 +499,10 @@
TrainableTokensConfig,
{"target_modules": ["emb"], "token_indices": [0, 1, 3], "init_weights": False},
),
###################
# LoRA + SineLoRA #
###################
("Vanilla MLP LoRA + SineLoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"],"use_sinelora": True}),
]

# For this test matrix, each tuple consists of: