25 changes: 23 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -1035,6 +1035,10 @@ def _mixed_batch_forward(
# extra argument that allows mixing different adapters in the same batch at inference time.
result = self.base_layer(x, *args, **kwargs)

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
@@ -1054,7 +1058,13 @@ def _mixed_batch_forward(
# layer output
sub_batch = x[sub_batch_indices_list[i]]
after_A = self._embed(sub_batch, embedding_A)
result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result[sub_batch_indices_list[i]] += adapter_output

return result
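
For context, the kind of scaled embedding this change accounts for looks roughly like the minimal sketch below (a toy stand-in, not the actual transformers implementation of Gemma3TextScaledWordEmbedding). Because the base layer's output already carries the scale factor, the LoRA delta has to be multiplied by the same factor, otherwise merged and unmerged forward passes would diverge.

import torch
import torch.nn as nn

class ScaledEmbedding(nn.Embedding):
    """Toy stand-in for a scaled embedding such as Gemma3TextScaledWordEmbedding."""

    def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
        super().__init__(num_embeddings, embedding_dim)
        # non-persistent buffer, mirroring how such models typically store the scale
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        # the base output already includes the scaling ...
        return super().forward(input_ids) * self.embed_scale

# ... so the LoRA contribution must carry the same factor, i.e. the result should be
#     embed_scale * (W[x] + (A[x] @ B) * scaling)
# rather than
#     embed_scale * W[x] + (A[x] @ B) * scaling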

@@ -1086,6 +1096,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
@@ -1095,7 +1110,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result = result + adapter_output
else:
result = self.lora_variant[active_adapter].forward(
self,
5 changes: 5 additions & 0 deletions src/peft/tuners/trainable_tokens/layer.py
@@ -232,6 +232,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since we're using F.embedding directly, we need to apply this scaling manually.
embed_scale = self._get_embed_scale()
if embed_scale is not None:
result = result * embed_scale.to(result.dtype)
elif isinstance(self.base_layer, torch.nn.Linear):
# Probably a tied adapter that wraps an LM head.
result = F.linear(
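
As a quick illustration of why the scale must be re-applied in this branch (a toy check, reusing the ScaledEmbedding sketch from the LoRA section above; not part of the PR): F.embedding bypasses the base layer's forward, so its output misses the scaling that base_layer(x) would have applied.

import torch
import torch.nn.functional as F

emb = ScaledEmbedding(10, 4, embed_scale=3.0)  # toy scaled embedding from the earlier sketch
x = torch.tensor([0, 1, 2])

unscaled = F.embedding(x, emb.weight)    # bypasses ScaledEmbedding.forward, so no scaling
rescaled = unscaled * emb.embed_scale    # re-applying the scale manually ...
assert torch.allclose(rescaled, emb(x))  # ... recovers the base layer's own output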
38 changes: 38 additions & 0 deletions src/peft/tuners/tuners_utils.py
@@ -49,6 +49,7 @@
set_additional_trainable_modules,
)
from peft.utils.peft_types import PeftType, TaskType
from peft.utils.warning import PeftWarning

from ..config import PeftConfig
from ..utils import _get_submodules
@@ -1201,6 +1202,43 @@ def get_base_layer(self) -> nn.Module:
base_layer = base_layer.base_layer
return base_layer

def _get_embed_scale(self):
"""
Extract embed_scale from base layer if present and valid.

Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be
a scalar and validates its shape accordingly.
Member:
I think otherwise the sentence doesn't quite make sense:

Suggested change:
- a scalar and validates its shape accordingly.
+ a scalar. Its shape is validated accordingly.

Contributor Author:
Cool, changed that in commit e4ab8b1.


Returns:
torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
"""
base_layer = self.get_base_layer()
if not hasattr(base_layer, "embed_scale"):
return None

embed_scale = base_layer.embed_scale

# Convert scalar values to tensors
if isinstance(embed_scale, (int, float)):
return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

# Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
if isinstance(embed_scale, torch.Tensor):
if embed_scale.numel() == 1:
return embed_scale
else:
# Log warning but don't fail - this maintains backward compatibility
warnings.warn(
f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
"Embedding scaling will not be applied. If this is unexpected, please open an issue at "
"https://github.com/huggingface/peft/issues",
PeftWarning,
)
return None

return None

@property
def weight(self) -> torch.Tensor:
# This is required for some transformers code, e.g. for T5, weight is accessed as:
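
A rough usage sketch of the new helper (assuming this branch of PEFT and the tiny Gemma3 checkpoint used by the tests below; output values are illustrative): for a base layer that exposes a scalar embed_scale, the helper returns it as a tensor, while a plain nn.Embedding yields None.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = get_peft_model(model, LoraConfig(target_modules=["embed_tokens"]))

# the LoRA wrapper around the scaled embedding picks up the base layer's embed_scale ...
lora_embedding = peft_model.base_model.model.get_input_embeddings()
print(lora_embedding._get_embed_scale())  # a scalar tensor for Gemma3
# ... whereas for a base layer without an embed_scale attribute the helper returns None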
55 changes: 55 additions & 0 deletions tests/test_decoder_models.py
@@ -795,3 +795,58 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_
)
else:
assert not contains_embedding

def test_lora_embed_scale_is_applied(self):
"""Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
Member:
Suggested change:
- base_model = AutoModelForCausalLM.from_pretrained(model_id)
+ base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)

So that this test can also run on GPU.

Contributor Author:
Resolved in eb609b8.

orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
Member:
Let's assign x = torch.arange(10).to(self.torch_device) here and use it as input below for better readability.

Contributor Author:
Resolved in eb609b8.


# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()
Member:
Oh, I had another idea how to extend the test:

Suggested change:
- # sanity check: with the default embed_scale, the embedding output should be reasonably sized
- peft_embedding = peft_model.base_model.model.get_input_embeddings()
- max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
- assert (max_embedding_output < 100.0).all()
+ peft_embedding = peft_model.base_model.model.get_input_embeddings()
+ embedding_output = peft_embedding(torch.arange(10))
+ max_embedding_output = embedding_output.abs().max(0)[0]
+ assert (max_embedding_output < 100.0).all()
+ peft_model.merge_adapter()
+ embedding_merged = peft_embedding(torch.arange(10))
+ assert torch.allclose(embedding_output, embedding_merged)
+ peft_model.unmerge_adapter()

The point is that if the embedding scale is not applied correctly, we would expect the merged and unmerged outputs to differ. Thus, by checking the merged output too, we can better ensure that everything works as expected.

The test for trainable tokens can also be updated to check the merged output. test_lora_embed_scale_is_applied_mixed_batch doesn't need to test this, as it assumes unmerged adapters.

Contributor Author (@sambhavnoobcoder), Oct 13, 2025:
Yup, cool. Implemented this in eb609b8.


# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()

def test_lora_embed_scale_is_applied_mixed_batch(self):
"""Test that LoRA correctly handles embeddings with scaling in mixed batch mode."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
peft_model.add_adapter("adapter2", peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
input_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)
adapter_names = ["default", "adapter2"]
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output < 100.0

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output > 100.0

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(input_ids, adapter_names=adapter_names)
assert (embedding_output == 0.0).all()
62 changes: 62 additions & 0 deletions tests/test_trainable_tokens.py
@@ -918,3 +918,65 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa
assert contains_embedding
else:
assert not contains_embedding

def test_embed_scale_is_applied(self):
"""Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()

def test_scaled_embedding_with_lora(self):
"""Test that TrainableTokens works with LoRA on scaled embeddings."""
Member:
Could you please explain in a bit more detail what this tests? Would it not be possible to have a similar test setup to the one above here?

Contributor Author:
This test validates the combined scenario where both TrainableTokens and LoRA are active simultaneously (line 977: trainable_token_indices and target_modules are both set). So on one hand, test_embed_scale_is_applied tests TrainableTokens alone on Gemma3, while test_scaled_embedding_with_lora tests TrainableTokens + LoRA together on a mock model.

This ensures embed_scale is correctly applied when both PEFT methods interact on the same model.

Would you prefer Gemma3 instead of the mock model for this test as well? Or should test_embed_scale_is_applied be updated to use this combined approach instead? The third alternative is to keep the tests as they are now, with no change.

Member:
Ideally, we should have a test that is similar to test_embed_scale_is_applied, using a real model architecture and following similar steps. It would be great if you could rewrite the test accordingly.

Contributor Author:
Cool, rewritten in 883b3d3.


class ScaledEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
super().__init__(num_embeddings, embedding_dim)
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

def forward(self, input_ids):
return super().forward(input_ids) * self.embed_scale

class ModelWithScaledEmb(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = ScaledEmbedding(100, 10, embed_scale=3.0)
self.lin0 = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin0(self.emb(x))

def get_input_embeddings(self):
return self.emb

base_model = ModelWithScaledEmb()
x = torch.tensor([[0, 1, 2, 3]])

# Get base embeddings before applying PEFT
base_embeddings = base_model.emb(x)

peft_config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
peft_model = get_peft_model(base_model, peft_config)

# Verify embed_scale detection works in combined mode
peft_embeddings = peft_model.base_model.model.emb(x)
assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6)