Commits
30 commits
387ad2a
Add KaSA implementation to layer.py
nsbg May 17, 2025
ae00a34
Add `use_kasa` argument to LoraConfig
nsbg Aug 2, 2025
6588f1a
Add use_kasa parameter to Linear class
nsbg Sep 2, 2025
a824ac9
Add KasaLinearVariant class (just copy of DoraLinearVariant class) in…
nsbg Sep 2, 2025
05e4e07
Add kasa description
nsbg Sep 2, 2025
d1e7e43
Remove unnecessary self.kasa
nsbg Sep 4, 2025
9e53b9f
[WIP] update KasaLinearVariant class with SVD implementation
nsbg Sep 8, 2025
aa37111
Modify merge/unmerge method in KasaLinearVariant class
nsbg Sep 8, 2025
9cfe65c
update KasaLinearVariant class with SVD implementation
nsbg Sep 8, 2025
f9d7cc7
fix type in init method
nsbg Sep 8, 2025
84813a3
delete unnecessary part in layer.py
nsbg Sep 16, 2025
39abcad
add original reference in layer.py
nsbg Sep 16, 2025
06f76d8
merge main to peft-kasa
nsbg Sep 16, 2025
0043ae3
re-add KaSA implementation to variants.py
nsbg Sep 20, 2025
ea59432
add use_kasa param to resolve_lora_variant in other layers
nsbg Sep 27, 2025
574c1b8
delete unnecessary part in layer.py
nsbg Sep 29, 2025
73fa58f
delete unnecessary part in variants.py
nsbg Sep 29, 2025
b9e3190
add _get_delta_weight static method in KasaLinearVariants class
nsbg Oct 4, 2025
2649fdf
update module.get_delta_weight to KasaLinearVariant._get_delta_weight…
nsbg Oct 7, 2025
d06cafd
add kasa test
nsbg Oct 7, 2025
4ee5da1
add dropout
nsbg Oct 8, 2025
0377170
Update tests/test_custom_models.py
nsbg Oct 11, 2025
a536bbf
Update src/peft/tuners/lora/variants.py
nsbg Oct 11, 2025
f8d8057
Update src/peft/tuners/lora/variants.py
nsbg Oct 11, 2025
cd57c7b
Update src/peft/tuners/lora/variants.py
nsbg Oct 11, 2025
2431a2c
add use_kasa param in LoraModel class
nsbg Oct 11, 2025
7276b3b
restore output_tensor variable in Linear class get_delta_weight method
nsbg Oct 11, 2025
5a67b1f
add use_kasa handling condition in resolve_lora_variant method
nsbg Oct 11, 2025
3ec6b18
fix KaSA self, mat1, mat2 dtype error
nsbg Oct 11, 2025
461a89c
fix make style error
nsbg Oct 11, 2025
6 changes: 6 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -527,6 +527,12 @@ class LoraConfig(PeftConfig):
},
)

use_kasa: bool = field(
default=False,
metadata={
"help": (
"Enable KaSA (Knowledge-aware Singular-value Adaptation), a LoRA variant that truncates the base "
"weight via SVD and learns an additional diagonal of singular values for each adapter."
)
},
)
def to_dict(self):
"""
Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
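For orientation (not part of this PR's diff), a minimal usage sketch of how the new `use_kasa` flag would be toggled on a `LoraConfig`; the model name and target modules are arbitrary placeholders:

# Illustrative only: assumes the `use_kasa` field added in the hunk above is available on LoraConfig.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    use_kasa=True,  # enable the KaSA variant instead of plain LoRA
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()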
18 changes: 17 additions & 1 deletion src/peft/tuners/lora/layer.py
@@ -33,7 +33,7 @@
from peft.utils.other import transpose

from .config import LoraConfig

from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer

class LoraVariant:
"""
@@ -107,6 +107,9 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.cast_input_dtype_enabled: bool = True
self.lora_variant: dict[str, LoraVariant] = {}
self.kwargs = kwargs

# KaSA: learnable diagonal (one singular value per LoRA rank) for each adapter
self.lora_diag = nn.ParameterDict({})

base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
Expand Down Expand Up @@ -217,7 +220,20 @@ def update_layer(
self.scaling[adapter_name] = lora_alpha / r

self.use_dora[adapter_name] = use_dora

############ kasa #############
# learnable diagonal of singular values for this adapter
self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)

# truncate the base weight to its top (in_features - r) principal singular components
weight = self.get_base_layer().weight
dtype = weight.dtype
svd_rank = self.in_features - r
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)
#########################
Reviewer comment (Member):
All of this can be removed, since it's part of KasaLinearVariant.init, right?


# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
with gather_params_ctx(self.get_base_layer().weight):
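In line with the review comment above, here is a minimal sketch of what moving this logic into the variant could look like. It mirrors the `lora_diag` parameter and SVD truncation from this diff and the `_get_delta_weight` static method mentioned in the commit history, but the class name `KasaLinearVariantSketch` and the exact signatures are illustrative assumptions, not the PR's final implementation:

import torch
import torch.nn as nn


class KasaLinearVariantSketch:
    """Illustrative only; not the actual KasaLinearVariant from this PR."""

    @staticmethod
    def init(module, adapter_name: str, **kwargs) -> None:
        r = module.r[adapter_name]
        # Learnable diagonal of singular values, one entry per LoRA rank.
        module.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)

        # Truncate the base weight to its top (in_features - r) principal singular
        # components, mirroring the update_layer code in the diff above.
        weight = module.get_base_layer().weight
        dtype = weight.dtype
        svd_rank = module.in_features - r
        U, S, Vh = torch.linalg.svd(weight.to(torch.float32), full_matrices=False)
        module.get_base_layer().weight.data = (
            U[:, :svd_rank] @ torch.diag(S[:svd_rank]) @ Vh[:svd_rank, :]
        ).to(dtype)

    @staticmethod
    def _get_delta_weight(module, adapter_name: str) -> torch.Tensor:
        # Delta = B @ diag(lora_diag) @ A, scaled like plain LoRA.
        lora_A = module.lora_A[adapter_name].weight
        lora_B = module.lora_B[adapter_name].weight
        diag = torch.diag(module.lora_diag[adapter_name].to(lora_A.dtype))
        return (lora_B @ diag @ lora_A) * module.scaling[adapter_name]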