2 changes: 1 addition & 1 deletion .gitignore
@@ -142,4 +142,4 @@ wandb

# method_comparison logs
method_comparison/MetaMathQA/cancelled_results/
method_comparison/MetaMathQA/temporary_results/
method_comparison/MetaMathQA/temporary_results/
26 changes: 26 additions & 0 deletions src/peft/peft_model.py
@@ -52,6 +52,7 @@
SAFETENSORS_WEIGHTS_NAME,
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
WEIGHTS_NAME,
ModulesToSaveWrapper,
PeftType,
TaskType,
_get_batch_size,
@@ -1845,6 +1846,31 @@ def __init__(
super().__init__(model, peft_config, adapter_name, **kwargs)
self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

        # If the embedding layer (`embed_tokens`) is added to `modules_to_save`, ensure that the
        # `lm_head` does not diverge from the `embed_tokens` layer
if (
peft_config.task_type == "CAUSAL_LM"
Member: Why is it important?

Author: Updating lm_head makes sense only for CAUSAL_LM tasks. We could extend it to Seq2Seq, but I might be wrong here.

Member: The proposed code does not really require there to be an lm_head, right? In theory, it should also work when there are other tied weights. I'm not sure how relevant that is in practice, but if there are no strict reasons for this check, I'd say let's remove it.

Author: Yes, I'll make the changes so that we only worry about weight tying when one of the tied layers is added to modules_to_save.

and hasattr(model.get_input_embeddings(), "modules_to_save")
and getattr(peft_config, "ensure_weight_tieing")
):
module_keys = BaseTuner._get_tied_modules_to_save(self, model)

if not module_keys:
warnings.warn("You have requested ensure_weight_tieing, but no tied modules were found")

tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)

_set_trainable(
model,
adapter_name,
inference_mode=peft_config.inference_mode,
module_names=module_keys,
strict_module_check=True,
wrapper_cls=ModulesToSaveWrapper,
tied_module=tied_module,
)

def forward(
self,
input_ids=None,
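For context on what the weight-tying guard above protects: in a causal LM with `tie_word_embeddings=True`, the input embedding and `lm_head` share a single `Parameter`, so replacing only `embed_tokens` with a trainable copy silently breaks the tie. A minimal sketch, not part of this PR, showing the shared parameter; the checkpoint name is only an example of a tied-embedding model.

```python
# Minimal sketch (assumption: any causal LM checkpoint with tied embeddings works here;
# the model name below is only an example, not something this PR depends on).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

embed = model.get_input_embeddings()
lm_head = model.get_output_embeddings()

print(model.config.tie_word_embeddings)  # True for tied-embedding models
print(embed.weight is lm_head.weight)    # True: both point to the same Parameter object
```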
8 changes: 8 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -663,6 +663,14 @@ class LoraConfig(PeftConfig):
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
ensure_weight_tieing: bool = field(
default=False,
metadata={
"help": (
"Whether to tie weights or not after peft initialization.Only supported for `task_type` == CAUSAL_LM"
)
},
)

def to_dict(self):
"""
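Combined with the `peft_model.py` change above, the new flag is used roughly as in the tests added later in this PR. A hedged usage sketch: `base_model` stands for any causal LM with tied embeddings, and the target module names are placeholders.

```python
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],  # placeholder target modules, adjust per model
    modules_to_save=["embed_tokens"],     # fully train a copy of the embedding layer
    ensure_weight_tieing=True,            # keep lm_head tied to that trainable copy
)
peft_model = get_peft_model(base_model, config)  # base_model: a tied-embedding causal LM
```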
22 changes: 21 additions & 1 deletion src/peft/tuners/tuners_utils.py
@@ -51,7 +51,7 @@
from peft.utils.peft_types import PeftType, TaskType

from ..config import PeftConfig
from ..utils import _get_submodules
from ..utils import ModulesToSaveWrapper, _get_submodules
from ._buffer_dict import BufferDict


@@ -1154,6 +1154,26 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
tied_target_modules.append(target_module)
return tied_target_modules

def _get_tied_modules_to_save(self, model: nn.Module) -> list[str]:
"""
        Get the list of modules that need to be tied.

        For example, for models that have `embed_tokens` and `lm_head` as tied keys, this function will return
        [`lm_head`].
"""
model_config = self.get_model_config(model)
if (
model_config.get("tie_word_embeddings", False)
and model._tied_weights_keys is not None
and isinstance(model.get_input_embeddings(), ModulesToSaveWrapper)
):
# Get the original reference of the `ModulesToSaveWrapper` for the embedding layer
module_keys = [".".join(n.split(".")[:-1]) for n in model._tied_weights_keys]

return module_keys

return []

def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
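The helper derives module names from parameter keys by dropping the final path component. A standalone sketch of just that mapping, using the `_tied_weights_keys` value that the new test model below also uses:

```python
# Sketch of the key-to-module mapping used in _get_tied_modules_to_save:
# a parameter key such as "lm_head.weight" becomes the module name "lm_head".
tied_weights_keys = ["lm_head.weight"]
module_keys = [".".join(n.split(".")[:-1]) for n in tied_weights_keys]
print(module_keys)  # ['lm_head']
```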
16 changes: 12 additions & 4 deletions src/peft/utils/other.py
@@ -508,10 +508,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

def __init__(self, module_to_save, adapter_name):
super().__init__(module_to_save, adapter_name)
def __init__(self, module_to_save, adapter_name, tied_module=None):
super().__init__(module_to_save, adapter_name, tied_module=tied_module)

def init_modules(self, adapter_name):
def init_modules(self, adapter_name, **kwargs):
# we treat each adapter separately, so we have multiple adapters, same (copied) module for each
self.modules_to_save = torch.nn.ModuleDict({})

@@ -548,9 +548,17 @@ def update(self, adapter_name, **kwargs):
context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
break

tied_module = kwargs.get("tied_module", None)
Member: Let's add tied_module to the signature of init_modules with default None, instead of getting it from kwargs (but you can leave kwargs in the signature too, shouldn't hurt).

Author (romitjain, Oct 10, 2025): Did you mean update? This piece of code lies in the update method.

Author: I have added it in update.


if adapter_name not in self.modules_to_save:
with context_manager:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
if tied_module:
new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
new_linear.weight = tied_module.weight

self.modules_to_save[adapter_name] = new_linear
else:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
old_hook = self.modules_to_save[adapter_name]._hf_hook
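The `update` change above ties the freshly created linear layer to the embedding adapter by assigning the same `Parameter` object rather than deep-copying it, which is what the new test asserts via `embed_np[k] is lm_head_np[k]`. A tiny self-contained sketch of that mechanism; the shapes are arbitrary examples.

```python
import torch

embed = torch.nn.Embedding(10, 4)             # stands in for the embed_tokens adapter copy
lm_head = torch.nn.Linear(4, 10, bias=False)  # stands in for the wrapped lm_head
lm_head.weight = embed.weight                 # share the Parameter instead of copying it

assert lm_head.weight is embed.weight         # both modules now train the same tensor
```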
110 changes: 110 additions & 0 deletions tests/test_initialization.py
@@ -110,6 +110,47 @@ def forward(self, x):

return MyModule().eval().to(self.torch_device)

def get_lm_model(self, bias=True, tie_weights=True):
class MyModule(nn.Module):
def __init__(self):
super().__init__()

self.linear = nn.Linear(1000, 1000, bias=bias)
self.embed_tokens = nn.Embedding(1000, 1000)
self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)

def forward(self, x):
x_int = (x * 100).int()
x_4d = x.reshape(1, 100, 10, 10)

                return self.linear(x), self.embed_tokens(x_int), self.conv2d(x_4d)
Member: Since we don't need to call forward, let's remove this to focus on what matters for these tests.


class CausalLM(nn.Module):
if tie_weights:
_tied_weights_keys = ["lm_head.weight"]

def __init__(self):
super().__init__()
self.model = MyModule()
self.config = {"tie_word_embeddings": tie_weights}

if tie_weights:
self.lm_head = nn.Linear(1000, 1000, bias=False)
self.lm_head.weight = self.model.embed_tokens.weight
else:
self.lm_head = nn.Linear(1000, 1000, bias=bias)

def forward(self, x):
return self.model(x)
Member: Also remove this.


def prepare_inputs_for_generation(self):
return

def get_input_embeddings(self):
return self.model.embed_tokens

return CausalLM().eval().to(self.torch_device)

@pytest.fixture
def data(self):
return torch.rand(10, 1000).to(self.torch_device)
@@ -1566,6 +1607,75 @@ def test_multiple_configs_with_bias_raises(self, tmp_path):
config2 = LoraConfig(target_modules=["linear"], bias="none")
model.add_adapter("other", config2) # does not raise

def test_weight_tieing_tied_model(self):
        # If weight tying is enabled and `embed_tokens` is added to `modules_to_save`, it must be
        # ensured that lm_head stays tied to the adapter copy created for `embed_tokens`

from peft.utils.other import ModulesToSaveWrapper
Member: Can be imported at root level instead of locally on each test.


model = self.get_lm_model()
embed_token_config = LoraConfig(
task_type="CAUSAL_LM",
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tieing=True,
)
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
"Embed tokens and LM head types are not same"
)

        # Validate that the embedding and lm_head parameters are identical (same objects)
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
lm_head_np = dict(model.base_model.model.lm_head.named_parameters())

for k in embed_np.keys():
assert torch.allclose(embed_np[k], lm_head_np[k])
assert embed_np[k] is lm_head_np[k]

def test_weight_tieing_non_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model(tie_weights=False)
embed_token_config = LoraConfig(
task_type="CAUSAL_LM",
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tieing=True,
)
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)

def test_not_weight_tieing_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model()
embed_token_config = LoraConfig(
task_type="CAUSAL_LM",
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tieing=False,
)
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)


class TestLokrInitialization:
torch_device = infer_device()