Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,18 @@ def update_layer(
use_rslora,
use_dora: bool = False,
use_sinelora: bool = False,
sinelora_frequency: float = 200.0,
sinelora_scaling: Optional[float] = None,
lora_bias: bool = False,
):
# collect the kwargs
kwargs = locals().copy()
del kwargs["self"]

if use_sinelora:
self.sinelora_frequency = sinelora_frequency
self.sinelora_scaling = sinelora_scaling

# This code works for linear layers, override for other layer types
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
Expand Down Expand Up @@ -572,6 +578,8 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
sinelora_frequency: float = 200.0,
sinelora_scaling: Optional[float] = None,
lora_bias: bool = False,
use_sinelora: bool = False,
**kwargs,
Expand All @@ -591,6 +599,8 @@ def __init__(
use_dora=use_dora,
lora_bias=lora_bias,
use_sinelora=use_sinelora,
sinelora_frequency=sinelora_frequency,
sinelora_scaling=sinelora_scaling,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

Expand Down Expand Up @@ -780,6 +790,8 @@ def __init__(
use_dora: bool = False,
lora_bias: bool = False,
use_sinelora: bool = False,
sinelora_frequency=200.0,
sinelora_scaling: Optional[float] = None,
**kwargs,
) -> None:
if lora_bias:
Expand All @@ -801,6 +813,8 @@ def __init__(
use_dora=use_dora,
use_sinelora=use_sinelora,
lora_bias=lora_bias,
sinelora_frequency=sinelora_frequency,
sinelora_scaling=sinelora_scaling,
)

def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
Expand All @@ -826,11 +840,17 @@ def update_layer(
use_dora,
use_sinelora,
lora_bias,
sinelora_frequency,
sinelora_scaling,
):
# collect the kwargs
kwargs = locals().copy()
del kwargs["self"]

if use_sinelora:
self.sinelora_frequency = sinelora_frequency
self.sinelora_scaling = sinelora_scaling

if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

Expand Down
17 changes: 9 additions & 8 deletions src/peft/tuners/lora/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,9 @@ class SineLoraLinearVariant(LoraVariant):
def init(module: Linear, adapter_name: str, **kwargs) -> None:
module.sinelora_frequency = kwargs['sinelora_frequency']

sinelora_scaling = kwargs.get('sinelora_scaling')
if sinelora_scaling is not None:
module.sinelora_scaling = sinelora_scaling
module.sinelora_scaling = kwargs['sinelora_scaling']
if module.sinelora_scaling is None:
module.sinelora_scaling = math.sqrt(module.in_features)

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
Expand All @@ -307,7 +307,7 @@ def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.
lora_scaling = module.scaling[active_adapter]
sine_output = (
x
@ torch.sin(module.sinelora_frequency * lora_B.weight.T @ lora_A.weight)
@ torch.sin(module.sinelora_frequency * lora_A.weight.T @ lora_B.weight.T)
/ module.sinelora_scaling
* lora_scaling
)
Expand All @@ -319,17 +319,18 @@ class SineLoraEmbeddingVariant(SineLoraLinearVariant):
def init(module: Embedding, adapter_name: str, **kwargs) -> None:
module.sinelora_frequency = kwargs['sinelora_frequency']

sinelora_scaling = kwargs.get('sinelora_scaling')
if sinelora_scaling is not None:
module.sinelora_scaling = sinelora_scaling
sinelora_scaling = kwargs['sinelora_scaling']
if sinelora_scaling is None:
module.sinelora_scaling = math.sqrt(module.in_features)

@staticmethod
def forward(module: Embedding, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
lora_embedding_A = module.lora_embedding_A[active_adapter]
lora_embedding_B = module.lora_embedding_B[active_adapter]
lora_scaling = module.scaling[active_adapter]
sine_output = (
module._embed(x)
@ torch.sin(module.sinelora_frequency * lora_embedding_B.weight.T @ lora_embedding_A.weight)
@ torch.sin(module.sinelora_frequency * lora_embedding_A.weight.T @ lora_embedding_B.weight.T)
/ module.sinelora_scaling
* lora_scaling
)
Expand Down