
Commit 84813a3

committed: delete unnecessary part in layer.py
1 parent f9d7cc7 · commit 84813a3

File tree: 1 file changed (+7, −14 lines)

src/peft/tuners/lora/layer.py

Lines changed: 7 additions & 14 deletions
@@ -232,19 +232,7 @@ def update_layer(
         self.scaling[adapter_name] = lora_alpha / r

         self.use_dora[adapter_name] = use_dora
-
-        ############ kasa #############
-        self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)
-
-        weight = self.get_base_layer().weight
-        dtype = weight.dtype
-        svd_rank = self.in_features - r
-        weight = weight.to(torch.float32)
-        U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
-        U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
-        self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)
-
-        #########################
+        self.use_kasa[adapter_name] = use_kasa

         # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
         if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
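For context, the block removed here performed the KaSA-specific setup inside update_layer: it registered a learnable diagonal (lora_diag) of size r and overwrote the frozen base weight with its rank-(in_features − r) SVD truncation. Below is a minimal standalone sketch of that removed computation, assuming a Linear-style weight of shape (out_features, in_features); the function name and arguments are placeholders for illustration, not part of the library.

import torch
import torch.nn as nn

def kasa_setup_removed_here(base_weight: torch.Tensor, r: int):
    # Learnable per-adapter diagonal of size r (as in the deleted lora_diag line).
    lora_diag = nn.Parameter(torch.randn(r), requires_grad=True)

    # SVD-truncate the base weight to its top (in_features - r) singular directions,
    # mirroring the deleted block; done in float32, then cast back to the original dtype.
    dtype = base_weight.dtype
    svd_rank = base_weight.shape[1] - r  # in_features - r
    U, S, Vh = torch.linalg.svd(base_weight.to(torch.float32), full_matrices=False)
    U_p, S_p, Vh_p = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
    truncated_weight = (U_p @ torch.diag(S_p) @ Vh_p).to(dtype)
    return lora_diag, truncated_weight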
@@ -733,7 +721,12 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
             weight_A = weight_A.float()
             weight_B = weight_B.float()

-        output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
+        # KaSA handling
+        if self.use_kasa.get(adapter, False):
+            diag = torch.diag(self.lora_diag[adapter])
+            output_tensor = transpose(weight_B @ diag @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
+        else:
+            output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]

         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
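After this commit, the KaSA-specific computation left in the layer is the delta-weight composition above: when use_kasa is set for the adapter, the learnable diagonal is inserted between lora_B and lora_A before scaling. A minimal sketch of that computation follows; the shapes, out_features, in_features, r, and lora_alpha are made-up values for illustration, and fan_in_fan_out is assumed False so the transpose is a no-op.

import torch

out_features, in_features, r, lora_alpha = 16, 32, 4, 8   # illustrative values
scaling = lora_alpha / r                                   # as set in update_layer

lora_A = torch.randn(r, in_features)    # lora_A.weight
lora_B = torch.randn(out_features, r)   # lora_B.weight
lora_diag = torch.randn(r)              # per-adapter learnable diagonal

# KaSA branch: B @ diag(d) @ A, then scale
delta_kasa = (lora_B @ torch.diag(lora_diag) @ lora_A) * scaling
# Plain LoRA branch: B @ A, then scale
delta_lora = (lora_B @ lora_A) * scaling

assert delta_kasa.shape == delta_lora.shape == (out_features, in_features)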
