Ensure weight tying is maintained for embed_tokens and lm_head #2803
Conversation
@BenjaminBossan Request your review on the approach. If this looks fine, I can handle the other case (target_modules) and add tests. Thanks
Thanks for working on the PR to deal with targeting tied weights.
I have a suspicion that parts of this PR can be implemented more easily by extending peft_config.modules_to_save to include the tied module instead of calling _set_trainable explicitly.
However, before checking that, it makes most sense to add one or a few unit tests to establish if everything works as expected. I think the best place for those would be in this test class. LMK if you need help with implementing the tests.
Once we have the tests and confirm that everything works as expected, we can check if my proposal above works.
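For illustration, here is a minimal sketch of that proposal — the tied-partner mapping and the in-place config edit are assumptions made for this example, not the PR's implementation:

# Sketch: if a tied module (e.g. embed_tokens) is requested in modules_to_save,
# also add its tied partner so that both end up wrapped in a ModulesToSaveWrapper
# once the PEFT model is created.
from peft import LoraConfig

config = LoraConfig(task_type="CAUSAL_LM", modules_to_save=["embed_tokens"])
tied_partners = {"embed_tokens": "lm_head"}  # illustrative mapping for a tied LM head

for module, partner in tied_partners.items():
    if module in config.modules_to_save and partner not in config.modules_to_save:
        config.modules_to_save.append(partner)

print(config.modules_to_save)  # ["embed_tokens", "lm_head"]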
@BenjaminBossan I have addressed your comments. You mentioned that I can extend peft_config.modules_to_save to enable the functionality. Can you elaborate on that? Essentially, we would need to mark the tied modules as modules to save and then update the weight pointer to point to the source adapter. PTAL.
Thanks for adding the tests, they LGTM.
You mentioned that I can extend peft_config.modules_to_save to enable the functionality. Can you elaborate on that?
Yes, so my idea was the following: I would like to avoid spreading the _set_trainable calls further, as it makes it harder to reason about what is set where. Instead, let's detect early if additional weights need to be added for modules_to_save and extend the peft_config if so. Then, when the PEFT model is created, we will automatically have a ModulesToSaveWrapper everywhere it's needed. Next, let's update _set_trainable such that it ties the weights of ModulesToSaveWrapper (it is a bit wasteful to do that post hoc, but IMO more robust, as we don't rely on the order in which the tied modules are created).
For this to work, I had to make a small change to the signature of _set_trainable and move some of the logic you implemented to different locations. Overall, this is the diff I came up with:
diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py
index c3ad517e..3d608642 100644
--- a/src/peft/mixed_model.py
+++ b/src/peft/mixed_model.py
@@ -247,7 +247,7 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def set_adapter(self, adapter_name: Union[str, list[str]], inference_mode: bool = False) -> None:
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 36eccabb..73596487 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1626,7 +1626,7 @@ class PeftModelForSequenceClassification(PeftModel):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
@@ -2482,7 +2482,7 @@ class PeftModelForTokenClassification(PeftModel):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
@@ -2703,7 +2703,7 @@ class PeftModelForQuestionAnswering(PeftModel):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 36ad31ce..493a300b 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -663,6 +663,14 @@ class LoraConfig(PeftConfig):
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
+ ensure_weight_tieing: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to tie weights or not after peft initialization.Only supported for `task_type` == CAUSAL_LM"
+ )
+ },
+ )
def to_dict(self):
"""
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 4fd0d128..8ec6e0af 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -647,6 +647,30 @@ class BaseTuner(nn.Module, ABC):
) -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not support targeting nn.Parameter.")
+ def _check_tied_modules(self, peft_config: PeftConfig, model_config: dict, model: nn.Module) -> PeftConfig:
+ """TODO"""
+ if not getattr(peft_config, "ensure_weight_tieing", False):
+ return peft_config
+
+ from transformers.modeling_utils import _get_tied_weight_keys # TODO hmm, maybe just copy that function
+
+ tied_weight_keys = _get_tied_weight_keys(model)
+
+ if not (
+ (peft_config.task_type == TaskType.CAUSAL_LM)
+ and tied_weight_keys
+ and (getattr(peft_config, "modules_to_save", None) is not None)
+ ):
+ warnings.warn("TODO") # user wants to tie weights but we can't, there should be a warning
+ return peft_config
+
+ modules_to_save = set(peft_config.modules_to_save)
+ missing_keys = set(tied_weight_keys) - modules_to_save
+ for key in missing_keys:
+ module_name = key.rpartition(".")[0] # remove the parameter name, e.g. "lm_head.weight" => "lm_head"
+ peft_config.modules_to_save.append(module_name)
+ return peft_config
+
def inject_adapter(
self,
model: nn.Module,
@@ -693,6 +717,7 @@ class BaseTuner(nn.Module, ABC):
model_config = self.get_model_config(model)
peft_config = self._prepare_adapter_config(peft_config, model_config)
+ peft_config = self._check_tied_modules(peft_config, model_config, model)
self._prepare_model(peft_config, model)
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index fe7fb83c..ed7b29a8 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -941,7 +941,7 @@ def _set_trainable(
model,
adapter_name,
module_names,
- inference_mode: bool,
+ peft_config,
strict_module_check: bool = False,
wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None,
activate_adapter: bool = True,
@@ -958,7 +958,7 @@ def _set_trainable(
The default is to wrap the module in a `ModulesToSaveWrapper` wrapper.
If `strict_module_check` is set, this method raises an ValueError, similar to BaseTuner.inject_adapter when none of
- the requested modules in `module_names` is not found in the model.
+ the requested modules in `module_names` are found in the model.
The `active_adapter` flag indicates if this new adapter should be activated.
"""
@@ -974,6 +974,7 @@ def _set_trainable(
trainable_modules = []
found_modules = set()
+ inference_mode = peft_config.inference_mode
# disable removal of duplicates to support targeting tied weights
key_list = [key for key, _ in model.named_modules(remove_duplicate=False)]
@@ -1008,6 +1009,20 @@ def _set_trainable(
trainable_modules.append(new_module)
found_modules.add(target_name)
+ # deal with tied weights for modules_to_save
+ if getattr(peft_config, "ensure_weight_tieing", False) and hasattr(model, "get_input_embeddings") and issubclass(wrapper_cls, ModulesToSaveWrapper):
+ from transformers.modeling_utils import _get_tied_weight_keys # TODO hmm, maybe just copy that function
+
+ tied_weight_keys = _get_tied_weight_keys(model)
+ # remove the parameter name, e.g. "lm_head.weight" => "lm_head"
+ tied_module_names = {key.rpartition(".")[0] for key in tied_weight_keys}
+ orig_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
+ for module_name in found_modules:
+ if issubclass(wrapper_cls, ModulesToSaveWrapper) and (module_name in tied_module_names):
+ tied_module = model.get_submodule(module_name)
+ del tied_module.modules_to_save[adapter_name].weight
+ tied_module.modules_to_save[adapter_name].register_parameter("weight", orig_module.weight)
+
not_found = set(module_names).difference(found_modules)
if strict_module_check and not found_modules:
raise ValueError(
@@ -1397,7 +1412,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable(
model,
adapter_name,
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
module_names=getattr(peft_config, "modules_to_save", None),
activate_adapter=activate_adapter,
)
@@ -1423,7 +1438,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable(
model,
adapter_name,
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
module_names=[target_layer],
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
@@ -1447,7 +1462,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable(
model,
adapter_name,
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
module_names=module_keys,
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
Testing locally, the tests you added pass. As you can see, I made some small changes to the implementation compared to what you did. Most notably, I changed model._tied_weights_keys to _get_tied_weight_keys(model), where _get_tied_weight_keys is from transformers. Of course, it's risky to use this private function; probably it's better to copy it over to PEFT.
Anyway, LMK what you think of this proposal.
Apart from this, I made some smaller comments. Also, please note that the correct spelling is "tying", not "tieing".
src/peft/peft_model.py (Outdated)

# in `modules_to_save` and we want to ensure the `lm_head`
# does not diverge from the `embed_tokens` layer
if (
    peft_config.task_type == "CAUSAL_LM"
Why is it important?
Updating lm_head makes sense only for CAUSAL_LM tasks. We can extend it to Seq2Seq. I might be wrong here though.
The proposed code doesn't really require there to be an lm_head, right? In theory, it should also work when there are other tied weights. I'm not sure how relevant that is in practice, but if there are no strict reasons for this check, I'd say let's remove it.
Yes, I'll make the changes such that we only worry about weight tying in case one of the tied layers is added to modules_to_save.
@BenjaminBossan I think what you suggested is a cleaner approach. I am aligned with that. I realized that the extra weights are being created in both situations, so I will go ahead with the solution you proposed.
I think what you suggested is a cleaner approach. I am aligned with that.

The only concern I have is that we would first register the redundant weights (extra memory) for the tied modules and then replace the pointer post hoc. However, the extra memory might not be too big, so it should be okay? What do you think?

I realized that the extra weights are being created in both situations, I will go ahead with the solution you proposed.

We could prevent those weights from being created, e.g. by using init_empty_weights, but I'd be okay without having it.
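To illustrate the init_empty_weights idea, here is a hedged sketch, assuming the tied module is a bias-free nn.Linear like a typical lm_head — not the PR's actual code:

import torch.nn as nn
from accelerate import init_empty_weights

def make_tied_copy(tied_module: nn.Linear, source_weight: nn.Parameter) -> nn.Linear:
    # build the module on the meta device so no real weight memory is allocated
    with init_empty_weights():
        new_module = nn.Linear(tied_module.in_features, tied_module.out_features, bias=False)
    # drop the meta placeholder and point the copy at the already-existing source parameter
    del new_module.weight
    new_module.register_parameter("weight", source_weight)
    return new_module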
I have made slight tweaks to your proposed approach. Instead of modifying modules_to_save, I add the tied modules to a separate list, modules_to_tie. On tests, I have covered the following cases:
Thanks for your continued work on this PR. The idea I proposed was to create the … I guess one could argue that this would be a user error and the user needs to take care to define …
I would add them to a separate list called modules_to_tie.
In my proposed flow, the modules in …
Thanks for explaining further. To summarize how our proposals differ:
In my approach, I extend modules_to_save, which results in ModulesToSaveWrapper being applied to all necessary modules, and then do a second pass outside of the modules to tie weights.
In your approach, you define a new attribute at runtime, modules_to_tie. Then you call _set_trainable on those modules and pass an additional argument, the module to which the new ModulesToSaveWrapper should be tied. The tying happens inside of ModulesToSaveWrapper.
I can see some advantages and disadvantages to both approaches and can't confidently say one is better than the other. Thus I'm fine going with your approach for now, as it's already there.
I made a few suggestions on how to improve the PR, please check my comments.
src/peft/tuners/lora/config.py (Outdated)

"help": (
    "Whether to tie weights or not after peft initialization. "
    "This will ensure that the adapters added to the tied layers "
    "are also tied."
Let's mention that right now, this only applies to modules_to_save, not the LoRA weights.
src/peft/tuners/tuners_utils.py (Outdated)

# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

modules_to_save = (
Okay, so let's move this whole block into a sub-method on BaseTuner, e.g. def _check_tied_modules. On the BaseTuner, if the peft_config has modules_to_save, you can do a basic check if modules_to_save targets a tied module. This way, all PEFT methods with modules_to_save benefit from having the warnings.
Then, on LoraModel, override _check_tied_modules with the logic below. There, you can remove the PeftType.LORA check.
src/peft/utils/other.py (Outdated)

context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
break

tied_module = kwargs.get("tied_module", None)
Let's add tied_module to the signature of init_modules with default None, instead of getting it from kwargs (but you can leave kwargs in the signature too, shouldn't hurt).
Did you mean update? Since this piece of code lies in the update method.
I have added it in update.
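For illustration, a hedged sketch of how a tied_module argument could flow through an update-style method; the class below is a reduced stand-in, not PEFT's actual ModulesToSaveWrapper:

import copy
from typing import Optional

import torch.nn as nn

class ModulesToSaveWrapperSketch(nn.Module):
    """Illustrative stand-in, reduced to the weight-tying part of the discussion."""

    def __init__(self, module_to_save: nn.Module, adapter_name: str, tied_module: Optional[nn.Module] = None):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = nn.ModuleDict({})
        self.update(adapter_name, tied_module=tied_module)

    def update(self, adapter_name: str, tied_module: Optional[nn.Module] = None, **kwargs):
        # create the trainable copy of the wrapped module
        self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
        if tied_module is not None:
            # re-point the copy's weight at the source adapter's weight, so that
            # e.g. lm_head stays tied to embed_tokens after wrapping
            source = tied_module.modules_to_save[adapter_name]
            target = self.modules_to_save[adapter_name]
            del target.weight
            target.register_parameter("weight", source.weight)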
src/peft/tuners/tuners_utils.py (Outdated)

self._check_new_adapter_config(peft_config)

modules_to_save = (
    set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
This is enough, right?

Suggested change:
- set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
+ set(getattr(peft_config, "modules_to_save", []))
modules_to_save is set as None in src/peft/mixed_model.py, which results in an error.
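A compact alternative that handles both a missing attribute and an explicit None (as set in src/peft/mixed_model.py) could look like the line below — a suggestion for the discussion, not the code that was merged:

# treats a missing attribute and modules_to_save=None the same way
modules_to_save = set(getattr(peft_config, "modules_to_save", None) or [])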
tests/test_initialization.py (Outdated)

# is passed as a `modules_to_save`, it needs to be ensured
# that lm_head is tied to the adapter added to `embed_tokens`

from peft.utils.other import ModulesToSaveWrapper
Can be imported at root level instead of locally on each test.
tests/test_initialization.py (Outdated)

ensure_weight_tying=True,
)

model = get_peft_model(model, embed_token_config)
How about checking that there is no warning related to weight tying here?
@BenjaminBossan I have resolved the comments. Let me know if I missed anything.
Thanks for the updates, we're getting close. I added a few smaller comments, please check. Once you're done, don't forget to call make style.
Did you mean update? Since this piece of code lies in the update method
Yes, I looked at the wrong line.
return tied_weight_keys

def _check_tied_modules(self, model: nn.Module, peft_config):
IMO, it makes more sense to just do some basic checks here in this generic method of the BaseTuner, and add a more LoRA-specific implementation to LoraModel. That way, there is no need to check for peft_config.peft_type == PeftType.LORA.
I don't follow. The _check_tied_modules function is called inside inject_adapter, which the LoraModel does not inherit.
What you suggested makes sense, I can remove the LoRA-specific changes from this function and handle that condition outside. I don't know where to place the call self._check_tied_modules(model, peft_config). Do you have any suggestions?
PS: I have addressed the rest of your comments.
This should work. LoraModel inherits from BaseTuner, thus it does have LoraModel.inject_adapter. So if we have:

class BaseTuner:
    def inject_adapter(...):
        self._check_tied_modules(...)
        ...

    def _check_tied_modules(...): ...

and

class LoraModel(BaseTuner):
    def _check_tied_modules(...): ...

then when calling LoraModel.inject_adapter, it calls BaseTuner.inject_adapter, which in turn will call LoraModel._check_tied_modules.
Got it, so I should define _check_tied_modules as an abstract method and implement it in LoraModel.
Hmm, I was about to say that we can have some basic checks for existence of tied weights and modules_to_save here and give a warning for non-LoRA; for LoRA, we would override the method and give more specific warnings. But now I wonder: Why only implement this for LoRA in the first place? It should work for other PEFT methods too, at least I don't see why it wouldn't.
Yes, that was my original thought as well, but since I have not worked with PEFT methods other than LoRA, I am not sure whether it would work out of the box.
Alternatively, we can expose this for LoRA (or other similar PEFT methods where you are confident) and open it up for all PEFT methods based on feedback?
Your proposal makes sense, let's keep the scope small for now; later, it can be extended to other PEFT methods if there is a need. This means, let's keep just a generic check here in BaseTuner._check_tied_modules and then, in LoraModel._check_tied_modules, do the more specific check.
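To make the split concrete, here is a rough sketch under the assumption of a helper like _get_tied_weight_keys on the tuner; the method bodies, the warning text, and the modules_to_tie handling are illustrative, not the PR's final code:

import warnings

class BaseTuner:
    def _get_tied_weight_keys(self, model):
        # stub for the sketch: a real implementation would detect weights tied to the embedding
        return []

    def _check_tied_modules(self, model, peft_config):
        # generic check: only warn when modules_to_save targets a tied module,
        # since ensure_weight_tying is only implemented for LoRA at this point
        modules_to_save = set(getattr(peft_config, "modules_to_save", None) or [])
        tied = {key.rpartition(".")[0] for key in self._get_tied_weight_keys(model)}
        if modules_to_save & tied:
            warnings.warn("modules_to_save targets tied weights; weight tying is not maintained here.")
        return peft_config

class LoraModel(BaseTuner):
    def _check_tied_modules(self, model, peft_config):
        # LoRA-specific: collect the tied partners so they can be wrapped and re-tied later
        if getattr(peft_config, "ensure_weight_tying", False):
            tied = {key.rpartition(".")[0] for key in self._get_tied_weight_keys(model)}
            peft_config.modules_to_tie = sorted(tied - set(peft_config.modules_to_save or []))
            return peft_config
        return super()._check_tied_modules(model, peft_config)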
I have updated my implementation by adding an internal method _maybe_add_modules_to_tie. This is the cleanest approach I could find. Let me know your thoughts!
Thanks for the latest updates.
I have updated my implementation by adding an internal method _maybe_add_modules_to_tie. This is the cleanest approach I could find. Let me know your thoughts!
Overall looks good, but I have a comment concerning one of the warning messages, WDYT?
if is_embedding_to_save and tied_weight_keys:
    msg = (
        "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
        "but `ensure_weight_tying` is not set to True. "
        "This can lead to complications, for example when merging the adapter "
        "or converting your model to formats other than safetensors. "
        "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
    )
    warnings.warn(msg)
I think this message has the potential to confuse users. ensure_weight_tying is only implemented for LoRA right now, but this message can also show up for non-LoRA methods, right?
This PR solves: #2777

Essentially, the PR:
1. Adds a flag in LoraConfig for users to maintain weight tying in the model
2. Detects the tied modules that need to be handled (collected in modules_to_tie)
3. Adds a ModulesToSaveWrapper on modules returned in (2) and a reference to the original layer

This ensures that weight tying is enabled if the user passes either embed_tokens or lm_head.

Limitations
1. Only implemented for LoraConfig
2. Only supports the modules_to_save parameter. This can be extended to target_modules as well once the approach is approved.
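A hedged usage sketch of the proposed flag — the checkpoint name and the attribute paths used in the final check are assumptions, and the API may still change while the PR is under review:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# any causal LM with tie_word_embeddings=True works here; the checkpoint is only an example
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens"],  # a tied layer
    ensure_weight_tying=True,          # proposed flag: keep lm_head tied to embed_tokens
)
peft_model = get_peft_model(model, config)

# with tying maintained, the trainable copies of embed_tokens and lm_head should share storage
emb_weight = peft_model.get_input_embeddings().modules_to_save["default"].weight
head_weight = peft_model.get_output_embeddings().modules_to_save["default"].weight
assert emb_weight.data_ptr() == head_weight.data_ptr()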