ENH: Tie weights for target_modules in Lora (#2864) #2879
Conversation
@BenjaminBossan At a high level … Thank you
BenjaminBossan
left a comment
Thanks for this draft PR to extend the feature to target_modules. I haven't done a full review yet, as some implementation details have yet to be figured out, but I gave some early feedback. This feature could be a bit more difficult to implement than it was for modules_to_save; I added some comments on why, please check.
src/peft/tuners/lora/model.py (outdated)

    peft_config.modules_to_tie = missing_keys

def _add_targets_to_tie(self, peft_config, tied_weight_keys):
    target_modules = set(getattr(peft_config, "target_modules", []) or [])
We need to consider the case that target_modules is a string and not a list of strings. If it's a string, we perform a regex match. Honestly, I'm not sure if there is a good solution. So far, I have 3 ideas (a small sketch of the list vs. string case follows below this list):
- We could try to use the model.targeted_module_names attribute, which lists all targeted modules after the targets have been resolved. But that would mean that we need to first apply all LoRA layers and only then can we check for tied layers, which is the opposite order of how things are implemented right now.
- We could try using the string directly and then, for example, do something like config.target_modules += f"|{missing_key}", but this is very brittle and won't work with all regexes, so I would like to avoid this.
- We could forbid using ensure_weight_tying=True together with target_modules = <str>. Then we'd raise an error and tell users they have to pass a list of str if they want ensure_weight_tying.
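To make the list vs. string distinction concrete, here is a minimal sketch; the config values and missing_key are illustrative and not taken from the PR:

```python
from peft import LoraConfig

# Case 1: target_modules as a list of str -- adding the tied layer is straightforward.
config_list = LoraConfig(target_modules=["embed_tokens"])
config_list.target_modules = set(config_list.target_modules) | {"lm_head"}

# Case 2: target_modules as a str -- PEFT treats it as a regex, so naive
# concatenation (idea 2 above) may change the pattern's meaning depending on
# how the original regex was written.
config_re = LoraConfig(target_modules=r"^model\.embed_tokens$")
missing_key = "lm_head"
config_re.target_modules = f"{config_re.target_modules}|{missing_key}"  # brittle
```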
Yes, after going through the code a few more times, I realized this would not work for all the cases. I would go with the 1st approach and move the call to this function after model.targeted_module_names is updated
Sounds good. This has yet to be updated, right?
Yes, I will do this in the next commit
Moving this after model.targeted_module_names is populated is tough, as the loop which populates it (https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L773-L819) also needs to check for and skip tied layers.
Reversing the order would mean that we may end up adding adapters where they're not required. The subsequent code would become more involved: essentially, we would have to remove the adapters from all tied layers, re-add one in embed_tokens, and then tie the remaining adapters to it. In my view, this is the opinionated solution with the least complexity.
We can go with (1) in your original comment and redo a few things, or keep the current flow and go with (3).
I think the above might have become tough to follow 😅, so let me know and I can share some schematics. Will wait for your input.
@BenjaminBossan This is now ready for review. I have also updated the logic for tied layers in … I have also added a few tests for the above case, and all of the tests pass. The only thing remaining is how to check for …
BenjaminBossan
left a comment
Thanks for the latest updates. This looks much cleaner now, I think we're approaching the finish line. There were still some issues I had though, so please check my comments.
    return CausalLM().eval().to(self.torch_device)

def test_weight_tying_tied_model_lora(self):
@pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
Suggested change:
@pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
@pytest.mark.parametrize("modules_to_save", [["lm_head"], ["embed_tokens"], ["lm_head", "embed_tokens"]])
Let's call this modules_to_save to make it immediately clear what is meant and also put everything into a list, so that we don't need to check if isinstance(layer, list) below. Same for the other tests.
Part 1 of the comment is not resolved. I think all these tests would benefit from s/layer/modules_to_save/ here.
@BenjaminBossan I have addressed your comments. PTAL
BenjaminBossan
left a comment
Thanks for the updates. I think there are some still-unaddressed comments from before, and I also added a few more; please check.
There is, however, a bit of a blocker right now. Currently, a huge PR in transformers is on the way: huggingface/transformers#41580. It is intended to be released soon with transformers v5. A change that might affect us is that _tied_weights_keys will be converted from a list to a dict (with keys being targets and values sources). It could also affect _get_tied_weight_keys. We're still discussing how this will affect PEFT. Possibly it's going to be fine, but we're not sure yet, the PR is still changing.
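For context, a hypothetical illustration of the shape change being described; the exact keys and values depend on the model and on the final transformers v5 API:

```python
# transformers v4.x: a list of the target parameter names that get tied
_tied_weights_keys = ["lm_head.weight"]

# transformers v5 as described above: a dict mapping target -> source
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
```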
    tied_weight_keys = set(tied_weight_keys)
    peft_config.target_modules_to_tie = tied_weight_keys

    raw_target_modules = getattr(peft_config, "target_modules", None)
@BenjaminBossan Please review this logic. I know this is a bit hacky! I am open to suggestions
Hmm yeah, this is rough. We can't really operate on the string like this, as there are too many possible ways that the regex could be formed. I wonder if we should just leave it be and deal with the tied module edge case in inject_adapter directly. I haven't fully thought this through, perhaps you already tried that and there is a caveat that I'm missing?
It should be possible, it would just make the flow very convoluted.
I redid this a bit. We just need to make sure that embed_tokens is present in the target_modules
@BenjaminBossan Please review the latest changes now. I believe I have addressed all your comments, but let me know if I missed something. I have added test cases where we pass target_modules as a regex string.

Regarding the transformers v5 update: since the transformers version would be locked in peft, I believe that if this PR advances faster than that update, we can merge this. I can take up the follow-up changes too, whenever they're needed. However, you are much closer to this, so you can decide and let me know.
@BenjaminBossan Done
@BenjaminBossan Let me know if any steps remain on my side for the final push.
@romitjain No, thank you, let's wait for @githubnemo's review.
Hi @githubnemo, it would be very helpful if you could review the PR. One of our internal features depends on this :)
Hi @githubnemo, gentle reminder. Would really appreciate an update. Thanks
githubnemo
left a comment
Hey, sorry for the delayed review.
Thanks for working on this, it's quite the gnarly topic :)
I'm not sure if I understood everything correctly, so some of my comments may just be me misunderstanding something, but I think there are some places where the matching is rather probabilistic.
I wonder if we need to normalize layer names at some point so that we only work with fully-qualified names after that point. For example in _add_modules_to_tie we will look at the modules to save set:
modules_to_save = getattr(peft_config, "modules_to_save", []) or []
I don't think that we have guaranteed fully-qualified names here as they are still user-supplied. IMO it would be worthwhile to first collect the full names of all values in modules_to_save and then check if they are tied to save us from having various places where we do prefix/suffix/infix/whatever comparisons.
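A minimal sketch of that normalization idea, assuming model is the base transformers model and modules_to_save / tied_weight_keys come from the peft_config and _get_tied_weight_keys as discussed above; the helper name is illustrative:

```python
import torch.nn as nn


def resolve_fully_qualified_names(model: nn.Module, short_names) -> set:
    """Map user-supplied (possibly non-qualified) module names to the
    fully-qualified module names actually present in the model."""
    resolved = set()
    for name, _ in model.named_modules():
        if any(name == s or name.endswith(f".{s}") for s in short_names):
            resolved.add(name)
    return resolved


# fqn_modules_to_save = resolve_fully_qualified_names(self.model, modules_to_save)
# tied_module_names = {k.removesuffix(".weight") for k in tied_weight_keys}
# tied_and_saved = fqn_modules_to_save & tied_module_names
```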
src/peft/tuners/lora/model.py (outdated)

    modules_to_save = getattr(peft_config, "modules_to_save", []) or []

    embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight)
I think there is no guarantee that this will return the name of the embedding layer. It could also return the name of a layer tied to the embedding layer. It is probably safer to compare module identity instead (even though for transformers <5 this will also be flaky for models like T5).
Updated to check for module identity instead
src/peft/tuners/lora/model.py (outdated)

    # find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
    embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "")
Not sure if replacing these strings is a good idea. encoder_model.embed_tokens would be turned into encoder_embed_tokens. Maybe using a more restricted approach (only one replacement, only if the key is found) would be better? .weight for example could be dropped by using .removesuffix.
Updated the logic
src/peft/tuners/lora/model.py (outdated)

    if m in modules_to_save:
        modules_to_save.remove(m)
I'm not sure how often this will generate a match. If I understand correctly, tied_weight_keys are fully-qualified keys. So this check will only match if the keys in modules_to_save are also fully-qualified. I don't think this happens often. cc @BenjaminBossan
@githubnemo In the flow, I am unable to find any place where we are converting keys in modules_to_save to fully qualified keys.
The only two relevant checks I see are:
- https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L1655
- https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L1016
Here, we want to make sure that the fully qualified keys from tied_weight_keys match the ones in modules_to_save. I propose the following (sketched below):
For every key in model.named_parameters, I perform a check similar to what is given in (1) and match it against tied_weight_keys. If both match, I remove the corresponding entry from modules_to_save.
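A sketch of that proposal, using the same suffix-style comparison PEFT applies elsewhere when resolving user-supplied names; the helper names are illustrative, not the PR's code:

```python
import torch.nn as nn


def _suffix_match(full_name: str, short_name: str) -> bool:
    return full_name == short_name or full_name.endswith(f".{short_name}")


def drop_tied_modules_to_save(model: nn.Module, modules_to_save, tied_weight_keys):
    remaining = set(modules_to_save)
    for param_name, _ in model.named_parameters():
        # only consider parameters that are among the tied weights (fully qualified keys)
        if not any(_suffix_match(param_name, tied) for tied in tied_weight_keys):
            continue
        module_name = param_name.removesuffix(".weight")
        for short in list(remaining):
            if _suffix_match(module_name, short):
                remaining.discard(short)
    return remaining
```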
src/peft/tuners/lora/model.py (outdated)

    for m in tied_weight_keys:
        if m in target_modules:
            target_modules.remove(m)
This will also only occasionally match, right? Only if users supply the fully-qualified module names.
@githubnemo Thanks for the review, I'll redo the PR a bit based on your suggestions.
@githubnemo I redid a few parts. Let me explain.

The major concern in your comments was that we check the tied keys, which are fully qualified names (fqns) of the layers, against the user-supplied modules_to_save / target_modules, which are not guaranteed to be fully qualified. Essentially, we now convert the names so that they can be compared against the fqns.

Regarding your suggestion of collecting the fqns first: we would have to break this loop, https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L774-L820, so that we first generate the fqns and then, in a second loop, add adapters which are aware of the weight tying. I can implement this as well, if you feel this is the best solve.

My latest changes (67a71d6) incorporate all of your suggestions and an opinionated solution for the target_modules handling discussed above. Let me know your thoughts.
(force-pushed cd2e82c to 1cd9137, then 1cd9137 to 67a71d6)
Hi @githubnemo, gentle reminder for your review :)
@romitjain Sorry for the delay, we will hopefully get this finished soon. As the corresponding PR for trainable tokens was merged, we now have some merge conflicts, but that should be easily resolved. Could you please take care?
@BenjaminBossan Resolved the merge conflict
Hi @BenjaminBossan / @githubnemo … Thank you
Thanks for updating. From my POV, it's good to go, but since it's a complex topic, let's wait for the 2nd review.
@githubnemo Could you please review this? Thank you
    adapter_name: str,
    r: int,
    lora_alpha: int,
    lora_dropout,
I think this was added by a merge commit; let me take a look and remove it if not required.
Yeah, this looks like a wrong merge. #2960 refactored this code to only include lora_alpha and config instead of all those params + kwargs. The parameters from lora_dropout through use_bdlora should be removed.
githubnemo
left a comment
Hey @romitjain,
thanks for the updates and keeping up-to-date with main!
I have a few medium sized comments but generally this looks good.
One note regarding the review process: I think it's best to let the reviewer mark the issues as resolved to minimize the chances of missing a marked but still incomplete comment. It also helps me keep track of what I need to revise next time. It would help me a lot if we did it that way. Thanks :)
    assert embed_lora_A.data_ptr() == lm_lora_B.data_ptr()
    assert embed_lora_B.data_ptr() == lm_lora_A.data_ptr()

@pytest.mark.parametrize("layer", [".*embed_tokens$", ".*lm_head$", ".*(embed_tokens|lm_head)$"])
Let's rename layer to target_modules in accordance with https://github.com/huggingface/peft/pull/2879/files#r2486161404
    lora_dropout=lora_dropout,
    init_lora_weights=init_lora_weights,
    use_rslora=use_rslora,
    use_dora=use_dora,
    use_alora=use_alora,
    lora_bias=lora_bias,
    arrow_config=arrow_config,
    tied_adapter=kwargs.pop("tied_adapter", None),
    lora_ga_config=lora_ga_config,
    use_bdlora=use_bdlora,
Remove, merge artifact
def find_parameter_name_by_tensor(model: nn.Module, reference_tensor: torch.Tensor) -> str:
    """
    Find layer name from the model by matching the reference tensor to the model parameters

    Args:
        model (nn.Module): The model with named modules
        reference_tensor (torch.Tensor): The reference tensor to find

    Returns:
        str: Name of the layer
    """
    for n, m in model.named_modules():
        if m is reference_tensor:
            return n

    return ""
I think this is a misnomer now. reference_tensor should actually be reference_module. Let's change the function name as well to match the new behavior.
    if embed_layer_name.endswith(".weight"):
        embed_layer_name = embed_layer_name.removesuffix(".weight")
Same comment as above
    prefix, sep, suffix = embed_layer_name.partition(".")
    if sep and "model" in prefix:
        embed_layer_name = suffix
Why do we split the fully-qualified name and make it a less precise name only to add it to the raw_target_modules where it could have been a fully-qualified name? If there's a good reason then it should be documented in a comment. If it is just escaping, use re.escape(embed_layer_name) and use that result in the new raw_target_modules string.
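To illustrate the escaping point, a small sketch; the values are illustrative and this is not the PR's exact code:

```python
import re

embed_layer_name = "model.embed_tokens"      # fully-qualified name, kept as-is
raw_target_modules = r".*(q_proj|v_proj)$"   # user-supplied regex

if isinstance(raw_target_modules, str):
    # escape the name so its dots are matched literally, instead of shortening it first
    raw_target_modules = f"{raw_target_modules}|{re.escape(embed_layer_name)}"
```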
    if sep and "model" in prefix:
        embed_layer_name = suffix

    if isinstance(raw_target_modules, str):
I think that target_modules="all-linear" qualifies for this case (technically) but it probably doesn't reach this code path since it doesn't match the embedding layer. Still, I think we should have a test that makes sure that LoraConfig(target_modules="all-linear", ensure_weight_tying=True) doesn't break (if it doesn't exist already, I may have missed it!)
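A hypothetical shape for such a test; the fixture is assumed and the check may need adjusting once the final behavior is settled:

```python
import torch
from peft import LoraConfig, get_peft_model


def test_all_linear_with_ensure_weight_tying(tied_lm_model):  # assumed fixture with tied embeddings
    config = LoraConfig(target_modules="all-linear", ensure_weight_tying=True)
    peft_model = get_peft_model(tied_lm_model, config)
    # "all-linear" does not target the embedding/output layer, so this should
    # simply not raise, and a forward pass should still work
    peft_model(input_ids=torch.tensor([[0, 1, 2]]))
```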
    # Iterate over `tied_weight_keys` which are
    # fully qualified keys and remove matching keys from
    # `target_modules`. It will only remove first encounter
    # in `target_modules`, which should be safe, because `tied_weight_keys`
Same as above, misses the 'why'
@pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
def test_weight_tying_tied_model_target_modules_lora(self, layer):
I think it would be worthwhile to have a similar test for encoder-decoder models that attempts to target embed_tokens in T5 or something like that.
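Such a test could look roughly like this; a hypothetical sketch where the checkpoint is illustrative and the assertions would mirror the decoder-only tests in this PR:

```python
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model


def test_weight_tying_encoder_decoder_target_modules():
    # t5-small shares its embedding between encoder.embed_tokens, decoder.embed_tokens
    # and (with tie_word_embeddings) the lm_head
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    config = LoraConfig(target_modules=["embed_tokens"], ensure_weight_tying=True)
    peft_model = get_peft_model(model, config)
    # assertions would compare data_ptr() of the LoRA weights on the tied modules,
    # analogous to the decoder-only tests above
```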
Solves #2864 for target_modules.

Enables the ensure_weight_tying flag in LoraConfig for target_modules. For LoRA, if any of the tied layers are added to target_modules and ensure_weight_tying == True, the adapters added to that layer are shared with all the tied layers.

For example, if a model has tied weights and target_modules=['embed_tokens'], then LoRA adapters are added to both embed_tokens and lm_head. The adapters in lm_head share their weights with the adapters added to embed_tokens (see the usage sketch below).