@@ -244,15 +244,15 @@ def update_layer(
         self.fourierft_n_frequency[adapter_name] = n_frequency
         self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
         self.indices[adapter_name] = torch.randperm(
-            self.out_features * self.in_features,
+            self.out_features * self.in_features * self.kW * self.kH,
             generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
         )[:n_frequency]
         self.indices[adapter_name] = torch.stack(
-            [self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
+            [self.indices[adapter_name] // (self.in_features * self.kW), self.indices[adapter_name] % (self.in_features * self.kW)], dim=0
         )
         self.fourierft_scaling[adapter_name] = scaling
         # Actual trainable parameters
-        self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency, self.kW, self.kH), requires_grad=True)
+        self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True)
 
         if init_weights:
             self.reset_fourier_parameters(adapter_name)
@@ -310,13 +310,12 @@ def unmerge(self) -> None:
                 self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
-        # careful: ifft2 does not work with float16 or bfloat16
         spectrum = self.fourierft_spectrum[adapter]
         indices = self.indices[adapter].to(spectrum.device)
-        dense_spectrum = torch.zeros(self.out_features, self.in_features, self.kW, self.kH, device=spectrum.device)
+        dense_spectrum = torch.zeros(self.out_features * self.kH, self.in_features * self.kW, device=spectrum.device)
         dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
         delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
-        return delta_weight
+        return delta_weight.reshape((self.out_features, self.in_features, self.kW, self.kH))
 
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         previous_dtype = x.dtype
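
For reference, a minimal standalone sketch of the indexing scheme and inverse-FFT reconstruction that this patch applies to Conv2d weights. The sizes (out_features, in_features, kW, kH, n_frequency, scaling) are hypothetical values chosen only for illustration; the snippet is not part of the patch and merely mirrors the patched update_layer / get_delta_weight logic.

import torch

# Hypothetical Conv2d dimensions and FourierFT settings, for illustration only.
out_features, in_features, kW, kH = 8, 4, 3, 3
n_frequency = 16
scaling = 300.0

# Draw n_frequency flat positions over the full (out_features * kH) x (in_features * kW) grid,
# then split each flat index into a (row, col) pair, as in the patched update_layer.
flat = torch.randperm(
    out_features * in_features * kW * kH,
    generator=torch.Generator().manual_seed(0),
)[:n_frequency]
indices = torch.stack([flat // (in_features * kW), flat % (in_features * kW)], dim=0)

# One trainable coefficient per selected frequency (the patched parameter shape).
spectrum = torch.randn(n_frequency)

# Scatter into the dense 2D spectrum, take the real part of the 2D inverse FFT,
# and reshape to the Conv2d delta-weight layout returned by the patched get_delta_weight.
dense_spectrum = torch.zeros(out_features * kH, in_features * kW)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
delta_weight = torch.fft.ifft2(dense_spectrum).real * scaling
delta_weight = delta_weight.reshape((out_features, in_features, kW, kH))
print(delta_weight.shape)  # torch.Size([8, 4, 3, 3])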