14 changes: 14 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -663,6 +663,16 @@ class LoraConfig(PeftConfig):
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
ensure_weight_tying: bool = field(
default=False,
metadata={
"help": (
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
"are also tied."
)
},
)

Review comment (Member): Let's mention that, right now, this only applies to modules_to_save, not the LoRA weights.

def to_dict(self):
"""
@@ -681,6 +691,10 @@ def __post_init__(self):
self.exclude_modules = (
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
)

if self.ensure_weight_tying:
self.modules_to_tie = None

if isinstance(self.target_parameters, str):
raise TypeError("`target_parameters` must be a list of strings or None.")

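For reference, a minimal usage sketch of the new flag, mirroring the tests added further down in this diff (the toy model and the "linear"/"embed_tokens" module names come from those tests, not from a real checkpoint):

    from peft import LoraConfig, get_peft_model

    # `model` is assumed to be a causal LM whose input embeddings (embed_tokens) are
    # tied to lm_head, e.g. the toy CausalLM defined in tests/test_initialization.py below.
    config = LoraConfig(
        target_modules=["linear"],
        modules_to_save=["embed_tokens"],
        ensure_weight_tying=True,  # keep the modules_to_save copy of embed_tokens tied to lm_head
    )
    peft_model = get_peft_model(model, config)
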
54 changes: 54 additions & 0 deletions src/peft/tuners/tuners_utils.py
@@ -690,6 +690,36 @@ def inject_adapter(
# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

modules_to_save = (
set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
)

Review comment (Member): Okay, so let's move this whole block into a sub-method on BaseTuner, e.g. def _check_tied_modules. On the BaseTuner, if the peft_config has modules_to_save, you can do a basic check if modules_to_save targets a tied module. This way, all PEFT methods with modules_to_save benefit from having the warnings.
Then, on LoraModel, override _check_tied_modules with the logic below. There, you can remove the PeftType.LORA check.

Review comment (Member): This is enough, right?
Suggested change:
-    set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
+    set(getattr(peft_config, "modules_to_save", []))

Reply (Author): modules_to_save is set as None in src/peft/mixed_model.py, which results in an error.

is_embedding_to_save = [m for m in modules_to_save if m in EMBEDDING_LAYER_NAMES]

tied_weight_keys = self._get_tied_weight_keys(model)

# Condition to check if embedding layer is added
# in `modules_to_save` and the model has tied keys
if is_embedding_to_save and tied_weight_keys and peft_config.peft_type == PeftType.LORA:
missing_keys = set(tied_weight_keys) - modules_to_save

if getattr(peft_config, "ensure_weight_tying", False):
peft_config.modules_to_tie = missing_keys
elif not getattr(peft_config, "ensure_weight_tying", False):
msg = (
"Model has `tie_word_embeddings=True` and the tied layer is part of the adapter, "
"but `ensure_weight_tying` is not set to True. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"See for discussion: https://github.com/huggingface/peft/issues/2777"
)
warnings.warn(msg)
elif (
getattr(peft_config, "ensure_weight_tying", False)
and not tied_weight_keys
and peft_config.peft_type == PeftType.LORA
):
warnings.warn("You have requested ensure_weight_tying, but no tied modules were found in the model")

model_config = self.get_model_config(model)

peft_config = self._prepare_adapter_config(peft_config, model_config)
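
A rough sketch of the refactor suggested in the review comment above (the name _check_tied_modules comes from that comment; the body just reuses the logic already shown in this hunk, so treat it as illustrative rather than the final implementation):

    # On BaseTuner: a generic warning for any PEFT method that uses modules_to_save.
    def _check_tied_modules(self, model: nn.Module, peft_config) -> None:
        modules_to_save = set(getattr(peft_config, "modules_to_save", None) or [])
        is_embedding_to_save = [m for m in modules_to_save if m in EMBEDDING_LAYER_NAMES]
        tied_weight_keys = self._get_tied_weight_keys(model)
        if is_embedding_to_save and tied_weight_keys:
            warnings.warn(
                "modules_to_save targets a tied module; the tied counterpart will not be kept in sync."
            )

    # On LoraModel: override _check_tied_modules with the ensure_weight_tying logic from
    # this hunk (populating modules_to_tie / warning), dropping the PeftType.LORA check.
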
@@ -1154,6 +1184,30 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
tied_target_modules.append(target_module)
return tied_target_modules

def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> list[str]:
"""
Get the list of modules that need to be tied.

For example, for models that have `embed_tokens` and `lm_head` as the tied keys, this function will return
[`lm_head`].

From: https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563
"""
tied_weight_keys = []
if getattr(model, "_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys]
tied_weight_keys.extend(names)
if getattr(model, "_dynamic_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys]
tied_weight_keys.extend(names)
for name, submodule in model.named_children():
local_prefix = f"{prefix}.{name}" if prefix else name
tied_weight_keys.extend(self._get_tied_weight_keys(submodule, prefix=local_prefix))

tied_weight_keys = [".".join(n.split(".")[:-1]) for n in tied_weight_keys]

return tied_weight_keys
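
For illustration, a hedged sketch of what the helper returns on a toy tied LM like the one used in the tests below (TinyLM and the access path through base_model are assumptions for demonstration purposes only):

    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    class TinyLM(nn.Module):
        # transformers-style class attribute marking lm_head.weight as tied
        _tied_weights_keys = ["lm_head.weight"]

        def __init__(self):
            super().__init__()
            self.embed_tokens = nn.Embedding(10, 8)
            self.linear = nn.Linear(8, 8)
            self.lm_head = nn.Linear(8, 10, bias=False)
            self.lm_head.weight = self.embed_tokens.weight  # tie the weights
            self.config = {"tie_word_embeddings": True}

        def get_input_embeddings(self):
            return self.embed_tokens

        def prepare_inputs_for_generation(self):
            return

    peft_model = get_peft_model(TinyLM(), LoraConfig(target_modules=["linear"]))
    # base_model is the LoraModel (a BaseTuner), so the new helper is reachable there;
    # "lm_head.weight" is collected and the trailing ".weight" is stripped.
    print(peft_model.base_model._get_tied_weight_keys(peft_model.base_model.model))
    # expected, per the docstring above: ["lm_head"]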

def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
27 changes: 23 additions & 4 deletions src/peft/utils/other.py
@@ -508,10 +508,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

def __init__(self, module_to_save, adapter_name):
super().__init__(module_to_save, adapter_name)
def __init__(self, module_to_save, adapter_name, tied_module=None):
super().__init__(module_to_save, adapter_name, tied_module=tied_module)

def init_modules(self, adapter_name):
def init_modules(self, adapter_name, **kwargs):
# we treat each adapter separately, so we have multiple adapters, same (copied) module for each
self.modules_to_save = torch.nn.ModuleDict({})

@@ -548,9 +548,17 @@ def update(self, adapter_name, **kwargs):
context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
break

tied_module = kwargs.get("tied_module", None)

Review comment (Member): Let's add tied_module to the signature of init_modules with default None, instead of getting it from kwargs (but you can leave kwargs in the signature too, shouldn't hurt).

Reply (Author, @romitjain, Oct 10, 2025): Did you mean update? Since this piece of code lies in the update method.

Reply (Author): I have added it in update.


if adapter_name not in self.modules_to_save:
with context_manager:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
if tied_module:
new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
new_linear.weight = tied_module.weight

self.modules_to_save[adapter_name] = new_linear
else:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
old_hook = self.modules_to_save[adapter_name]._hf_hook
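
A small sketch of the signature change discussed in the thread above, i.e. accepting tied_module as an explicit keyword argument of update instead of pulling it out of kwargs (illustrative only; the rest of the body stays as in this diff):

    def update(self, adapter_name, tied_module=None, **kwargs):
        # ... deepspeed context-manager setup as shown above ...
        if adapter_name not in self.modules_to_save:
            with context_manager:
                if tied_module is not None:
                    # same tied branch as in the diff: build a Linear that shares
                    # tied_module.weight so the adapter copies stay tied
                    ...
                else:
                    self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
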
@@ -1402,6 +1410,17 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
activate_adapter=activate_adapter,
)

if getattr(peft_config, "modules_to_tie", None) is not None:
tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
_set_trainable(
model,
adapter_name,
inference_mode=peft_config.inference_mode,
module_names=getattr(peft_config, "modules_to_tie", None),
activate_adapter=activate_adapter,
tied_module=tied_module,
)

if getattr(peft_config, "trainable_token_indices", None) is not None:
if isinstance(peft_config.trainable_token_indices, dict):
target_layers = peft_config.trainable_token_indices
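
To make the effect of modules_to_tie concrete: after wrapping, the adapter copy created for lm_head shares its Parameter objects with the embed_tokens copy, which is exactly what the new tests below assert. A hedged check (peft_model refers to the usage sketch near the top of this diff; the attribute paths follow the test assertions):

    embed_params = dict(peft_model.base_model.model.model.embed_tokens.named_parameters())
    lm_head_params = dict(peft_model.base_model.model.lm_head.named_parameters())
    for name, param in embed_params.items():
        assert param is lm_head_params[name]  # same Parameter object, not just equal values
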
133 changes: 133 additions & 0 deletions tests/test_initialization.py
@@ -110,6 +110,47 @@ def forward(self, x):

return MyModule().eval().to(self.torch_device)

def get_lm_model(self, bias=True, tie_weights=True):
class MyModule(nn.Module):
def __init__(self):
super().__init__()

self.linear = nn.Linear(1000, 1000, bias=bias)
self.embed_tokens = nn.Embedding(1000, 1000)
self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)

def forward(self, x):
x_int = (x * 100).int()
x_4d = x.reshape(1, 100, 10, 10)

return self.linear(x), self.embed_tokens(x_int), self.conv2d(x_4d)
Review comment (Member): Since we don't need to call forward, let's remove this to focus on what matters for these tests.


class CausalLM(nn.Module):
if tie_weights:
_tied_weights_keys = ["lm_head.weight"]

def __init__(self):
super().__init__()
self.model = MyModule()
self.config = {"tie_word_embeddings": tie_weights}

if tie_weights:
self.lm_head = nn.Linear(1000, 1000, bias=False)
self.lm_head.weight = self.model.embed_tokens.weight
else:
self.lm_head = nn.Linear(1000, 1000, bias=bias)

def forward(self, x):
return self.model(x)
Review comment (Member): Also remove this.


def prepare_inputs_for_generation(self):
return

def get_input_embeddings(self):
return self.model.embed_tokens

return CausalLM().eval().to(self.torch_device)

@pytest.fixture
def data(self):
return torch.rand(10, 1000).to(self.torch_device)
@@ -1566,6 +1607,98 @@ def test_multiple_configs_with_bias_raises(self, tmp_path):
config2 = LoraConfig(target_modules=["linear"], bias="none")
model.add_adapter("other", config2) # does not raise

def test_weight_tying_tied_model(self):
# If weight tying is enabled and `embed_tokens` is passed in
# `modules_to_save`, we need to ensure that lm_head is tied to
# the adapter added to `embed_tokens`

from peft.utils.other import ModulesToSaveWrapper
Review comment (Member): Can be imported at root level instead of locally on each test.


model = self.get_lm_model()
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=True,
)
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
"Embed tokens and LM head types are not same"
)

# Validating that all model parameters are same
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
lm_head_np = dict(model.base_model.model.lm_head.named_parameters())

for k in embed_np.keys():
assert torch.allclose(embed_np[k], lm_head_np[k])
assert embed_np[k] is lm_head_np[k]

def test_weight_tying_non_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model(tie_weights=False)
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=True,
)
with pytest.warns(UserWarning, match="no tied modules were found in the model"):
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)

def test_not_weight_tying_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model()
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=False,
)
with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"):
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)

def test_weight_tying_tied_model_no_embed(self):
# If weight tying is requested but `embed_tokens` is not in
# `modules_to_save`, no adapter is added to the tied layers and
# the original tying between embed_tokens and lm_head is preserved

model = self.get_lm_model()
embed_token_config = LoraConfig(
target_modules=["linear"],
ensure_weight_tying=True,
)

model = get_peft_model(model, embed_token_config)
Review comment (Member): How about checking that there is no warning related to weight tying here?


assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)

# Validating that all model parameters are same
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
lm_head_np = dict(model.base_model.model.lm_head.named_parameters())

for k in embed_np.keys():
assert torch.allclose(embed_np[k], lm_head_np[k])
assert embed_np[k] is lm_head_np[k]
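
One way to address the review comment above, asserting that no tying-related warning is raised (a sketch using pytest's built-in recwarn fixture; the matched substrings come from the warning messages added in tuners_utils.py):

    def test_weight_tying_tied_model_no_embed_raises_no_warning(self, recwarn):
        model = self.get_lm_model()
        config = LoraConfig(target_modules=["linear"], ensure_weight_tying=True)
        get_peft_model(model, config)
        tying_warnings = [
            w for w in recwarn.list
            if "ensure_weight_tying" in str(w.message) or "tied modules" in str(w.message)
        ]
        assert not tying_warnings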


class TestLokrInitialization:
torch_device = infer_device()