1 change: 0 additions & 1 deletion examples/boft_controlnet/test_controlnet.py
@@ -22,7 +22,6 @@

import numpy as np
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from diffusers import DDIMScheduler
from diffusers.utils import check_min_version
25 changes: 23 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -1035,6 +1035,10 @@ def _mixed_batch_forward(
# extra argument that allows mixing different adapters in the same batch at inference time.
result = self.base_layer(x, *args, **kwargs)

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
@@ -1054,7 +1058,13 @@ def _mixed_batch_forward(
# layer output
sub_batch = x[sub_batch_indices_list[i]]
after_A = self._embed(sub_batch, embedding_A)
result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result[sub_batch_indices_list[i]] += adapter_output

return result

@@ -1086,6 +1096,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
@@ -1095,7 +1110,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result = result + adapter_output
else:
result = self.lora_variant[active_adapter].forward(
self,
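For context, a minimal, self-contained sketch (not code from this PR; ToyScaledEmbedding is a made-up stand-in for a layer like Gemma3TextScaledWordEmbedding) of why the LoRA contribution must receive the same embed_scale as the base output: only then do the unmerged and merged adapter paths agree.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyScaledEmbedding(nn.Embedding):
    # Hypothetical stand-in for a scaled embedding: forward() multiplies the
    # looked-up embeddings by a scalar embed_scale buffer.
    def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale


torch.manual_seed(0)
base = ToyScaledEmbedding(10, 4)
lora_A = torch.randn(10, 2)  # stand-in for lora_embedding_A (r=2), scaling assumed to be 1.0
lora_B = torch.randn(2, 4)   # stand-in for lora_embedding_B
x = torch.arange(5)

# Merged view: the delta weight is folded into the base weight, so the base
# forward applies embed_scale to (W + delta_W) as a whole.
merged_out = F.embedding(x, base.weight + lora_A @ lora_B) * base.embed_scale

# Unmerged view: base(x) already includes embed_scale, so the LoRA delta must
# be scaled as well; otherwise merged and unmerged outputs would differ.
after_A = F.embedding(x, lora_A)
unmerged_out = base(x) + (after_A @ lora_B) * base.embed_scale

assert torch.allclose(merged_out, unmerged_out, atol=1e-6)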
5 changes: 5 additions & 0 deletions src/peft/tuners/trainable_tokens/layer.py
@@ -232,6 +232,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since we're using F.embedding directly, we need to apply this scaling manually.
embed_scale = self._get_embed_scale()
if embed_scale is not None:
result = result * embed_scale.to(result.dtype)
elif isinstance(self.base_layer, torch.nn.Linear):
# Probably a tied adapter that wraps an LM head.
result = F.linear(
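A short sketch of the situation the comment above describes (again not PR code, just an illustration assuming a Gemma3-style scalar embed_scale): F.embedding reads base_layer.weight directly and therefore bypasses the scaling that the base layer's forward would apply, which is why the result has to be multiplied by embed_scale manually.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
base = nn.Embedding(10, 4)
embed_scale = torch.tensor(8.0)  # what a Gemma3-style forward() would multiply in
x = torch.arange(5)

scaled = base(x) * embed_scale     # what base_layer(x) returns for a scaled embedding layer
raw = F.embedding(x, base.weight)  # direct weight lookup, no scaling applied
assert not torch.allclose(scaled, raw)
assert torch.allclose(scaled, raw * embed_scale)  # the manual multiply restores parity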
38 changes: 38 additions & 0 deletions src/peft/tuners/tuners_utils.py
@@ -49,6 +49,7 @@
set_additional_trainable_modules,
)
from peft.utils.peft_types import PeftType, TaskType
from peft.utils.warning import PeftWarning

from ..config import PeftConfig
from ..utils import _get_submodules
@@ -1201,6 +1202,43 @@ def get_base_layer(self) -> nn.Module:
base_layer = base_layer.base_layer
return base_layer

def _get_embed_scale(self):
"""
Extract embed_scale from base layer if present and valid.

Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is expected to be a
scalar (a Python number, a 0-d tensor, or a 1-element tensor), and its shape is validated accordingly.

Returns:
torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
"""
base_layer = self.get_base_layer()
if not hasattr(base_layer, "embed_scale"):
return None

embed_scale = base_layer.embed_scale

# Convert scalar values to tensors
if isinstance(embed_scale, (int, float)):
return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

# Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
if isinstance(embed_scale, torch.Tensor):
if embed_scale.numel() == 1:
return embed_scale
else:
# Warn but don't fail - this maintains backward compatibility
warnings.warn(
f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
"Embedding scaling will not be applied. If this is unexpected, please open an issue at "
"https://github.com/huggingface/peft/issues",
PeftWarning,
)
return None

return None

@property
def weight(self) -> torch.Tensor:
# This is required for some transformers code, e.g. for T5, weight is accessed as:
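For illustration, the extraction logic above can be exercised in isolation. The sketch below uses a free function (get_embed_scale is an illustrative name, not part of the PEFT API) that mirrors _get_embed_scale and shows the three outcomes: plain Python numbers are promoted to tensors, scalar or 1-element tensors pass through unchanged, and anything with more elements triggers a warning and is ignored.

import warnings
import torch
import torch.nn as nn


def get_embed_scale(layer: nn.Module):
    # Mirror of the _get_embed_scale logic, written as a standalone function.
    embed_scale = getattr(layer, "embed_scale", None)
    if embed_scale is None:
        return None
    if isinstance(embed_scale, (int, float)):
        return torch.tensor(embed_scale, device=layer.weight.device, dtype=layer.weight.dtype)
    if isinstance(embed_scale, torch.Tensor) and embed_scale.numel() == 1:
        return embed_scale
    warnings.warn(f"Unexpected embed_scale shape {tuple(embed_scale.shape)}; skipping scaling.")
    return None


emb = nn.Embedding(10, 4)
assert get_embed_scale(emb) is None        # plain embedding: no embed_scale attribute

emb.embed_scale = 4.0
assert get_embed_scale(emb).item() == 4.0  # Python float is promoted to a tensor

emb.embed_scale = torch.tensor(4.0)
assert get_embed_scale(emb).item() == 4.0  # 0-d tensor passes through unchanged

emb.embed_scale = torch.ones(4)            # wrong shape: warns and returns None
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    assert get_embed_scale(emb) is None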
55 changes: 55 additions & 0 deletions tests/test_decoder_models.py
@@ -795,3 +795,58 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_
)
else:
assert not contains_embedding

def test_lora_embed_scale_is_applied(self):
"""Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
Reviewer comment (Member):
    Suggested change: replace
        base_model = AutoModelForCausalLM.from_pretrained(model_id)
    with
        base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
    So that this test can also run on GPU.

Author reply (Contributor): resolved in eb609b8

orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
Reviewer comment (Member):
    Let's assign x = torch.arange(10).to(self.torch_device) here and use it as input below for better readability.

Author reply (Contributor): resolved in eb609b8


# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()
Reviewer comment (Member):
    Oh, I had another idea how to extend the test:

    Suggested change: replace
        # sanity check: with the default embed_scale, the embedding output should be reasonably sized
        peft_embedding = peft_model.base_model.model.get_input_embeddings()
        max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
        assert (max_embedding_output < 100.0).all()
    with
        peft_embedding = peft_model.base_model.model.get_input_embeddings()
        embedding_output = peft_embedding(torch.arange(10))
        max_embedding_output = embedding_output.abs().max(0)[0]
        assert (max_embedding_output < 100.0).all()
        peft_model.merge_adapter()
        embedding_merged = peft_embedding(torch.arange(10))
        assert torch.allclose(embedding_output, embedding_merged)
        peft_model.unmerge_adapter()

    The point is that if the embedding scale is not applied correctly, we would expect results to differ between a merged vs non-merged output. Thus, by checking the merged output too, we can better ensure that everything works as expected.

    The test for trainable tokens can also be updated to check the merged output. test_lora_embed_scale_is_applied_mixed_batch doesn't need to test this, as it assumes unmerged adapters.

Author reply (sambhavnoobcoder, Contributor, Oct 13, 2025): yup, cool. Implemented this in eb609b8.


# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()

def test_lora_embed_scale_is_applied_mixed_batch(self):
"""Test that LoRA correctly handles embeddings with scaling in mixed batch mode."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
peft_model.add_adapter("adapter2", peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
input_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)
adapter_names = ["default", "adapter2"]
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output < 100.0

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output > 100.0

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(input_ids, adapter_names=adapter_names)
assert (embedding_output == 0.0).all()
55 changes: 55 additions & 0 deletions tests/test_trainable_tokens.py
@@ -918,3 +918,58 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa
assert contains_embedding
else:
assert not contains_embedding

def test_embed_scale_is_applied(self):
"""Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()

def test_scaled_embedding_with_lora(self):
"""
Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously.
"""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

# Apply both TrainableTokens and LoRA to the same model
peft_config = LoraConfig(target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]})
peft_model = get_peft_model(base_model, peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()