Skip to content
1 change: 0 additions & 1 deletion examples/boft_controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import numpy as np
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from diffusers import DDIMScheduler
from diffusers.utils import check_min_version
Expand Down
25 changes: 23 additions & 2 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,10 @@ def _mixed_batch_forward(
# extra argument that allows mixing different adapters in the same batch at inference time.
result = self.base_layer(x, *args, **kwargs)

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
Expand All @@ -1054,7 +1058,13 @@ def _mixed_batch_forward(
# layer output
sub_batch = x[sub_batch_indices_list[i]]
after_A = self._embed(sub_batch, embedding_A)
result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result[sub_batch_indices_list[i]] += adapter_output

return result

Expand Down Expand Up @@ -1086,6 +1096,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
Expand All @@ -1095,7 +1110,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result = result + adapter_output
else:
result = self.lora_variant[active_adapter].forward(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/peft/tuners/trainable_tokens/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since we're using F.embedding directly, we need to apply this scaling manually.
embed_scale = self._get_embed_scale()
if embed_scale is not None:
result = result * embed_scale.to(result.dtype)
elif isinstance(self.base_layer, torch.nn.Linear):
# Probably a tied adapter that wraps an LM head.
result = F.linear(
Expand Down
38 changes: 38 additions & 0 deletions src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
set_additional_trainable_modules,
)
from peft.utils.peft_types import PeftType, TaskType
from peft.utils.warning import PeftWarning

from ..config import PeftConfig
from ..utils import _get_submodules
Expand Down Expand Up @@ -1201,6 +1202,43 @@ def get_base_layer(self) -> nn.Module:
base_layer = base_layer.base_layer
return base_layer

def _get_embed_scale(self):
"""
Extract embed_scale from base layer if present and valid.

Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be a
scalar. Its shape is validated accordingly.

Returns:
torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
"""
base_layer = self.get_base_layer()
if not hasattr(base_layer, "embed_scale"):
return None

embed_scale = base_layer.embed_scale

# Convert scalar values to tensors
if isinstance(embed_scale, (int, float)):
return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

# Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
if isinstance(embed_scale, torch.Tensor):
if embed_scale.numel() == 1:
return embed_scale
else:
# Log warning but don't fail - this maintains backward compatibility
warnings.warn(
f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
"Embedding scaling will not be applied. If this is unexpected, please open an issue at "
"https://github.com/huggingface/peft/issues",
PeftWarning,
)
return None

return None

@property
def weight(self) -> torch.Tensor:
# This is required for some transformers code, e.g. for T5, weight is accessed as:
Expand Down
60 changes: 60 additions & 0 deletions tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,3 +795,63 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_
)
else:
assert not contains_embedding

def test_lora_embed_scale_is_applied(self):
"""Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's assign x = torch.arange(10).to(self.torch_device) here and use it as input below for better readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved in eb609b8


x = torch.arange(10).to(self.torch_device)
peft_embedding = peft_model.base_model.model.get_input_embeddings()
embedding_output = peft_embedding(x)
max_embedding_output = embedding_output.abs().max(0)[0]
assert (max_embedding_output < 100.0).all()
peft_model.merge_adapter()
embedding_merged = peft_embedding(x)
assert torch.allclose(embedding_output, embedding_merged)
peft_model.unmerge_adapter()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(x).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(x)
assert (embedding_output == 0.0).all()

def test_lora_embed_scale_is_applied_mixed_batch(self):
"""Test that LoRA correctly handles embeddings with scaling in mixed batch mode."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
peft_model.add_adapter("adapter2", peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
input_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)
adapter_names = ["default", "adapter2"]
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output < 100.0

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output > 100.0

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(input_ids, adapter_names=adapter_names)
assert (embedding_output == 0.0).all()
60 changes: 60 additions & 0 deletions tests/test_trainable_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,3 +918,63 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa
assert contains_embedding
else:
assert not contains_embedding

def test_embed_scale_is_applied(self):
"""Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()

def test_scaled_embedding_with_lora(self):
"""
Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously.
"""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

# Apply both TrainableTokens and LoRA to the same model
peft_config = LoraConfig(target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]})
peft_model = get_peft_model(base_model, peft_config)

x = torch.arange(10)
peft_embedding = peft_model.base_model.model.get_input_embeddings()
embedding_output = peft_embedding(x)
max_embedding_output = embedding_output.abs().max(0)[0]
assert (max_embedding_output < 100.0).all()
peft_model.merge_adapter()
embedding_merged = peft_embedding(x)
assert torch.allclose(embedding_output, embedding_merged)
peft_model.unmerge_adapter()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(x).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(x)
assert (embedding_output == 0.0).all()