FEAT Add sine-LoRA #2434 #2457
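Summary of the diff below: the new `use_sinelora`, `sinelora_frequency`, and `sinelora_scaling` arguments are threaded through the LoRA layer classes (the base `LoraLayer`, `Linear`, `Embedding`, `_ConvNd`, `Conv1d`), and `resolve_lora_variant()` learns to dispatch to the new `SineLora*` variants. As a hedged usage sketch, assuming the PR also exposes these options on `LoraConfig` (only the layer-level plumbing is visible in this diff), enabling the variant might look like this:

# Hypothetical usage sketch; assumes LoraConfig gains the use_sinelora/sinelora_* fields.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    use_sinelora=True,          # assumed config flag routing to the SineLora* variants
    sinelora_frequency=200.0,   # default frequency used in this diff
    sinelora_scaling=None,      # None: let the variant derive its own scaling
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()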
@@ -161,7 +161,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
         self.in_features = in_features
         self.out_features = out_features

-    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
+    def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
         """Return a matching LoRA variant for this layer type.

         Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
@@ -173,6 +173,7 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
         convention, and not here.

         """
         return None

     def update_layer(
@@ -184,17 +185,24 @@ def update_layer(
         init_lora_weights,
         use_rslora,
         use_dora: bool = False,
+        use_sinelora: bool = False,
+        sinelora_frequency: float = 200.0,
+        sinelora_scaling: Optional[float] = None,
         lora_bias: bool = False,
     ):
         # collect the kwargs
         kwargs = locals().copy()
         del kwargs["self"]

+        if use_sinelora:
+            self.sinelora_frequency = sinelora_frequency
+            self.sinelora_scaling = sinelora_scaling
+
         # This code works for linear layers, override for other layer types
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

-        lora_variant = self.resolve_lora_variant(use_dora=use_dora)
+        lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
         if lora_variant is not None:
             self.lora_variant[adapter_name] = lora_variant
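A note on the `kwargs = locals().copy()` pattern in the hunk above: any keyword argument added to `update_layer()` (such as `use_sinelora`) is automatically captured in the collected kwargs and forwarded downstream. A standalone illustration of the mechanism, not PEFT code:

# Illustration only: locals().copy() snapshots every argument by name,
# so new parameters need no extra bookkeeping to be forwarded.
def update_layer(adapter_name, r, use_dora=False, use_sinelora=False,
                 sinelora_frequency=200.0, sinelora_scaling=None):
    kwargs = locals().copy()
    return kwargs

print(update_layer("default", 8, use_sinelora=True))
# {'adapter_name': 'default', 'r': 8, 'use_dora': False, 'use_sinelora': True,
#  'sinelora_frequency': 200.0, 'sinelora_scaling': None}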
@@ -571,7 +579,10 @@ def __init__(
         init_lora_weights: Union[bool, str] = True,
         use_rslora: bool = False,
         use_dora: bool = False,
+        sinelora_frequency: float = 200.0,
+        sinelora_scaling: Optional[float] = None,
         lora_bias: bool = False,
+        use_sinelora: bool = False,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -588,16 +599,24 @@ def __init__(
             use_rslora=use_rslora,
             use_dora=use_dora,
             lora_bias=lora_bias,
+            use_sinelora=use_sinelora,
+            sinelora_frequency=sinelora_frequency,
+            sinelora_scaling=sinelora_scaling,
         )
         self.is_target_conv_1d_layer = is_target_conv_1d_layer

-    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
-        if not use_dora:
-            return None
+    def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
+        if use_dora:
+            from .variants import DoraLinearVariant

-        from .variants import DoraLinearVariant
+            return DoraLinearVariant()
+        elif use_sinelora:
+            from .variants import SineLoraLinearVariant

-        return DoraLinearVariant()
+            return SineLoraLinearVariant()
+
+        return None

     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         """
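`SineLoraLinearVariant` itself lives in `variants.py` and is not part of this file's diff. For orientation only, a rough sketch of the sine-activated low-rank idea, assuming a formulation along the lines of `delta_W ≈ sin(frequency * (B @ A)) / scaling`; the actual variant may handle scaling and `lora_alpha` differently:

import math
from typing import Optional

import torch

def sinelora_delta_weight(
    lora_A: torch.Tensor,            # (r, in_features)
    lora_B: torch.Tensor,            # (out_features, r)
    frequency: float = 200.0,
    scaling: Optional[float] = None,
) -> torch.Tensor:
    # Sketch of a sine-activated low-rank update; not the PR's implementation.
    if scaling is None:
        # Placeholder: the diff leaves sinelora_scaling as None by default,
        # presumably so the variant can derive a value (e.g. from the rank).
        scaling = math.sqrt(lora_A.shape[0])
    return torch.sin(frequency * (lora_B @ lora_A)) / scaling

delta_W = sinelora_delta_weight(torch.randn(8, 128), torch.randn(256, 8))
print(delta_W.shape)  # torch.Size([256, 128])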
@@ -773,6 +792,9 @@ def __init__(
         use_rslora: bool = False,
         use_dora: bool = False,
         lora_bias: bool = False,
+        use_sinelora: bool = False,
+        sinelora_frequency=200.0,
+        sinelora_scaling: Optional[float] = None,
         **kwargs,
     ) -> None:
         if lora_bias:
@@ -792,28 +814,50 @@ def __init__(
             init_lora_weights=init_lora_weights,
             use_rslora=use_rslora,
             use_dora=use_dora,
+            use_sinelora=use_sinelora,
             lora_bias=lora_bias,
+            sinelora_frequency=sinelora_frequency,
+            sinelora_scaling=sinelora_scaling,
         )

-    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
-        if not use_dora:
-            return None
+    def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
+        if use_dora:
+            from .variants import DoraEmbeddingVariant

-        from .variants import DoraEmbeddingVariant
+            return DoraEmbeddingVariant()
+        elif use_sinelora:
+            from .variants import SineLoraEmbeddingVariant

-        return DoraEmbeddingVariant()
+            return SineLoraEmbeddingVariant()
+        else:
+            return None

     def update_layer(
-        self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
+        self,
+        adapter_name,
+        r,
+        lora_alpha,
+        lora_dropout,
+        init_lora_weights,
+        use_rslora,
+        use_dora,
+        lora_bias,
+        use_sinelora: bool = False,
+        sinelora_frequency: float = 200.0,
+        sinelora_scaling: Optional[float] = None,
     ):
         # collect the kwargs
         kwargs = locals().copy()
         del kwargs["self"]

+        if use_sinelora:
+            self.sinelora_frequency = sinelora_frequency
+            self.sinelora_scaling = sinelora_scaling
+
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

-        lora_variant = self.resolve_lora_variant(use_dora=use_dora)
+        lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
         if lora_variant is not None:
             self.lora_variant[adapter_name] = lora_variant
@@ -988,7 +1032,7 @@ def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
             norm_type=base_layer.norm_type,
             scale_grad_by_freq=base_layer.scale_grad_by_freq,
             sparse=base_layer.sparse,
-        )
+        )

     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         # TODO: no dtype conversion here, unlike in Linear, is that correct?
@@ -1068,7 +1112,16 @@ def __init__(
         )

     def update_layer(
-        self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
+        self,
+        adapter_name,
+        r,
+        lora_alpha,
+        lora_dropout,
+        init_lora_weights,
+        use_rslora,
+        use_dora,
+        use_sinelora,
+        lora_bias,
     ):
         # collect the kwargs
         kwargs = locals().copy()

Review comment on the added `use_sinelora,` line: Make sure to change the …
@@ -1077,7 +1130,7 @@ def update_layer(
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

-        lora_variant = self.resolve_lora_variant(use_dora=use_dora)
+        lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_sinelora=use_sinelora)
         if lora_variant is not None:
             self.lora_variant[adapter_name] = lora_variant
@@ -1326,13 +1379,12 @@ def __init__(self, *args, **kwargs):
             raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
         self.conv_fn = F.conv1d

-    def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
+    def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
         if not use_dora:
-            return None
-
-        from .variants import DoraConv1dVariant
-
-        return DoraConv1dVariant()
+            from .variants import DoraConv1dVariant
+        elif use_sinelora:
+            from .variants import SineLoraConv1dVariant
+        return None


 class Conv3d(_ConvNd):
Review comment on lines +1384 to +1387: If you have an implementation for conv*d, I'd suggest adding it. If you don't, maybe it is worthwhile to skip it for now and undo the changes in the …
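If a conv variant does exist, one way to address this would be to mirror the Linear/Embedding dispatch above. This is only a sketch of the reviewer's first option and assumes a `SineLoraConv1dVariant` is actually implemented in `variants.py`:

    def resolve_lora_variant(self, *, use_dora: bool, use_sinelora: bool, **kwargs) -> Optional[LoraVariant]:
        # Sketch: same dispatch pattern as Linear/Embedding; requires SineLoraConv1dVariant to exist.
        if use_dora:
            from .variants import DoraConv1dVariant

            return DoraConv1dVariant()
        elif use_sinelora:
            from .variants import SineLoraConv1dVariant

            return SineLoraConv1dVariant()

        return None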
Review comment: If this value is optional, it should be marked as type Optional[float].

Review comment: Change the type of sinelora_scaling here to Optional[float], as it is defined in code.
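Concretely, the annotation being requested, matching how these parameters are typed elsewhere in this diff:

from typing import Optional

# Keep the annotations consistent with LoraLayer.update_layer in this diff.
sinelora_frequency: float = 200.0
sinelora_scaling: Optional[float] = None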