Add conv2d support for fourierft and other improvements #2794
base: main
Conversation
Force-pushed from daa328d to 017be24.
Added an alpha parameter to dynamically set n_frequency based on the number of parameters of the patched layer.
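A minimal sketch of how such an alpha could translate into a per-layer n_frequency; the Linear formula appears later in this thread, and generalizing it over the full weight shape is an assumption here, not necessarily what the PR does:

```python
import math

def n_frequency_from_alpha(alpha: float, weight_shape: tuple[int, ...]) -> int:
    # n_frequency scales with the number of parameters of the patched layer
    return int(alpha * math.prod(weight_shape))

# e.g. for an nn.Linear(320, 320) weight: n_frequency_from_alpha(0.01, (320, 320)) == 1024
# for an nn.Conv2d(320, 320, 3) weight:   n_frequency_from_alpha(0.01, (320, 320, 3, 3)) == 9216
```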
Added dynamic_scaling to learn the optimal scaling parameter for each layer, as choosing it seemed to be very empirical. EDIT: This is not working, the gradients stay too small for it to change!
Thanks @frutiemax92 for contributing! And thanks @BenjaminBossan for pinging me! The code seems all right to me.
Force-pushed from 5b81b6b to 41d9012.
Forget the dynamic_scaling, the gradients always stay too small for it to change!
Force-pushed from 0f36880 to 017be24.
I've scrapped the dynamic_scaling thing, as in practice it didn't work because the gradients were so small. This should be good to merge now.
Thanks @frutiemax92 for adding support for `nn.Conv2d` to FourierFT. Overall, the PR looks good, I found a few areas for improvement though, please check my comments.
Moreover, let's add tests for FourierFT with conv2d and for setting `alpha`. For this, just add a few settings here:
peft/tests/test_custom_models.py
Lines 621 to 665 in 50329a7
```python
#############
# FourierFT #
#############
# FourierFT is not initialized as an identity transform by default, hence set init_weights=True
(
    "Vanilla MLP 1 FourierFT",
    "MLP",
    FourierFTConfig,
    {"n_frequency": 10, "target_modules": "lin0", "init_weights": True},
),
(
    "Vanilla MLP 2 FourierFT",
    "MLP",
    FourierFTConfig,
    {"n_frequency": 10, "target_modules": ["lin0"], "init_weights": True},
),
(
    "Vanilla MLP 3 FourierFT",
    "MLP",
    FourierFTConfig,
    {"n_frequency": 10, "target_modules": ["lin1"], "init_weights": True},
),
(
    "Vanilla MLP 5 FourierFT",
    "MLP",
    FourierFTConfig,
    {"n_frequency": 10, "target_modules": ["lin0"], "modules_to_save": ["lin1"], "init_weights": True},
),
(
    "Vanilla MLP 6 FourierFT",
    "MLP",
    FourierFTConfig,
    {"n_frequency": 10, "target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"], "init_weights": True},
),
(
    "Vanilla MLP 7 FourierFT",
    "MLP",
    FourierFTConfig,
    {
        "n_frequency_pattern": {"lin0": 5, "lin1": 10},
        "target_modules": ["lin0", "lin1"],
        "modules_to_save": ["lin1"],
        "init_weights": True,
    },
),
```
Here are some example settings for LoRA with conv2d:
peft/tests/test_custom_models.py
Lines 120 to 121 in 50329a7
("Conv2d 1 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d"]}), | |
("Conv2d 2 LoRA", "Conv2d", LoraConfig, {"target_modules": ["conv2d", "lin0"]}), |
@BenjaminBossan Do you think the ifft2 on the dense_spectrum of shape (out_features, in_features, kW, kH) is a good choice? I've reread what ifft2 does: it takes the last two dimensions and applies the operation in batches over the other dimensions. So for a Conv2d layer of (320, 320, 3, 3), it does 320*320 ifft2 operations on a 3x3 2D patch. Is this good, or do we want something else? i.e. should we convert that to one huge ifft2 operation on a patch of size (in_features*kW, out_features*kH)? I thought this made less sense because out_features shouldn't be related to in_features... What do you think about it? It seems to work very well this way though, based on an SDXL finetune I just did. I've checked the delta_w from the Conv2d layers and there are some non-zero entries with mostly zeroes, as expected.
EDIT2 - Actually, I've stuck closer to the paper and now it does an ifft2 operation on the whole block, i.e. (in_features*kW, out_features*kH). This should result in fewer dead weights.
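For illustration, a minimal sketch (not the PR's exact code) contrasting the two layouts discussed above; it assumes a Conv2d weight of shape (out_features, in_features, kH, kW), and the final reshape is just one possible way to map the big block back onto that shape:

```python
import torch

out_features, in_features, kH, kW = 320, 320, 3, 3

# Layout 1: batched ifft2 over the trailing (kH, kW) dims -> 320*320 tiny 3x3 transforms
dense_small = torch.randn(out_features, in_features, kH, kW)
delta_w_batched = torch.fft.ifft2(dense_small, norm="backward").real  # same shape as the conv weight

# Layout 2 (closer to the paper): one large ifft2 over an (in_features*kW, out_features*kH) block
dense_big = torch.randn(in_features * kW, out_features * kH)
delta_w_full = torch.fft.ifft2(dense_big, norm="backward").real
delta_w_full = delta_w_full.view(in_features, kW, out_features, kH).permute(2, 0, 3, 1)
```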
Force-pushed from 0c8dc75 to 017be24.
Thanks for the updates, this looks quite good, I just saw a few small issues, please check my comments.
Regarding your thoughts on ifft2 efficiency: Great observation. I think with your recent formula, the overhead shouldn't be too large, but we can do some measurements if you think it could still be an issue.
We had a problem with the CI lately, could you please merge with/rebase on the latest main branch? That should resolve it more or less.
Finally, before committing, please don't forget to run `make style`.
Force-pushed from daf8707 to 9cbc72d.
Ran make style at the end, not ideal...
@BenjaminBossan I'm testing this with SANA 1.6B, and there is a 1x1 convolution layer so big it goes out of VRAM on the ifft2 operation. I believe there's not much to be done to fix this though.
Thanks for the latest updates, implementation-wise we're pretty much finished. I only added some comments regarding tests and docs.
Also, let's update this line to mention Conv2d support:
```
- Only `nn.Linear` layers are supported.
```
> I'm testing this with SANA 1.6B, and there is a 1x1 convolution layer so big it goes out of VRAM on the ifft2 operation. I believe there's not much to be done to fix this though.
I agree, at the end of the day it's the user's responsibility to avoid applying PEFT methods to layers where it doesn't make sense. If this is an easy mistake to make, we can help with documentation or examples, but we cannot cover all possible edge cases.
raise ValueError("Don't set both alpha and n_frequency, as alpha overrides ...") | ||
|
||
if (self.alpha is not None) and (self.n_frequency_pattern != {}): | ||
raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides ...") |
I didn't mean to literally put `...` in the error message, I was just too lazy to type it out :-)
Let's replace it with `as alpha overrides the latter's value.`.
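For illustration, with that wording applied the two checks might read as follows (a sketch; the guard on the first check is not visible in the diff above, so the condition shown for it is an assumption):

```python
if (self.alpha is not None) and (self.n_frequency is not None):  # assumed guard, not shown in the diff
    raise ValueError("Don't set both alpha and n_frequency, as alpha overrides the latter's value.")
if (self.alpha is not None) and (self.n_frequency_pattern != {}):
    raise ValueError("Don't set both alpha and n_frequency_pattern, as alpha overrides the latter's value.")
```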
```python
    },
)

ifft2_norm: Optional[Literal["backward", "forward", "ortho"]] = field(
```
`ifft2_norm` and `alpha` also have to be added to the docstring (you can just copy the same help text).
"help": ( | ||
"The normalization applied for the ifft2 operation." | ||
"It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details" | ||
"The default value is `backward`." |
"The default value is `backward`." | |
"(https://docs.pytorch.org/docs/stable/generated/torch.fft.ifft2.html). The default value is `backward`." |
"The normalization applied for the ifft2 operation." | ||
"It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details" |
"The normalization applied for the ifft2 operation." | |
"It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details" | |
"The normalization applied for the ifft2 operation. " | |
"It has to be either `backward`, `forward` or `ortho`. See the pytorch documentation for the ifft2 function for more details " |
```python
alpha: float = field(
    default=None,
    metadata={
        "help": ("The alpha value dynamically sets the n_frequency = int(alpha * out_features * in_features)")
```
Let's mention that if this value is passed, users shouldn't set `n_frequency` and `n_frequency_pattern`.
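One way the updated help text could read (the wording is only a suggestion):

```python
alpha: float = field(
    default=None,
    metadata={
        "help": (
            "The alpha value dynamically sets n_frequency = int(alpha * out_features * in_features). "
            "If alpha is set, don't set n_frequency or n_frequency_pattern, as alpha overrides their values."
        )
    },
)
```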
```python
def set_indices(self, adapter_name: str, n_frequency: int):
    super().set_indices(adapter_name, n_frequency)
```
Can be safely removed.
```python
# check for layers_to_transform and layers_pattern
if self.layers_pattern and not self.layers_to_transform:
    raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")
```
Thanks for adding these checks. Let's ensure that they work correctly by adding a new test class `TestFourierFTInitialization` to `tests/test_initialization.py`.
Next, I think it's easiest to copy this method from the LoRA tests:
peft/tests/test_initialization.py
Lines 91 to 98 in 190f987
```python
def get_model(self, bias=True):
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # choose a large weight so that averages are close to expected values
            self.linear = nn.Linear(1000, 1000, bias=bias)
            self.embed = nn.Embedding(1000, 1000)
            self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)
```
Then add two tests, `test_fourierft_set_alpha_and_n_frequency_raises` and `test_fourierft_set_alpha_and_n_frequency_pattern_raises`. The tests would work analogously to this one:
peft/tests/test_initialization.py
Lines 319 to 326 in 190f987
```python
def test_lora_init_orthogonal_odd_rank_raises(self):
    torch.manual_seed(0)
    model = self.get_model()
    config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal", r=7)
    msg = "Orthogonal initialization requires the LoRA rank to be even, got 7 instead."
    with pytest.raises(ValueError, match=msg):
        get_peft_model(model, config)
```
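A possible sketch of the requested test class, modeled on the LoRA example above. The error messages and the exact place where the ValueError is raised are assumptions, so the config construction is placed inside the pytest.raises block to cover both possibilities:

```python
import pytest
import torch.nn as nn

from peft import FourierFTConfig, get_peft_model


class TestFourierFTInitialization:
    def get_model(self, bias=True):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(1000, 1000, bias=bias)
                self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)

        return MyModule().eval()

    def test_fourierft_set_alpha_and_n_frequency_raises(self):
        model = self.get_model()
        msg = "Don't set both alpha and n_frequency"  # assumed to match the final error message
        with pytest.raises(ValueError, match=msg):
            config = FourierFTConfig(target_modules=["linear"], alpha=0.1, n_frequency=10)
            get_peft_model(model, config)

    def test_fourierft_set_alpha_and_n_frequency_pattern_raises(self):
        model = self.get_model()
        msg = "Don't set both alpha and n_frequency_pattern"  # assumed to match the final error message
        with pytest.raises(ValueError, match=msg):
            config = FourierFTConfig(target_modules=["linear"], alpha=0.1, n_frequency_pattern={"linear": 10})
            get_peft_model(model, config)
```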
I'm currently exploring PEFT training methods for efficiently training SDXL, and I found the FourierFT method quite interesting. It didn't have support for Conv2D layers, so this commit adds it.
To do this, I've added a `FourierFTConv2D` class, so it gets called instead of `FourierFTLinear` when parsing a `Conv2d` layer. The main key differences are the following:
- `fourierft_spectrum` is now a parameter with shape `(n_frequency, kernel_w, kernel_h)`.
- `get_delta_weight` returns the same shape as the convolution layer weights.
- The forward pass uses `F.conv2d` with the same stride and padding parameters to support the spatial upscaling/downscaling from the unet architecture.

The merging is the same as before; we add the weights from `get_delta_weight` to the original weights.
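A rough sketch of the forward path described above (not the PR's actual code; the function and argument names are illustrative):

```python
import torch.nn.functional as F

def fourierft_conv2d_forward(x, base_layer, delta_w, scaling):
    # run the frozen convolution, then add the FourierFT delta applied with the
    # same stride/padding (and dilation/groups) as the base layer
    result = base_layer(x)
    result = result + scaling * F.conv2d(
        x,
        delta_w,  # same shape as base_layer.weight
        stride=base_layer.stride,
        padding=base_layer.padding,
        dilation=base_layer.dilation,
        groups=base_layer.groups,
    )
    return result
```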
I've tested the code with my custom training code for SDXL, and it is quite dependent on a good `scaling_factor`. The default value of `150` is really too high for this model; `10` seems to work fine with `n_frequency=8000`. I also need to set the initialize_weights option to true, as otherwise it would corrupt the image generation process. So when you start training, the convolution results are zeros, but the gradients still flow back to the spectrum parameter of the layer.

For information, with those parameters, the .safetensors file takes 16.4 MB on disk.
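For concreteness, the kind of configuration described above might look roughly like this (the target module names are illustrative, and I'm assuming the scaling factor is exposed as the `scaling` field of `FourierFTConfig`, which the text refers to as scaling_factor):

```python
from peft import FourierFTConfig

config = FourierFTConfig(
    target_modules=["to_q", "to_k", "to_v", "conv1", "conv2"],  # hypothetical SDXL module names
    n_frequency=8000,
    scaling=10,          # the default of 150 was reported as too high for SDXL
    init_weights=True,   # start from a zero delta so generation isn't corrupted at step 0
)
```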