26 changes: 25 additions & 1 deletion src/peft/tuners/fourierft/config.py
@@ -15,7 +15,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Union
from typing import Literal, Optional, Union

from peft.config import PeftConfig
from peft.utils import PeftType
@@ -174,6 +174,17 @@ class FourierFTConfig(PeftConfig):
)
},
)

ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field(
Member:

ifft2_norm and alpha also have to be added to the docstring (you can just copy the same help text).
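A sketch of what those two docstring entries could look like, reusing the help texts from this diff (exact wording and placement are up to the author):

    ifft2_norm (`Literal["backward", "forward", "ortho"]`, *optional*, defaults to `"backward"`):
        The normalization applied for the ifft2 operation. It has to be either `backward`, `forward` or
        `ortho`. See the pytorch documentation for the ifft2 function for more details.
    alpha (`float`, *optional*):
        The alpha value dynamically sets n_frequency = int(alpha * out_features * in_features).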

default="backward",
metadata={
"help": (
"The normalization applied for the ifft2 operation."
"It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details"
Member, on lines +182 to +183:

Suggested change (add trailing spaces so the concatenated help strings don't run together):
- "The normalization applied for the ifft2 operation."
- "It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details"
+ "The normalization applied for the ifft2 operation. "
+ "It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details "

"The default value is `backward`."
Member:

Suggested change:
- "The default value is `backward`."
+ "(https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html). The default value is `backward`."

)
},
)
init_weights: bool = field(
default=False,
metadata={
@@ -185,6 +196,13 @@
},
)

alpha: Optional[float] = field(
default=None,
metadata={
"help": ("The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)")
Member:

Let's mention that if this value is passed, users shouldn't set n_frequency and n_frequency_pattern.
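A possible revision of the help text along those lines (a sketch, not the final wording):

"help": (
    "The alpha value dynamically sets n_frequency = int(alpha * out_features * in_features). "
    "If alpha is passed, do not set n_frequency or n_frequency_pattern, as alpha overrides them."
)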

},
)

def __post_init__(self):
super().__post_init__()
self.peft_type = PeftType.FOURIERFT
@@ -204,3 +222,9 @@ def __post_init__(self):
# check for layers_to_transform and layers_pattern
if self.layers_pattern and not self.layers_to_transform:
raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")

Member:

Thanks for adding these checks. Let's ensure that they work correctly by adding a new test class TestFourierFTInitialization to tests/test_initialization.py.

Next, I think it's easiest to copy this method from the LoRA tests:

    def get_model(self, bias=True):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                # choose a large weight so that averages are close to expected values
                self.linear = nn.Linear(1000, 1000, bias=bias)
                self.embed = nn.Embedding(1000, 1000)
                self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)

Then add two tests, test_fourierft_set_alpha_and_n_frequency_raises and test_fourierft_set_alpha_and_n_frequency_pattern_raises. The tests would work analogous to this one:

    def test_lora_init_orthogonal_odd_rank_raises(self):
        torch.manual_seed(0)
        model = self.get_model()
        config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal", r=7)
        msg = "Orthogonal initialization requires the LoRA rank to be even, got 7 instead."
        with pytest.raises(ValueError, match=msg):
            get_peft_model(model, config)
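A sketch of how the two requested FourierFT tests might look (names follow the reviewer's suggestion; since the checks live in FourierFTConfig.__post_init__, constructing the config is enough to trigger them):

    def test_fourierft_set_alpha_and_n_frequency_raises(self):
        msg = "Don't set both alpha and n_frequency"
        with pytest.raises(ValueError, match=msg):
            FourierFTConfig(target_modules=["linear"], alpha=0.01, n_frequency=500)

    def test_fourierft_set_alpha_and_n_frequency_pattern_raises(self):
        msg = "Don't set both alpha and n_frequency_pattern"
        with pytest.raises(ValueError, match=msg):
            FourierFTConfig(target_modules=["linear"], alpha=0.01, n_frequency_pattern={"linear": 500})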

if (self.alpha is not None) and (self.n_frequency != 1000):
raise ValueError("Don't set both alpha and n_frequency, as alpha overrides ...")

if (self.alpha is not None) and (self.n_frequency_pattern != {}):
raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...")
Member, on lines +227 to +230:
I didn't mean to literally put ... in the error message, I was just too lazy to type it out :-)

Let's replace it with `as alpha overrides the latter's value.`.
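With the suggested wording filled in, the two checks might read (a sketch of the requested change, not the committed code):

if (self.alpha is not None) and (self.n_frequency != 1000):
    raise ValueError("Don't set both alpha and n_frequency, as alpha overrides the latter's value.")

if (self.alpha is not None) and (self.n_frequency_pattern != {}):
    raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides the latter's value.")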

175 changes: 144 additions & 31 deletions src/peft/tuners/fourierft/layer.py
@@ -27,7 +27,12 @@ class FourierFTLayer(BaseTunerLayer):
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names = ("fourierft_spectrum",)
# All names of other parameters that may contain adapter-related parameters
other_param_names = ("fourierft_n_frequency", "fourierft_scaling", "fourierft_random_loc_seed")
other_param_names = (
"fourierft_n_frequency",
"fourierft_scaling",
"fourierft_random_loc_seed",
"fourierft_ifft2_norm",
)

def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.base_layer = base_layer
@@ -39,6 +44,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.fourierft_ifft2_norm = kwargs["ifft2_norm"]
self.kwargs = kwargs

base_layer = self.get_base_layer()
@@ -48,6 +54,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.in_features, self.out_features = (
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
)
elif isinstance(base_layer, nn.Conv2d):
self.in_features = base_layer.in_channels
self.out_features = base_layer.out_channels
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")

@@ -56,20 +65,22 @@ def update_layer(
):
if n_frequency <= 0:
raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
if n_frequency > self.in_features * self.out_features:

if isinstance(self, FourierFTLinear):
max_freqs = self.in_features * self.out_features
else:
kW = self.base_layer.kernel_size[0]
kH = self.base_layer.kernel_size[1]
max_freqs = self.in_features * self.out_features * kW * kH

if n_frequency >= max_freqs:
raise ValueError(
f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
f"but the value passed is {n_frequency} and the product is {self.in_features * self.out_features}"
f"but the value passed is {n_frequency} and the product is {max_freqs}"
)
self.fourierft_n_frequency[adapter_name] = n_frequency
self.fourierft_random_loc_seed[adapter_name] = random_loc_seed
self.indices[adapter_name] = torch.randperm(
self.out_features * self.in_features,
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
)[:n_frequency]
self.indices[adapter_name] = torch.stack(
[self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
)
self.set_indices(adapter_name, n_frequency)
self.fourierft_scaling[adapter_name] = scaling
# Actual trainable parameters
self.fourierft_spectrum[adapter_name] = nn.Parameter(torch.randn(n_frequency), requires_grad=True)
@@ -91,29 +102,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
indices = self.indices[adapter].to(spectrum.device)
dense_spectrum = torch.zeros(self.out_features, self.in_features, device=spectrum.device)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
delta_weight = torch.fft.ifft2(dense_spectrum).real * self.fourierft_scaling[adapter]
delta_weight = (
torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
)
return delta_weight.to(spectrum.dtype)
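As a side reference (not part of the diff), a tiny sketch of what the three norm modes do to torch.fft.ifft2, mirroring the PyTorch docs:

import torch

x = torch.zeros(4, 4)
x[0, 0] = 1.0  # a single spectral coefficient

for norm in ("backward", "forward", "ortho"):
    w = torch.fft.ifft2(x, norm=norm).real
    # "backward" divides by n=16, "forward" applies no scaling, "ortho" divides by sqrt(n)=4
    print(norm, w[0, 0].item())  # 0.0625, 1.0, 0.25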


class FourierFTLinear(nn.Module, FourierFTLayer):
# FourierFT implemented in a dense layer
def __init__(
self,
base_layer,
adapter_name: str,
n_frequency: int = 1000,
scaling: float = 150.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: Union[bool, str] = False,
random_loc_seed: int = 777,
**kwargs,
) -> None:
super().__init__()
FourierFTLayer.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
@@ -163,6 +156,41 @@ def unmerge(self) -> None:
if active_adapter in self.fourierft_spectrum.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)

def set_indices(self, adapter_name: str, n_frequency: int):
self.indices[adapter_name] = torch.randperm(
self.out_features * self.in_features,
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
)[:n_frequency]
self.indices[adapter_name] = torch.stack(
[self.indices[adapter_name] // self.in_features, self.indices[adapter_name] % self.in_features], dim=0
)


class FourierFTLinear(nn.Module, FourierFTLayer):
# FourierFT implemented in a dense layer
def __init__(
self,
base_layer,
adapter_name: str,
n_frequency: int = 1000,
alpha: Optional[float] = None,
scaling: float = 150.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: Union[bool, str] = False,
random_loc_seed: int = 777,
**kwargs,
) -> None:
super().__init__()
FourierFTLayer.__init__(self, base_layer, **kwargs)

# apply alpha patch
if alpha:
n_frequency = int(alpha * self.in_features * self.out_features)

self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)

def get_delta_weight(self, adapter) -> torch.Tensor:
return super().get_delta_weight(adapter)

@@ -191,3 +219,88 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
def __repr__(self) -> str:
rep = super().__repr__()
return "fourierft." + rep

def set_indices(self, adapter_name: str, n_frequency: int):
super().set_indices(adapter_name, n_frequency)
Member, on lines +223 to +224:
Can be safely removed.



class FourierFTConv2D(nn.Module, FourierFTLayer):
# FourierFT implemented in a Conv2d layer
def __init__(
self,
base_layer,
adapter_name: str,
n_frequency: int = 1000,
alpha: Optional[float] = None,
scaling: float = 150.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: Union[bool, str] = False,
random_loc_seed: int = 777,
**kwargs,
) -> None:
super().__init__()
FourierFTLayer.__init__(self, base_layer, **kwargs)

self.fan_in_fan_out = fan_in_fan_out
self._active_adapter = adapter_name
kW = base_layer.kernel_size[0]
kH = base_layer.kernel_size[1]

# apply alpha patch
if alpha:
n_frequency = int(alpha * self.in_features * self.out_features * kW * kH)
self.update_layer(adapter_name, n_frequency, scaling, init_weights, random_loc_seed)

def set_indices(self, adapter_name: str, n_frequency: int):
kW = self.base_layer.kernel_size[0]
kH = self.base_layer.kernel_size[1]
self.indices[adapter_name] = torch.randperm(
self.out_features * self.in_features * kW * kH,
generator=torch.Generator().manual_seed(self.fourierft_random_loc_seed[adapter_name]),
)[:n_frequency]
self.indices[adapter_name] = torch.stack(
[
self.indices[adapter_name] // (self.in_features * kW),
self.indices[adapter_name] % (self.in_features * kW),
],
dim=0,
)

def get_delta_weight(self, adapter) -> torch.Tensor:
kW = self.base_layer.kernel_size[0]
kH = self.base_layer.kernel_size[1]
spectrum = self.fourierft_spectrum[adapter]
indices = self.indices[adapter].to(spectrum.device)
dense_spectrum = torch.zeros(self.out_features * kH, self.in_features * kW, device=spectrum.device)
dense_spectrum[indices[0, :], indices[1, :]] = spectrum.float()
delta_weight = (
torch.fft.ifft2(dense_spectrum, norm=self.fourierft_ifft2_norm).real * self.fourierft_scaling[adapter]
)
return torch.reshape(delta_weight, (self.out_features, self.in_features, kW, kH))

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.fourierft_spectrum.keys():
continue

delta_w = self.get_delta_weight(active_adapter)
x = x.to(delta_w.dtype)
y = F.conv2d(x, delta_w, stride=self.base_layer.stride, padding=self.base_layer.padding)
result += y

result = result.to(previous_dtype)
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "fourierft." + rep
16 changes: 11 additions & 5 deletions src/peft/tuners/fourierft/model.py
@@ -18,14 +18,15 @@
from itertools import chain

import torch
from torch.nn import Conv2d
from transformers.pytorch_utils import Conv1D

from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils import (
TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING,
)

from .layer import FourierFTLayer, FourierFTLinear
from .layer import FourierFTConv2D, FourierFTLayer, FourierFTLinear


class FourierFTModel(BaseTuner):
@@ -71,11 +72,15 @@ def _create_and_replace(

n_frequency = fourierft_config.n_frequency_pattern.get(target_name_key, fourierft_config.n_frequency)
scaling = fourierft_config.scaling
alpha = fourierft_config.alpha
ifft2_norm = fourierft_config.ifft2_norm
random_loc_seed = fourierft_config.random_loc_seed
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"n_frequency": n_frequency,
"alpha": alpha,
"scaling": scaling,
"ifft2_norm": ifft2_norm,
"fan_in_fan_out": fourierft_config.fan_in_fan_out,
"init_weights": fourierft_config.init_weights,
"random_loc_seed": fourierft_config.random_loc_seed,
@@ -110,19 +115,20 @@ def _create_new_module(fourierft_config, adapter_name, target, **kwargs):
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = False
new_module = FourierFTLinear(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, Conv1D):
kwargs["is_target_conv_1d_layer"] = True
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = fourierft_config.fan_in_fan_out = True
new_module = FourierFTLinear(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, Conv2d):
new_module = FourierFTConv2D(target, adapter_name, **kwargs)
else:
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`."
"`torch.nn.Linear`, `torch.nn.Conv2d`"
)

new_module = FourierFTLinear(target, adapter_name, **kwargs)

return new_module
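A minimal end-to-end usage sketch of what this PR enables, assuming it is merged as shown (the model and shapes here are made up for illustration):

import torch.nn as nn
from peft import FourierFTConfig, get_peft_model


class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(8, 16, 3)
        self.lin0 = nn.Linear(64, 10)


# alpha derives n_frequency per layer, so n_frequency / n_frequency_pattern stay unset;
# ifft2_norm is forwarded to torch.fft.ifft2 when building the delta weights
config = FourierFTConfig(target_modules=["conv2d", "lin0"], alpha=0.01, ifft2_norm="ortho")
peft_model = get_peft_model(SmallModel(), config)
peft_model.print_trainable_parameters()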
20 changes: 20 additions & 0 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
@@ -663,6 +663,26 @@
"init_weights": True,
},
),
(
"Conv2d 1 FourierFT",
"Conv2d",
FourierFTConfig,
{
"target_modules": ["conv2d"],
"n_frequency": 1000,
},
),
(
"Conv2d 2 FourierFT",
"Conv2d",
FourierFTConfig,
{
"target_modules": ["conv2d", "lin0"],
"alpha": 0.01,
"init_weights": True,
"ifft2_norm": "ortho",
},
),
##########
# VBLoRA #
##########