-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Ensure weight tying is maintained for embed_tokens and lm_head #2803
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
romitjain
wants to merge
20
commits into
huggingface:main
Choose a base branch
from
romitjain:romit/fix-lmhead-module-wrapper
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
2ddc006
Initial working version with weight sharing
romitjain 5ba62f7
Removed print
romitjain aa67b3a
Added get tied modules to save to BaseTuner
romitjain a17d8cd
Updated linting
romitjain 1de6c5f
Merge branch 'main' of github.com:huggingface/peft into romit/fix-lmh…
romitjain 6108db1
Reversed gitignore
romitjain 0e2f966
Added test and warning
romitjain 32c393c
Style changes
romitjain bae029f
Merge branch 'main' of github.com:romitjain/peft into romit/fix-lmhea…
romitjain 7f9ce15
Updated the flow of updates
romitjain 7b80354
Added test cases
romitjain 2a1fa42
Apply suggestion from @BenjaminBossan
romitjain 68cf10c
Apply suggestion from @BenjaminBossan
romitjain 4696569
Apply suggestion from @BenjaminBossan
romitjain 43098ae
Resolving PR comments
romitjain 15f2949
Merge branch 'romit/fix-lmhead-module-wrapper' of github.com:romitjai…
romitjain acf4ce0
Resolved PR comments
romitjain c1d08c7
Update src/peft/tuners/tuners_utils.py
romitjain bc2d233
Removed redundant forward
romitjain b3e29b5
Moved core logic to LoraModel
romitjain File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,6 +68,7 @@ | |
from peft.tuners.lora.layer import LoraLayer | ||
from peft.utils import infer_device | ||
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap | ||
from peft.utils.other import ModulesToSaveWrapper | ||
|
||
from .testing_utils import load_dataset_english_quotes, require_deterministic_for_xpu | ||
|
||
|
@@ -110,6 +111,45 @@ def forward(self, x): | |
|
||
return MyModule().eval().to(self.torch_device) | ||
|
||
def get_lm_model(self, bias=True, tie_weights=True): | ||
BenjaminBossan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Mimicking a LM with embed_tokens and lm_head layers | ||
# to test weight tying of adapters | ||
class MyModule(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
self.embed_tokens = nn.Embedding(1000, 1000) | ||
self.linear = nn.Linear(1000, 1000, bias=bias) | ||
|
||
def forward(self, x): | ||
return | ||
romitjain marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
class CausalLM(nn.Module): | ||
if tie_weights: | ||
_tied_weights_keys = ["lm_head.weight"] | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.model = MyModule() | ||
self.config = {"tie_word_embeddings": tie_weights} | ||
|
||
if tie_weights: | ||
self.lm_head = nn.Linear(1000, 1000, bias=False) | ||
self.lm_head.weight = self.model.embed_tokens.weight | ||
else: | ||
self.lm_head = nn.Linear(1000, 1000, bias=bias) | ||
|
||
def forward(self, x): | ||
return | ||
|
||
def prepare_inputs_for_generation(self): | ||
return | ||
|
||
def get_input_embeddings(self): | ||
return self.model.embed_tokens | ||
|
||
return CausalLM().eval().to(self.torch_device) | ||
|
||
@pytest.fixture | ||
def data(self): | ||
return torch.rand(10, 1000).to(self.torch_device) | ||
|
@@ -1566,6 +1606,92 @@ def test_multiple_configs_with_bias_raises(self, tmp_path): | |
config2 = LoraConfig(target_modules=["linear"], bias="none") | ||
model.add_adapter("other", config2) # does not raise | ||
|
||
def test_weight_tying_tied_model(self): | ||
# If weight tying is enabled and `embed_tokens` | ||
# is passed as a `modules_to_save`, it needs to be ensured | ||
# that lm_head is tied to the adapter added to `embed_tokens` | ||
|
||
model = self.get_lm_model() | ||
embed_token_config = LoraConfig( | ||
modules_to_save=["embed_tokens"], | ||
target_modules=["linear"], | ||
ensure_weight_tying=True, | ||
) | ||
model = get_peft_model(model, embed_token_config) | ||
|
||
assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), ( | ||
"Embed tokens is not added in Modules to Save" | ||
) | ||
assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), ( | ||
"Embed tokens and LM head types are not same" | ||
) | ||
|
||
# Validating that all model parameters are same | ||
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters()) | ||
lm_head_np = dict(model.base_model.model.lm_head.named_parameters()) | ||
|
||
for k in embed_np.keys(): | ||
assert torch.allclose(embed_np[k], lm_head_np[k]) | ||
assert embed_np[k] is lm_head_np[k] | ||
|
||
def test_weight_tying_non_tied_model(self): | ||
from peft.utils.other import ModulesToSaveWrapper | ||
|
||
model = self.get_lm_model(tie_weights=False) | ||
embed_token_config = LoraConfig( | ||
modules_to_save=["embed_tokens"], | ||
target_modules=["linear"], | ||
ensure_weight_tying=True, | ||
) | ||
with pytest.warns(UserWarning, match="no tied modules were found in the model"): | ||
model = get_peft_model(model, embed_token_config) | ||
|
||
assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), ( | ||
"Embed tokens is not added in Modules to Save" | ||
) | ||
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), ( | ||
"LM head is not of type nn.linear" | ||
) | ||
|
||
def test_not_weight_tying_tied_model(self): | ||
from peft.utils.other import ModulesToSaveWrapper | ||
|
||
model = self.get_lm_model() | ||
embed_token_config = LoraConfig( | ||
modules_to_save=["embed_tokens"], | ||
target_modules=["linear"], | ||
ensure_weight_tying=False, | ||
) | ||
with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"): | ||
model = get_peft_model(model, embed_token_config) | ||
|
||
assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), ( | ||
"Embed tokens is not added in Modules to Save" | ||
) | ||
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), ( | ||
"LM head is not of type nn.linear" | ||
) | ||
|
||
def test_weight_tying_tied_model_no_embed(self): | ||
model = self.get_lm_model() | ||
embed_token_config = LoraConfig( | ||
target_modules=["linear"], | ||
ensure_weight_tying=True, | ||
) | ||
|
||
model = get_peft_model(model, embed_token_config) | ||
|
||
|
||
assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding) | ||
assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear) | ||
|
||
# Validating that all model parameters are same | ||
embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters()) | ||
lm_head_np = dict(model.base_model.model.lm_head.named_parameters()) | ||
|
||
for k in embed_np.keys(): | ||
assert torch.allclose(embed_np[k], lm_head_np[k]) | ||
assert embed_np[k] is lm_head_np[k] | ||
|
||
|
||
class TestLokrInitialization: | ||
torch_device = infer_device() | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO, it makes more sense to just do some basic checks here in this generic method of the
BaseTuner
, and add a more LoRA-specific implementation toLoraModel
. That way, there is no need to check forpeft_config.peft_type == PeftType.LORA
.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't follow.
The
_check_tied_modules
function is called insideinject_adaptors
which theLoraModel
does not inherit.What you suggested makes sense, I can remove Lora specific changes from this function and handle that condition outside. I don't know where to place the call
self._check_tied_modules(model, peft_config)
. Do you have any suggestions?PS: I have addressed rest of your comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should work.
LoraModel
inherits fromBaseTuner
, thus it does haveLoraModel.inject_adapter
. So if we have:and
then when calling
LoraModel.inject_adapter
, it callsBaseTuner.inject_adapter
, which in turn will callLoraModel._check_tied_modules
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it, so I should initialize
_check_tied_modules
as an abstract method and implement inLoraModel
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I was about to say that we can have some basic checks for existence of tied weights and
modules_to_save
here and give a warning for non-LoRA; for LoRA, we would override the method and give more specific warnings. But now I wonder: Why only implement this for LoRA in the first place? It should work for other PEFT methods too, at least I don't see why it wouldn't.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that was my original thought as well, but since I have not worked with other PEFT methods apart from LoRA, I am not sure if I can comment if it would work or not just OOB.
Alternatively, we can expose this for LoRA (or other similar PEFT wherever you are confident) and open it for all PEFT based on feedback?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your proposal makes sense, let's keep the scope small for now, and later, it can be extended to other PEFT methods if there is a need.
This means, let's keep just a generic check here in
BaseTuner._check_tied_modules
and then inLoraModel._check_tied_modules
, do the more specific check.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated my implementation by adding an internal method
_maybe_add_modules_to_tie
. This is the cleanest approach I could find. Let me know your thoughts!