
Conversation


@romitjain romitjain commented Sep 29, 2025

This PR solves: #2777

Essentially, the PR:

  1. Adds an argument to LoraConfig that lets users maintain weight tying in the model
  2. Checks whether the passed model has weight tying enabled and collects the layers that are tied to the embedding layer
  3. Adds a ModulesToSaveWrapper to the modules returned in (2), with a reference to the original layer

This ensures that weight tying is preserved if the user passes either embed_tokens or lm_head.
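For illustration, the intended usage would look roughly like this (a sketch only; `ensure_weight_tying` is the option name after the spelling fix discussed below, and the base model is assumed to be any causal LM with tie_word_embeddings=True):

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    task_type="CAUSAL_LM",
    modules_to_save=["embed_tokens"],  # lm_head is tied to embed_tokens in the base model
    ensure_weight_tying=True,          # option proposed by this PR
)
peft_model = get_peft_model(base_model, config)  # lm_head's trainable copy stays tied to embed_tokens'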

Limitations

  1. Only added for LoraConfig
  2. Only added for the modules_to_save parameter; this can be extended to target_modules as well once the approach is approved

@romitjain romitjain marked this pull request as ready for review September 29, 2025 08:03
@romitjain (Author)

@BenjaminBossan Requesting your review of the approach. If this looks fine, I can handle the other case (target_modules) and add tests.

Thanks

@BenjaminBossan BenjaminBossan left a comment

Thanks for working on the PR to deal with targeting tied weights.

I have a suspicion that parts of this PR can be implemented more easily by extending the peft_config.modules_to_save to include the tied module instead of calling _set_trainable explicitly.

However, before checking that, it makes most sense to add one or a few unit tests to establish if everything works as expected. I think the best place for those would be in this test class. LMK if you need help with implementing the tests.

Once we have the tests and confirm that everything works as expected, we can check if my proposal above works.
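For reference, a minimal test along these lines could look as follows (a sketch only: the tiny checkpoint, the attribute paths after wrapping, and the final option name are assumptions, not part of this PR):

from transformers import AutoConfig, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

def test_modules_to_save_keeps_weight_tying():
    # any small causal LM works, as long as tie_word_embeddings is True
    cfg = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    cfg.tie_word_embeddings = True
    model = AutoModelForCausalLM.from_config(cfg)

    config = LoraConfig(
        task_type="CAUSAL_LM",
        modules_to_save=["embed_tokens"],
        ensure_weight_tying=True,
    )
    peft_model = get_peft_model(model, config)

    # both tied modules should be wrapped, and their trainable copies should share storage
    embed = peft_model.base_model.model.get_input_embeddings()
    lm_head = peft_model.base_model.model.get_output_embeddings()
    assert (
        embed.modules_to_save["default"].weight.data_ptr()
        == lm_head.modules_to_save["default"].weight.data_ptr()
    )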

@romitjain (Author)

@BenjaminBossan I have addressed your comments. You mentioned that I can extend peft_config.modules_to_save to enable the functionality. Can you elaborate on that?

Essentially, we would need to mark the tied modules as modules to save and then update the weight pointer to point to the source adapter.

PTAL.
Thanks

@BenjaminBossan BenjaminBossan left a comment

Thanks for adding the tests, they LGTM.

You mentioned that I can extend peft_config.modules_to_save to enable the functionality. Can you elaborate on that?

Yes, so my idea was the following: I would like to avoid spreading the _set_trainable calls further, as it makes it harder to reason what is set where. Instead, let's detect early if additional weights need to be added for modules_to_save and extend the peft_config if so. Then, when the PEFT model is created, we will automatically have a ModulesToSaveWrapper everywhere it's needed. Next, let's update _set_trainable such that it ties the weights of ModulesToSaveWrapper (it is a bit wasteful to do that post hoc, but IMO more robust, as we don't rely on the order of the tied modules being created).

For this to work, I had to make a small change to the signature of _set_trainable and move some of the logic you implemented around to different locations. Overall, this is the diff I came up with:

diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py
index c3ad517e..3d608642 100644
--- a/src/peft/mixed_model.py
+++ b/src/peft/mixed_model.py
@@ -247,7 +247,7 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
             self,
             adapter_name,
             module_names=getattr(peft_config, "modules_to_save", None),
-            inference_mode=peft_config.inference_mode,
+            peft_config=peft_config,
         )
 
     def set_adapter(self, adapter_name: Union[str, list[str]], inference_mode: bool = False) -> None:
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 36eccabb..73596487 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1626,7 +1626,7 @@ class PeftModelForSequenceClassification(PeftModel):
             self,
             adapter_name,
             module_names=getattr(peft_config, "modules_to_save", None),
-            inference_mode=peft_config.inference_mode,
+            peft_config=peft_config,
         )
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
@@ -2482,7 +2482,7 @@ class PeftModelForTokenClassification(PeftModel):
             self,
             adapter_name,
             module_names=getattr(peft_config, "modules_to_save", None),
-            inference_mode=peft_config.inference_mode,
+            peft_config=peft_config,
         )
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
@@ -2703,7 +2703,7 @@ class PeftModelForQuestionAnswering(PeftModel):
             self,
             adapter_name,
             module_names=getattr(peft_config, "modules_to_save", None),
-            inference_mode=peft_config.inference_mode,
+            peft_config=peft_config,
         )
 
     def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 36ad31ce..493a300b 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -663,6 +663,14 @@ class LoraConfig(PeftConfig):
     arrow_config: Optional[ArrowConfig] = field(
         default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
     )
+    ensure_weight_tieing: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to tie weights or not after peft initialization.Only supported for `task_type` == CAUSAL_LM"
+            )
+        },
+    )
 
     def to_dict(self):
         """
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 4fd0d128..8ec6e0af 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -647,6 +647,30 @@ class BaseTuner(nn.Module, ABC):
     ) -> None:
         raise NotImplementedError(f"{self.__class__.__name__} does not support targeting nn.Parameter.")
 
+    def _check_tied_modules(self, peft_config: PeftConfig, model_config: dict, model: nn.Module) -> PeftConfig:
+        """TODO"""
+        if not getattr(peft_config, "ensure_weight_tieing", False):
+            return peft_config
+
+        from transformers.modeling_utils import _get_tied_weight_keys  # TODO hmm, maybe just copy that function
+
+        tied_weight_keys = _get_tied_weight_keys(model)
+
+        if not (
+            (peft_config.task_type == TaskType.CAUSAL_LM)
+            and tied_weight_keys
+            and (getattr(peft_config, "modules_to_save", None) is not None)
+        ):
+            warnings.warn("TODO")  # user wants to tie weights but we can't, there should be a warning
+            return peft_config
+
+        modules_to_save = set(peft_config.modules_to_save)
+        missing_keys = set(tied_weight_keys) - modules_to_save
+        for key in missing_keys:
+            module_name = key.rpartition(".")[0]  # remove the parameter name, e.g. "lm_head.weight" => "lm_head"
+            peft_config.modules_to_save.append(module_name)
+        return peft_config
+
     def inject_adapter(
         self,
         model: nn.Module,
@@ -693,6 +717,7 @@ class BaseTuner(nn.Module, ABC):
         model_config = self.get_model_config(model)
 
         peft_config = self._prepare_adapter_config(peft_config, model_config)
+        peft_config = self._check_tied_modules(peft_config, model_config, model)
 
         self._prepare_model(peft_config, model)
 
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index fe7fb83c..ed7b29a8 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -941,7 +941,7 @@ def _set_trainable(
     model,
     adapter_name,
     module_names,
-    inference_mode: bool,
+    peft_config,
     strict_module_check: bool = False,
     wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None,
     activate_adapter: bool = True,
@@ -958,7 +958,7 @@ def _set_trainable(
     The default is to wrap the module in a `ModulesToSaveWrapper` wrapper.
 
     If `strict_module_check` is set, this method raises an ValueError, similar to BaseTuner.inject_adapter when none of
-    the requested modules in `module_names` is not found in the model.
+    the requested modules in `module_names` are found in the model.
 
     The `active_adapter` flag indicates if this new adapter should be activated.
     """
@@ -974,6 +974,7 @@ def _set_trainable(
 
     trainable_modules = []
     found_modules = set()
+    inference_mode = peft_config.inference_mode
     # disable removal of duplicates to support targeting tied weights
     key_list = [key for key, _ in model.named_modules(remove_duplicate=False)]
 
@@ -1008,6 +1009,20 @@ def _set_trainable(
                 trainable_modules.append(new_module)
             found_modules.add(target_name)
 
+    # deal with tied weights for modules_to_save
+    if getattr(peft_config, "ensure_weight_tieing", False) and hasattr(model, "get_input_embeddings") and issubclass(wrapper_cls, ModulesToSaveWrapper):
+        from transformers.modeling_utils import _get_tied_weight_keys  # TODO hmm, maybe just copy that function
+
+        tied_weight_keys = _get_tied_weight_keys(model)
+        # remove the parameter name, e.g. "lm_head.weight" => "lm_head"
+        tied_module_names = {key.rpartition(".")[0] for key in tied_weight_keys}
+        orig_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
+        for module_name in found_modules:
+            if issubclass(wrapper_cls, ModulesToSaveWrapper) and (module_name in tied_module_names):
+                tied_module = model.get_submodule(module_name)
+                del tied_module.modules_to_save[adapter_name].weight
+                tied_module.modules_to_save[adapter_name].register_parameter("weight", orig_module.weight)
+
     not_found = set(module_names).difference(found_modules)
     if strict_module_check and not found_modules:
         raise ValueError(
@@ -1397,7 +1412,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
         _set_trainable(
             model,
             adapter_name,
-            inference_mode=peft_config.inference_mode,
+            peft_config=peft_config,
             module_names=getattr(peft_config, "modules_to_save", None),
             activate_adapter=activate_adapter,
         )
@@ -1423,7 +1438,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
             _set_trainable(
                 model,
                 adapter_name,
-                inference_mode=peft_config.inference_mode,
+                peft_config=peft_config,
                 module_names=[target_layer],
                 strict_module_check=True,
                 wrapper_cls=TrainableTokensWrapper,
@@ -1447,7 +1462,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
             _set_trainable(
                 model,
                 adapter_name,
-                inference_mode=peft_config.inference_mode,
+                peft_config=peft_config,
                 module_names=module_keys,
                 strict_module_check=True,
                 wrapper_cls=TrainableTokensWrapper,

Testing locally, the tests you added pass. As you can see, I made some small changes to the implementation compared to what you did. Most notably, I changed model._tied_weights_keys to _get_tied_weight_keys(model), where _get_tied_weight_keys is from transformers. Of course, it's risky to use this private function, so it's probably better to copy it over to PEFT.
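If it is copied over, a simplified stand-in could look roughly like this (a sketch based on the `_tied_weights_keys` attribute that transformers models define; the real private helper may differ in details):

from torch import nn

def get_tied_weight_keys(module: nn.Module, prefix: str = "") -> list[str]:
    # recursively collect `_tied_weights_keys`, e.g. ["lm_head.weight"] for a
    # causal LM whose output embedding is tied to the input embedding
    keys = []
    tied = getattr(module, "_tied_weights_keys", None)
    if tied is not None:
        keys.extend(f"{prefix}{key}" for key in tied)
    for name, child in module.named_children():
        keys.extend(get_tied_weight_keys(child, prefix=f"{prefix}{name}."))
    return keys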

Anyway, LMK what you think of this proposal.

Apart from this, I made some smaller comments; also, please note that the correct spelling is "tying", not "tieing".

# in `modules_to_save` and we want to ensure the `lm_head`
# does not diverge from the `embed_tokens` layer
if (
    peft_config.task_type == "CAUSAL_LM"
Member:

Why is it important?

Author:

Updating lm_head makes sense only for CAUSAL_LM tasks. We can extend it to Seq2Seq.
I might be wrong here though.

Member:

The proposed code is not really requiring there to be an lm_head, right? In theory, it should also work when there are other tied weights. I'm not sure how relevant that is in practice, but if there are no strict reasons for this check, I'd say let's remove it.

Author:

Yes, I'll make the changes so that we only worry about weight tying when one of the tied layers is added to modules_to_save.

@romitjain romitjain changed the title Ensure weight tieing is maintained for embed_tokens and lm_head Ensure weight tying is maintained for embed_tokens and lm_head Oct 8, 2025
@romitjain romitjain commented Oct 8, 2025

@BenjaminBossan I think what you suggested is a cleaner approach. I am aligned with that.
The only concern I have is that we would first register the redundant weights (extra memory) for the tied modules and then replace the pointer post hoc.
However, the extra memory might not be too big, so it should be okay? What do you think?

I realized that the extra weights are being created in both situations, so I will go ahead with the solution you proposed.

@BenjaminBossan BenjaminBossan left a comment

I think what you suggested is a cleaner approach. I am aligned with that. The only concern I have is that we would first register the redundant weights (extra memory) for the tied modules and then replace the pointer post hoc. However, the extra memory might not be too big, so it should be okay? What do you think?

I realized that the extra weights are being created in both situations, so I will go ahead with the solution you proposed.

We could prevent those weights from being created, e.g. by using init_empty_weights, but I'd be okay without having it.
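For reference, init_empty_weights from accelerate works roughly like this (a generic illustration, not code from this PR):

from accelerate import init_empty_weights
from torch import nn

# modules created inside the context get meta-device parameters, so no real
# memory is allocated for weights that would be replaced right afterwards
with init_empty_weights():
    dummy_head = nn.Linear(4096, 32000, bias=False)

print(dummy_head.weight.device)  # meta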

@romitjain (Author)

@BenjaminBossan

I have made slight tweaks to your proposed approach. Instead of modifying _set_trainable, which would then have to check for tied modules on top of its other checks, I do this once in BaseTuner.inject_adapter and add a new runtime field in LoraConfig, modules_to_tie, which can be used later on to tie the modules.
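Roughly, the idea is the following (a sketch with assumed names; modules_to_tie is the runtime field proposed here, and get_tied_weight_keys stands in for the helper discussed above):

# inside BaseTuner.inject_adapter, before any module is wrapped
tied_module_names = {key.rpartition(".")[0] for key in get_tied_weight_keys(model)}
modules_to_save = set(getattr(peft_config, "modules_to_save", None) or [])
if getattr(peft_config, "ensure_weight_tying", False) and modules_to_save:
    # e.g. records "lm_head" when only "embed_tokens" was passed by the user;
    # the actual tying happens later, after the wrappers have been created
    peft_config.modules_to_tie = sorted(tied_module_names - modules_to_save)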

For tests, I have covered the following cases:

- tied weights
    - `embed_tokens` not added as modules_to_save
        - `ensure_weight_tying` == True : Warning test
        - `ensure_weight_tying` == False : No need to test
    - `embed_tokens` added as modules_to_save
        - `ensure_weight_tying` == True : Test
        - `ensure_weight_tying` == False : Warning test
- non tied weights
    - `embed_tokens` added/not added as modules_to_save
        - `ensure_weight_tying` == True : Warning test
        - `ensure_weight_tying` == False : No need to test

@BenjaminBossan (Member)

Thanks for your continued work on this PR.

The idea I proposed was to create the ModulesToSaveWrappers as normal in a first pass and then apply the weight tying in a second pass. If I understand your proposal correctly, you would apply the weight tying on the fly (while iterating through the modules). The reason why I wanted to avoid that is because it requires a specific order, right? If the module that should be tied is wrapped before the module that it's being tied to, your proposed approach would fail. Is my understanding correct here?

I guess one could argue that this would be a user error and the user needs to take care to define modules_to_save correctly. On the other hand, the very idea of this feature is to make it easier for the user to avoid errors. So if my understanding is correct, I would propose to go with an approach closer to my proposal with two passes, even if it results in a bit of extra code.

@romitjain (Author)

@BenjaminBossan

If I understand your proposal correctly, you would apply the weight tying on the fly (while iterating through the modules).

I would add them to a separate list called modules_to_tie. The modules will not be tied yet.

If the module that should be tied is wrapped before the module that it's being tied to, your proposed approach would fail. Is my understanding correct here?

In my proposed flow, the modules in modules_to_save will always be wrapped first, and only then is weight tying applied. Please see the flow of the set_additional_trainable_modules function.

@BenjaminBossan BenjaminBossan left a comment

Thanks for explaining further. To summarize how our proposals differ:

In my approach, I extend modules_to_save, which results in ModulesToSaveWrapper being applied to all necessary modules, and then do a second pass outside of the modules to tie weights.

In your approach, you define a new attribute at runtime, modules_to_tie. Then you call _set_trainable on those modules and pass an additional argument: the module to which the new ModulesToSaveWrapper should be tied. The tying happens inside ModulesToSaveWrapper.

I can see some advantages and disadvantages to both approaches and can't confidently say one is better than the other. Thus I'm fine going with your approach for now, as it's already there.

I made a few suggestions on how to improve the PR, please check my comments.

"help": (
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
"are also tied."
Member:

Let's mention that right now, this only applies to modules_to_save, not the LoRA weights.

# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

modules_to_save = (
Member:

Okay, so let's move this whole block into a sub-method on BaseTuner, e.g. def _check_tied_modules. On the BaseTuner, if the peft_config has modules_to_save, you can do a basic check if modules_to_save targets a tied module. This way, all PEFT methods with modules_to_save benefit from having the warnings.

Then, on LoraModel, override _check_tied_modules with the logic below. There, you can remove the PeftType.LORA check.

context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
break

tied_module = kwargs.get("tied_module", None)
Member:

Let's add tied_module to the signature of init_modules with default None, instead of getting it from kwargs (but you can leave kwargs in the signature too, shouldn't hurt).
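Something along these lines (the other parameters of init_modules are omitted / assumed here):

def init_modules(self, adapter_name, tied_module=None, **kwargs):
    # `tied_module`, when given, is the already-wrapped source module (e.g. the
    # embed_tokens wrapper) whose trainable copy this wrapper's weight should
    # share; None keeps the current behaviour
    ...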

@romitjain romitjain Oct 10, 2025

Did you mean update? This piece of code lies in the update method.

Author:

I have added it in update.

self._check_new_adapter_config(peft_config)

modules_to_save = (
set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
Member:

This is enough, right?

Suggested change
set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
set(getattr(peft_config, "modules_to_save", []))

Author:

modules_to_save is set to None in src/peft/mixed_model.py, which results in an error.
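In that case, a single expression can handle both a missing attribute and an explicit None (a sketch, not necessarily what the PR ends up with):

# treats both a missing attribute and modules_to_save=None as the empty set
modules_to_save = set(getattr(peft_config, "modules_to_save", None) or [])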

# is passed as a `modules_to_save`, it needs to be ensured
# that lm_head is tied to the adapter added to `embed_tokens`

from peft.utils.other import ModulesToSaveWrapper
Member:

Can be imported at root level instead of locally on each test.

ensure_weight_tying=True,
)

model = get_peft_model(model, embed_token_config)
Member:

How about checking that there is no warning related to weight tying here?
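One way to assert that (a sketch using the standard library; pytest's recwarn fixture would work just as well):

import warnings

with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("always")
    model = get_peft_model(model, embed_token_config)

# no weight-tying related warning should have been emitted
assert not any("tying" in str(w.message).lower() for w in recorded)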

@romitjain (Author)

@BenjaminBossan I have resolved the comments. Let me know if I missed anything.

@BenjaminBossan BenjaminBossan left a comment

Thanks for the updates, we're getting close. I added a few smaller comments, please check. Once you're done, don't forget to call make style.

Did you mean update? This piece of code lies in the update method.

Yes, I looked at the wrong line.


return tied_weight_keys

def _check_tied_modules(self, model: nn.Module, peft_config):
Member:

IMO, it makes more sense to just do some basic checks here in this generic method of the BaseTuner, and add a more LoRA-specific implementation to LoraModel. That way, there is no need to check for peft_config.peft_type == PeftType.LORA.

@romitjain romitjain Oct 11, 2025

I don't follow.

The _check_tied_modules function is called inside inject_adapter, which LoraModel does not inherit.

What you suggested makes sense; I can remove the LoRA-specific changes from this function and handle that condition outside. However, I don't know where to place the call to self._check_tied_modules(model, peft_config). Do you have any suggestions?

PS: I have addressed the rest of your comments.

Member:

This should work. LoraModel inherits from BaseTuner, thus it does have LoraModel.inject_adapter. So if we have:

class BaseTuner:
    def inject_adapter(...):
        self._check_tied_modules(...)
        ...
    def _check_tied_modules(...): ...

and

class LoraModel(BaseTuner):
    def _check_tied_modules(...): ...

then when calling LoraModel.inject_adapter, it calls BaseTuner.inject_adapter, which in turn will call LoraModel._check_tied_modules.

Author:

Got it, so I should define _check_tied_modules as an abstract method and implement it in LoraModel.

Member:

Hmm, I was about to say that we can have some basic checks for existence of tied weights and modules_to_save here and give a warning for non-LoRA; for LoRA, we would override the method and give more specific warnings. But now I wonder: Why only implement this for LoRA in the first place? It should work for other PEFT methods too, at least I don't see why it wouldn't.

Author:

Yes, that was my original thought as well, but since I have not worked with PEFT methods other than LoRA, I am not sure whether it would work out of the box.

Alternatively, we can expose this for LoRA (or other similar PEFT methods where you are confident) and open it up for all PEFT methods based on feedback?

@BenjaminBossan BenjaminBossan Oct 13, 2025

Your proposal makes sense, let's keep the scope small for now, and later, it can be extended to other PEFT methods if there is a need.

This means, let's keep just a generic check here in BaseTuner._check_tied_modules and then in LoraModel._check_tied_modules, do the more specific check.

Author:

I have updated my implementation by adding an internal method _maybe_add_modules_to_tie. This is the cleanest approach I could find. Let me know your thoughts!

romitjain and others added 2 commits October 11, 2025 10:39
@BenjaminBossan BenjaminBossan left a comment

Thanks for the latest updates.

I have updated my implementation by adding an internal method _maybe_add_modules_to_tie. This is the cleanest approach I could find. Let me know your thoughts!

Overall looks good, but I have a comment concerning one of the warning messages, WDYT?

Comment on lines +1220 to +1228
if is_embedding_to_save and tied_weight_keys:
    msg = (
        "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
        "but `ensure_weight_tying` is not set to True. "
        "This can lead to complications, for example when merging the adapter "
        "or converting your model to formats other than safetensors. "
        "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
    )
    warnings.warn(msg)
Member:

I think this message has the potential to confuse users. ensure_weight_tying is only implemented for LoRA right now but this message can also show up for non-LoRA methods, right?
