-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add conv2d support for fourierft and other improvements #2794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
28b7fd4
bf4dce5
3172bf0
7b408d8
8926877
9cbc72d
2b205f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,7 +15,7 @@ | |||||||||||||||||||||||||||||||||
from __future__ import annotations | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
from dataclasses import dataclass, field | ||||||||||||||||||||||||||||||||||
from typing import Optional, Union | ||||||||||||||||||||||||||||||||||
from typing import Literal, Optional, Union | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
from peft.config import PeftConfig | ||||||||||||||||||||||||||||||||||
from peft.utils import PeftType | ||||||||||||||||||||||||||||||||||
|
@@ -174,6 +174,17 @@ class FourierFTConfig(PeftConfig): | |||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field( | ||||||||||||||||||||||||||||||||||
default_factory="backward", | ||||||||||||||||||||||||||||||||||
metadata={ | ||||||||||||||||||||||||||||||||||
"help": ( | ||||||||||||||||||||||||||||||||||
"The normalization applied for the ifft2 operation." | ||||||||||||||||||||||||||||||||||
"It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details" | ||||||||||||||||||||||||||||||||||
Comment on lines
+182
to
+183
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
"The default value is `backward`." | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
init_weights: bool = field( | ||||||||||||||||||||||||||||||||||
default=False, | ||||||||||||||||||||||||||||||||||
metadata={ | ||||||||||||||||||||||||||||||||||
|
@@ -185,6 +196,13 @@ class FourierFTConfig(PeftConfig): | |||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
alpha: float = field( | ||||||||||||||||||||||||||||||||||
frutiemax92 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||
default=None, | ||||||||||||||||||||||||||||||||||
metadata={ | ||||||||||||||||||||||||||||||||||
"help": ("The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)") | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's mention that if this value is passed, users shouldn't set |
||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def __post_init__(self): | ||||||||||||||||||||||||||||||||||
super().__post_init__() | ||||||||||||||||||||||||||||||||||
self.peft_type = PeftType.FOURIERFT | ||||||||||||||||||||||||||||||||||
|
@@ -204,3 +222,9 @@ def __post_init__(self): | |||||||||||||||||||||||||||||||||
# check for layers_to_transform and layers_pattern | ||||||||||||||||||||||||||||||||||
if self.layers_pattern and not self.layers_to_transform: | ||||||||||||||||||||||||||||||||||
raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for adding these checks. Let's ensure that they work correctly by adding a new test class Next, I think it's easiest to copy this method from the LoRA tests: peft/tests/test_initialization.py Lines 91 to 98 in 190f987
Then add two tests, peft/tests/test_initialization.py Lines 319 to 326 in 190f987
|
||||||||||||||||||||||||||||||||||
if (self.alpha is not None) and (self.n_frequency != 1000): | ||||||||||||||||||||||||||||||||||
raise ValueError("Don't set both alpha and n_frequency, as alpha overrides ...") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if (self.alpha is not None) and (self.n_frequency_pattern != {}): | ||||||||||||||||||||||||||||||||||
raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...") | ||||||||||||||||||||||||||||||||||
Comment on lines
+227
to
+230
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't mean to literally put Let's replace it with |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,12 @@ class FourierFTLayer(BaseTunerLayer): | |
# All names of layers that may contain (trainable) adapter weights | ||
adapter_layer_names = ("fourierft_spectrum",) | ||
# All names of other parameters that may contain adapter-related parameters | ||
other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed") | ||
other_param_names = ( | ||
"fourierft_n_frequency", | ||
"fourierft_scaling", | ||
"fourierft_random_loc_seed", | ||
"fourierft_ifft2_norm", | ||
) | ||
|
||
def __init__(self, base_layer: nn.Module, **kwargs) -> None: | ||
self.base_layer = base_layer | ||
|
@@ -39,6 +44,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: | |
# Mark the weight as unmerged | ||
self._disable_adapters = False | ||
self.merged_adapters = [] | ||
self.fourierft_ifft2_norm = kwargs["ifft2_norm"] | ||
self.kwargs = kwargs | ||
|
||
base_layer = self.get_base_layer() | ||
|
@@ -48,6 +54,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: | |
self.in_features, self.out_features = ( | ||
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape | ||
) | ||
elif isinstance(base_layer, nn.Conv2d): | ||
self.in_features = base_layer.in_channels | ||
self.out_features = base_layer.out_channels | ||
else: | ||
raise ValueError(f"Unsupported layer type {type(base_layer)}") | ||
|
||
|
@@ -56,20 +65,22 @@ def update_layer( | |
): | ||
if n_frequency <= 0: | ||
raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}") | ||
if n_frequency > self.in_features * self.out_features: | ||
|
||
if isinstance(self, FourierFTLinear): | ||
max_freqs = self.in_features * self.out_features | ||
else: | ||
kW = self.base_layer.kernel_size[0] | ||
kH = self.base_layer.kernel_size[1] | ||
max_freqs = self.in_features * self.out_features * kW * kH | ||
|
||
if n_frequency >= max_freqs: | ||
raise ValueError( | ||
f"`n_frequency` should be less than or equal to the product of the input and output dimensions " | ||
f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}" | ||
f"but the value passed is {n_frequency} and the product is {max_freqs}" | ||
) | ||
self.fourierft_n_frequency[adapter_name] = n_frequency | ||
self.fourierft_random_loc_seed[adapter_name] = random_loc_seed | ||
self.indices[adapter_name] = torch.randperm( | ||
self.out_features * self.in_features, | ||
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), | ||
)[:n_frequency] | ||
self.indices[adapter_name] = torch.stack( | ||
[self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0 | ||
) | ||
self.set_indices(adapter_name, n_frequency) | ||
self.fourierft_scaling[adapter_name] = scaling | ||
# Actual trainable parameters | ||
self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True) | ||
|
@@ -91,29 +102,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: | |
indices = self.indices[adapter].to(spectrum.device) | ||
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device) | ||
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() | ||
delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter] | ||
delta_weight = ( | ||
torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] | ||
) | ||
return delta_weight.to(spectrum.dtype) | ||
|
||
|
||
class FourierFTLinear(nn.Module, FourierFTLayer): | ||
# FourierFT implemented in a dense layer | ||
def __init__( | ||
self, | ||
base_layer, | ||
adapter_name: str, | ||
n_frequency: int = 1000, | ||
scaling: float = 150.0, | ||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) | ||
init_weights: Union[bool, str] = False, | ||
random_loc_seed: int = 777, | ||
**kwargs, | ||
) -> None: | ||
super().__init__() | ||
FourierFTLayer.__init__(self, base_layer, **kwargs) | ||
self.fan_in_fan_out = fan_in_fan_out | ||
self._active_adapter = adapter_name | ||
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) | ||
|
||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: | ||
""" | ||
Merge the active adapter weights into the base weights | ||
|
@@ -163,6 +156,41 @@ def unmerge(self) -> None: | |
if active_adapter in self.fourierft_spectrum.keys(): | ||
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) | ||
|
||
def set_indices(self, adapter_name: str, n_frequency: int): | ||
self.indices[adapter_name] = torch.randperm( | ||
self.out_features * self.in_features, | ||
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), | ||
)[:n_frequency] | ||
self.indices[adapter_name] = torch.stack( | ||
[self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0 | ||
) | ||
|
||
|
||
class FourierFTLinear(nn.Module, FourierFTLayer): | ||
# FourierFT implemented in a dense layer | ||
def __init__( | ||
self, | ||
base_layer, | ||
adapter_name: str, | ||
n_frequency: int = 1000, | ||
alpha: Optional[float] = None, | ||
scaling: float = 150.0, | ||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) | ||
init_weights: Union[bool, str] = False, | ||
random_loc_seed: int = 777, | ||
**kwargs, | ||
) -> None: | ||
super().__init__() | ||
FourierFTLayer.__init__(self, base_layer, **kwargs) | ||
|
||
# apply alpha patch | ||
if alpha: | ||
n_frequency = int(alpha * self.in_features * self.out_features) | ||
|
||
self.fan_in_fan_out = fan_in_fan_out | ||
self._active_adapter = adapter_name | ||
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) | ||
|
||
def get_delta_weight(self, adapter) -> torch.Tensor: | ||
return super().get_delta_weight(adapter) | ||
|
||
|
@@ -191,3 +219,88 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: | |
def __repr__(self) -> str: | ||
rep = super().__repr__() | ||
return "fourierft." + rep | ||
|
||
def set_indices(self, adapter_name: str, n_frequency: int): | ||
super().set_indices(adapter_name, n_frequency) | ||
Comment on lines
+223
to
+224
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can be safely removed. |
||
|
||
|
||
class FourierFTConv2D(nn.Module, FourierFTLayer): | ||
# FourierFT implemented in a dense layer | ||
def __init__( | ||
self, | ||
base_layer, | ||
adapter_name: str, | ||
n_frequency: int = 1000, | ||
alpha: Optional[float] = None, | ||
scaling: float = 150.0, | ||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) | ||
init_weights: Union[bool, str] = False, | ||
random_loc_seed: int = 777, | ||
**kwargs, | ||
) -> None: | ||
super().__init__() | ||
FourierFTLayer.__init__(self, base_layer, **kwargs) | ||
|
||
self.fan_in_fan_out = fan_in_fan_out | ||
self._active_adapter = adapter_name | ||
kW = base_layer.kernel_size[0] | ||
kH = base_layer.kernel_size[1] | ||
|
||
# apply alpha patch | ||
if alpha: | ||
n_frequency = int(alpha * self.in_features * self.out_features * kW * kH) | ||
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed) | ||
|
||
def set_indices(self, adapter_name: str, n_frequency: int): | ||
kW = self.base_layer.kernel_size[0] | ||
kH = self.base_layer.kernel_size[1] | ||
self.indices[adapter_name] = torch.randperm( | ||
self.out_features * self.in_features * kW * kH, | ||
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]), | ||
)[:n_frequency] | ||
self.indices[adapter_name] = torch.stack( | ||
[ | ||
self.indices[adapter_name] // (self.in_features * kW), | ||
self.indices[adapter_name] % (self.in_features * kW), | ||
], | ||
dim=0, | ||
) | ||
|
||
def get_delta_weight(self, adapter) -> torch.Tensor: | ||
kW = self.base_layer.kernel_size[0] | ||
kH = self.base_layer.kernel_size[1] | ||
spectrum = self.fourierft_spectrum[adapter] | ||
indices = self.indices[adapter].to(spectrum.device) | ||
dense_spectrum = torch.zeros(self.out_features * kH, self.in_features * kW, device=spectrum.device) | ||
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float() | ||
delta_weight = ( | ||
torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter] | ||
) | ||
return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH)) | ||
|
||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: | ||
previous_dtype = x.dtype | ||
|
||
if self.disable_adapters: | ||
if self.merged: | ||
self.unmerge() | ||
result = self.base_layer(x, *args, **kwargs) | ||
elif self.merged: | ||
result = self.base_layer(x, *args, **kwargs) | ||
else: | ||
result = self.base_layer(x, *args, **kwargs) | ||
for active_adapter in self.active_adapters: | ||
if active_adapter not in self.fourierft_spectrum.keys(): | ||
continue | ||
|
||
delta_w = self.get_delta_weight(active_adapter) | ||
x = x.to(delta_w.dtype) | ||
y = F.conv2d(x, delta_w, stride=self.base_layer.stride, padding=self.base_layer.padding) | ||
result += y | ||
|
||
result = result.to(previous_dtype) | ||
return result | ||
|
||
def __repr__(self) -> str: | ||
rep = super().__repr__() | ||
return "fourierft." + rep |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ifft2_norm
andalpha
also have to be added to the docstring (you can just copy the same help text).