Commits
387ad2a
Add KaSA implementation to layer.py
nsbg May 17, 2025
ae00a34
Add `use_kasa` argument to LoraConfig
nsbg Aug 2, 2025
6588f1a
Add use_kasa parameter to Linear class
nsbg Sep 2, 2025
a824ac9
Add KasaLinearVariant class (just copy of DoraLinearVariant class) in…
nsbg Sep 2, 2025
05e4e07
Add kasa description
nsbg Sep 2, 2025
d1e7e43
Remove unnecessary self.kasa
nsbg Sep 4, 2025
9e53b9f
[WIP] update KasaLinearVariant class with SVD implementation
nsbg Sep 8, 2025
aa37111
Modify merge/unmerge method in KasaLinearVariant class
nsbg Sep 8, 2025
9cfe65c
update KasaLinearVariant class with SVD implementation
nsbg Sep 8, 2025
f9d7cc7
fix type in init method
nsbg Sep 8, 2025
84813a3
delete unnecessary part in layer.py
nsbg Sep 16, 2025
39abcad
add original reference in layer.py
nsbg Sep 16, 2025
06f76d8
merge main to peft-kasa
nsbg Sep 16, 2025
0043ae3
re-add KaSA implementation to variants.py
nsbg Sep 20, 2025
ea59432
add use_kasa param to resolve_lora_variant in other layers
nsbg Sep 27, 2025
574c1b8
delete unnecessary part in layer.py
nsbg Sep 29, 2025
73fa58f
delete unnecessary part in variants.py
nsbg Sep 29, 2025
b9e3190
add _get_delta_weight static method in KasaLinearVariants class
nsbg Oct 4, 2025
2649fdf
update module.get_delta_weight to KasaLinearVariant._get_delta_weight…
nsbg Oct 7, 2025
d06cafd
add kasa test
nsbg Oct 7, 2025
4ee5da1
add dropout
nsbg Oct 8, 2025
0377170
Update tests/test_custom_models.py
nsbg Oct 11, 2025
a536bbf
Update src/peft/tuners/lora/variants.py
nsbg Oct 11, 2025
f8d8057
Update src/peft/tuners/lora/variants.py
nsbg Oct 11, 2025
cd57c7b
Update src/peft/tuners/lora/variants.py
nsbg Oct 11, 2025
2431a2c
add use_kasa param in LoraModel class
nsbg Oct 11, 2025
7276b3b
restore output_tensor variable in Linear class get_delta_weight method
nsbg Oct 11, 2025
5a67b1f
add use_kasa handling condition in resolve_lora_variant method
nsbg Oct 11, 2025
3ec6b18
fix KaSA self, mat1, mat2 dtype error
nsbg Oct 11, 2025
461a89c
fix make style error
nsbg Oct 11, 2025
10 changes: 10 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -664,6 +664,16 @@ class LoraConfig(PeftConfig):
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)

use_kasa: bool = field(
default=False,
metadata={
"help": (
"Enable <a href='https://arxiv.org/abs/2412.06071'>'Knowledge-Aware Singular-Value Adaptation of Large Language Models' (KaSA)</a>. This technique leverages "
"singular value decomposition (SVD) with knowledge-aware singular values to dynamically "
"activate parametric knowledge according to its relevance to downstream tasks."
)
}
)
def to_dict(self):
"""
Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
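For reference, the new `use_kasa` flag is toggled like any other `LoraConfig` option. A minimal usage sketch, assuming this PR's branch is installed (`use_kasa` does not exist in released PEFT) and with hypothetical target module names:

# Minimal sketch for the flag added above; branch-only, not released PEFT.
from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # hypothetical module names
    use_kasa=True,  # enables the KaSA variant introduced in this PR
)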
30 changes: 19 additions & 11 deletions src/peft/tuners/lora/layer.py
@@ -113,6 +113,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self._disable_adapters = False
self.merged_adapters = []
self.use_dora: dict[str, bool] = {} # not actively used anymore after #2443, keep it for BC
self.use_kasa: dict[str, bool] = {}
self.use_rslora: dict[str, bool] = {}
self.lora_bias: dict[str, bool] = {}
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
@@ -181,7 +182,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.in_features = in_features
self.out_features = out_features

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
@@ -204,6 +205,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_kasa: bool = False,
use_alora: bool = False,
use_qalora: bool = False,
lora_bias: bool = False,
@@ -229,11 +231,13 @@

lora_variant = self.resolve_lora_variant(
use_dora=use_dora,
use_kasa=use_kasa,
use_alora=use_alora,
use_qalora=use_qalora,
qalora_group_size=qalora_group_size,
arrow_config=arrow_config,
)

if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -664,6 +668,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_kasa: bool = False,
use_alora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
@@ -682,14 +687,15 @@
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_kasa=use_kasa,
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, use_kasa: bool, **kwargs
) -> Optional[LoraVariant]:
if arrow_config is not None:
from .variants import ArrowLinearVariant
@@ -819,8 +825,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
weight_A = weight_A.float()
weight_B = weight_B.float()

output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)

@@ -916,7 +920,7 @@ def __init__(
arrow_config=arrow_config,
)

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -933,6 +937,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora,
use_kasa,
lora_bias,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
@@ -945,7 +950,8 @@
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_kasa=use_kasa, arrow_config=arrow_config)

if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1180,6 +1186,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_kasa: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
@@ -1222,6 +1229,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora,
use_kasa,
lora_bias,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
@@ -1241,7 +1249,7 @@
PeftWarning,
)

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config, use_kasa=use_kasa)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1485,7 +1493,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv2d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -1502,7 +1510,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv1d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -1519,7 +1527,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv3d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -2019,7 +2027,7 @@ def update_layer(
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(
use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size, use_kasa=use_kasa
)
if lora_variant is not None:
raise ValueError(f"lora.{self.__class__.__name__} does not work with LoRA variants like DoRA.")
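The layer.py hunks above mostly thread `use_kasa` through the `resolve_lora_variant` signatures; the method bodies that actually choose a variant are collapsed in this view. As a rough, self-contained toy of the dispatch pattern being extended (the classes below are stand-ins, not PEFT's internals):

# Toy illustration of the variant-dispatch pattern that use_kasa plugs into.
# Stand-in classes only; PEFT's real variants live in lora/variants.py.
from typing import Optional


class LoraVariant:
    pass


class DoraLinearVariant(LoraVariant):
    pass


class KasaLinearVariant(LoraVariant):
    pass


def resolve_lora_variant(*, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
    # Each flag maps to at most one variant; plain LoRA returns None.
    if use_kasa:
        return KasaLinearVariant()
    if use_dora:
        return DoraLinearVariant()
    return None


print(type(resolve_lora_variant(use_dora=False, use_kasa=True)).__name__)  # KasaLinearVariant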
76 changes: 76 additions & 0 deletions src/peft/tuners/lora/variants.py
@@ -439,6 +439,82 @@ def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)


class KasaLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
if not module.lora_diag:
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_diag",)

# Initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

# SVD
Review comment (Member):
Let's add a reference here, so that we know the origin:
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

Reply (Author):
# initialize lora_diag
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

# SVD

I put it in here, how is it?

weight = module.get_base_layer().weight
dtype = weight.dtype
svd_rank = module.in_features - module.r[adapter_name]
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
module.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)

@staticmethod
def _get_delta_weight(weight_A, weight_B, lora_diag, scaling, fan_in_fan_out):
diag = torch.diag(lora_diag)
delta = weight_B @ diag @ weight_A
if fan_in_fan_out:
delta = delta.transpose(0, 1)
delta = delta * scaling
return delta

@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out
)
return orig_weight + delta_weight

@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out,
)
orig_weight.data += delta_weight

@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out,
)
return orig_weight - delta_weight

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
diag = torch.diag(module.lora_diag[active_adapter])

# KaSA calculation
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
lora_output = lora_B(torch.einsum("ijk,kl->ijl", lora_A(dropout(x)), diag)) * scaling
return result + lora_output


class QALoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
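To make `KasaLinearVariant` easier to follow, here is a self-contained numeric sketch of the two pieces it implements: the SVD truncation applied to the base weight in `init`, and the `B @ diag(lora_diag) @ A` delta that both `forward` and the merge helpers use. Shapes, seed, and scaling are arbitrary; this is not PEFT code.

# Standalone sketch of the KaSA math above (arbitrary shapes, not PEFT code).
import torch

torch.manual_seed(0)
in_features, out_features, r, scaling = 12, 16, 4, 2.0

W = torch.randn(out_features, in_features)

# init(): keep only the largest (in_features - r) singular components of the base weight
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
k = in_features - r
W_base = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k, :]

# adapter parameters: lora_A (r x in), lora_B (out x r), lora_diag (r)
A = torch.randn(r, in_features)
B = torch.randn(out_features, r)
d = torch.randn(r)

# merged delta, as in _get_delta_weight (fan_in_fan_out=False)
delta = (B @ torch.diag(d) @ A) * scaling

# forward path, as in forward(): lora_B(lora_A(x) @ diag) * scaling
x = torch.randn(2, 5, in_features)  # (batch, seq, in_features)
lora_out = (x @ A.T) @ torch.diag(d) @ B.T * scaling

# the batched forward path and the merged weight agree
assert torch.allclose(x @ W_base.T + lora_out, x @ (W_base + delta).T, atol=1e-5)

In short, `init` keeps at most `in_features - r` singular components of the frozen weight, and the trainable `lora_diag` scales the rank-r adapter update, which is the knowledge-aware singular-value idea the config help text describes.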
2 changes: 2 additions & 0 deletions tests/test_custom_models.py
@@ -159,6 +159,8 @@
LoraConfig,
{"target_modules": ["lin0"], "target_parameters": ["lin1.weight"]},
),
("Vanilla MLP 7 LoRA with KaSA", "MLP", LoraConfig, {"target_modules": ["lin0"], "use_kasa": True}),
("Vanilla MLP 8 LoRA with KaSA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_kasa": True})
#######
# IA³ #
#######
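Finally, a rough end-to-end sketch of the workflow the two new test entries exercise: wrapping a small MLP with `use_kasa=True`. It assumes this PR's branch is installed; the MLP below is a stand-in, not the test suite's model.

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


# Mirrors the second new test case: KaSA applied to both linear layers.
config = LoraConfig(target_modules=["lin0", "lin1"], r=4, use_kasa=True)
model = get_peft_model(MLP(), config)

# The KaSA forward path uses a 3-D einsum, so pass (batch, seq, features) inputs.
x = torch.randn(2, 8, 10)
print(model(x).shape)  # torch.Size([2, 8, 2])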