Ensure weight tying is maintained for embed_tokens and lm_head #2803
Conversation
Signed-off-by: romit <romit@ibm.com>
@BenjaminBossan Request your review on the approach. If this looks fine, I can handle the other case (target_modules) and add tests. Thanks
BenjaminBossan
left a comment
Thanks for working on the PR to deal with targeting tied weights.
I have a suspicion that parts of this PR can be implemented more easily by extending the peft_config.modules_to_save to include the tied module instead of calling _set_trainable explicitly.
However, before checking that, it makes most sense to add one or a few unit tests to establish if everything works as expected. I think the best place for those would be in this test class. LMK if you need help with implementing the tests.
Once we have the tests and confirm that everything works as expected, we can check if my proposal above works.
Signed-off-by: romit <romit@ibm.com>
@BenjaminBossan I have addressed your comments. You mentioned that I can extend `peft_config.modules_to_save` to enable the functionality. Can you elaborate on that? Essentially, we would need to mark the tied modules as modules to save and then update the weight pointer to point to the source adapter. PTAL.
BenjaminBossan
left a comment
Thanks for adding the tests, they LGTM.
> You mentioned that I can extend peft_config.modules_to_save to enable the functionality. Can you elaborate on that?
Yes, so my idea was the following: I would like to avoid spreading the _set_trainable calls further, as it makes it harder to reason what is set where. Instead, let's detect early if additional weights need to be added for modules_to_save and extend the peft_config if so. Then, when the PEFT model is created, we will automatically have a ModulesToSaveWrapper everywhere it's needed. Next, let's update _set_trainable such that it ties the weights of ModulesToSaveWrapper (it is a bit wasteful to do that post hoc, but IMO more robust, as we don't rely on the order of the tied modules being created).
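The post-hoc tying step described above can be illustrated with plain PyTorch, outside of PEFT (the module names and the copy step are illustrative stand-ins for what ModulesToSaveWrapper does, not PEFT's actual wrapper code):

```python
from torch import nn

# Toy stand-ins for tied embed_tokens / lm_head (names illustrative).
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # the base model ties the weights

# Wrapping each module independently copies the weights,
# silently breaking the tie:
embed_copy = nn.Embedding(10, 4)
embed_copy.load_state_dict(embed.state_dict())
head_copy = nn.Linear(4, 10, bias=False)
head_copy.load_state_dict(lm_head.state_dict())
assert head_copy.weight is not embed_copy.weight

# Post-hoc re-tying: drop the duplicate parameter and re-register a
# pointer to the source module's parameter.
del head_copy.weight
head_copy.register_parameter("weight", embed_copy.weight)
assert head_copy.weight is embed_copy.weight
```

This mirrors the `del ... .weight` / `register_parameter` pattern in the diff below: the copied parameter is discarded and replaced by a reference to the source adapter's parameter.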
For this to work, I had to make a small change to the signature of _set_trainable and move some of the logic you implemented around to different locations. Overall, this is the diff I came up with:
diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py
index c3ad517e..3d608642 100644
--- a/src/peft/mixed_model.py
+++ b/src/peft/mixed_model.py
@@ -247,7 +247,7 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def set_adapter(self, adapter_name: Union[str, list[str]], inference_mode: bool = False) -> None:
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 36eccabb..73596487 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1626,7 +1626,7 @@ class PeftModelForSequenceClassification(PeftModel):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
@@ -2482,7 +2482,7 @@ class PeftModelForTokenClassification(PeftModel):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
@@ -2703,7 +2703,7 @@ class PeftModelForQuestionAnswering(PeftModel):
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 36ad31ce..493a300b 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -663,6 +663,14 @@ class LoraConfig(PeftConfig):
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
+ ensure_weight_tieing: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to tie weights or not after peft initialization.Only supported for `task_type` == CAUSAL_LM"
+ )
+ },
+ )
def to_dict(self):
"""
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 4fd0d128..8ec6e0af 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -647,6 +647,30 @@ class BaseTuner(nn.Module, ABC):
) -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not support targeting nn.Parameter.")
+ def _check_tied_modules(self, peft_config: PeftConfig, model_config: dict, model: nn.Module) -> PeftConfig:
+ """TODO"""
+ if not getattr(peft_config, "ensure_weight_tieing", False):
+ return peft_config
+
+ from transformers.modeling_utils import _get_tied_weight_keys # TODO hmm, maybe just copy that function
+
+ tied_weight_keys = _get_tied_weight_keys(model)
+
+ if not (
+ (peft_config.task_type == TaskType.CAUSAL_LM)
+ and tied_weight_keys
+ and (getattr(peft_config, "modules_to_save", None) is not None)
+ ):
+ warnings.warn("TODO") # user wants to tie weights but we can't, there should be a warning
+ return peft_config
+
+ modules_to_save = set(peft_config.modules_to_save)
+ missing_keys = set(tied_weight_keys) - modules_to_save
+ for key in missing_keys:
+ module_name = key.rpartition(".")[0] # remove the parameter name, e.g. "lm_head.weight" => "lm_head"
+ peft_config.modules_to_save.append(module_name)
+ return peft_config
+
def inject_adapter(
self,
model: nn.Module,
@@ -693,6 +717,7 @@ class BaseTuner(nn.Module, ABC):
model_config = self.get_model_config(model)
peft_config = self._prepare_adapter_config(peft_config, model_config)
+ peft_config = self._check_tied_modules(peft_config, model_config, model)
self._prepare_model(peft_config, model)
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index fe7fb83c..ed7b29a8 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -941,7 +941,7 @@ def _set_trainable(
model,
adapter_name,
module_names,
- inference_mode: bool,
+ peft_config,
strict_module_check: bool = False,
wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None,
activate_adapter: bool = True,
@@ -958,7 +958,7 @@ def _set_trainable(
The default is to wrap the module in a `ModulesToSaveWrapper` wrapper.
If `strict_module_check` is set, this method raises an ValueError, similar to BaseTuner.inject_adapter when none of
- the requested modules in `module_names` is not found in the model.
+ the requested modules in `module_names` are found in the model.
The `active_adapter` flag indicates if this new adapter should be activated.
"""
@@ -974,6 +974,7 @@ def _set_trainable(
trainable_modules = []
found_modules = set()
+ inference_mode = peft_config.inference_mode
# disable removal of duplicates to support targeting tied weights
key_list = [key for key, _ in model.named_modules(remove_duplicate=False)]
@@ -1008,6 +1009,20 @@ def _set_trainable(
trainable_modules.append(new_module)
found_modules.add(target_name)
+ # deal with tied weights for modules_to_save
+ if getattr(peft_config, "ensure_weight_tieing", False) and hasattr(model, "get_input_embeddings") and issubclass(wrapper_cls, ModulesToSaveWrapper):
+ from transformers.modeling_utils import _get_tied_weight_keys # TODO hmm, maybe just copy that function
+
+ tied_weight_keys = _get_tied_weight_keys(model)
+ # remove the parameter name, e.g. "lm_head.weight" => "lm_head"
+ tied_module_names = {key.rpartition(".")[0] for key in tied_weight_keys}
+ orig_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
+ for module_name in found_modules:
+ if issubclass(wrapper_cls, ModulesToSaveWrapper) and (module_name in tied_module_names):
+ tied_module = model.get_submodule(module_name)
+ del tied_module.modules_to_save[adapter_name].weight
+ tied_module.modules_to_save[adapter_name].register_parameter("weight", orig_module.weight)
+
not_found = set(module_names).difference(found_modules)
if strict_module_check and not found_modules:
raise ValueError(
@@ -1397,7 +1412,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable(
model,
adapter_name,
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
module_names=getattr(peft_config, "modules_to_save", None),
activate_adapter=activate_adapter,
)
@@ -1423,7 +1438,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable(
model,
adapter_name,
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
module_names=[target_layer],
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
@@ -1447,7 +1462,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable(
model,
adapter_name,
- inference_mode=peft_config.inference_mode,
+ peft_config=peft_config,
module_names=module_keys,
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,

Testing locally, the tests you added pass. As you can see, I made some small changes to the implementation compared to what you did. Most notably, I changed model._tied_weights_keys to _get_tied_weight_keys(model), where _get_tied_weight_keys is from transformers. Of course, it's risky to use this private function; it's probably better to copy it over to PEFT.
Anyway, LMK what you think of this proposal.
Apart from this, I made some smaller comments; also, please note that the correct spelling is "tying", not "tieing".
src/peft/peft_model.py
Outdated
# in `modules_to_save` and we want to ensure the `lm_head`
# does not diverge from the `embed_tokens` layer
if (
    peft_config.task_type == "CAUSAL_LM"
Updating lm_head makes sense only for CAUSAL_LM tasks. We can extend it to Seq2Seq.
I might be wrong here though.
The proposed code is not really requiring there to be an lm_head, right? In theory, it should also work when there are other tied weights. I'm not sure how relevant that is in practice, but if there are no strict reasons for this check, I'd say let's remove it.
Yes, I'll make the changes such that we only worry about weight tying in case one of the tied layers is added in the modules to save
@BenjaminBossan I think what you suggested is a cleaner approach. I am aligned with that. I realized that the extra weights are being created in both situations, so I will go ahead with the solution you proposed.
BenjaminBossan
left a comment
> I think what you suggested is a cleaner approach. I am aligned with that.

The only concern I have is that we would first register the redundant weights (extra memory) for the tied modules and then replace the pointer post hoc. However, the extra memory might not be too big, so it should be okay? What do you think?

> I realized that the extra weights are being created in both situations, I will go ahead with the solution you proposed.

We could prevent those weights from being created, e.g. by using init_empty_weights, but I'd be okay without having it.
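For reference, the memory-saving idea can be sketched with PyTorch's meta device, which is what accelerate's `init_empty_weights` uses under the hood (this is an illustration, not the PR's code; it requires PyTorch 2.0+ for the device context manager):

```python
import torch
from torch import nn

# Parameters created on the meta device have shape and dtype but no
# storage, so the redundant copy for the tied module costs no memory.
with torch.device("meta"):
    head_copy = nn.Linear(4, 10, bias=False)
assert head_copy.weight.is_meta

# Instead of ever materializing the placeholder, point the copy at the
# parameter of the module it should be tied to.
embed = nn.Embedding(10, 4)
del head_copy.weight
head_copy.register_parameter("weight", embed.weight)
assert head_copy.weight is embed.weight
```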
Signed-off-by: romit <romit@ibm.com>
I have made slight tweaks to your proposed approach. Instead of modifying `modules_to_save`, I track the tied modules in a separate runtime attribute. On tests, I have covered the following cases.
Thanks for your continued work on this PR. The idea I proposed was to create the wrappers via `modules_to_save`. I guess one could argue that this would be a user error and the user needs to take care to define `modules_to_save` correctly.

> I would add them to a separate list called `modules_to_tie`

In my proposed flow, the modules in `modules_to_tie` would then be tied to their source modules.
BenjaminBossan
left a comment
Thanks for explaining further. To summarize how our proposals differ:
In my approach, I extend modules_to_save, which results in ModulesToSaveWrapper being applied to all necessary modules, and then do a second pass outside of the modules to tie weights.
In your approach, you define a new attribute at runtime, modules_to_tie. Then you call _set_trainable on those modules and pass an additional argument, the module to which the new ModulesToSaveWrapper should be tied. The tying happens inside of ModulesToSaveWrapper.
I can see some advantages and disadvantages to both approaches and can't confidently say one is better than the other. Thus I'm fine going with your approach for now, as it's already there.
I made a few suggestions on how to improve the PR, please check my comments.
src/peft/tuners/lora/config.py
Outdated
"help": (
    "Whether to tie weights or not after peft initialization. "
    "This will ensure that the adapters added to the tied layers "
    "are also tied."
Let's mention that right now, this only applies to modules_to_save, not the LoRA weights.
src/peft/tuners/tuners_utils.py
Outdated
# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

modules_to_save = (
Okay, so let's move this whole block into a sub-method on BaseTuner, e.g. def _check_tied_modules. On the BaseTuner, if the peft_config has modules_to_save, you can do a basic check if modules_to_save targets a tied module. This way, all PEFT methods with modules_to_save benefit from having the warnings.
Then, on LoraModel, override _check_tied_modules with the logic below. There, you can remove the PeftType.LORA check.
src/peft/utils/other.py
Outdated
context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
break

tied_module = kwargs.get("tied_module", None)
Let's add tied_module to the signature of init_modules with default None, instead of getting it from kwargs (but you can leave kwargs in the signature too, shouldn't hurt).
Did you mean update? Since this piece of code lies in the update method
I have added it in update
src/peft/tuners/tuners_utils.py
Outdated
self._check_new_adapter_config(peft_config)

modules_to_save = (
    set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
This is enough, right? Suggested change:

- set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
+ set(getattr(peft_config, "modules_to_save", []))
modules_to_save is set as None in src/peft/mixed_model.py, which results in an error
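To illustrate the pitfall: `getattr`'s default only applies when the attribute is missing, not when it exists and is set to `None` (the `Cfg` class below is a minimal stand-in for the mixed-model config):

```python
class Cfg:
    modules_to_save = None  # attribute exists, but is None

# The default [] is ignored because the attribute exists (as None)...
value = getattr(Cfg(), "modules_to_save", [])
assert value is None

# ...so set(value) would raise TypeError; a guard is needed:
try:
    set(value)
except TypeError:
    pass
modules = set(value or [])
assert modules == set()
```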
tests/test_initialization.py
Outdated
# is passed as a `modules_to_save`, it needs to be ensured
# that lm_head is tied to the adapter added to `embed_tokens`

from peft.utils.other import ModulesToSaveWrapper
Can be imported at root level instead of locally on each test.
tests/test_initialization.py
Outdated
ensure_weight_tying=True,
)

model = get_peft_model(model, embed_token_config)
How about checking that there is no warning related to weight tying here?
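One possible shape for that check, as a sketch (the helper name `assert_no_warning_matching` is made up for illustration; pytest's `recwarn` fixture or `warnings.catch_warnings` would both work):

```python
import warnings

def assert_no_warning_matching(substring, fn, *args, **kwargs):
    """Call fn and fail if any emitted warning message contains substring."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = fn(*args, **kwargs)
    matching = [str(w.message) for w in caught if substring in str(w.message)]
    assert not matching, f"unexpected warnings: {matching}"
    return result

# Toy callables standing in for the get_peft_model call in the test.
def quiet():
    return 42

def noisy():
    warnings.warn("weight tying is not ensured")

assert assert_no_warning_matching("weight tying", quiet) == 42
```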
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Signed-off-by: romit <romit@ibm.com>
@BenjaminBossan I have resolved the comments. Let me know if I missed something.
BenjaminBossan
left a comment
Thanks for the updates, we're getting close. I added a few smaller comments, please check. Once you're done, don't forget to call make style.
> Did you mean update? Since this piece of code lies in the update method

Yes, I looked at the wrong line.
return tied_weight_keys

def _check_tied_modules(self, model: nn.Module, peft_config):
IMO, it makes more sense to just do some basic checks here in this generic method of the BaseTuner, and add a more LoRA-specific implementation to LoraModel. That way, there is no need to check for peft_config.peft_type == PeftType.LORA.
I don't follow.
The _check_tied_modules function is called inside inject_adapter, which the LoraModel does not inherit.
What you suggested makes sense, I can remove Lora specific changes from this function and handle that condition outside. I don't know where to place the call self._check_tied_modules(model, peft_config). Do you have any suggestions?
PS: I have addressed the rest of your comments.
This should work. LoraModel inherits from BaseTuner, thus it does have LoraModel.inject_adapter. So if we have:
class BaseTuner:
    def inject_adapter(...):
        self._check_tied_modules(...)
        ...

    def _check_tied_modules(...): ...

and

class LoraModel(BaseTuner):
    def _check_tied_modules(...): ...

then when calling LoraModel.inject_adapter, it calls BaseTuner.inject_adapter, which in turn will call LoraModel._check_tied_modules.
Got it, so I should initialize _check_tied_modules as an abstract method and implement in LoraModel.
Hmm, I was about to say that we can have some basic checks for existence of tied weights and modules_to_save here and give a warning for non-LoRA; for LoRA, we would override the method and give more specific warnings. But now I wonder: Why only implement this for LoRA in the first place? It should work for other PEFT methods too, at least I don't see why it wouldn't.
Yes, that was my original thought as well, but since I have not worked with other PEFT methods apart from LoRA, I am not sure if I can comment if it would work or not just OOB.
Alternatively, we can expose this for LoRA (or other similar PEFT wherever you are confident) and open it for all PEFT based on feedback?
Your proposal makes sense, let's keep the scope small for now, and later, it can be extended to other PEFT methods if there is a need.
This means, let's keep just a generic check here in BaseTuner._check_tied_modules and then in LoraModel._check_tied_modules, do the more specific check.
I have updated my implementation by adding an internal method _maybe_add_modules_to_tie. This is the cleanest approach I could find. Let me know your thoughts!
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Signed-off-by: romit <romit@ibm.com>
BenjaminBossan
left a comment
Thanks for the latest updates.
I have updated my implementation by adding an internal method _maybe_add_modules_to_tie. This is the cleanest approach I could find. Let me know your thoughts!
Overall looks good, but I have a comment concerning one of the warning messages, WDYT?
src/peft/tuners/tuners_utils.py
Outdated
if is_embedding_to_save and tied_weight_keys:
    msg = (
        "Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, "
        "but `ensure_weight_tying` is not set to True. "
        "This can lead to complications, for example when merging the adapter "
        "or converting your model to formats other than safetensors. "
        "Check the discussion here: https://github.com/huggingface/peft/issues/2777"
    )
    warnings.warn(msg)
I think this message has the potential to confuse users. ensure_weight_tying is only implemented for LoRA right now but this message can also show up for non-LoRA methods, right?
Ahh yes, I missed this.
Fixed it now
BenjaminBossan
left a comment
Thanks for your continued work and for your patience. I didn't quite expect the scope of this PR to grow to what it is now, but I think this will be a nice benefit for PEFT users once it's merged.
I did another review pass and added a few more comments. Please check if they make sense.
tests/test_initialization.py
Outdated
config2 = LoraConfig(target_modules=["linear"], bias="none")
model.add_adapter("other", config2)  # does not raise

def test_weight_tying_tied_model(self):
I'm rethinking how to best deal with these tests. Right now, they are tied to LoRA, but with the recent changes, other PEFT methods will also benefit (even if it's just a warning about potential issues). Therefore, I propose the following: Let's create a new test class, say TestWeightTying. Move get_lm_model and the test methods (test_weight_tying_tied_model etc.) into that class. Add _lora to the name of these test methods.
Then, let's add tests for a non-LoRA PEFT method that supports modules_to_save, e.g. LoKr. The test can be very similar to the LoRA tests, just that there is no ensure_weight_tying option and that the warning may look different. Add a comment that LoKr is just one example of many PEFT methods that could be chosen for this test.
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
githubnemo
left a comment
This change looks cool :)
One question, otherwise LGTM.
with any layers that needs to be tied
"""
modules_to_save = set(getattr(peft_config, "modules_to_save", []) or [])
is_embedding_to_save = any(m in EMBEDDING_LAYER_NAMES for m in modules_to_save)
Should we use model.get_input_embeddings if it is available instead of guessing by name?
@githubnemo There was some precedent for using this here and here
Makes sense. I think in the future we shouldn't rely on the embedding names if possible (e.g., custom models wouldn't work otherwise). But it's okay for me to leave it as is and do this change separately.
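To illustrate the difference (the `EMBEDDING_LAYER_NAMES` list here is an abridged stand-in for PEFT's constant, and the model is a toy):

```python
from torch import nn

# Abridged stand-in for peft's EMBEDDING_LAYER_NAMES (an assumption here).
EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(10, 4)  # non-standard attribute name

    def get_input_embeddings(self):
        return self.tok_emb

model = Toy()

# Name-based detection misses the non-standard attribute name...
by_name = [n for n, _ in model.named_modules() if n.split(".")[-1] in EMBEDDING_LAYER_NAMES]
assert by_name == []

# ...while comparing identity against get_input_embeddings() still finds it.
emb = model.get_input_embeddings()
by_identity = [n for n, m in model.named_modules() if m is emb]
assert by_identity == ["tok_emb"]
```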
BenjaminBossan
left a comment
Thanks a lot @romitjain for working on this tricky topic. The PR LGTM. Would you also be interested to implement a similar functionality for LoRA targeting the tied weights?
Thanks for merging @BenjaminBossan, yes, I can work on a similar PR for tied weights. I think now with the base work done, it should be easier to integrate. I am away this week, so I'll come back to it next week.
This PR solves: #2777
Essentially, the PR:

1. Adds a new `ensure_weight_tying` option in `LoraConfig` for users to maintain weight tying in the model
2. Detects the tied modules in the model
3. Adds a `ModulesToSaveWrapper` on the modules returned in (2) and a reference to the original layer

This ensures that the weight tying is enabled if the user passes either `embed_tokens` or `lm_head`.

Limitations:

1. Only supported via `LoraConfig`
2. Only supports the `modules_to_save` parameter. This can be extended to `target_modules` as well once the approach is approved