fix: auto-untie word embeddings on merge_and_unload when both are adapted #2972
Summary
This PR fixes the issue where `merge_and_unload()` produces broken models when adapters are applied to both `embed_tokens` and `lm_head` on models with `tie_word_embeddings=True`.

Resolves #2777
Problem
When a base model has `tie_word_embeddings=True` (e.g., Gemma, Llama):

- `embed_tokens` and `lm_head` share the same weight tensor
- Adapters can be applied to both layers (via `modules_to_save` or `target_modules`)
- `merge_and_unload()` merges both layers with their respective deltas
- The saved config still reports `tie_word_embeddings=True`
- On reload with `AutoModel.from_pretrained()`, the `lm_head` weights are overwritten with the `embed_tokens` weights due to weight tying
- The merged `lm_head` weights are lost, causing degraded or garbage output

Solution
This PR modifies `_unload_and_optionally_merge()` in `BaseTuner` to:

- Detect when adapters are applied to both embedding layers
- Untie (clone) `lm_head.weight` before the merge
- Set `config.tie_word_embeddings = False` in all relevant config locations

This ensures that:

- `embed_tokens` and `lm_head` each keep their own merged weights
- Reloading the saved model does not re-tie the weights and overwrite `lm_head`
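The helper names below come from this PR; their bodies are only a minimal sketch of the idea, not the actual implementation (which updates every relevant config location):

```python
import torch


def _untie_embedding_weights(model):
    # Replace the shared tensor with an independent copy so that
    # merging adapter deltas into lm_head no longer writes into
    # embed_tokens (and vice versa).
    lm_head = model.get_output_embeddings()
    lm_head.weight = torch.nn.Parameter(lm_head.weight.clone())


def _update_tie_word_embeddings_config(model):
    # Persist the untying, so that AutoModel.from_pretrained() does
    # not re-tie the weights and overwrite the merged lm_head on load.
    model.config.tie_word_embeddings = False
```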
Changes
`src/peft/tuners/tuners_utils.py`:

- Added `_untie_embedding_weights()` helper method
- Added `_update_tie_word_embeddings_config()` helper method
- Added `_has_adapters_on_both_embeddings()` helper method
- Updated `_unload_and_optionally_merge()` to auto-handle tied embeddings

`tests/test_tie_word_embeddings_merge.py`:

- New test module covering the tied-embedding merge behavior

Test Plan
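The new tests can be run directly (pytest is assumed as the runner, as for the rest of the test suite):

```bash
pytest tests/test_tie_word_embeddings_merge.py
```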
Example
Before this fix:
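A minimal sketch of the failure mode (the checkpoint name and training step are placeholders; any model with `tie_word_embeddings=True` reproduces it):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
peft_config = LoraConfig(target_modules=["embed_tokens", "lm_head"])
model = get_peft_model(base, peft_config)

# ... training happens here ...

merged = model.merge_and_unload()
merged.save_pretrained("merged-model")

# The saved config still says tie_word_embeddings=True, so on reload
# the merged lm_head weights are overwritten by embed_tokens and the
# model produces degraded or garbage output.
reloaded = AutoModelForCausalLM.from_pretrained("merged-model")
```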
After this fix:
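Continuing the snippet above, the assertions illustrate the intended post-fix behavior as described in this PR:

```python
merged = model.merge_and_unload()

# Both embedding layers carried adapters, so lm_head.weight was cloned
# and the config flag was flipped before merging.
assert merged.config.tie_word_embeddings is False

merged.save_pretrained("merged-model")

# Reloading now keeps embed_tokens and lm_head as separate tensors,
# so the merged lm_head weights survive intact.
reloaded = AutoModelForCausalLM.from_pretrained("merged-model")
assert reloaded.get_output_embeddings().weight is not reloaded.get_input_embeddings().weight
```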