Handling embeddings scaling for TrainableTokensModel #2809 #2825
@@ -795,3 +795,58 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_

```python
            )
        else:
            assert not contains_embedding

    def test_lora_embed_scale_is_applied(self):
        """Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)
            orig_embedding = base_model.get_input_embeddings()

        peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
        peft_model = get_peft_model(base_model, peft_config)

        # sanity check: with the default embed_scale, the embedding output should be reasonably sized
        peft_embedding = peft_model.base_model.model.get_input_embeddings()
        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
        assert (max_embedding_output < 100.0).all()
```

Review comment (on the `get_peft_model(...)` line): Let's assign […]

Reply: resolved in eb609b8
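For context, the embed_scale being tested here is a multiplier that Gemma3-style models apply inside the embedding's forward pass. A minimal sketch of such a module (the class name and constructor are assumptions for illustration; only the embed_scale buffer and the multiplication mirror what the tests rely on):

```python
import torch
import torch.nn as nn


class ScaledEmbedding(nn.Embedding):
    """Hypothetical stand-in for a Gemma3-style scaled embedding (sketch only)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Stored as a tensor buffer, which is why the tests can mutate it
        # in place via `embed_scale.fill_(...)`.
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # The scale is applied inside forward, so an adapter that adds its own
        # delta on top of this output must remember to scale that delta too.
        return super().forward(input_ids) * self.embed_scale
```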
Suggested change:

```python
        # sanity check: with the default embed_scale, the embedding output should be reasonably sized
        peft_embedding = peft_model.base_model.model.get_input_embeddings()
        embedding_output = peft_embedding(torch.arange(10))
        max_embedding_output = embedding_output.abs().max(0)[0]
        assert (max_embedding_output < 100.0).all()
        peft_model.merge_adapter()
        embedding_merged = peft_embedding(torch.arange(10))
        assert torch.allclose(embedding_output, embedding_merged)
        peft_model.unmerge_adapter()
```
The point is that if the embedding scale is not applied correctly, we would expect the merged and unmerged outputs to differ. By checking the merged output as well, we can better ensure that everything works as expected. The test for trainable tokens can also be updated to check the merged output; test_lora_embed_scale_is_applied_mixed_batch doesn't need this check, as it assumes unmerged adapters.
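A small numeric sketch (plain tensors, not PEFT internals) of why the merged check catches a missing scale: folding the delta into the base weight scales it automatically, so an unmerged forward pass that forgets to scale its delta cannot match the merged result.

```python
import torch

torch.manual_seed(0)
W = torch.randn(10, 4)       # base embedding weight
delta = torch.randn(10, 4)   # adapter delta, e.g. lora_B @ lora_A
scale = 2.0                  # the module's embed_scale
ids = torch.arange(10)

# Buggy unmerged forward: the base output is scaled, the delta is not.
buggy = W[ids] * scale + delta[ids]
# Merged forward: the delta is folded into W and thus scaled with it.
merged = (W + delta)[ids] * scale
assert not torch.allclose(buggy, merged)

# Correct unmerged forward scales the delta too and matches the merge.
correct = (W[ids] + delta[ids]) * scale
assert torch.allclose(correct, merged)
```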
Reply: Yup, cool. Implemented this in eb609b8.
@@ -918,3 +918,61 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa

```python
            assert contains_embedding
        else:
            assert not contains_embedding

    def test_embed_scale_is_applied(self):
        """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)
            orig_embedding = base_model.get_input_embeddings()

        peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
        peft_model = get_peft_model(base_model, peft_config)

        # sanity check: with the default embed_scale, the embedding output should be reasonably sized
        peft_embedding = peft_model.base_model.model.get_input_embeddings()
        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
        assert (max_embedding_output < 100.0).all()

        # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a
        # high value
        orig_embedding.embed_scale.fill_(10000.0)
        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
        assert (max_embedding_output > 100.0).all()

        # set embed_scale to zero, then check that the embedding output is also zero
        orig_embedding.embed_scale.fill_(0)
        embedding_output = peft_embedding(torch.arange(10))
        assert (embedding_output == 0.0).all()

    def test_scaled_embedding_with_lora(self):
        """
        Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously.

        This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply
        embed_scale when used together.
        """
```
Review comment (on the docstring):

> This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply embed_scale when used together.

No need to mention this is a real model; we don't explicitly mention it in other tests.
Reply: Done, @BenjaminBossan. Changed the docstring in b3d8150.
Review comment: So that this test can also run on GPU.
Reply: resolved in eb609b8
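This presumably refers to allocating the test inputs on the model's device rather than implicitly on CPU; a sketch of that change (assuming the model exposes a `device` attribute, as transformers models do):

```python
        # sketch: create the input ids on the model's device so the test also
        # passes when the model has been moved to GPU
        device = peft_model.device
        embedding_output = peft_embedding(torch.arange(10, device=device))
```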