-
Notifications
You must be signed in to change notification settings - Fork 2.1k
[WIP] Update LoraConfig
for KaSA implementation
#2698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
387ad2a
ae00a34
6588f1a
a824ac9
05e4e07
d1e7e43
9e53b9f
aa37111
9cfe65c
f9d7cc7
84813a3
39abcad
06f76d8
0043ae3
ea59432
574c1b8
73fa58f
b9e3190
2649fdf
d06cafd
4ee5da1
0377170
a536bbf
f8d8057
cd57c7b
2431a2c
7276b3b
5a67b1f
3ec6b18
461a89c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
from peft.utils.other import transpose | ||
|
||
from .config import LoraConfig | ||
|
||
from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer | ||
|
||
class LoraVariant: | ||
""" | ||
|
@@ -107,6 +107,9 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, * | |
self.cast_input_dtype_enabled: bool = True | ||
self.lora_variant: dict[str, LoraVariant] = {} | ||
self.kwargs = kwargs | ||
|
||
# Diag value | ||
self.lora_diag = nn.ParameterDict({}) | ||
|
||
base_layer = self.get_base_layer() | ||
if isinstance(base_layer, nn.Linear): | ||
|
@@ -161,7 +164,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, * | |
self.in_features = in_features | ||
self.out_features = out_features | ||
|
||
def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]: | ||
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]: | ||
"""Return a matching LoRA variant for this layer type. | ||
|
||
Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this | ||
|
@@ -173,7 +176,18 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari | |
convention, and not here. | ||
|
||
""" | ||
return None | ||
if use_dora and use_kasa: | ||
|
||
raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.") | ||
|
||
variant = None | ||
if use_dora: | ||
from .variants import DoraLinearVariant | ||
variant = DoraLinearVariant() | ||
elif use_kasa: | ||
from .variants import KasaLinearVariant | ||
variant = KasaLinearVariant() | ||
|
||
return variant | ||
|
||
def update_layer( | ||
self, | ||
|
@@ -184,6 +198,7 @@ def update_layer( | |
init_lora_weights, | ||
use_rslora, | ||
use_dora: bool = False, | ||
use_kasa: bool = False, | ||
lora_bias: bool = False, | ||
): | ||
# collect the kwargs | ||
|
@@ -194,7 +209,7 @@ def update_layer( | |
if r <= 0: | ||
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") | ||
|
||
lora_variant = self.resolve_lora_variant(use_dora=use_dora) | ||
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_kasa=use_kasa) | ||
if lora_variant is not None: | ||
self.lora_variant[adapter_name] = lora_variant | ||
|
||
|
@@ -217,7 +232,20 @@ def update_layer( | |
self.scaling[adapter_name] = lora_alpha / r | ||
|
||
self.use_dora[adapter_name] = use_dora | ||
|
||
############ kasa ############# | ||
self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True) | ||
|
||
weight = self.get_base_layer().weight | ||
dtype = weight.dtype | ||
svd_rank = self.in_features - r | ||
weight = weight.to(torch.float32) | ||
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False) | ||
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :] | ||
self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype) | ||
|
||
######################### | ||
|
||
|
||
# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed | ||
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): | ||
with gather_params_ctx(self.get_base_layer().weight): | ||
|
@@ -571,6 +599,7 @@ def __init__( | |
init_lora_weights: Union[bool, str] = True, | ||
use_rslora: bool = False, | ||
use_dora: bool = False, | ||
use_kasa: bool = False, | ||
lora_bias: bool = False, | ||
**kwargs, | ||
) -> None: | ||
|
@@ -587,6 +616,7 @@ def __init__( | |
init_lora_weights=init_lora_weights, | ||
use_rslora=use_rslora, | ||
use_dora=use_dora, | ||
use_kasa=use_kasa, | ||
lora_bias=lora_bias, | ||
) | ||
self.is_target_conv_1d_layer = is_target_conv_1d_layer | ||
|
@@ -813,7 +843,7 @@ def update_layer( | |
if r <= 0: | ||
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") | ||
|
||
lora_variant = self.resolve_lora_variant(use_dora=use_dora) | ||
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_kasa=use_kasa) | ||
if lora_variant is not None: | ||
self.lora_variant[adapter_name] = lora_variant | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
from peft.utils.other import transpose | ||
|
||
from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer | ||
from .kasa import KasaLinearLayer | ||
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd | ||
|
||
|
||
|
@@ -308,3 +309,56 @@ class DoraConv3dVariant(_DoraConvNdVariant): | |
def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None: | ||
dora_layer = DoraConv3dLayer(fan_in_fan_out=False) | ||
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer) | ||
|
||
class KasaLinearVariant(LoraVariant): | ||
@staticmethod | ||
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None: | ||
if not module.lora_diag: | ||
nsbg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# first kasa layer being added, add lora_diag to the list of learnable parameters | ||
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_diag",) | ||
|
||
# initialize lora_diag | ||
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True) | ||
|
||
# SVD | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a reference here, so that we know the origin: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I put it in here, how is it? |
||
weight = module.get_base_layer().weight # original weight | ||
dtype = weight.dtype | ||
svd_rank = module.in_features - module.r[adapter_name] | ||
weight = weight.to(torch.float32) | ||
U, S, Vh = torch.linalg.svd(weight, full_matrices=False) | ||
U_principle = U[:, :svd_rank] | ||
S_principle = S[:svd_rank] | ||
Vh_principle = Vh[:svd_rank, :] | ||
|
||
module.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype) | ||
|
||
@staticmethod | ||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: | ||
delta_weight = module.get_delta_weight(active_adapter) | ||
return orig_weight + delta_weight | ||
|
||
@staticmethod | ||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None: | ||
delta_weight = module.get_delta_weight(active_adapter) | ||
orig_weight.data += delta_weight | ||
|
||
@staticmethod | ||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor: | ||
delta_weight = module.get_delta_weight(active_adapter) | ||
|
||
return orig_weight - delta_weight | ||
Comment on lines
479
to
510
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. KaSA should have an influence on the merged weights, should it not? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although this PR is closed, it seems I've incorporated everything else except for this comment (of course, you'd have to look at the code). Could you explain this question in more detail? |
||
|
||
@staticmethod | ||
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor: | ||
nsbg marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
lora_A = module.lora_A[active_adapter] | ||
lora_B = module.lora_B[active_adapter] | ||
dropout = module.lora_dropout[active_adapter] | ||
scaling = module.scaling[active_adapter] | ||
|
||
diag = torch.diag(module.lora_diag[active_adapter]) | ||
x = module._cast_input_dtype(x, lora_A.weight.dtype) | ||
if isinstance(dropout, nn.Identity) or not module.training: | ||
x = dropout(x) | ||
|
||
# KaSA calculation | ||
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling | ||
|
||
return result + lora_output |
Uh oh!
There was an error while loading. Please reload this page.