@@ -232,19 +232,7 @@ def update_layer(
         self.scaling[adapter_name] = lora_alpha / r
 
         self.use_dora[adapter_name] = use_dora
-
-        ############ kasa #############
-        self.lora_diag[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True)
-
-        weight = self.get_base_layer().weight
-        dtype = weight.dtype
-        svd_rank = self.in_features - r
-        weight = weight.to(torch.float32)
-        U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
-        U_principle, S_principle, Vh_principle = U[:, :svd_rank], S[:svd_rank], Vh[:svd_rank, :]
-        self.get_base_layer().weight.data = (U_principle @ torch.diag(S_principle) @ Vh_principle).to(dtype)
-
-        #########################
+        self.use_kasa[adapter_name] = use_kasa
 
         # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
         if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
@@ -733,7 +721,12 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
             weight_A = weight_A.float()
             weight_B = weight_B.float()
 
-        output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
+        # KaSA handling
+        if self.use_kasa.get(adapter, False):
+            diag = torch.diag(self.lora_diag[adapter])
+            output_tensor = transpose(weight_B @ diag @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
+        else:
+            output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
 
         if cast_to_fp32:
             output_tensor = output_tensor.to(dtype=dtype)
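
With the added branch, the KaSA delta weight simply inserts the learnable diagonal between the two LoRA factors, and plain LoRA is recovered when the diagonal is all ones. A minimal sketch of the formula the use_kasa branch computes (names are illustrative; the layer itself uses self.lora_diag, self.scaling, and PEFT's transpose helper):

import torch

def kasa_delta_weight(weight_B: torch.Tensor, weight_A: torch.Tensor,
                      lora_diag: torch.Tensor, scaling: float) -> torch.Tensor:
    # delta_W = scaling * B @ diag(d) @ A
    return (weight_B @ torch.diag(lora_diag) @ weight_A) * scaling

# With lora_diag = torch.ones(r), this reduces to the standard LoRA delta, scaling * B @ A.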