15 changes: 15 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -663,6 +663,17 @@ class LoraConfig(PeftConfig):
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
ensure_weight_tying: bool = field(
default=False,
metadata={
"help": (
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
"are also tied. This is only applicable for layers passed via "
"`modules_to_save`."
)
},
)

def to_dict(self):
"""
@@ -681,6 +692,10 @@ def __post_init__(self):
self.exclude_modules = (
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
)

if self.ensure_weight_tying:
self.modules_to_tie = None

if isinstance(self.target_parameters, str):
raise TypeError("`target_parameters` must be a list of strings or None.")

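A minimal usage sketch of the new flag (not part of the diff). It assumes the field lands as defined above; `base_model` and the module names `q_proj`, `v_proj`, and `embed_tokens` are illustrative placeholders.

from peft import LoraConfig

# Ask PEFT to also tie the adapter copy created for the tied head (e.g. lm_head)
# when its tied counterpart is listed in `modules_to_save`.
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # illustrative LoRA targets
    modules_to_save=["embed_tokens"],     # tied input embedding
    ensure_weight_tying=True,
)
# peft_model = peft.get_peft_model(base_model, config)  # base_model: a tied-embedding causal LM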
60 changes: 60 additions & 0 deletions src/peft/tuners/tuners_utils.py
@@ -690,6 +690,8 @@ def inject_adapter(
# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

self._check_tied_modules(model, peft_config)

model_config = self.get_model_config(model)

peft_config = self._prepare_adapter_config(peft_config, model_config)
@@ -1154,6 +1156,64 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
tied_target_modules.append(target_module)
return tied_target_modules

def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> list[str]:
"""
Get the list of modules that need to be tied

For example, for models that have `embed_tokens` and `lm_head` as tied keys, this function will return
[`lm_head`]

From: https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563
"""
tied_weight_keys = []
if getattr(model, "_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys]
tied_weight_keys.extend(names)
if getattr(model, "_dynamic_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys]
tied_weight_keys.extend(names)
for name, submodule in model.named_children():
local_prefix = f"{prefix}.{name}" if prefix else name
tied_weight_keys.extend(self._get_tied_weight_keys(submodule, prefix=local_prefix))

tied_weight_keys = [".".join(n.split(".")[:-1]) for n in tied_weight_keys]

return tied_weight_keys
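An illustrative sketch (not part of the diff) of what this helper collects: for a model that declares `_tied_weights_keys = ["lm_head.weight"]`, the trailing parameter name is stripped and the module path is returned. The toy model below is an assumption used only for illustration.

import torch.nn as nn

class TinyTiedLM(nn.Module):
    # transformers-style declaration of tied parameters (toy example)
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 8)
        self.lm_head = nn.Linear(8, 10, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie head to embedding

# self._get_tied_weight_keys(TinyTiedLM()) gathers "lm_head.weight" and, after
# stripping the parameter name, returns ["lm_head"].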

def _check_tied_modules(self, model: nn.Module, peft_config):
Member:
IMO, it makes more sense to just do some basic checks here in this generic method of the BaseTuner, and add a more LoRA-specific implementation to LoraModel. That way, there is no need to check for peft_config.peft_type == PeftType.LORA.

Author (@romitjain, Oct 11, 2025):
I don't follow.

The _check_tied_modules function is called inside inject_adapter, which the LoraModel does not inherit.

What you suggested makes sense: I can remove the LoRA-specific changes from this function and handle that condition outside. I don't know where to place the call self._check_tied_modules(model, peft_config), though. Do you have any suggestions?

PS: I have addressed the rest of your comments.

Member:
This should work. LoraModel inherits from BaseTuner, thus it does have LoraModel.inject_adapter. So if we have:

class BaseTuner:
    def inject_adapter(...):
        self._check_tied_modules(...)
        ...
    def _check_tied_modules(...): ...

and

class LoraModel(BaseTuner):
    def _check_tied_modules(...): ...

then when calling LoraModel.inject_adapter, it calls BaseTuner.inject_adapter, which in turn will call LoraModel._check_tied_modules.

Author:
Got it, so I should define _check_tied_modules as an abstract method and implement it in LoraModel.

Member:
Hmm, I was about to say that we can have some basic checks for the existence of tied weights and modules_to_save here and give a warning for non-LoRA; for LoRA, we would override the method and give more specific warnings. But now I wonder: why only implement this for LoRA in the first place? It should work for other PEFT methods too; at least I don't see why it wouldn't.

Author:
Yes, that was my original thought as well, but since I have not worked with PEFT methods other than LoRA, I am not sure whether it would work out of the box.

Alternatively, we can expose this for LoRA (or other similar PEFT methods where you are confident) and open it up to all PEFT methods based on feedback?

Member (@BenjaminBossan, Oct 13, 2025):
Your proposal makes sense; let's keep the scope small for now, and later it can be extended to other PEFT methods if there is a need.

This means: let's keep just a generic check here in BaseTuner._check_tied_modules and then do the more specific check in LoraModel._check_tied_modules.

Author:
I have updated my implementation by adding an internal method _maybe_add_modules_to_tie. This is the cleanest approach I could find. Let me know your thoughts!

"""
Checks if any of the tied layers are targeted via `modules_to_save`
Updates `peft_config.modules_to_tie` with any layers that need to be tied
"""
modules_to_save = (
set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
)
is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save)

tied_weight_keys = self._get_tied_weight_keys(model)

if is_embedding_to_save and tied_weight_keys and peft_config.peft_type == PeftType.LORA:
missing_keys = set(tied_weight_keys) - modules_to_save

if getattr(peft_config, "ensure_weight_tying", False):
peft_config.modules_to_tie = missing_keys
elif not getattr(peft_config, "ensure_weight_tying", False):
msg = (
"Model has `tie_word_embeddings=True` and the tied layer is part of the adapter, "
"but `ensure_weight_tying` is not set to True. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"Check the discussion here: https://github.com/huggingface/peft/issues/2777"
)
warnings.warn(msg)
elif (
getattr(peft_config, "ensure_weight_tying", False)
and not tied_weight_keys
and peft_config.peft_type == PeftType.LORA
):
warnings.warn("You have requested ensure_weight_tying, but no tied modules were found in the model")


def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
30 changes: 25 additions & 5 deletions src/peft/utils/other.py
@@ -508,10 +508,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

def __init__(self, module_to_save, adapter_name):
super().__init__(module_to_save, adapter_name)
def __init__(self, module_to_save, adapter_name, tied_module=None):
super().__init__(module_to_save, adapter_name, tied_module=tied_module)

def init_modules(self, adapter_name):
def init_modules(self, adapter_name, **kwargs):
# we treat each adapter separately, so we have multiple adapters, same (copied) module for each
self.modules_to_save = torch.nn.ModuleDict({})

@@ -535,7 +535,7 @@ def _hasattr_wrapped(self, name, modules):
def _getattr_wrapped(self, name, modules):
return getattr(modules["modules_to_save"][self.active_adapters[0]], name)

def update(self, adapter_name, **kwargs):
def update(self, adapter_name, tied_module=None, **kwargs):
super().update(adapter_name)

context_manager = nullcontext()
@@ -550,7 +550,13 @@ def update(self, adapter_name, **kwargs):

if adapter_name not in self.modules_to_save:
with context_manager:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
if tied_module:
new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
new_linear.weight = tied_module.weight

self.modules_to_save[adapter_name] = new_linear
else:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
old_hook = self.modules_to_save[adapter_name]._hf_hook
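An illustrative sketch (not part of the diff) of what the `tied_module` branch above achieves: instead of deep-copying the tied head, a fresh nn.Linear is created that reuses the tied module's weight Parameter, so the adapter copies stay tied. The toy modules below are assumptions for illustration.

import torch

embed = torch.nn.Embedding(10, 8)
lm_head = torch.nn.Linear(8, 10, bias=False)
lm_head.weight = embed.weight  # base-model weight tying

# Mirror the tied_module branch: a new Linear that shares the tied weight
new_head = torch.nn.Linear(lm_head.in_features, lm_head.out_features, bias=False)
new_head.weight = lm_head.weight  # same Parameter object, not a copy

assert new_head.weight is embed.weight  # still tied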
@@ -1402,6 +1408,20 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
activate_adapter=activate_adapter,
)

if getattr(peft_config, "modules_to_tie", None) is not None:
# Tie the modules if any tied layer is passed in `modules_to_save`.
# This should always be called after
# `_set_trainable` is called for `modules_to_save`.
tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
_set_trainable(
model,
adapter_name,
inference_mode=peft_config.inference_mode,
module_names=getattr(peft_config, "modules_to_tie", None),
activate_adapter=activate_adapter,
tied_module=tied_module,
)

if getattr(peft_config, "trainable_token_indices", None) is not None:
if isinstance(peft_config.trainable_token_indices, dict):
target_layers = peft_config.trainable_token_indices
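An illustrative sketch (not part of the diff) of the end state this produces, mirroring the assertions in the tests below; `tied_lm` is an assumed causal LM with tied embeddings, and the attribute paths follow the test file.

from peft import LoraConfig

config = LoraConfig(
    modules_to_save=["embed_tokens"],
    target_modules=["linear"],
    ensure_weight_tying=True,
)
# peft_model = peft.get_peft_model(tied_lm, config)  # tied_lm is assumed
#
# Both the embedding and the head end up wrapped in ModulesToSaveWrapper, and
# their "default" adapter copies share the very same weight Parameter:
# embed = peft_model.base_model.model.model.embed_tokens.modules_to_save["default"]
# head = peft_model.base_model.model.lm_head.modules_to_save["default"]
# assert embed.weight is head.weight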
126 changes: 126 additions & 0 deletions tests/test_initialization.py
@@ -68,6 +68,7 @@
from peft.tuners.lora.layer import LoraLayer
from peft.utils import infer_device
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
from peft.utils.other import ModulesToSaveWrapper

from .testing_utils import load_dataset_english_quotes, require_deterministic_for_xpu

@@ -110,6 +111,45 @@ def forward(self, x):

return MyModule().eval().to(self.torch_device)

def get_lm_model(self, bias=True, tie_weights=True):
# Mimicking a LM with embed_tokens and lm_head layers
# to test weight tying of adapters
class MyModule(nn.Module):
def __init__(self):
super().__init__()

self.embed_tokens = nn.Embedding(1000, 1000)
self.linear = nn.Linear(1000, 1000, bias=bias)

def forward(self, x):
return

class CausalLM(nn.Module):
if tie_weights:
_tied_weights_keys = ["lm_head.weight"]

def __init__(self):
super().__init__()
self.model = MyModule()
self.config = {"tie_word_embeddings": tie_weights}

if tie_weights:
self.lm_head = nn.Linear(1000, 1000, bias=False)
self.lm_head.weight = self.model.embed_tokens.weight
else:
self.lm_head = nn.Linear(1000, 1000, bias=bias)

def forward(self, x):
return

def prepare_inputs_for_generation(self):
return

def get_input_embeddings(self):
return self.model.embed_tokens

return CausalLM().eval().to(self.torch_device)

@pytest.fixture
def data(self):
return torch.rand(10, 1000).to(self.torch_device)
@@ -1566,6 +1606,92 @@ def test_multiple_configs_with_bias_raises(self, tmp_path):
config2 = LoraConfig(target_modules=["linear"], bias="none")
model.add_adapter("other", config2) # does not raise

def test_weight_tying_tied_model(self):
# If weight tying is enabled and `embed_tokens`
# is passed as a `modules_to_save`, it needs to be ensured
# that lm_head is tied to the adapter added to `embed_tokens`

model = self.get_lm_model()
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=True,
)
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
"Embed tokens and LM head types are not same"
)

# Validating that all model parameters are same
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
lm_head_np = dict(model.base_model.model.lm_head.named_parameters())

for k in embed_np.keys():
assert torch.allclose(embed_np[k], lm_head_np[k])
assert embed_np[k] is lm_head_np[k]

def test_weight_tying_non_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model(tie_weights=False)
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=True,
)
with pytest.warns(UserWarning, match="no tied modules were found in the model"):
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)

def test_not_weight_tying_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model()
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=False,
)
with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"):
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)

def test_weight_tying_tied_model_no_embed(self):
model = self.get_lm_model()
embed_token_config = LoraConfig(
target_modules=["linear"],
ensure_weight_tying=True,
)

model = get_peft_model(model, embed_token_config)
Member:
How about checking that there is no warning related to weight tying here?
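One possible way to implement this suggestion (a sketch, not part of the diff), slotting in for the plain `get_peft_model` call above: record all warnings and assert that none of them mention tied weights. The substring check matches both warning messages introduced in this PR.

import warnings

with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("always")
    model = get_peft_model(model, embed_token_config)
assert not any("tied" in str(w.message) for w in recorded)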


assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)

# Validating that all model parameters are same
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
lm_head_np = dict(model.base_model.model.lm_head.named_parameters())

for k in embed_np.keys():
assert torch.allclose(embed_np[k], lm_head_np[k])
assert embed_np[k] is lm_head_np[k]


class TestLokrInitialization:
torch_device = infer_device()