Feature request
I wanted to describe an unexpected (or unusual) behaviour:
If we pass either the embedding layer or the LM head in modules_to_save for a model whose embeddings are tied (tie_word_embeddings = True), PEFT does the following:
1. It (correctly) adds adapters to the mentioned layers (separately, if both layers are passed), which effectively unties them and diverges from the original model architecture.
2. It trains them separately, so the resulting model has different weights for the embedding layer and the LM head, and the state dict stores separate weight parameters for each.
3. Most critically, it does not update the model config to set tie_word_embeddings to False, which leads to unexpected downstream behaviour when the model is loaded and merged:
3.1. If we load the model again via PeftModel.from_pretrained, it works as expected.
3.2. If we merge it and load it again using AutoModel.from_pretrained, the embedding layer and the LM head are tied again. This results in nonsensical generations.
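For concreteness, here is a minimal sketch of the reproduction. The model id, LoRA hyperparameters, and save path are placeholders I picked for illustration; any causal LM with tie_word_embeddings = True should behave the same way:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # assumed tied-embedding model
model = AutoModelForCausalLM.from_pretrained(model_id)
print(model.config.tie_word_embeddings)  # True

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["lm_head"],  # or ["embed_tokens"], or both
)
peft_model = get_peft_model(model, lora_config)
# ... finetune peft_model: lm_head now trains independently of embed_tokens ...

# Loading back via PeftModel.from_pretrained works (case 3.1), but merging
# and reloading with from_pretrained does not (case 3.2):
merged = peft_model.merge_and_unload()
merged.save_pretrained("merged-model")  # config still says tie_word_embeddings=True

reloaded = AutoModelForCausalLM.from_pretrained("merged-model")
# transformers re-ties lm_head to embed_tokens here, discarding the finetuned
# lm_head weights -> nonsensical generations.
```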
There have been several reported issues and solutions in the past:
- With tied embeddings adapter merged to tied layers #2018
- Lm_head layer Problem in gemma2 : 2b-it #2244
- How to finetune embeddings and LM head as a single layer when they are tied? #1750
One fix that I am aware of:
- Warn if using tied target module with tie_word_embeddings #2025 - but this only adds a warning, it does not change the behaviour
Your contribution
There are two possible solutions I see (applicable only to models with tie_word_embeddings = True); a minimal sketch follows the list:
- If either the embedding layer or the lm_head is added as a module to save, add an adapter to both modules and tie them explicitly (or tie the adapters), so that the original model's config is respected.
- Update the model config to mark tie_word_embeddings = False. This solves the issue, but every LoRA finetune's model config will then diverge from the base model's.
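Here is a minimal sketch of the second option, assuming the merged lm_head and embed_tokens are already separate tensors after merge_and_unload (the save path is a placeholder; the first option would instead tie the trainable copies that PEFT creates for the two layers):

```python
from transformers import AutoModelForCausalLM

# Untie the config before saving the merged model so that a plain
# from_pretrained load does not re-tie lm_head to the embedding layer.
merged = peft_model.merge_and_unload()
merged.config.tie_word_embeddings = False  # record the architectural divergence
merged.save_pretrained("merged-model-untied")

reloaded = AutoModelForCausalLM.from_pretrained("merged-model-untied")
# reloaded keeps the separately finetuned lm_head and embed_tokens weights.
```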
If we can agree on the right solution, I can raise a PR to solve this.