Handling embeddings scaling for TrainableTokensModel #2809 #2825
@@ -795,3 +795,58 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_
            )
        else:
            assert not contains_embedding

    def test_lora_embed_scale_is_applied(self):
        """Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)
Suggested change:
-            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
So that this test can also run on GPU.
resolved in eb609b8
Let's assign x = torch.arange(10).to(self.torch_device) here and use it as input below for better readability.
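Roughly, the refactored snippet could look like this (a sketch only; it assumes the same peft_model / peft_embedding setup as in the test above, with the final version in eb609b8):

x = torch.arange(10).to(self.torch_device)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(x).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()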
resolved in eb609b8
Oh, I had another idea how to extend the test:
Suggested change:
-        # sanity check: with the default embed_scale, the embedding output should be reasonably sized
-        peft_embedding = peft_model.base_model.model.get_input_embeddings()
-        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
-        assert (max_embedding_output < 100.0).all()
+        peft_embedding = peft_model.base_model.model.get_input_embeddings()
+        embedding_output = peft_embedding(torch.arange(10))
+        max_embedding_output = embedding_output.abs().max(0)[0]
+        assert (max_embedding_output < 100.0).all()
+        peft_model.merge_adapter()
+        embedding_merged = peft_embedding(torch.arange(10))
+        assert torch.allclose(embedding_output, embedding_merged)
+        peft_model.unmerge_adapter()
The point is that if the embedding scale is not applied correctly, we would expect the merged and non-merged outputs to differ. Thus, by checking the merged output too, we can better ensure that everything works as expected.
The test for trainable tokens can also be updated to check the merged output. test_lora_embed_scale_is_applied_mixed_batch doesn't need to test this, as it assumes unmerged adapters.
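For the trainable tokens test, the analogous extension could look roughly like this (a sketch reusing the names from test_embed_scale_is_applied; the change actually implemented in eb609b8 may differ):

embedding_output = peft_embedding(torch.arange(10))

# merging should not change the output; if embed_scale were dropped during merge,
# the merged result would deviate from the unmerged one
peft_model.merge_adapter()
embedding_merged = peft_embedding(torch.arange(10))
assert torch.allclose(embedding_output, embedding_merged)
peft_model.unmerge_adapter()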
Yup, cool. Implemented this in eb609b8.
BenjaminBossan marked this conversation as resolved.
@@ -918,3 +918,65 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa
            assert contains_embedding
        else:
            assert not contains_embedding

    def test_embed_scale_is_applied(self):
        """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)
        orig_embedding = base_model.get_input_embeddings()

        peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
        peft_model = get_peft_model(base_model, peft_config)

        # sanity check: with the default embed_scale, the embedding output should be reasonably sized
        peft_embedding = peft_model.base_model.model.get_input_embeddings()
        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
        assert (max_embedding_output < 100.0).all()

        # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
        # value
        orig_embedding.embed_scale.fill_(10000.0)
        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
        assert (max_embedding_output > 100.0).all()

        # set embed_scale to zero, then check that the embedding output is also zero
        orig_embedding.embed_scale.fill_(0)
        embedding_output = peft_embedding(torch.arange(10))
        assert (embedding_output == 0.0).all()

    def test_scaled_embedding_with_lora(self):
        """Test that TrainableTokens works with LoRA on scaled embeddings."""

        class ScaledEmbedding(torch.nn.Embedding):
            def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
                super().__init__(num_embeddings, embedding_dim)
                self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

            def forward(self, input_ids):
                return super().forward(input_ids) * self.embed_scale

        class ModelWithScaledEmb(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = ScaledEmbedding(100, 10, embed_scale=3.0)
                self.lin0 = torch.nn.Linear(10, 1)

            def forward(self, x):
                return self.lin0(self.emb(x))

            def get_input_embeddings(self):
                return self.emb

        base_model = ModelWithScaledEmb()
        x = torch.tensor([[0, 1, 2, 3]])

        # Get base embeddings before applying PEFT
        base_embeddings = base_model.emb(x)

        peft_config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
        peft_model = get_peft_model(base_model, peft_config)

        # Verify embed_scale detection works in combined mode
        peft_embeddings = peft_model.base_model.model.emb(x)
        assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6)
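For context, what these tests exercise is that the adapter layer re-applies the base embedding's embed_scale to anything it computes from the raw weights. A minimal sketch of that idea (a hypothetical helper, not the actual code changed in this PR; the getattr-based attribute lookup is an assumption):

import torch

def apply_embed_scale(base_layer: torch.nn.Module, embeddings: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: if the wrapped embedding defines an embed_scale buffer
    # (as Gemma3's scaled embedding does), re-apply it to embeddings computed from
    # the raw weight so they stay consistent with the base layer's scaled output
    embed_scale = getattr(base_layer, "embed_scale", None)
    if embed_scale is not None:
        embeddings = embeddings * embed_scale.to(dtype=embeddings.dtype, device=embeddings.device)
    return embeddings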
I think otherwise the sentence doesn't quite make sense:
Cool. Changed that in commit e4ab8b1.