
Conversation

@romitjain
Contributor

@romitjain romitjain commented Oct 29, 2025

Solves #2864 for target_modules

Enables the ensure_weight_tying flag in LoraConfig for target_modules.

For LoRA, if any of the tied layers is added to target_modules and ensure_weight_tying == True, the adapter added to that layer is shared with all of its tied layers.

For example, if a model has tied weights and target_modules=['embed_tokens'], then LoRA adapters are added to both embed_tokens and lm_head, and the adapter in lm_head shares its weights with the adapter added to embed_tokens.
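
A minimal usage sketch of the behavior described above (the model id is a placeholder for any causal LM with tied input/output embeddings and Llama-style module names; this is illustrative, not code from the PR):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# a causal LM whose lm_head weight is tied to embed_tokens
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
config = LoraConfig(target_modules=["embed_tokens"], ensure_weight_tying=True)
peft_model = get_peft_model(model, config)
# LoRA is injected into both embed_tokens and lm_head, and the two adapters
# share the same underlying weights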

@romitjain
Copy link
Contributor Author

romitjain commented Oct 29, 2025

@BenjaminBossan
I have added the relevant test cases and implemented the ensure_weight_tying flag for target_modules. The current implementation works only if embed_tokens is added, not if lm_head is added. I will implement that fix and update the PR, but in the meantime I would appreciate your views on the logic and implementation.

At a high level

  1. I have updated BaseTuner._check_tied_modules to check for tied modules in target_modules
  2. I have added a private method BaseTuner._add_targets_to_tie that needs to be implemented by the inheriting classes
  3. I have added a loop in BaseTuner.inject_adapter to tie the adapters. This extra loop ensures that the order in which adapters are added to the target modules does not matter.

Thank you
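
A conceptual sketch of the extra tying pass described in point 3 above; the helper below is not the PEFT API, just an illustration of sharing adapter parameters across tied modules (the embedding/linear transpose handling, where lora_A and lora_B swap roles, is glossed over):

import torch.nn as nn

def tie_lora_adapters(model: nn.Module, tied_targets: dict[str, list[str]]) -> None:
    # tied_targets maps a source module name to the tied module names whose
    # adapters should reuse the source's LoRA parameters
    for source_name, tied_names in tied_targets.items():
        source = model.get_submodule(source_name)
        for tied_name in tied_names:
            tied = model.get_submodule(tied_name)
            # reuse the same parameter objects so both adapters stay in sync
            tied.lora_A = source.lora_A
            tied.lora_B = source.lora_B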

@romitjain romitjain changed the title Tie weights for target_modules in Lora Tie weights for target_modules in Lora (#2864) Oct 29, 2025
Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for this draft PR to extend the feature to target_modules. I haven't done a full review yet, as some implementation details have yet to be figured out, but I gave some early feedback. This feature could be a bit more difficult to implement than for modules_to_save; I added some comments explaining why, please check.

peft_config.modules_to_tie = missing_keys

def _add_targets_to_tie(self, peft_config, tied_weight_keys):
target_modules = set(getattr(peft_config, "target_modules", []) or [])
Member

We need to consider the case that target_modules is a string and not a list of strings. If it's a string, we perform a regex match. Honestly, I'm not sure if there is a good solution. So far, I have 3 ideas:

  1. We could try to use the model.targeted_module_names attribute, which lists all targeted modules after the targets have been resolved. But that would mean that we need to first apply all LoRA layers and only then can we check for tied layers, which is the opposite order of how things are implemented right now.
  2. We could try using the string directly and then for example do something like: config.target_modules += f"|{missing_key}" but this is very brittle and won't work with all regexes, so I would like to avoid this.
  3. We could forbid using ensure_weight_tying=True and target_modules = <str>. Then we'd raise an error and tell users they have to pass a list of str if they want ensure_weight_tying.
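
A minimal sketch of the check in idea 3, assuming it would live somewhere like LoraConfig.__post_init__ (placement, wording, and the helper name are assumptions, not the merged implementation):

def _validate_weight_tying_config(config):
    # hypothetical guard for idea 3: forbid regex-string target_modules
    if config.ensure_weight_tying and isinstance(config.target_modules, str):
        raise ValueError(
            "ensure_weight_tying=True cannot be combined with a regex string for "
            "target_modules; please pass the target module names as a list of str."
        )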

Contributor Author

Yes, after going through the code a few more times, I realized this would not work in all cases. I would go with the 1st approach and move the call to this function to after model.targeted_module_names is updated.

Member

Sounds good. This has yet to be updated, right?

Contributor Author

Yes, I will do this in the next commit

Contributor Author

@BenjaminBossan

Moving this to after model.targeted_module_names is populated is tough, as the loop that populates it (https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L773-L819) also needs to check for and skip tied layers.

Reversing the order would mean that we may end up adding adapters where they're not required. The subsequent code would become more involved: essentially, we would have to remove the adapters from all tied layers, re-add the one on embed_tokens, and then tie the remaining adapters to it. This is an opinionated solution, but to me it has the least complexity.

We can go with (1) in your original comment and redo a few things, or keep the current flow and go with (3).

I think the above might have become tough to follow 😅, so let me know and I can share some schematics. Will wait for your input.

@romitjain romitjain marked this pull request as ready for review October 31, 2025 12:11
@romitjain
Contributor Author

@BenjaminBossan This is now ready for review. I have also updated the logic for tied layers in modules_to_save so that the lm_head and [embed_tokens, lm_head] cases are supported; earlier, they would not have worked. The high-level implementation remains the same, but in my opinion it is much better placed than in my earlier commits.

I have also added a few tests for the above case, and all of the tests pass.

The only thing remaining is how to check for target_modules in case it's a string. I will come back to it, but you can go ahead and review the core logic.

@romitjain romitjain changed the title Tie weights for target_modules in Lora (#2864) ENH: Tie weights for target_modules in Lora (#2864) Oct 31, 2025
Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the latest updates. This looks much cleaner now; I think we're approaching the finish line. I still had some issues, though, so please check my comments.

return CausalLM().eval().to(self.torch_device)

def test_weight_tying_tied_model_lora(self):
@pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
Member

Suggested change
@pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
@pytest.mark.parametrize("modules_to_save", [["lm_head"], ["embed_tokens"], ["lm_head", "embed_tokens"]])

Let's call this modules_to_save to make it immediately clear what is meant and also put everything into a list, so that we don't need to check if isinstance(layer, list) below. Same for the other tests.

Collaborator

Part 1 of the comment is not resolved. I think all these tests would benefit from s/layer/modules_to_save/ here.

@romitjain
Contributor Author

@BenjaminBossan I have addressed your comments. PTAL

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the updates. I think there are some yet unaddressed comments from before and I also added a few more, please check.

There is, however, a bit of a blocker right now. Currently, a huge PR in transformers is on the way: huggingface/transformers#41580. It is intended to be released soon with transformers v5. A change that might affect us is that _tied_weights_keys will be converted from a list to a dict (with keys being targets and values being sources). It could also affect _get_tied_weight_keys. We're still discussing how this will affect PEFT. Possibly it's going to be fine, but we're not sure yet; the PR is still changing.

tied_weight_keys = set(tied_weight_keys)
peft_config.target_modules_to_tie = tied_weight_keys

raw_target_modules = getattr(peft_config, "target_modules", None)
Contributor Author

@BenjaminBossan Please review this logic. I know this is a bit hacky! I am open to suggestions

Member

Hmm yeah, this is rough. We can't really operate on the string like this, as there are too many possible ways that the regex could be formed. I wonder if we should just leave it be and deal with the tied module edge case in inject_adapter directly. I haven't fully thought this through, perhaps you already tried that and there is a caveat that I'm missing?

Contributor Author

#2879 (comment)

It should be possible, it would just make the flow very convoluted.

Contributor Author

I redid this a bit. We just need to make sure that embed_tokens is present in the target_modules

@romitjain
Contributor Author

@BenjaminBossan please review the latest changes now. I believe I have addressed all your comments, but let me know if I missed something.

I have added test cases where target_modules is passed as a str and added a (slightly) hacky solution for that. It's opinionated, in order to keep the flow simple.

Regarding the transformers v5 update: since the transformers version would be locked in peft, I believe that if this PR advances faster than that release, we can merge it. I can also take up any follow-up changes whenever they're needed. However, you are much closer to this, so you can decide and let me know.

@romitjain
Contributor Author

@BenjaminBossan Done

@romitjain
Contributor Author

@BenjaminBossan Let me know if any steps remain on my side for the final push.

@BenjaminBossan
Member

@romitjain No, thank you, let's wait for @githubnemo's review.

@romitjain
Contributor Author

Hi @githubnemo, it would be very helpful if you could review the PR. One of our internal features depends on this :)

@romitjain
Contributor Author

Hi @githubnemo, gentle reminder. Would really appreciate an update. Thanks

Collaborator

@githubnemo githubnemo left a comment

Hey, sorry for the delayed review.

Thanks for working on this, it's quite the gnarly topic :)

I'm not sure if I understood everything correctly, so some of my comments may just be me misunderstanding something, but I think there are some places where the matching is rather probabilistic.

I wonder if we need to normalize layer names at some point so that we only work with fully-qualified names from then on. For example, in _add_modules_to_tie we look at the modules_to_save set:

modules_to_save = getattr(peft_config, "modules_to_save", []) or []

I don't think we have guaranteed fully-qualified names here, as they are still user-supplied. IMO it would be worthwhile to first collect the full names of all values in modules_to_save and then check whether they are tied, to save us from having various places where we do prefix/suffix/infix/whatever comparisons.
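
A minimal sketch of that normalization idea, assuming plain suffix matching against model.named_modules() (this is an illustration, not the PR's code):

import torch.nn as nn

def resolve_to_fqns(model: nn.Module, short_names: list[str]) -> set[str]:
    # map user-supplied names like "embed_tokens" to fully-qualified names
    # like "model.embed_tokens", so later comparisons use one convention
    short = set(short_names)
    fqns = set()
    for name, _ in model.named_modules():
        if name in short or any(name.endswith(f".{s}") for s in short):
            fqns.add(name)
    return fqns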


modules_to_save = getattr(peft_config, "modules_to_save", []) or []

embed_layer_name = find_parameter_name_by_tensor(self.model, self.model.get_input_embeddings().weight)
Collaborator

I think there is no guarantee that this will return the name of the embedding layer. It could also return the name of a layer tied to the embedding layer. It is probably safer to compare module identity instead (even though for transformers <5 this will also be flaky for models like T5).

Contributor Author

Updated to check for module identity instead

Comment on lines 888 to 889
# find_parameter_name_by_tensor returns the parameter name, so we need to strip the weight from the name
embed_layer_name = embed_layer_name.replace(".weight", "").replace("model.", "")
Collaborator

Not sure if replacing these strings is a good idea. encoder_model.embed_tokens would be turned into encoder_embed_tokens. Maybe using a more restricted approach (only one replacement, only if the key is found) would be better? .weight for example could be dropped by using .removesuffix.
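
The pitfall, illustrated on a hypothetical layer name:

# chained str.replace also eats "model." inside "encoder_model."
"encoder_model.embed_tokens.weight".replace(".weight", "").replace("model.", "")
# -> "encoder_embed_tokens"
"encoder_model.embed_tokens.weight".removesuffix(".weight")
# -> "encoder_model.embed_tokens"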

Contributor Author

Updated the logic

Comment on lines 895 to 896
if m in modules_to_save:
modules_to_save.remove(m)
Collaborator

I'm not sure how often this will generate a match. If I understand correctly, tied_weight_keys are fully-qualified keys. So this check will only match if the keys in modules_to_save are also fully-qualified. I don't think this happens often. cc @BenjaminBossan

Contributor Author

@githubnemo In the flow, I am unable to find any place where we are converting keys in modules_to_save to fully qualified keys.

The only two relevant checks I see are:

  1. https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L1655
  2. https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L1016

Here, we want to make sure that fully qualified keys from tied_weight_keys match the ones in modules_to_save. I propose that I do the following:

For every key in model.named_parameters, I perform a check similar to what is given in (1) and match it with tied_weight_keys. If both of them give a match, I remove the key from modules_to_save
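
A rough sketch of that proposal (helper and variable names are illustrative, and the suffix-style match only approximates the check linked in (1); this is not the PR's code):

import re
import torch.nn as nn

def drop_tied_modules_to_save(model: nn.Module, modules_to_save: set[str], tied_weight_keys: set[str]) -> set[str]:
    tied_module_names = {k.removesuffix(".weight") for k in tied_weight_keys}
    kept = set(modules_to_save)
    for param_name, _ in model.named_parameters():
        module_name = param_name.rpartition(".")[0]  # drop the parameter suffix, e.g. ".weight"
        if module_name not in tied_module_names:
            continue
        for short in list(kept):
            # suffix-style match of the short, user-supplied name against the FQN
            if re.fullmatch(rf"(.*\.)?{re.escape(short)}", module_name):
                kept.discard(short)
    return kept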

Comment on lines 928 to 930
for m in tied_weight_keys:
if m in target_modules:
target_modules.remove(m)
Collaborator

This will also only occasionally match, right? Only if users supply the fully-qualified module names.

@romitjain
Contributor Author

@githubnemo Thanks for the review, I'll redo the PR a bit based on your suggestions

@romitjain
Contributor Author

romitjain commented Dec 16, 2025

@githubnemo I redid a few parts. Let me explain:

The major concern in your comments was that we compare tied keys (which are fully-qualified names (FQNs) of the layers) against modules_to_save and target_modules, which might not be FQNs at the time of the comparison.

For modules_to_save, I have added a comment here: #2879 (comment)

For target_modules, I think the alternate flow would be too complex to follow. I explain my thought process in point 3 here: #2879 (comment)

But essentially, we convert target_modules to FQNs while also adding the adapters.

For your suggestion of collecting the FQNs first, we would have to break this loop - https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py#L774-L820 - into two: first generate the FQNs, then in a second loop add the adapters in a way that is aware of the weight tying. I can implement this as well if you feel it is the best solution.

My latest changes (67a71d6) incorporate all of your suggestions plus an opinionated solution for matching target_modules against tied keys (different from what you proposed).

Let me know your thoughts.

@romitjain romitjain requested a review from githubnemo December 16, 2025 09:38
@romitjain romitjain force-pushed the enh/tie-target-modules branch from cd2e82c to 1cd9137 Compare December 16, 2025 10:07
@romitjain
Contributor Author

Hi @githubnemo, gentle reminder for your review :)

@BenjaminBossan
Member

@romitjain Sorry for the delay, we will hopefully get this finished soon. As the corresponding PR for trainable tokens was merged, we now have some merge conflicts, but that should be easily resolved. Could you please take care of that?

@romitjain
Contributor Author

@BenjaminBossan Resolved the merge conflict

@romitjain
Contributor Author

Hi @BenjaminBossan/ @githubnemo
Gentle reminder for the review. Let me know if the explanation I added to #2879 makes sense.

Thank you

@BenjaminBossan
Member

Thanks for updating. From my POV, it's good to go, but since it's a complex topic, let's wait for the 2nd review.

@romitjain
Contributor Author

@githubnemo Request you to please review this.
I have addressed your comments in the comment: #2879 (comment)

Thank you

adapter_name: str,
r: int,
lora_alpha: int,
lora_dropout,
Contributor Author

I think this was added by the merge commit; let me take a look and remove it if it is not required.

Collaborator

Yeah, this looks like a wrong merge. #2960 refactored this code to only take lora_alpha and config instead of all those params + kwargs. Everything from lora_dropout to use_bdlora should be removed.

Collaborator

@githubnemo githubnemo left a comment

Hey @romitjain,

thanks for the updates and keeping up-to-date with main!

I have a few medium-sized comments, but generally this looks good.

One note regarding the review process: I think it's best to let the reviewer mark the issues as resolved to minimize the chances of missing a marked but still incomplete comment. It also helps me keep track of what I need to revise next time. It would help me a lot if we did it that way. Thanks :)

assert embed_lora_A.data_ptr() == lm_lora_B.data_ptr()
assert embed_lora_B.data_ptr() == lm_lora_A.data_ptr()

@pytest.mark.parametrize("layer", [".*embed_tokens$", ".*lm_head$", ".*(embed_tokens|lm_head)$"])
Collaborator

Let's rename layer to target_modules in accordance with https://github.com/huggingface/peft/pull/2879/files#r2486161404

Comment on lines +770 to +779
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
tied_adapter=kwargs.pop("tied_adapter", None),
lora_ga_config=lora_ga_config,
use_bdlora=use_bdlora,
Collaborator

Remove, merge artifact

Comment on lines +1980 to +1995
def find_parameter_name_by_tensor(model: nn.Module, reference_tensor: torch.Tensor) -> str:
"""
Find layer name from the model by matching the reference tensor to the model parameters

Args:
model (nn.Module): The model with named modules
reference_tensor (torch.Tensor): The reference tensor to find

Returns:
str: Name of the layer
"""
for n, m in model.named_modules():
if m is reference_tensor:
return n

return ""
Collaborator

I think this is a misnomer now. reference_tensor should actually be reference_module. Let's change the function name as well to match the new behavior.

Comment on lines +924 to +925
if embed_layer_name.endswith(".weight"):
embed_layer_name = embed_layer_name.removesuffix(".weight")
Collaborator

Same comment as above

Comment on lines +926 to +928
prefix, sep, suffix = embed_layer_name.partition(".")
if sep and "model" in prefix:
embed_layer_name = suffix
Collaborator

Why do we split the fully-qualified name and make it a less precise name only to add it to the raw_target_modules where it could have been a fully-qualified name? If there's a good reason then it should be documented in a comment. If it is just escaping, use re.escape(embed_layer_name) and use that result in the new raw_target_modules string.

if sep and "model" in prefix:
embed_layer_name = suffix

if isinstance(raw_target_modules, str):
Collaborator

I think that target_modules="all-linear" qualifies for this case (technically) but it probably doesn't reach this code path since it doesn't match the embedding layer. Still, I think we should have a test that makes sure that LoraConfig(target_modules="all-linear", ensure_weight_tying=True) doesn't break (if it doesn't exist already, I may have missed it!)
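
A possible test along these lines (the get_model fixture and the test name are assumptions about the surrounding test class, with LoraConfig and get_peft_model imported as in the existing tests):

def test_all_linear_with_ensure_weight_tying_does_not_raise(self):
    model = self.get_model()
    config = LoraConfig(target_modules="all-linear", ensure_weight_tying=True)
    # "all-linear" never matches the embedding layer, so this should simply work
    get_peft_model(model, config)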

Comment on lines +941 to +944
# Iterate over `tied_weight_keys` which are
# fully qualified keys and remove matching keys from
# `target_modules`. It will only remove first encounter
# in `target_modules`, which should be safe, because `tied_weight_keys`
Collaborator

Same as above, misses the 'why'

Comment on lines +5072 to +5073
@pytest.mark.parametrize("layer", ["lm_head", "embed_tokens", ["lm_head", "embed_tokens"]])
def test_weight_tying_tied_model_target_modules_lora(self, layer):
Collaborator

I think it would be worthwhile to have a similar test for encoder-decoder models that attempts to target embed_tokens in T5 or something like that.
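
A possible encoder-decoder variant, sketched against a real T5 checkpoint; the attribute paths and the exact assertion depend on how tying is defined for the shared encoder/decoder embeddings, so treat this as a starting point rather than the intended test:

def test_weight_tying_tied_encoder_decoder_target_modules_lora(self):
    from transformers import AutoModelForSeq2SeqLM
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    config = LoraConfig(target_modules=["embed_tokens"], ensure_weight_tying=True)
    model = get_peft_model(model, config)
    # the encoder and decoder embeddings are tied to the shared embedding,
    # so their LoRA parameters should point at the same storage
    enc = model.base_model.model.encoder.embed_tokens
    dec = model.base_model.model.decoder.embed_tokens
    assert enc.lora_embedding_A["default"].data_ptr() == dec.lora_embedding_A["default"].data_ptr()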
