Changes from all commits
30 commits
387ad2a  Add KaSA implementation to layer.py (nsbg, May 17, 2025)
ae00a34  Add `use_kasa` argument to LoraConfig (nsbg, Aug 2, 2025)
6588f1a  Add use_kasa parameter to Linear class (nsbg, Sep 2, 2025)
a824ac9  Add KasaLinearVariant class (just copy of DoraLinearVariant class) in… (nsbg, Sep 2, 2025)
05e4e07  Add kasa description (nsbg, Sep 2, 2025)
d1e7e43  Remove unnecessary self.kasa (nsbg, Sep 4, 2025)
9e53b9f  [WIP] update KasaLinearVariant class with SVD implementation (nsbg, Sep 8, 2025)
aa37111  Modify merge/unmerge method in KasaLinearVariant class (nsbg, Sep 8, 2025)
9cfe65c  update KasaLinearVariant class with SVD implementation (nsbg, Sep 8, 2025)
f9d7cc7  fix type in init method (nsbg, Sep 8, 2025)
84813a3  delete unnecessary part in layer.py (nsbg, Sep 16, 2025)
39abcad  add original reference in layer.py (nsbg, Sep 16, 2025)
06f76d8  merge main to peft-kasa (nsbg, Sep 16, 2025)
0043ae3  re-add KaSA implementation to variants.py (nsbg, Sep 20, 2025)
ea59432  add use_kasa param to resolve_lora_variant in other layers (nsbg, Sep 27, 2025)
574c1b8  delete unnecessary part in layer.py (nsbg, Sep 29, 2025)
73fa58f  delete unnecessary part in variants.py (nsbg, Sep 29, 2025)
b9e3190  add _get_delta_weight static method in KasaLinearVariants class (nsbg, Oct 4, 2025)
2649fdf  update module.get_delta_weight to KasaLinearVariant._get_delta_weight… (nsbg, Oct 7, 2025)
d06cafd  add kasa test (nsbg, Oct 7, 2025)
4ee5da1  add dropout (nsbg, Oct 8, 2025)
0377170  Update tests/test_custom_models.py (nsbg, Oct 11, 2025)
a536bbf  Update src/peft/tuners/lora/variants.py (nsbg, Oct 11, 2025)
f8d8057  Update src/peft/tuners/lora/variants.py (nsbg, Oct 11, 2025)
cd57c7b  Update src/peft/tuners/lora/variants.py (nsbg, Oct 11, 2025)
2431a2c  add use_kasa param in LoraModel class (nsbg, Oct 11, 2025)
7276b3b  restore output_tensor variable in Linear class get_delta_weight method (nsbg, Oct 11, 2025)
5a67b1f  add use_kasa handling condition in resolve_lora_variant method (nsbg, Oct 11, 2025)
3ec6b18  fix KaSA self, mat1, mat2 dtype error (nsbg, Oct 11, 2025)
461a89c  fix make style error (nsbg, Oct 11, 2025)
10 changes: 10 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -664,6 +664,16 @@ class LoraConfig(PeftConfig):
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)

use_kasa: bool = field(
default=False,
metadata={
"help": (
"Enable <a href='https://arxiv.org/abs/2412.06071'>'Knowledge-Aware Singular-Value Adaptation of Large Language Models' (KaSA)</a>. This technique leverages "
"singular value decomposition (SVD) with knowledge-aware singular values to dynamically "
"activate parametric knowledge according to its relevance to downstream tasks."
)
}
)
def to_dict(self):
"""
Returns the configuration for your adapter model as a dictionary. Removes runtime configurations.
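For context, a minimal sketch of how the new flag is meant to be used from the public API; the model name and target modules below are placeholders, not taken from this PR:

```python
# Minimal usage sketch for the new `use_kasa` flag (placeholder model and targets).
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    use_kasa=True,  # enables the KaSA variant for the targeted nn.Linear layers
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```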
34 changes: 23 additions & 11 deletions src/peft/tuners/lora/layer.py
@@ -181,7 +181,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
self.in_features = in_features
self.out_features = out_features

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
@@ -204,6 +204,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_kasa: bool = False,
use_alora: bool = False,
use_qalora: bool = False,
lora_bias: bool = False,
@@ -229,11 +230,13 @@

lora_variant = self.resolve_lora_variant(
use_dora=use_dora,
use_kasa=use_kasa,
use_alora=use_alora,
use_qalora=use_qalora,
qalora_group_size=qalora_group_size,
arrow_config=arrow_config,
)

if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -664,6 +667,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_kasa: bool = False,
use_alora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
@@ -682,27 +686,30 @@
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_kasa=use_kasa,
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, use_kasa: bool, **kwargs
) -> Optional[LoraVariant]:
if arrow_config is not None:
from .variants import ArrowLinearVariant

return ArrowLinearVariant()

if not use_dora and not use_alora:
if not use_dora and not use_alora and not use_kasa:
return None

from .variants import ALoraLinearVariant, DoraLinearVariant
from .variants import ALoraLinearVariant, DoraLinearVariant, KasaLinearVariant

if use_alora:
return ALoraLinearVariant()
elif use_kasa:
return KasaLinearVariant()
else:
return DoraLinearVariant()

@@ -916,7 +923,7 @@ def __init__(
arrow_config=arrow_config,
)

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -933,6 +940,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora,
use_kasa,
lora_bias,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
@@ -945,7 +953,8 @@
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, use_kasa=use_kasa, arrow_config=arrow_config)

if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1180,6 +1189,7 @@ def __init__(
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
use_kasa: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
@@ -1222,6 +1232,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora,
use_kasa,
lora_bias,
arrow_config: ArrowConfig = None,
inference_mode: bool = False,
@@ -1241,7 +1252,7 @@
PeftWarning,
)

lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config, use_kasa=use_kasa)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant

@@ -1485,7 +1496,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv2d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -1502,7 +1513,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv1d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -1519,7 +1530,7 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv3d

def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
if not use_dora:
return None

@@ -2003,6 +2014,7 @@ def update_layer(
init_lora_weights,
use_rslora,
use_dora: bool = False,
use_kasa: bool = False,
use_qalora: bool = False,
lora_bias: bool = False,
qalora_group_size: int = 32,
@@ -2019,7 +2031,7 @@
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

lora_variant = self.resolve_lora_variant(
use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size, use_kasa=use_kasa
)
if lora_variant is not None:
raise ValueError(f"lora.{self.__class__.__name__} does not work with LoRA variants like DoRA.")
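To summarize the dispatch that the layer.py changes implement, here is a small standalone sketch (not PEFT code; it only mirrors the branch order shown above). Note that the Embedding and Conv layers now accept `use_kasa` but still return a variant only for `use_dora`, so KaSA currently applies to `nn.Linear` targets only.

```python
# Standalone sketch of the branch order in Linear.resolve_lora_variant.
# Arrow takes precedence, then aLoRA, then KaSA, then DoRA.
from typing import Optional


def resolve_variant_name(*, arrow: bool, use_dora: bool, use_alora: bool, use_kasa: bool) -> Optional[str]:
    if arrow:
        return "ArrowLinearVariant"
    if not (use_dora or use_alora or use_kasa):
        return None
    if use_alora:
        return "ALoraLinearVariant"
    if use_kasa:
        return "KasaLinearVariant"
    return "DoraLinearVariant"


assert resolve_variant_name(arrow=False, use_dora=False, use_alora=False, use_kasa=True) == "KasaLinearVariant"
# If both flags were set, aLoRA would currently win over KaSA:
assert resolve_variant_name(arrow=False, use_dora=False, use_alora=True, use_kasa=True) == "ALoraLinearVariant"
```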
2 changes: 2 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -218,6 +218,7 @@ def _create_and_replace(
"use_dora": lora_config.use_dora,
"use_alora": lora_config.alora_invocation_tokens is not None,
"use_qalora": lora_config.use_qalora,
"use_kasa": lora_config.use_kasa,
"qalora_group_size": lora_config.qalora_group_size,
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
"lora_bias": lora_config.lora_bias,
@@ -255,6 +256,7 @@ def _create_and_replace(
init_lora_weights=lora_config.init_lora_weights,
use_rslora=lora_config.use_rslora,
use_dora=lora_config.use_dora,
use_kasa=lora_config.use_kasa,
lora_bias=lora_config.lora_bias,
arrow_config=lora_config.arrow_config,
inference_mode=lora_config.inference_mode,
108 changes: 108 additions & 0 deletions src/peft/tuners/lora/variants.py
@@ -439,6 +439,114 @@ def init(module: Conv3d, adapter_name: str, **kwargs: Any) -> None:
_DoraConvNdVariant.init_convd_variant(module, adapter_name, dora_layer=dora_layer)


class KasaLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
if not hasattr(module, "lora_diag"):
module.lora_diag = nn.ParameterDict()
module.adapter_layer_names = module.adapter_layer_names[:] + ("lora_diag",)

# Initialize lora_diag with the same dtype as the base layer
base_dtype = module.get_base_layer().weight.dtype
module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name], dtype=base_dtype), requires_grad=True)

# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132

# SVD: reconstruct the base weight from its top (in_features - r) singular components, discarding the r smallest
weight = module.get_base_layer().weight
dtype = weight.dtype
svd_rank = module.in_features - module.r[adapter_name]
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
reconstructed_weight = U_principle @ torch.diag(S_principle) @ Vh_principle
module.get_base_layer().weight.data = reconstructed_weight.to(dtype)

@staticmethod
def _get_delta_weight(weight_A, weight_B, lora_diag, scaling, fan_in_fan_out):
# Ensure all tensors have the same dtype
target_dtype = weight_A.dtype
weight_B = weight_B.to(target_dtype)
lora_diag = lora_diag.to(target_dtype)

diag = torch.diag(lora_diag)
delta = weight_B @ diag @ weight_A
if fan_in_fan_out:
delta = delta.transpose(0, 1)
delta = delta * scaling
return delta

@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out
)
return orig_weight + delta_weight

@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out,
)
orig_weight.data += delta_weight

@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
delta_weight = KasaLinearVariant._get_delta_weight(
module.lora_A[active_adapter].weight,
module.lora_B[active_adapter].weight,
module.lora_diag[active_adapter],
module.scaling[active_adapter],
module.fan_in_fan_out,
)
return orig_weight - delta_weight

@staticmethod
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor, **kwargs) -> torch.Tensor:
# Check if adapters are disabled
if module.disable_adapters:
return result

lora_A = module.lora_A[active_adapter]
lora_B = module.lora_B[active_adapter]
dropout = module.lora_dropout[active_adapter]
scaling = module.scaling[active_adapter]
diag = torch.diag(module.lora_diag[active_adapter])

# KaSA calculation
# see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110

# Ensure all tensors have the same dtype as the result
target_dtype = result.dtype
x = x.to(target_dtype)
diag = diag.to(target_dtype)

# Convert LoRA weights to target dtype
lora_A.weight.data = lora_A.weight.data.to(target_dtype)
lora_B.weight.data = lora_B.weight.data.to(target_dtype)

lora_A_output = lora_A(dropout(x))

if x.ndim == 3:
einsum_output = torch.einsum("ijk,kl->ijl", lora_A_output, diag)
lora_output = lora_B(einsum_output) * scaling
elif x.ndim == 2:
matmul_output = lora_A_output @ diag
lora_output = lora_B(matmul_output) * scaling
else:
raise ValueError(f"Using KaSA with inputs of shape {x.ndim} is not supported, only 2 or 3 dims.")

return result + lora_output


class QALoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
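A self-contained sketch of the two pieces KasaLinearVariant adds, using toy shapes rather than PEFT objects: the SVD truncation of the base weight performed in `init`, and the `B @ diag(lora_diag) @ A` delta used by merge/unmerge and forward.

```python
# Toy-shape sketch of the KaSA math implemented above (not PEFT code).
import torch

torch.manual_seed(0)
in_features, out_features, r, scaling = 16, 12, 4, 2.0
base_weight = torch.randn(out_features, in_features)

# (1) init: keep only the top (in_features - r) singular components of the base weight
svd_rank = in_features - r
U, S, Vh = torch.linalg.svd(base_weight, full_matrices=False)
truncated_weight = U[:, :svd_rank] @ torch.diag(S[:svd_rank]) @ Vh[:svd_rank, :]

# (2) delta weight: B @ diag(lora_diag) @ A, scaled (in PEFT, lora_B starts at zero;
# random values are used here only so the equivalence check below is non-trivial)
lora_A = torch.randn(r, in_features)   # corresponds to lora_A.weight
lora_B = torch.randn(out_features, r)  # corresponds to lora_B.weight
lora_diag = torch.randn(r)             # the new lora_diag parameter
delta = (lora_B @ torch.diag(lora_diag) @ lora_A) * scaling

merged = truncated_weight + delta  # what merge_safe returns for a Linear layer

# The forward path used by KasaLinearVariant.forward for 2D inputs is equivalent
# to applying the delta weight directly.
x = torch.randn(3, in_features)
out_forward = (x @ lora_A.T @ torch.diag(lora_diag) @ lora_B.T) * scaling
out_delta = x @ delta.T
assert torch.allclose(out_forward, out_delta, atol=1e-4)
```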
2 changes: 2 additions & 0 deletions tests/test_custom_models.py
@@ -159,6 +159,8 @@
LoraConfig,
{"target_modules": ["lin0"], "target_parameters": ["lin1.weight"]},
),
("Vanilla MLP 7 LoRA with KaSA", "MLP", LoraConfig, {"target_modules": ["lin0"], "use_kasa": True}),
("Vanilla MLP 8 LoRA with KaSA", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"], "use_kasa": True}),
#######
# IA³ #
#######
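A rough end-to-end version of what the two new parametrizations exercise, assuming this branch is installed; the MLP below only approximates the one defined in tests/test_custom_models.py.

```python
# Rough end-to-end version of the new KaSA test cases (approximate MLP, branch required).
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


config = LoraConfig(target_modules=["lin0", "lin1"], r=8, use_kasa=True)
model = get_peft_model(MLP(), config)

x = torch.randn(5, 10)
with torch.no_grad():
    out_adapted = model(x)
    model.merge_adapter()
    out_merged = model(x)
    model.unmerge_adapter()

# Merging the KaSA delta into the (SVD-truncated) base weights should not change the output.
assert torch.allclose(out_adapted, out_merged, atol=1e-4)
```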