Commit 39abcad

add original reference in layer.py
1 parent 84813a3 commit 39abcad

File tree

1 file changed: +4 -1 lines


src/peft/tuners/lora/variants.py

Lines changed: 4 additions & 1 deletion
@@ -320,6 +320,8 @@ def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
         # initialize lora_diag
         module.lora_diag[adapter_name] = nn.Parameter(torch.randn(module.r[adapter_name]), requires_grad=True)
 
+        # see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L132
+
         # SVD
         weight = module.get_base_layer().weight  # original weight
         dtype = weight.dtype
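For orientation, this first hunk touches the KaSA-style initialization: a trainable per-rank diagonal (lora_diag) plus an SVD of the frozen base weight. Below is a minimal, hypothetical sketch of that pattern in isolation; the function name init_kasa_sketch and the use of torch.linalg.svd are assumptions for illustration, not the actual PEFT code.

import torch
import torch.nn as nn

def init_kasa_sketch(base_weight: torch.Tensor, r: int):
    # Trainable per-rank diagonal, randomly initialized (mirrors the lora_diag line in the hunk).
    lora_diag = nn.Parameter(torch.randn(r), requires_grad=True)

    # SVD of the original (frozen) weight, computed in float32 for numerical stability.
    # What the reference implementation does with U, S, Vh afterwards is not shown in this hunk.
    dtype = base_weight.dtype
    U, S, Vh = torch.linalg.svd(base_weight.to(torch.float32), full_matrices=False)
    return lora_diag, (U, S, Vh), dtype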
@@ -358,7 +360,8 @@ def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.
         x = module._cast_input_dtype(x, lora_A.weight.dtype)
         if isinstance(dropout, nn.Identity) or not module.training:
             x = dropout(x)
-
+
         # KaSA calculation
+        # see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
         lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
         return result + lora_output
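As a rough usage illustration of the einsum in the second hunk: lora_A(x) has shape (batch, seq, r), and the diagonal matrix rescales each of the r rank dimensions before lora_B projects back to the output features. The shapes, the scaling value, and building diag with torch.diag(lora_diag) are assumptions made for this sketch; the hunk itself does not show how diag is constructed.

import torch
import torch.nn as nn

batch, seq, in_features, r, out_features = 2, 3, 8, 4, 16   # illustrative sizes
scaling = 0.5                                                # illustrative LoRA scaling

x = torch.randn(batch, seq, in_features)
lora_A = nn.Linear(in_features, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)
lora_diag = torch.randn(r)          # stands in for module.lora_diag[active_adapter]
diag = torch.diag(lora_diag)        # (r, r) diagonal matrix (assumed construction)

# Same contraction as the diff: scale each rank dimension by the diagonal
# between the A and B projections, then apply the usual LoRA scaling.
lora_output = lora_B(torch.einsum('ijk,kl->ijl', lora_A(x), diag)) * scaling
assert lora_output.shape == (batch, seq, out_features)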

0 commit comments