fine-tuning OpenCLIP with Hugging Face's PEFT (such as LoRA) #761
Sorry, could you please provide more details? Are you looking for help with how to achieve that, or are you suggesting that it doesn't work right now?
Currently, Hugging Face's PEFT (such as LoRA) cannot fine-tune the linear layers of a transformer model based on torch.nn.MultiheadAttention (such as OpenCLIP). If I want to use LoRA, I have to replace the torch.nn.MultiheadAttention layer with a self-implemented naive multi-head attention layer. Can you help integrate support for this into the official PEFT library?
I see, thanks for explaining. Indeed, right now it is impossible as a user to change what type of LoRA layer is being used. We have ideas about exposing a "low level" API that would allow users more fine-grained control, including the possibility of using custom layers, as you suggest. I cannot say yet whether it will really work out or when it will be ready, but I'll let you know.
Thanks for your efforts!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
I'd like to bump this. Being unable to put LoRA weights on anything that uses nn.MultiheadAttention is a real pain, and using a naive implementation is clunky and cumbersome. It seems strange that LoRA-Torch can do it but PEFT cannot.
Hey, I created a PR to add MHA: #1324. The implementation was a bit tricky because this layer is not very "friendly" to LoRA adaptation, but I think I got it working. For now, this is just a rough draft, so it would be great if you could test it and tell me if it works for your use case. To install from this branch, run:
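(The exact install command was not preserved in this extract. Installing a PEFT PR branch directly from GitHub usually looks something like `pip install git+https://github.com/huggingface/peft.git@refs/pull/1324/head`; this exact ref is an assumption based on the PR number, not the command originally posted.)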
So far, I did the following testing:

```python
import torch
from torch import nn
import open_clip
from peft import LoraConfig, get_peft_model
from PIL import Image
import requests

model, preprocess = open_clip.create_model_from_pretrained('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')

# The LoraConfig definition was not shown in the original comment; this is an
# assumed example targeting the nn.MultiheadAttention modules ("attn" in open_clip).
config = LoraConfig(target_modules=["attn"])

peft_model = get_peft_model(model, config)
opt = torch.optim.SGD(peft_model.parameters(), 0.1)

# text encoder
text = tokenizer(["a diagram", "a dog", "a cat"])
text_features = peft_model.encode_text(text)
loss = text_features.sum()
loss.backward()
opt.step()

# image encoder
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image = preprocess(image).unsqueeze(0)
image_features = model.encode_image(image)
image_features.sum().backward()
opt.step()
```
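As a quick follow-up check (a sketch continuing from the snippet above; it assumes the PR branch exposes the new layer as peft.tuners.lora.layer.MultiheadAttention, as later comments in this thread do), you can confirm which attention modules were actually wrapped and how many parameters remain trainable:

```python
from peft.tuners.lora.layer import MultiheadAttention as PeftMha

# Count the nn.MultiheadAttention modules that received a LoRA wrapper.
print(len([m for m in peft_model.modules() if isinstance(m, PeftMha)]))
# Summarize how many parameters are actually trainable after wrapping.
peft_model.print_trainable_parameters()
```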
@ambroser53 I think the linked LoRA-Torch library has some bugs. For instance:

```python
import torch, loratorch
import torch.nn as nn

model_torch = loratorch.Linear(5, 6, r=4, lora_alpha=1)
loratorch.mark_only_lora_as_trainable(model_torch)
print(model_torch.state_dict().keys())
# prints odict_keys(['weight', 'bias', 'w_lora_A', 'w_lora_B'])

optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)
for _ in range(3):
    model_torch.train()
    x = torch.rand(2, 5)
    loss2 = model_torch(x).sum()
    optimizer_torch.zero_grad()
    loss2.backward()
    optimizer_torch.step()

print(model_torch.state_dict().keys())
# odict_keys(['bias', 'w_lora_A', 'w_lora_B'])
# note the missing 'weight' key!
```

As you can see, the 'weight' entry disappears from the state_dict after training.
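For contrast, here is a minimal sketch of the same experiment with PEFT's regular LoRA Linear wrapping (the tiny module and the key filtering are illustrative; exact state_dict key names depend on the PEFT version): the base weight stays in the state_dict after training.

```python
import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(5, 6)

    def forward(self, x):
        return self.lin(x)

peft_model = get_peft_model(Tiny(), LoraConfig(target_modules=["lin"], r=4))
opt = torch.optim.SGD(peft_model.parameters(), lr=0.1)
loss = peft_model(torch.rand(2, 5)).sum()
loss.backward()
opt.step()
# The base 'weight' is still present alongside the lora_A / lora_B parameters.
print([k for k in peft_model.state_dict().keys() if "lin" in k])
```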
Hey @BenjaminBossan, cheers for the fork, I'll run some tests on Tuesday. I realised that LoRA-Torch was a bit buggy after I started trying to combine it with PEFT's LoraLayer, but if there's a way to do it without it, that'd be much better.
@ambroser53 Did you have time to give it a try?
Hi, sorry, I meant to get back to you sooner. It appears the layers are placed on the nn.MultiheadAttention blocks just fine in my model. My use case is very complicated though, as it's a custom architecture, so I will need to get back to you on how effective it is and whether the OpenCLIP fine-tuning is bottlenecked or non-performative in some way. Once I have these answers I'll report back.
Great, thanks for testing. Do you have an ETA for when these tests finish? Regarding performance, I would expect a larger overhead than for simple LoRA layers like Linear.
Should get initial results early next week if there are no disasters. Out of curiosity, is said overhead computational or memory?
Thanks!
It should be computational only. However, since we take the same approach here as LoRA-Torch, it shouldn't be better or worse than using that.
I've dug deeper in my testing. Mine is a very specific case where LoRA weights are only placed on specific layers and the model is mixed quantisation, so the placement needed further tinkering. However, now that I've specifically made sure the layers are getting where they need to, there's a logic error that seems to only occur some of the time. Essentially, say you have […]. It doesn't look like the […]. Could there be a simple fix to just do the same as there is on […]?
Thanks a lot @ambroser53, your analysis is 100% correct. I pushed a new commit to the PR that now takes this into account. As is, we now apply LoRA to both in_proj and out_proj. There is currently no way to specify only one of them. I'll be out of office starting next week, so that PR may stall for a while unless one of the other maintainers has time to take over. Still, please try out this new PR and give us feedback on whether it works for you.
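For anyone wanting to confirm this behaviour on their own model, a minimal sketch (it only assumes that PEFT puts "lora" in the names of its injected parameters, as usual, and that `peft_model` is the model returned by get_peft_model):

```python
# List the injected LoRA parameters and inspect a few names to check that both
# the input and output projections of the attention blocks are covered.
lora_param_names = [n for n, p in peft_model.named_parameters() if "lora" in n and p.requires_grad]
print(len(lora_param_names))
print(lora_param_names[:6])
```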
No, that sounds perfect; I don't think having one or the other would make sense. I should be able to give it a go now and give results next week.
Nice. If you can give some early feedback today, I may still have time to react to it :)
This may be a problem with my own complex setup, so it could be out of scope here, but does PEFT automatically cast parameters to int8 if the underlying model is loaded in int8? I'm asking since part of the model is in int8, but the rest is skipped via […].
Hmm, normally the weights should not be automatically cast to int8. If you have some way to reproduce this error, I could investigate. Looking at this issue in general, however, I think this implementation will not work correctly with quantized weights. As is, we merge the LoRA weights into the base weights. When the latter are quantized, this requires special treatment, similar to the bnb layers we have for LoRA; a normal merge would surely fail. So I think we would need a completely separate MHA class for quantized layers. I'm not exactly sure what it is that you're doing with quantization, but as you've remarked earlier, the […]
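To make the merging argument concrete, here is a minimal sketch of what merging LoRA weights into a float base weight amounts to (shapes and values are made up). With an int8-quantized base weight, the plain addition below is not possible without dequantizing first, which is why a special-cased layer, similar to the bnb LoRA layers, would be needed.

```python
import torch

# Toy dimensions for a single linear weight.
out_features, in_features, r, lora_alpha = 6, 5, 4, 8
W = torch.randn(out_features, in_features)   # base weight (float here)
A = torch.randn(r, in_features)              # lora_A
B = torch.randn(out_features, r)             # lora_B
scaling = lora_alpha / r

# Merging folds the low-rank update into the base weight.
W_merged = W + scaling * (B @ A)
# With an int8 W, this step would require dequantize -> add -> requantize.
```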
I understand that, but the point is that the MHA layers aren't quantised at all. The confusing part is that the MHA and […]
Ah I see, that is indeed very strange and should not happen.
Can you point me to a reference for […]?
Here's the code for […]. I'll try and get together a code sample that reproduces it (the code I'm referring to right now is proprietary, for a company).
One more potential bug. It seems that when using get_peft_model on a large model with an MHA inside, it puts the internal parameters (i.e. in_proj_weight and out_proj.weight) in the MHA as requires_grad=True. It's actually really hard to force it to not be true, and I don't quite know why. I wonder whether it's because of the nested […]
It is very bizarre. The following code is from my script:

```python
model.base_model.model.model.vision_model.attn_pool.attn.base_layer.in_proj_weight.requires_grad = False
model.base_model.model.model.vision_model.attn_pool.attn.base_layer.out_proj.base_layer.weight.requires_grad = False
trainable_params = [name for name, param in model.named_parameters() if param.requires_grad]
print(model.base_model.model.model.vision_model.attn_pool.attn.base_layer.in_proj_weight.requires_grad)
```

This outputs True, and both the […]
This repo is a self-contained case that reproduces the error when using the MHA PEFT branch. This takes priority over the int8 stuff.
Hi @ambroser53, I'm back in the office. Thanks a lot for figuring out this bug and providing a reproducer. I could identify the issue and it should now be fixed. When running your example locally, I now get the correct gradients. Please take a look.
This was indeed the case! The reason for this is explained here:
Sorry, did you mean to include a link here?
@sailfish009: @BenjaminBossan I tested this by applying the code in the link below. A runtime error occurs: (has no attribute 'weight'). This code works fine for similar CLIP models like the one below. If you have working sample code (based on the official one), I'd be happy to test it.
@sailfish009 Sorry, I can't read that blog post, but it seems to be using some custom code based on (some rather old) PEFT code. In general, if you want to apply LoRA to OpenCLIP, you have to use PEFT based on PR #1324. This is because OpenCLIP uses nn.MultiheadAttention.
@BenjaminBossan Thank you. I checked with the branch you provided, and it's working fine.
If you use the LAION Hugging Face CLIP checkpoints instead of the ones from OpenCLIP, then you will be able to use the PEFT package without any effort: https://huggingface.co/collections/laion/openclip-laion-2b-64fcade42d20ced4e9389b30
Thanks for testing.
I am not sure why this issue is marked as […].
Are there any plans to merge #1324 in the near future?
Thanks for the reminder @mm-tpx. Just to explain: the PR is not merged yet because the solution is kind of hacky (due to how MHA is implemented in torch), and people who tested it have reported a few issues, although I tried my best to address them over time, as witnessed in the discussion. The more confirmation I get that people have used this implementation successfully, the higher the confidence that the "hack" works and the PR can be merged. So if you gave that PR a try and it worked for you, please let me know. If not, it would be great if you could test it out and give me feedback.
For me it worked great :)
Thanks for the feedback. I'll do my best to keep the branch up-to-date so that it remains usable. Hopefully I'll be able to merge it soon, perhaps as an experimental feature.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Replying to let our robot overlords know this issue is not stale yet and people want this feature :)
I did not forget about this, it's still on my todo list. |
@BenjaminBossan

```python
lora_config = LoraConfig(
    r=16,
    target_modules=["in_proj_weight"],
    lora_alpha=32,
    lora_dropout=0.05
)
```

An error occurs as […]. By the way, I download […].
This is a consequence of how multihead attention is implemented and one of the reasons it is so complicated to apply LoRA to it.
Note that the PR you mentioned will target the whole multihead attention layer, not just one of its projections.
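To see why this fails, note that target_modules matches submodule names, while in_proj_weight is a plain nn.Parameter sitting directly on the attention module rather than a submodule. A quick inspection (independent of PEFT) shows the difference:

```python
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
# Submodules (names like these can be matched by target_modules):
print([name for name, _ in mha.named_children()])                 # ['out_proj']
# Parameters defined directly on the module (these cannot be targeted):
print([name for name, _ in mha.named_parameters(recurse=False)])  # ['in_proj_weight', 'in_proj_bias']
```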
@BenjaminBossan Previously, I tried with […]. Could you please provide demo code for how to set the […]?
I can investigate this issue, but I need the code for it. Could you provide a reproducer, please? I only need the model initialization and merging, no need for the data and training parts.
@BenjaminBossan

```python
import open_clip
from peft import LoraConfig, get_peft_model
from peft.tuners.lora.layer import MultiheadAttention as PeftMha

lora_config = LoraConfig(
    r=16,
    target_modules=["attn"],
    lora_alpha=32,
    lora_dropout=0.05
)

model, preprocess = open_clip.create_model_from_pretrained(model_name='ViT-L-14-quickgelu', pretrained="PATH-TO-YOUR-MODEL")
tokenizer = open_clip.get_tokenizer('ViT-L-14-quickgelu')

peft_model = get_peft_model(model, lora_config)
print(len([m for m in peft_model.modules() if isinstance(m, PeftMha)]))  # 36
peft_model.print_trainable_parameters()  # trainable params: 3,244,032 || all params: 430,860,545 || trainable%: 0.7529

peft_model.merge_and_unload()
#peft_model.merge_adapter()
print(peft_model.state_dict().keys())
```

In my code, I use MetaCLIP via […].
Thanks a lot for providing the reproducer. There was indeed a bug in the code; it should now be fixed. Could you try again based on the latest commit? Btw., this line […]
@BenjaminBossan

```python
import open_clip
import requests
import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from PIL import Image
from peft.tuners.lora.layer import MultiheadAttention as PeftMha

lora_config = LoraConfig(
    r=16,
    target_modules=["attn"],
    lora_alpha=32,
    lora_dropout=0.05
)

model, preprocess = open_clip.create_model_from_pretrained(model_name='ViT-L-14-quickgelu', pretrained="CLIP-PATH")

# original model
print(len(model.state_dict().keys()))         # 446
print(len(model.visual.state_dict().keys()))  # 296

# add LoRA
peft_model = get_peft_model(model, lora_config)
print(len(peft_model.state_dict().keys()))         # 590
print(len(peft_model.visual.state_dict().keys()))  # 392
print(peft_model.visual.state_dict().keys())

# merge LoRA
merged_model = peft_model.merge_and_unload()
print(len(merged_model.state_dict().keys()))         # 374
print(len(merged_model.visual.state_dict().keys()))  # 248
```

As the results show, there are 248 keys in the ViT of CLIP after merging, but the original number is 296. When printing the keys, I found that […]. Could you please fix this?
Thanks for the report. I pushed a new change to the branch that should fix it. Testing your snippet locally, I get the same values now after unloading.
Thanks for your contribution.
Thanks for testing @mashijie1028.
Yes, it's planned. Right now there is a blocker that prevents MHA from working with […]
Cheers! Thanks again for your contribution to the community! Hope everything goes well.
@BenjaminBossan @ambroser53 Hi, I have fixed this problem (Baijiong-Lin/LoRA-Torch@e3e20a0). Could you test it again?
Feature request
fine-tuning OpenCLIP with Hugging Face's PEFT (such as LoRA)
Motivation
fine-tuning OpenCLIP with Hugging Face's PEFT (such as LoRA)
Your contribution
refer to https://github.com/KyanChen/MakeMultiHeadNaive/tree/master for help!
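For context, the workaround that repository points to amounts to re-implementing multi-head attention with plain nn.Linear projections, which released PEFT versions can already target. A minimal sketch of that idea (not the repository's exact code; self-attention only, batch-first input, no masking or dropout) might look like this:

```python
import torch
from torch import nn

class NaiveMultiheadAttention(nn.Module):
    """Self-attention built from nn.Linear layers so LoRA can target q/k/v/out projections."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, embed_dim)
        b, s, d = x.shape
        q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-2, -1) / self.head_dim ** 0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, s, d)
        return self.out_proj(out)

# With this module in place of nn.MultiheadAttention, stock PEFT can target e.g.
# target_modules=["q_proj", "v_proj"] without the MHA-specific PR.
```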