15 changes: 15 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -663,6 +663,17 @@ class LoraConfig(PeftConfig):
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
ensure_weight_tying: bool = field(
default=False,
metadata={
"help": (
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
"are also tied. This is only applicable for layers passed via "
"`modules_to_save`."
)
},
)

def to_dict(self):
"""
@@ -681,6 +692,10 @@ def __post_init__(self):
self.exclude_modules = (
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
)

if self.ensure_weight_tying:
self.modules_to_tie = None

if isinstance(self.target_parameters, str):
raise TypeError("`target_parameters` must be a list of strings or None.")

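For context, a minimal end-to-end sketch of the new option (not part of the diff; the tiny Llama config and the adapter name "default" are illustrative, and the attribute paths follow the tests added below):

from transformers import LlamaConfig, LlamaForCausalLM
from peft import LoraConfig, get_peft_model

# Tiny random Llama-style model whose lm_head is tied to embed_tokens.
base = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=64,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=1,
        num_attention_heads=2,
        num_key_value_heads=2,
        tie_word_embeddings=True,
    )
)

config = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # LoRA on the attention projections
    modules_to_save=["embed_tokens"],     # fully train the tied input embedding
    ensure_weight_tying=True,             # new flag: also tie the lm_head copy to the embed_tokens copy
)
peft_model = get_peft_model(base, config)

# With this change, the adapter copies share a single Parameter.
emb_copy = peft_model.base_model.model.model.embed_tokens.modules_to_save["default"]
head_copy = peft_model.base_model.model.lm_head.modules_to_save["default"]
assert emb_copy.weight is head_copy.weight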
6 changes: 6 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -801,3 +801,9 @@ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adap
)

return tensors_lora

def _maybe_add_modules_to_tie(self, peft_config, tied_weight_keys):
modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
missing_keys = set(tied_weight_keys) - modules_to_save

peft_config.modules_to_tie = missing_keys
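As a worked example of the set difference above (standalone; the values mirror the tied-embedding case from the tests):

# Standalone illustration of _maybe_add_modules_to_tie's bookkeeping.
modules_to_save = {"embed_tokens"}   # layers the user asked to fully train
tied_weight_keys = ["lm_head"]       # tied modules discovered on the model
modules_to_tie = set(tied_weight_keys) - modules_to_save
print(modules_to_tie)                # {'lm_head'} -> its adapter copy gets tied downstream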
73 changes: 73 additions & 0 deletions src/peft/tuners/tuners_utils.py
@@ -690,6 +690,8 @@ def inject_adapter(
# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

self._check_tied_modules(model, peft_config)

model_config = self.get_model_config(model)

peft_config = self._prepare_adapter_config(peft_config, model_config)
@@ -1154,6 +1156,77 @@ def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
tied_target_modules.append(target_module)
return tied_target_modules

def _get_tied_weight_keys(self, model: nn.Module, prefix="") -> list[str]:
"""
Get the list of modules that need to be tied.
For example, for models that have `embed_tokens` and `lm_head` as the tied keys, this function will return
[`lm_head`].
From: https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/modeling_utils.py#L563
"""
tied_weight_keys = []
if getattr(model, "_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys]
tied_weight_keys.extend(names)
if getattr(model, "_dynamic_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys]
tied_weight_keys.extend(names)
for name, submodule in model.named_children():
local_prefix = f"{prefix}.{name}" if prefix else name
tied_weight_keys.extend(self._get_tied_weight_keys(submodule, prefix=local_prefix))

tied_weight_keys = [".".join(n.split(".")[:-1]) for n in tied_weight_keys]

return tied_weight_keys

def _maybe_add_modules_to_tie(self, peft_config, tied_weight_keys):
"""
This method adds modules to tie to `peft_config` so that those modules can be tied downstream. By default this
method only emits a warning; each tuner class extending `BaseTuner` can choose to implement it.
"""
msg = (
"Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
"but no implementation exists to tie the adapters. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"Check the discussion here: https://github.com/huggingface/peft/issues/2777"
)
warnings.warn(msg)

def _check_tied_modules(self, model: nn.Module, peft_config):
"""
Checks whether any of the tied layers are targeted via `modules_to_save` and updates `peft_config.modules_to_tie`
with any layers that need to be tied.
"""
modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save)

tied_weight_keys = self._get_tied_weight_keys(model)

if getattr(peft_config, "ensure_weight_tying", False):
if is_embedding_to_save and tied_weight_keys:
self._maybe_add_modules_to_tie(peft_config, tied_weight_keys)

elif not is_embedding_to_save and tied_weight_keys:
warnings.warn(
"You have requested `ensure_weight_tying`, but no tied modules are added in `modules_to_save`"
)

elif not tied_weight_keys:
warnings.warn("You have requested `ensure_weight_tying`, but no tied modules were found in the model")

else:
if is_embedding_to_save and tied_weight_keys:
msg = (
"Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
"but `ensure_weight_tying` is not set to True. "
"This can lead to complications, for example when merging the adapter "
"or converting your model to formats other than safetensors. "
"Check the discussion here: https://github.com/huggingface/peft/issues/2777"
)
warnings.warn(msg)
Comment on lines +1220 to +1228
I think this message has the potential to confuse users. ensure_weight_tying is only implemented for LoRA right now but this message can also show up for non-LoRA methods, right?


def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
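To illustrate the discovery step, a standalone sketch of what `_get_tied_weight_keys` computes, written here as a free function over a made-up `TinyLM` that declares `_tied_weights_keys` the way transformers models do:

import torch.nn as nn

def get_tied_weight_keys(model: nn.Module, prefix: str = "") -> list[str]:
    # Mirrors BaseTuner._get_tied_weight_keys: collect tied parameter names
    # (recursing into children), then strip the trailing ".weight" so the
    # result names modules rather than parameters.
    keys = []
    if getattr(model, "_tied_weights_keys", None) is not None:
        keys.extend(f"{prefix}.{k}" if prefix else k for k in model._tied_weights_keys)
    if getattr(model, "_dynamic_tied_weights_keys", None) is not None:
        keys.extend(f"{prefix}.{k}" if prefix else k for k in model._dynamic_tied_weights_keys)
    for name, child in model.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        keys.extend(get_tied_weight_keys(child, prefix=child_prefix))
    return [".".join(k.split(".")[:-1]) for k in keys]

class TinyLM(nn.Module):
    _tied_weights_keys = ["lm_head.weight"]  # transformers-style declaration

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie the weights

print(get_tied_weight_keys(TinyLM()))  # ['lm_head']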
30 changes: 25 additions & 5 deletions src/peft/utils/other.py
@@ -508,10 +508,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

def __init__(self, module_to_save, adapter_name):
super().__init__(module_to_save, adapter_name)
def __init__(self, module_to_save, adapter_name, tied_module=None):
super().__init__(module_to_save, adapter_name, tied_module=tied_module)

def init_modules(self, adapter_name):
def init_modules(self, adapter_name, **kwargs):
# we treat each adapter separately, so we have multiple adapters, same (copied) module for each
self.modules_to_save = torch.nn.ModuleDict({})

@@ -535,7 +535,7 @@ def _hasattr_wrapped(self, name, modules):
def _getattr_wrapped(self, name, modules):
return getattr(modules["modules_to_save"][self.active_adapters[0]], name)

def update(self, adapter_name, **kwargs):
def update(self, adapter_name, tied_module=None, **kwargs):
super().update(adapter_name)

context_manager = nullcontext()
@@ -550,7 +550,13 @@ def update(self, adapter_name, **kwargs):

if adapter_name not in self.modules_to_save:
with context_manager:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
if tied_module:
new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
new_linear.weight = tied_module.weight

self.modules_to_save[adapter_name] = new_linear
else:
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
old_hook = self.modules_to_save[adapter_name]._hf_hook
@@ -1402,6 +1408,20 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
activate_adapter=activate_adapter,
)

if getattr(peft_config, "modules_to_tie", None) is not None:
# Tie the modules if any tied layer is passed in `modules_to_save`.
# This should always be called after
# `_set_trainable` is called for `modules_to_save`.
tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
_set_trainable(
model,
adapter_name,
inference_mode=peft_config.inference_mode,
module_names=getattr(peft_config, "modules_to_tie", None),
activate_adapter=activate_adapter,
tied_module=tied_module,
)

if getattr(peft_config, "trainable_token_indices", None) is not None:
if isinstance(peft_config.trainable_token_indices, dict):
target_layers = peft_config.trainable_token_indices
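The tying performed in `update(..., tied_module=...)` boils down to plain PyTorch parameter sharing; a minimal standalone sketch (sizes arbitrary):

import torch
import torch.nn as nn

# Stand-in for the already-created embed_tokens adapter copy.
embed_copy = nn.Embedding(1000, 64)

# As in update(): build a fresh Linear and point its .weight at the same
# Parameter instead of deep-copying lm_head; the constructor shape is
# immediately overridden by the shared Parameter, so only the assignment matters.
new_linear = nn.Linear(*embed_copy.weight.shape, bias=False)
new_linear.weight = embed_copy.weight

assert new_linear.weight is embed_copy.weight
embed_copy.weight.data.add_(1.0)                          # update one side ...
assert torch.equal(new_linear.weight, embed_copy.weight)  # ... the other sees it

Because both adapter copies reference one Parameter, saving or merging the adapter cannot leave the embedding and lm_head copies out of sync, which is the kind of complication the warnings above point to (huggingface/peft#2777).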
121 changes: 121 additions & 0 deletions tests/test_initialization.py
@@ -68,6 +68,7 @@
from peft.tuners.lora.layer import LoraLayer
from peft.utils import infer_device
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
from peft.utils.other import ModulesToSaveWrapper

from .testing_utils import load_dataset_english_quotes, require_deterministic_for_xpu

@@ -110,6 +111,39 @@ def forward(self, x):

return MyModule().eval().to(self.torch_device)

def get_lm_model(self, bias=True, tie_weights=True):
# Mimicking an LM with embed_tokens and lm_head layers
# to test weight tying of adapters
class MyModule(nn.Module):
def __init__(self):
super().__init__()

self.embed_tokens = nn.Embedding(1000, 1000)
self.linear = nn.Linear(1000, 1000, bias=bias)

class CausalLM(nn.Module):
if tie_weights:
_tied_weights_keys = ["lm_head.weight"]

def __init__(self):
super().__init__()
self.model = MyModule()
self.config = {"tie_word_embeddings": tie_weights}

if tie_weights:
self.lm_head = nn.Linear(1000, 1000, bias=False)
self.lm_head.weight = self.model.embed_tokens.weight
else:
self.lm_head = nn.Linear(1000, 1000, bias=bias)

def prepare_inputs_for_generation(self):
return

def get_input_embeddings(self):
return self.model.embed_tokens

return CausalLM().eval().to(self.torch_device)

@pytest.fixture
def data(self):
return torch.rand(10, 1000).to(self.torch_device)
@@ -1566,6 +1600,93 @@ def test_multiple_configs_with_bias_raises(self, tmp_path):
config2 = LoraConfig(target_modules=["linear"], bias="none")
model.add_adapter("other", config2) # does not raise

def test_weight_tying_tied_model(self):
# If weight tying is enabled and `embed_tokens` is passed via
# `modules_to_save`, it must be ensured that lm_head is tied
# to the adapter added to `embed_tokens`

model = self.get_lm_model()
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=True,
)
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
"Embed tokens and LM head types are not same"
)

# Validating that all model parameters are the same
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
lm_head_np = dict(model.base_model.model.lm_head.named_parameters())

for k in embed_np.keys():
assert torch.allclose(embed_np[k], lm_head_np[k])
assert embed_np[k] is lm_head_np[k]

def test_weight_tying_non_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model(tie_weights=False)
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=True,
)
with pytest.warns(UserWarning, match="no tied modules were found in the model"):
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)

def test_not_weight_tying_tied_model(self):
from peft.utils.other import ModulesToSaveWrapper

model = self.get_lm_model()
embed_token_config = LoraConfig(
modules_to_save=["embed_tokens"],
target_modules=["linear"],
ensure_weight_tying=False,
)
with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"):
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
"Embed tokens is not added in Modules to Save"
)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
"LM head is not of type nn.linear"
)

def test_weight_tying_tied_model_no_embed(self):
model = self.get_lm_model()
embed_token_config = LoraConfig(
target_modules=["linear"],
ensure_weight_tying=True,
)

with pytest.warns(UserWarning, match="no tied modules are added in `modules_to_save`"):
model = get_peft_model(model, embed_token_config)

assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)

# Validating that all model parameters are the same
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
lm_head_np = dict(model.base_model.model.lm_head.named_parameters())

for k in embed_np.keys():
assert torch.allclose(embed_np[k], lm_head_np[k])
assert embed_np[k] is lm_head_np[k]


class TestLokrInitialization:
torch_device = infer_device()