Commits (30)

387ad2a - Add KaSA implementation to layer.py (nsbg, May 17, 2025)
ae00a34 - Add `use_kasa` argument to LoraConfig (nsbg, Aug 2, 2025)
6588f1a - Add use_kasa parameter to Linear class (nsbg, Sep 2, 2025)
a824ac9 - Add KasaLinearVariant class (just copy of DoraLinearVariant class) in… (nsbg, Sep 2, 2025)
05e4e07 - Add kasa description (nsbg, Sep 2, 2025)
d1e7e43 - Remove unnecessary self.kasa (nsbg, Sep 4, 2025)
9e53b9f - [WIP] update KasaLinearVariant class with SVD implementation (nsbg, Sep 8, 2025)
aa37111 - Modify merge/unmerge method in KasaLinearVariant class (nsbg, Sep 8, 2025)
9cfe65c - update KasaLinearVariant class with SVD implementation (nsbg, Sep 8, 2025)
f9d7cc7 - fix type in init method (nsbg, Sep 8, 2025)
84813a3 - delete unnecessary part in layer.py (nsbg, Sep 16, 2025)
39abcad - add original reference in layer.py (nsbg, Sep 16, 2025)
06f76d8 - merge main to peft-kasa (nsbg, Sep 16, 2025)
0043ae3 - re-add KaSA implementation to variants.py (nsbg, Sep 20, 2025)
ea59432 - add use_kasa param to resolve_lora_variant in other layers (nsbg, Sep 27, 2025)
574c1b8 - delete unnecessary part in layer.py (nsbg, Sep 29, 2025)
73fa58f - delete unnecessary part in variants.py (nsbg, Sep 29, 2025)
b9e3190 - add _get_delta_weight static method in KasaLinearVariants class (nsbg, Oct 4, 2025)
2649fdf - update module.get_delta_weight to KasaLinearVariant._get_delta_weight… (nsbg, Oct 7, 2025)
d06cafd - add kasa test (nsbg, Oct 7, 2025)
4ee5da1 - add dropout (nsbg, Oct 8, 2025)
0377170 - Update tests/test_custom_models.py (nsbg, Oct 11, 2025)
a536bbf - Update src/peft/tuners/lora/variants.py (nsbg, Oct 11, 2025)
f8d8057 - Update src/peft/tuners/lora/variants.py (nsbg, Oct 11, 2025)
cd57c7b - Update src/peft/tuners/lora/variants.py (nsbg, Oct 11, 2025)
2431a2c - add use_kasa param in LoraModel class (nsbg, Oct 11, 2025)
7276b3b - restore output_tensor variable in Linear class get_delta_weight method (nsbg, Oct 11, 2025)
5a67b1f - add use_kasa handling condition in resolve_lora_variant method (nsbg, Oct 11, 2025)
3ec6b18 - fix KaSA self, mat1, mat2 dtype error (nsbg, Oct 11, 2025)
461a89c - fix make style error (nsbg, Oct 11, 2025)
10 changes: 10 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -527,6 +527,16 @@ class LoraConfig(PeftConfig):
},
)

use_kasa: bool = field(
default=False,
metadata={
"help": (
"Enable <a href='https://arxiv.org/abs/2412.06071'>'Knowledge-Aware Singular-Value Adaptation of Large Language Models' (KaSA)</a>. This technique leverages "
"singular value decomposition (SVD) with knowledge-aware singular values to dynamically "
"activate parametric knowledge according to its relevance to downstream tasks."
)
}
)
def to_dict(self):
"""
Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
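For orientation, here is a minimal sketch of how the new flag would be enabled, assuming a PEFT build that includes this PR; the base model, rank, alpha, and target modules below are purely illustrative:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative values; only `use_kasa=True` is the point of this example.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    use_kasa=True,  # new flag added by this PR
)

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```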
40 changes: 35 additions & 5 deletions src/peft/tuners/lora/layer.py
@@ -33,7 +33,7 @@
from peft.utils.other import transpose

from .config import LoraConfig

from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer

class LoraVariant:
"""
@@ -107,6 +107,9 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.cast_input_dtype_enabled: bool = True
self.lora_variant: dict[str, LoraVariant] = {}
self.kwargs = kwargs

# KaSA: learnable diagonal (knowledge-aware singular values), one entry per adapter
self.lora_diag = nn.ParameterDict({})

base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
@@ -161,7 +164,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.in_features = in_features
self.out_features = out_features

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
@@ -173,7 +176,18 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari
convention, and not here.

"""
return None
if use_dora and use_kasa:
raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.")

variant = None
if use_dora:
from .variants import DoraLinearVariant
variant = DoraLinearVariant()
elif use_kasa:
from .variants import KasaLinearVariant
variant = KasaLinearVariant()

return variant

Member: Let's undo the changes in this method body and return None. Instead, since this KaSA layer is implemented for Linear only, add the logic to lora.Linear.resolve_lora_variant instead.

Also, we should update the resolve_lora_variant methods of the other layer types like lora.Embedding.resolve_lora_variant to accept the use_kasa argument but raise an error if it's True. Otherwise, users may add it to non-supported layers and not notice that it doesn't actually do anything there.
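As a reading aid, the reviewer's suggestion could take roughly the following shape. This is a self-contained sketch, not the actual PEFT classes; the class and variant names only mirror the PEFT ones, and the error message for unsupported layers is illustrative:

```python
# Sketch: base layer keeps returning None, lora.Linear resolves DoRA/KaSA,
# and layer types without KaSA support reject use_kasa instead of ignoring it.
from typing import Optional


class LoraVariant: ...
class DoraLinearVariant(LoraVariant): ...
class KasaLinearVariant(LoraVariant): ...


class LoraLayer:
    def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
        return None


class Linear(LoraLayer):
    def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
        if use_dora and use_kasa:
            raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.")
        if use_dora:
            return DoraLinearVariant()
        if use_kasa:
            return KasaLinearVariant()
        return None


class Embedding(LoraLayer):
    def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool = False, **kwargs) -> Optional[LoraVariant]:
        if use_kasa:
            raise ValueError("KaSA is currently only supported for lora.Linear layers.")
        return None  # DoRA handling for Embedding omitted in this sketch


print(type(Linear().resolve_lora_variant(use_dora=False, use_kasa=True)).__name__)  # KasaLinearVariant
```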

def update_layer(
self,
@@ -184,6 +198,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_kasa: bool = False,
lora_bias: bool = False,
):
# collect the kwargs
@@ -194,7 +209,7 @@
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_kasa=use_kasa)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -217,7 +232,20 @@
self.scaling[adapter_name] = lora_alpha / r

self.use_dora[adapter_name] = use_dora

############ kasa #############
self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)

weight = self.get_base_layer().weight
dtype = weight.dtype
svd_rank = self.in_features - r
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)

#########################
Member: All of this can be removed, since it's part of KasaLinearVariant.init, right?

# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
with gather_params_ctx(self.get_base_layer().weight):
@@ -571,6 +599,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_kasa: bool = False,
lora_bias: bool = False,
**kwargs,
) -> None:
@@ -587,6 +616,7 @@
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_kasa=use_kasa,
lora_bias=lora_bias,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer
@@ -813,7 +843,7 @@ def update_layer(
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_kasa=use_kasa)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

54 changes: 54 additions & 0 deletions src/peft/tuners/lora/variants.py
@@ -22,6 +22,7 @@
from peft.utils.other import transpose

from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
from .kasa import KasaLinearLayer
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd


@@ -308,3 +309,56 @@ class DoraConv3dVariant(_DoraConvNdVariant):
def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
dora_layer = DoraConv3dLayer(fan_in_fan_out=False)
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)

class KasaLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
if not module.lora_diag:
# first kasa layer being added, add lora_diag to the list of learnable parameters
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_diag",)

# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# SVD
weight = module.get_base_layer().weight # original weight
dtype = weight.dtype
svd_rank = module.in_features - module.r[adapter_name]
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
U_principle = U[:, :svd_rank]
S_principle = S[:svd_rank]
Vh_principle = Vh[:svd_rank, :]

module.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)

Member: Let's add a reference here, so that we know the origin:
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

Author:
# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

# SVD

I put it in here, how is it?
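For readers who want to see the truncation above in isolation, here is a toy, self-contained illustration of the same SVD step with made-up shapes; it is not code from the PR:

```python
import torch

# Toy shapes: a Linear weight W of shape (out_features, in_features) and LoRA rank r.
out_features, in_features, r = 32, 16, 4
W = torch.randn(out_features, in_features)

# Keep the (in_features - r) largest singular components of the base weight and
# drop the r smallest ones; the adapter is then free to relearn task-relevant directions.
svd_rank = in_features - r
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
W_principal = (U[:, :svd_rank] @ torch.diag(S[:svd_rank]) @ Vh[:svd_rank, :]).to(W.dtype)

print(W_principal.shape)                             # torch.Size([32, 16])
print(torch.linalg.matrix_rank(W_principal).item())  # 12 == svd_rank
```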

@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = module.get_delta_weight(active_adapter)
return orig_weight + delta_weight

@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
delta_weight = module.get_delta_weight(active_adapter)
orig_weight.data += delta_weight

@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = module.get_delta_weight(active_adapter)
return orig_weight - delta_weight

Member: Instead of module.get_delta_weight, these methods should call the newly added KasaLinearVariant._get_delta_weight, right?

Member (commenting on lines 479 to 510): KaSA should have an influence on the merged weights, should it not?

Author: Although this PR is closed, it seems I've incorporated everything else except for this comment (of course, you'd have to look at the code). Could you explain this question in more detail?

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]

diag = torch.diag(module.lora_diag[active_adapter])
x = module._cast_input_dtype(x, lora_A.weight.dtype)
if isinstance(dropout, nn.Identity) or not module.training:
x = dropout(x)

# KaSA calculation
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
return result + lora_output

Member: Again, let's add a reference:
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110

Author:
# KaSA calculation
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
return result + lora_output

I inserted this near where the actual calculation logic begins, rather than just in an empty space. I think this is a bit better.
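One way to make the earlier question about merged weights concrete: the KaSA forward contribution equals x @ delta_weight.T when the delta includes the learned diagonal. The following self-contained check (pure torch, toy shapes, helper name invented for the example) sketches what a _get_delta_weight-style helper could compute; it is not the PR's final code:

```python
import torch

def kasa_delta_weight(A: torch.Tensor, B: torch.Tensor, diag: torch.Tensor, scaling: float) -> torch.Tensor:
    """Sketch: weight delta of a merged KaSA adapter, including the learned diagonal.

    A: (r, in_features) like lora_A.weight, B: (out_features, r) like lora_B.weight,
    diag: (r,) like lora_diag[adapter].
    """
    return (B @ torch.diag(diag) @ A) * scaling

# toy shapes
r, d_in, d_out, scaling = 4, 16, 32, 0.5
A, B, d = torch.randn(r, d_in), torch.randn(d_out, r), torch.randn(r)
x = torch.randn(2, 3, d_in)  # (batch, seq, in_features), as in the einsum above

# forward-style computation: lora_B(einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
forward_out = torch.einsum("ijk,kl->ijl", x @ A.T, torch.diag(d)) @ B.T * scaling

# merged-style computation: x @ delta.T
merged_out = x @ kasa_delta_weight(A, B, d, scaling).T

print(torch.allclose(forward_out, merged_out, atol=1e-5))  # True
```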