5 changes: 5 additions & 0 deletions =3
@@ -0,0 +1,5 @@
Collecting oauthlib
Using cached oauthlib-3.2.2-py3-none-any.whl.metadata (7.5 kB)
Using cached oauthlib-3.2.2-py3-none-any.whl (151 kB)
Installing collected packages: oauthlib
Successfully installed oauthlib-3.2.2
Collaborator
Remove :)

30 changes: 30 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -297,6 +297,10 @@ class LoraConfig(PeftConfig):
ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure
LoRA, so it is recommended to merge weights for inference. For more information, see
https://arxiv.org/abs/2402.09353.
use_sinelora (`bool`):
Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA). This technique applies a sine activation to the
low-rank adapter, which can increase the effective rank of the low-rank matrices and enhance their capacity,
especially at low ranks. For more information, see https://arxiv.org/pdf/2403.19243.
layer_replication (`List[Tuple[int, int]]`):
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
@@ -493,6 +497,31 @@
)
},
)
use_sinelora: bool = field(
    default=False,
    metadata={
        "help": (
            "Enable 'Sine Activated Low-Rank Adaptation' (Sine-LoRA, https://arxiv.org/pdf/2403.19243). This "
            "technique applies a sine activation to the low-rank adapter, which can increase the effective rank "
            "of the low-rank matrices and enhance their capacity, especially at low ranks."
        )
    },
)
sinelora_frequency: float = field(
    default=200.0,
    metadata={
        "help": (
            "The frequency factor for the sine activation. If not specified, it will be set to the default value of 200."
        )
    },
)
sinelora_scaling: Optional[float] = field(
    default=None,
    metadata={
        "help": (
            "The scaling factor for the sine activation. If not specified, it will be set to the default value of sqrt(in_features)."
        )
    },
)
# Enables replicating layers in a model to expand it to a larger model.
layer_replication: Optional[list[tuple[int, int]]] = field(
default=None,
@@ -597,6 +626,7 @@ def __post_init__(self):
)
if self.use_dora:
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")


# Using post training conversion of modified base weights to restore their initial values PiSSA/CorDA/OLoRA cannot
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
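For context, this is roughly how the new options would be used from user code once the PR lands. A hypothetical usage sketch, with made-up target module names and a generic base model:

from peft import LoraConfig, get_peft_model

# Hypothetical usage sketch for the Sine-LoRA options added in this PR.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # made-up module names for illustration
    use_sinelora=True,          # enable the sine activation on the low-rank update
    sinelora_frequency=200.0,   # default frequency, as documented above
    sinelora_scaling=None,      # falls back to sqrt(in_features) per the help text
)
# peft_model = get_peft_model(base_model, config)  # base_model: any supported transformers model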
14 changes: 8 additions & 6 deletions src/peft/tuners/lora/layer.py
@@ -792,14 +792,16 @@ def __init__(
lora_bias=lora_bias,
)

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
    if not use_dora:
def resolve_lora_variant(self, *, use_dora: bool, use_sine_lora: bool, **kwargs) -> Optional[LoraVariant]:
    if use_dora:
        from .variants import DoraEmbeddingVariant
        return DoraEmbeddingVariant()
    elif use_sine_lora:
        from .variants import SineLoraLinearVariant
        return SineLoraLinearVariant()
Collaborator
This should reference the (not yet existing) SineLoraEmbeddingVariant since we're in the Embedding class.

But this code is good for Linear.resolve_lora_variant :) You can use it there!

Effectively every class that overrides modules in the model (Linear, Embedding, Conv2d, ...) needs its own variant implementation and resolve_lora_variant implementation but we can keep it at Linear and Embedding for now if you want.
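(For illustration, a minimal sketch of what the Embedding-specific resolver could look like. SineLoraEmbeddingVariant is the not-yet-existing class mentioned above, and the keyword is spelled use_sinelora here to match what model.py passes; this is a sketch, not the final code:)

def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
    # Sketch for the Embedding class; SineLoraEmbeddingVariant does not exist yet in this PR.
    if use_dora:
        from .variants import DoraEmbeddingVariant
        return DoraEmbeddingVariant()
    if use_sinelora:
        from .variants import SineLoraEmbeddingVariant
        return SineLoraEmbeddingVariant()
    return None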

Author
I've corrected it. Please check my new PR.

Collaborator
Hmm can you check again? I don't see changes in this regard :/

Author
I pushed again. I only implemented Linear and Embedding for now; is that ok?

Collaborator
Yep that's totally fine. We can add support for convolutions later once Linear and Embedding work as expected.

    else:
        return None

    from .variants import DoraEmbeddingVariant

    return DoraEmbeddingVariant()

def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
):
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -199,6 +199,9 @@ def _create_and_replace(
"init_lora_weights": lora_config.init_lora_weights,
"use_rslora": lora_config.use_rslora,
"use_dora": lora_config.use_dora,
"use_sinelora": lora_config.use_sinelora,
"sinelora_frequency": lora_config.sinelora_frequency,
"sinelora_scaling": lora_config.sinelora_scaling,
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
"lora_bias": lora_config.lora_bias,
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
18 changes: 18 additions & 0 deletions src/peft/tuners/lora/variants.py
@@ -287,3 +287,21 @@ class DoraConv3dVariant(_DoraConvNdVariant):
def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
dora_layer = DoraConv3dLayer(fan_in_fan_out=False)
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)


class SineLoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name:str) -> None:
Collaborator
Suggested change
def init(module: Linear, adapter_name:str) -> None:
def init(module: Linear, adapter_name:str, **kwargs) -> None:

With PR #2455 now merged, init() receives all the parameters that update_layer receives.

Author
Hmmmm, I did not use that. Do you think that is ok?

Collaborator
I don't think so, no.

Currently the tests do not work because of the changes necessary in Linear.__init__ and Embedding.__init__. Once the changes are in place you'll see that calls to init will complain about unexpected arguments passed to init(). That's because all the config args are passed to init and without the wildcard **kwargs you have to define them all (which we don't want, of course).

Also you need a place to set module.sinelora_scaling and module.sinelora_frequency. This is here, from the kwargs, e.g.

module.sinelora_frequency = kwargs['sinelora_frequency']

For sinelora_scaling you need to check if kwargs['sinelora_scaling'] is None.


if module.sinelora_scaling is None:
    import math
    module.sinelora_scaling = math.sqrt(module.in_features)
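Putting these suggestions together, init() could look roughly like this; a sketch based on the comments in this thread, not the final code, with the kwargs keys assumed to mirror the config field names (LoraVariant and Linear come from the surrounding peft.tuners.lora module, as in the diff above):

import math
from typing import Any

class SineLoraLinearVariant(LoraVariant):
    @staticmethod
    def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
        # Sketch: stash the Sine-LoRA hyperparameters on the module so forward() can read them.
        module.sinelora_frequency = kwargs["sinelora_frequency"]
        scaling = kwargs["sinelora_scaling"]
        if scaling is None:
            # Default suggested in this thread: sqrt of the layer's input dimension.
            scaling = math.sqrt(module.in_features)
        module.sinelora_scaling = scaling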

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    lora_A = module.lora_A[active_adapter]
    lora_B = module.lora_B[active_adapter]
    lora_scaling = module.scaling[active_adapter]
    # Apply the sine activation elementwise to the low-rank update B @ A, then rescale and add to the base output.
    delta_weight = torch.sin(module.sinelora_frequency * (lora_B.weight @ lora_A.weight)) / module.sinelora_scaling
    sine_output = (x @ delta_weight.T) * lora_scaling
    result = result + sine_output
    return result
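For reference, the update this forward pass is meant to implement, as I read the Sine-LoRA paper (notation here is mine, not the paper's: W_0 is the frozen base weight, \gamma the usual LoRA scaling lora_alpha / r, s the sinelora_scaling, \omega the sinelora_frequency):

y = x W_0^{\top} + \frac{\gamma}{s} \, x \, \sin(\omega \, B A)^{\top}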
3 changes: 2 additions & 1 deletion tests/test_custom_models.py
@@ -68,7 +68,8 @@
("Vanilla MLP 1 LoRA", "MLP", LoraConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0"]}),
("Vanilla MLP 3 LoRA", "MLP", LoraConfig, {"target_modules": ["lin1"]}),
("Vanilla MLP 4 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"]}),
("Vanilla MLP 4 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"],"use_sine_lora": True}),
("Vanilla MLP LoRA + SineLoRA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"]}),
("Vanilla MLP 5 LoRA", "MLP", LoraConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
(
"Vanilla MLP 6 LoRA",