Handling embeddings scaling for TrainableTokensModel #2809 #2825
Conversation
Thanks a lot for taking this issue and providing a PR to fix it. I especially like the detailed descriptions of the steps you took and how you checked the results.
Regarding the implementation, it already looks quite good, I only added a few small comments. Regarding the testing, I proposed a different approach, please check if it makes sense. Furthermore, please add tests for LoRA as well, not only for trainable tokens. These tests can be added to test_decoder_models.py (only for LoRA and Gemma3, no need to parametrize the test). There should be one test for the normal forward and one for the mixed batch forward (good job spotting that it needs to be updated too).
Searching through the PEFT code base, I think `XLoraEmbeddingLayer` and `DoraEmbeddingLayer` need to be updated too, but feel free to skip those for this PR. Especially for DoRA, I'm not even quite sure how (or if) the scale should be applied when calculating the weight norm.
src/peft/tuners/lora/layer.py (outdated):

```python
        Extract embed_scale from base layer if present and valid.

        Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
        method. This method checks for the presence of an `embed_scale` attribute and validates its shape.
```
Let's mention that if it exists, it is assumed to be a scalar. From my search through the transformers code base, this is always the case right now.
resolved in 0d216c7
src/peft/tuners/lora/layer.py (outdated):

```python
            return DoraEmbeddingVariant()

    def _get_embed_scale(self):
```
Let's factor this function out so that it can be reused between the two PEFT methods. It should be fine to put it on the `BaseTunerLayer` class. This also eliminates the small inconsistency between the two implementations you have.
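For illustration, a minimal sketch of what such a shared helper could look like (written here as a standalone function; in the PR it would live as a method on `BaseTunerLayer`, and the exact body may differ):

```python
import warnings

import torch


def get_embed_scale(base_layer):
    """Return the base layer's scalar embed_scale as a tensor, or None if absent or invalid."""
    embed_scale = getattr(base_layer, "embed_scale", None)
    if embed_scale is None:
        return None
    embed_scale = torch.as_tensor(embed_scale)
    if embed_scale.numel() != 1:
        # per the discussion above, embed_scale is assumed to be a scalar
        warnings.warn(f"Found embed_scale with shape {tuple(embed_scale.shape)}, expected scalar.")
        return None
    return embed_scale
```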
resolved in 0d216c7
f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. " | ||
"Embedding scaling will not be applied. If this is unexpected, please open an issue at " | ||
"https://github.com/huggingface/peft/issues", | ||
UserWarning, |
For LoRA, you use `PeftWarning`. I think that's the better choice.
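For instance (hedged sketch: the exact import path of `PeftWarning` is an assumption here, and the message text is illustrative):

```python
import warnings

from peft.utils import PeftWarning  # assumption: PeftWarning is exposed under peft.utils

warnings.warn(
    "Found embed_scale attribute with unexpected shape; embedding scaling will not be applied.",
    PeftWarning,
)
```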
resolved in 0d216c7
tests/test_trainable_tokens.py (outdated):

```python
        x = torch.tensor([[0, 1, 2, 3]])

        # Get outputs from base model and peft model
        base_output = base_model(x)
```
The `base_output` needs to be determined before calling `get_peft_model`, since this function will modify the base model in-place. In this particular case, it doesn't matter, but as a general precaution, it's better to do it correctly.
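A minimal sketch of the recommended ordering (the model id and config are taken from the test discussed further below; the details are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

from peft import TrainableTokensConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
x = torch.tensor([[0, 1, 2, 3]])
base_output = base_model(x)  # capture the reference output first...

peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, peft_config)  # ...because this mutates base_model in-place
peft_output = peft_model(x)
```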
resolved in 0d216c7
tests/test_trainable_tokens.py (outdated):

```python
        peft_embeddings = peft_model.model.emb(x)

        # Check that both apply the same scaling factor
        assert hasattr(peft_model.model.emb, "_get_embed_scale")
```
Can be removed.
resolved in 0d216c7
tests/test_trainable_tokens.py (outdated):

```python
        else:
            assert not contains_embedding

    def test_scaled_embedding_support(self):
```
Thanks for including tests for the correct application of `embed_scale`. However, I think the test is not focused enough to test its goal. To prove it, you can change `_get_embed_scale` to return a fixed scalar of 1.0 and the test would still pass, except for the `torch.allclose(embed_scale, torch.tensor(3.0))` assert, which by itself doesn't check if the scale is correctly applied.

Let me propose a test that I think will be more focused. It uses a real (but small) transformers model and checks the outputs of the embedding layer. By carefully varying the `embed_scale`, we can ensure that it's being applied. Let me know what you think.
```python
def test_embed_scale_is_applied(self):
    """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
    model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id)
    orig_embedding = base_model.get_input_embeddings()
    peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
    peft_model = get_peft_model(base_model, peft_config)

    # sanity check: with the default embed_scale, the embedding output should be reasonably sized
    peft_embedding = peft_model.base_model.model.get_input_embeddings()
    max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
    assert (max_embedding_output < 100.0).all()

    # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
    # value
    orig_embedding.embed_scale.fill_(10000.0)
    max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
    assert (max_embedding_output > 100.0).all()

    # set embed_scale to zero, then check that the embedding output is also zero
    orig_embedding.embed_scale.fill_(0)
    embedding_output = peft_embedding(torch.arange(10))
    assert (embedding_output == 0.0).all()
```
Resolved in 0d216c7. Implemented this with a mini version of Gemma3 for more realistic testing; all test cases pass.
Hi @BenjaminBossan. I went through all the changes you suggested and implemented them as well. Refactored some existing tests, wrote some new ones, and tested them; all pass successfully. About the changes to XLoraEmbeddingLayer and DoraEmbeddingLayer: I had come across this issue once when I was planning for this PR, and upon your suggestion, I had a look into it again. I may have some ideas about how both of those could be updated, less so about DoRA, but I suppose I will create separate issues for them both, tackling them one at a time, and if you agree we can move forward with this PR keeping it independent of the two. Kindly have a look at my present changes and tell me if I need to make any more amends; I'll make those as well ASAP.
src/peft/tuners/tuners_utils.py (outdated):

```python
        Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
        method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be
        a scalar and validates its shape accordingly.
```
I think otherwise the sentence doesn't quite make sense:

```diff
-        a scalar and validates its shape accordingly.
+        a scalar. Its shape is validated accordingly.
```
Cool, changed that in commit e4ab8b1.
Thanks for the update, we're almost done. Just some small comments, please check. Also, don't forget to run `make style`.
> About the changes to XLoraEmbeddingLayer and DoraEmbeddingLayer: I had come across this issue once when I was planning for this PR, and upon your suggestion, I had a look into it again. I may have some ideas about how both of those could be updated, less so about DoRA, but I suppose I will create separate issues for them both, tackling them one at a time, and if you agree we can move forward with this PR keeping it independent of the two.
Yes, having those in a separate PR works for me.
tests/test_trainable_tokens.py (outdated):

```python
        assert (embedding_output == 0.0).all()

    def test_scaled_embedding_with_lora(self):
        """Test that TrainableTokens works with LoRA on scaled embeddings."""
```
Could you please explain in a bit more detail what this tests? Would it not be possible to have a similar test setup to the one above here?
This test validates the combined scenario where both TrainableTokens and LoRA are active simultaneously (line 977: `trainable_token_indices` and `target_modules` both set). So on one hand, `test_embed_scale_is_applied` tests TrainableTokens alone on Gemma3, while `test_scaled_embedding_with_lora` tests TrainableTokens + LoRA together on the mock model. This ensures embed_scale is correctly applied when both PEFT methods interact on the same model (see the sketch below).
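For reference, a hedged sketch of such a combined setup (the token indices and module name are illustrative; `LoraConfig.trainable_token_indices` accepts a module-to-indices mapping):

```python
from peft import LoraConfig

# LoRA on the embedding weights plus TrainableTokens on selected token rows of the same module
combined_config = LoraConfig(
    target_modules=["embed_tokens"],
    trainable_token_indices={"embed_tokens": [0, 1, 3]},
)
```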
Would you prefer Gemma3 instead of the mock model for this test as well? Or should `test_embed_scale_is_applied` be updated to use this combined approach instead? Or, as a third alternative, we keep the tests as they are now with no change?
Ideally, we should have a test that is similar to `test_embed_scale_is_applied`, using a real model architecture and following similar steps. It would be great if you could rewrite the test accordingly.
Cool, rewritten in 883b3d3.
@BenjaminBossan I’ve gone through the comments and applied the requested fixes. We can discuss and decide how to handle the `XLoraEmbeddingLayer` and `DoraEmbeddingLayer` changes. I also ran `make style`. Lastly, I’ll raise a separate issue for them.
@BenjaminBossan made that test change in 883b3d3 and ran `make style`.
tests/test_trainable_tokens.py (outdated):

```python
        This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply
        embed_scale when used together.
```
```diff
-        This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply
-        embed_scale when used together.
```

No need to mention this is a real model, we don't explicitly mention it on other tests.
Done @BenjaminBossan, changed the docstring in b3d8150.
I did a final pass and found a few smaller changes that would improve the tests. Could you please take a look? The rest looks ready to go.
tests/test_decoder_models.py (outdated):

```python
        # sanity check: with the default embed_scale, the embedding output should be reasonably sized
        peft_embedding = peft_model.base_model.model.get_input_embeddings()
        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
        assert (max_embedding_output < 100.0).all()
```
Oh, I had another idea how to extend the test:

```diff
         # sanity check: with the default embed_scale, the embedding output should be reasonably sized
         peft_embedding = peft_model.base_model.model.get_input_embeddings()
-        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
+        embedding_output = peft_embedding(torch.arange(10))
+        max_embedding_output = embedding_output.abs().max(0)[0]
         assert (max_embedding_output < 100.0).all()
+        peft_model.merge_adapter()
+        embedding_merged = peft_embedding(torch.arange(10))
+        assert torch.allclose(embedding_output, embedding_merged)
+        peft_model.unmerge_adapter()
```
The point is that if the embedding scale is not applied correctly, we would expect results to differ between a merged vs. non-merged output. Thus, by checking the merged output too, we can better ensure that everything works as expected.
The test for trainable tokens can also be updated to check the merged output. `test_lora_embed_scale_is_applied_mixed_batch` doesn't need to test this, as it assumes unmerged adapters.
Yup, cool. Implemented this in eb609b8.
```python
        orig_embedding = base_model.get_input_embeddings()

        peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
        peft_model = get_peft_model(base_model, peft_config)
```
Let's assign `x = torch.arange(10).to(self.torch_device)` here and use it as input below for better readability.
resolved in eb609b8
tests/test_decoder_models.py (outdated):

```python
        """Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)
```
```diff
-            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
```

So that this test can also run on GPU.
resolved in eb609b8
Great, I had a look at all of the recommended changes and implemented and pushed them in the PR as well. I think this is ready to go now.
Fix: Add support for scaled embeddings (Gemma3TextScaledWordEmbedding) in TrainableTokens and LoRA
Fixes #2809
Original Issue
TrainableTokensModel and LoRA don't handle embedding scaling correctly for models like Gemma3 that use scaled embeddings.
The Problem:
- Gemma3 uses `Gemma3TextScaledWordEmbedding`, which extends `nn.Embedding` and applies scaling (`embed_scale`) in its forward method.
- PEFT's TrainableTokens forward path calls `F.embedding()` directly, skipping the scaling step.

Problem Identification
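To make the failure mode concrete, here is a minimal mimic of such a scaled embedding layer (an illustrative stand-in for `Gemma3TextScaledWordEmbedding`, similar to the custom `ScaledEmbedding` class used in the tests; the scale value 3.0 matches the value the tests expect):

```python
import torch
import torch.nn as nn


class ScaledEmbedding(nn.Embedding):
    """nn.Embedding subclass that multiplies its output by a scalar embed_scale."""

    def __init__(self, num_embeddings, embedding_dim, embed_scale=3.0):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale
```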
Root Cause Analysis
- The `isinstance(layer, nn.Embedding)` check passes for `Gemma3TextScaledWordEmbedding` (it's a subclass), so the scaled embedding is wrapped like a plain embedding.
- `TrainableTokensLayer` calls `F.embedding()` directly instead of `base_layer.forward()`, bypassing Gemma3's custom scaling logic.
- LoRA's embedding forward calls `base_layer(x)`, which correctly applies Gemma3's scaling to the base embeddings, but the adapter outputs `(A @ B) * scaling` are NOT multiplied by `embed_scale`.
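A toy demonstration of the bypass, reusing the `ScaledEmbedding` mimic above (illustrative only; the real code paths are in the PEFT layers named above):

```python
import torch
import torch.nn.functional as F

emb = ScaledEmbedding(10, 8, embed_scale=3.0)
x = torch.tensor([0, 1, 2])

scaled = emb(x)                        # forward() applies embed_scale
bypassed = F.embedding(x, emb.weight)  # the F.embedding shortcut skips it
assert torch.allclose(scaled, bypassed * 3.0)
assert not torch.allclose(scaled, bypassed)  # the scaling is silently lost before the fix
```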
Implementation Details
Changes Made
1. TrainableTokensLayer (src/peft/tuners/trainable_tokens/layer.py)
- Added helper method `_get_embed_scale()`: checks for an `embed_scale` attribute on the base layer and returns `None` if not found or invalid.
- Modified the `forward_adapters()` method: after `F.embedding()`, checks for `embed_scale` and applies `result = result * embed_scale.to(result.dtype)` when a valid `embed_scale` is detected.

2. LoRA Embedding (src/peft/tuners/lora/layer.py)
- Added helper method `_get_embed_scale()`.
- Modified the `forward()` method: reads `embed_scale` once before processing adapters and applies `adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)`.
- Modified the `_mixed_batch_forward()` method in the same way.

3. Tests (tests/test_trainable_tokens.py)
- Added `test_scaled_embedding_support()`: uses a custom `ScaledEmbedding` class mimicking Gemma3's behavior.
- Added `test_scaled_embedding_with_lora()`: checks that `embed_scale` detection works in combined mode.

Reproducer
The following reproducer demonstrates the fix working correctly with Gemma3-like scaled embeddings (a hedged stand-in sketch appears at the end of this description):

Test Results
New Tests Added
- `test_scaled_embedding_support` - Tests TrainableTokens with scaled embeddings ✅
- `test_scaled_embedding_with_lora` - Tests combined LoRA + TrainableTokens ✅

Regression Tests
Test Commands
Screenshots
Checklist
- Added `_get_embed_scale()` helper method to TrainableTokensLayer
- Added `_get_embed_scale()` helper method to LoRA Embedding
- Existing behavior for embeddings without `embed_scale` works unchanged

Related Issues
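For the Reproducer section above, a hedged stand-in sketch (the PR's actual script is not shown in this thread and may differ; this mirrors the zero-scale check used in the tests):

```python
import torch
from transformers import AutoModelForCausalLM

from peft import TrainableTokensConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = model.get_input_embeddings()

config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
peft_model = get_peft_model(model, config)
peft_embedding = peft_model.base_model.model.get_input_embeddings()

# With the fix, zeroing embed_scale must zero the PEFT embedding output as well.
orig_embedding.embed_scale.fill_(0)
out = peft_embedding(torch.arange(10))
assert (out == 0.0).all(), "embed_scale was not applied by the PEFT wrapper"
print("embed_scale is propagated through TrainableTokens")
```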