Skip to content

Commit 0c8dc75

Browse files
committed
Put the spectrum into one big 2D window, not batches of small windows
1 parent 017be24 commit 0c8dc75

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/peft/tuners/fourierft/layer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,15 +244,15 @@ def update_layer(
244244
self.fourierft_n_frequency[adapter_name] = n_frequency
245245
self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
246246
self.indices[adapter_name] = torch.randperm(
247-
self.out_features * self.in_features,
247+
self.out_features * self.in_features * self.kW * self.kH,
248248
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
249249
)[:n_frequency]
250250
self.indices[adapter_name] = torch.stack(
251-
[self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
251+
[self.indices[adapter_name] // (self.in_features * self.kW), self.indices[adapter_name] % (self.in_features * self.kW)], dim=0
252252
)
253253
self.fourierft_scaling[adapter_name] = scaling
254254
# Actual trainable parameters
255-
self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency, self.kW, self.kH), requires_grad=True)
255+
self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True)
256256

257257
if init_weights:
258258
self.reset_fourier_parameters(adapter_name)
@@ -310,13 +310,12 @@ def unmerge(self) -> None:
310310
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
311311

312312
def get_delta_weight(self, adapter) -> torch.Tensor:
313-
# careful: ifft2 does not work with float16 or bfloat16
314313
spectrum = self.fourierft_spectrum[adapter]
315314
indices = self.indices[adapter].to(spectrum.device)
316-
dense_spectrum = torch.zeros(self.out_features, self.in_features, self.kW, self.kH, device=spectrum.device)
315+
dense_spectrum = torch.zeros(self.out_features * self.kH, self.in_features * self.kW, device=spectrum.device)
317316
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
318317
delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
319-
return delta_weight
318+
return delta_weight.reshape((self.out_features, self.in_features, self.kW, self.kH))
320319

321320
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
322321
previous_dtype = x.dtype

0 commit comments

Comments
 (0)