1 change: 0 additions & 1 deletion examples/boft_controlnet/test_controlnet.py
@@ -22,7 +22,6 @@

import numpy as np
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from diffusers import DDIMScheduler
from diffusers.utils import check_min_version
25 changes: 23 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -1035,6 +1035,10 @@ def _mixed_batch_forward(
# extra argument that allows mixing different adapters in the same batch at inference time.
result = self.base_layer(x, *args, **kwargs)

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
@@ -1054,7 +1058,13 @@
# layer output
sub_batch = x[sub_batch_indices_list[i]]
after_A = self._embed(sub_batch, embedding_A)
result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result[sub_batch_indices_list[i]] += adapter_output

return result

@@ -1086,6 +1096,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
@@ -1095,7 +1110,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result = result + adapter_output
else:
result = self.lora_variant[active_adapter].forward(
self,
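The reason the LoRA contribution needs the extra factor is easiest to see against the merged weights. Below is a self-contained sketch (a toy `ScaledEmbedding` class with arbitrary shapes and an illustrative `embed_scale` of 8.0, not PEFT internals): since `base_layer(x)` already carries the scale, an unscaled LoRA delta no longer matches what merging the adapter into the weight would produce.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledEmbedding(nn.Embedding):
    """Toy stand-in for Gemma3TextScaledWordEmbedding: scales its output in forward."""

    def __init__(self, num_embeddings, embedding_dim, embed_scale=8.0):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale


torch.manual_seed(0)
base = ScaledEmbedding(10, 4)
lora_A = torch.randn(10, 2)  # (num_embeddings, r)
lora_B = torch.randn(2, 4)   # (r, embedding_dim)
scaling = 0.5
x = torch.tensor([1, 2, 3])

# LoRA contribution computed from the raw token ids, as in the Embedding layer's forward
delta = (F.embedding(x, lora_A) @ lora_B) * scaling

# Reference: merge the adapter into the weight, then run the (scaled) base forward
merged = F.embedding(x, base.weight + scaling * (lora_A @ lora_B)) * base.embed_scale

assert not torch.allclose(base(x) + delta, merged)  # unscaled delta is off by embed_scale
assert torch.allclose(base(x) + delta * base.embed_scale, merged)  # scaled delta matches
```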
5 changes: 5 additions & 0 deletions src/peft/tuners/trainable_tokens/layer.py
@@ -232,6 +232,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since we're using F.embedding directly, we need to apply this scaling manually.
embed_scale = self._get_embed_scale()
if embed_scale is not None:
result = result * embed_scale.to(result.dtype)
elif isinstance(self.base_layer, torch.nn.Linear):
# Probably a tied adapter that wraps an LM head.
result = F.linear(
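`forward_adapters` rebuilds the lookup with `F.embedding` on the delta-patched weight instead of calling the base layer, so any scaling the base layer applies inside its own forward is lost and has to be re-applied by hand. A minimal sketch of that effect (toy module standing in for the Gemma3 class, with an illustrative scale of 8.0):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledEmbedding(nn.Embedding):
    def forward(self, input_ids):
        # scale the looked-up vectors, like Gemma3's scaled word embedding does
        return super().forward(input_ids) * self.embed_scale


layer = ScaledEmbedding(10, 4)
layer.embed_scale = torch.tensor(8.0)
input_ids = torch.tensor([0, 1, 2])

# A functional lookup bypasses the module's forward and therefore its embed_scale ...
unscaled = F.embedding(input_ids, layer.weight)
# ... so the scaling has to be re-applied to match the module's own output.
assert torch.allclose(layer(input_ids), unscaled * layer.embed_scale)
```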
38 changes: 38 additions & 0 deletions src/peft/tuners/tuners_utils.py
@@ -49,6 +49,7 @@
set_additional_trainable_modules,
)
from peft.utils.peft_types import PeftType, TaskType
from peft.utils.warning import PeftWarning

from ..config import PeftConfig
from ..utils import _get_submodules
@@ -1201,6 +1202,43 @@ def get_base_layer(self) -> nn.Module:
base_layer = base_layer.base_layer
return base_layer

def _get_embed_scale(self):
"""
Extract embed_scale from base layer if present and valid.

Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is expected to be
a scalar (a Python number, a 0-d tensor, or a single-element tensor); anything else triggers a warning and no
scaling is applied.

Returns:
torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
"""
base_layer = self.get_base_layer()
if not hasattr(base_layer, "embed_scale"):
return None

embed_scale = base_layer.embed_scale

# Convert scalar values to tensors
if isinstance(embed_scale, (int, float)):
return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

# Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
if isinstance(embed_scale, torch.Tensor):
if embed_scale.numel() == 1:
return embed_scale
else:
# Log warning but don't fail - this maintains backward compatibility
warnings.warn(
f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
"Embedding scaling will not be applied. If this is unexpected, please open an issue at "
"https://github.com/huggingface/peft/issues",
PeftWarning,
)
return None

return None

@property
def weight(self) -> torch.Tensor:
# This is required for some transformers code, e.g. for T5, weight is accessed as:
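As a rough end-to-end view of what `_get_embed_scale` enables (a hedged sketch along the lines of the new tests, reusing the tiny Gemma3 checkpoint they load; it needs hub access and is not part of this PR): once the LoRA delta is multiplied by the embedding scale, the unmerged forward should agree with merging the adapter into the scaled embedding's weight.

```python
import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = get_peft_model(model, LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False))

embedding = peft_model.base_model.model.get_input_embeddings()
input_ids = torch.arange(10)
with torch.no_grad():
    unmerged = embedding(input_ids)
    peft_model.merge_adapter()
    merged = embedding(input_ids)
    peft_model.unmerge_adapter()

# Without the embed_scale handling, the unmerged output would be off on the LoRA part.
assert torch.allclose(unmerged, merged, atol=1e-5)
```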
13 changes: 12 additions & 1 deletion src/peft/tuners/xlora/layer.py
@@ -153,6 +153,10 @@ def forward(self, x: Tensor, *args: Any, scalings: Optional[Tensor] = None, **kw

result = self.target.base_layer(x, *args, **kwargs)

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to X-LoRA contributions too.
embed_scale = self.target._get_embed_scale()

# Ignore if disabled. We want to make sure this is always run.
if not self.target.merged:
for adapter_n, active_adapter in enumerate(self.target.active_adapters):
@@ -171,7 +175,14 @@
else:
after_A_mod = after_A
scaling_weight = 1
result += (after_A_mod @ embedding_B) * scaling * scaling_weight

adapter_output = (after_A_mod @ embedding_B) * scaling * scaling_weight

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result += adapter_output

return result

55 changes: 55 additions & 0 deletions tests/test_decoder_models.py
@@ -795,3 +795,58 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_
)
else:
assert not contains_embedding

def test_lora_embed_scale_is_applied(self):
"""Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()

def test_lora_embed_scale_is_applied_mixed_batch(self):
"""Test that LoRA correctly handles embeddings with scaling in mixed batch mode."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
peft_model.add_adapter("adapter2", peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
input_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)
adapter_names = ["default", "adapter2"]
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output < 100.0

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
assert max_embedding_output > 100.0

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(input_ids, adapter_names=adapter_names)
assert (embedding_output == 0.0).all()
53 changes: 53 additions & 0 deletions tests/test_trainable_tokens.py
@@ -918,3 +918,56 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa
assert contains_embedding
else:
assert not contains_embedding

def test_embed_scale_is_applied(self):
"""Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
peft_model = get_peft_model(base_model, peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()

def test_scaled_embedding_with_lora(self):
"""Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

# Apply both TrainableTokens and LoRA to the same model
peft_config = LoraConfig(target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]})
peft_model = get_peft_model(base_model, peft_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()
47 changes: 47 additions & 0 deletions tests/test_xlora.py
@@ -26,6 +26,8 @@
from peft.tuners.xlora.layer import XLoraLayer
from peft.utils import infer_device

from .testing_utils import hub_online_once


def flaky(num_tries: int):
"""Decorator for test functions that are flaky"""
@@ -424,3 +426,48 @@ def mock_get_maybe_topk_scalings(self, scalings):
assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
"Per-token scaling weights are not normalized to sum to 1."
)

def test_xlora_embed_scale_is_applied(self, tmp_path):
"""Test that X-LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
# Create and save Gemma3-compatible LoRA adapters
adapters = {}
for i in range(2):
torch.manual_seed(i + 1)
lora_config = LoraConfig(
task_type="CAUSAL_LM", init_lora_weights=False, target_modules=["embed_tokens"]
)
model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = get_peft_model(model, lora_config)
adapter_path = os.path.join(tmp_path, f"checkpoint-{i + 1}")
peft_model.save_pretrained(adapter_path)
adapters[str(i)] = adapter_path

# Load base model and test X-LoRA with embed_scale
base_model = AutoModelForCausalLM.from_pretrained(model_id)
base_model.config.use_cache = False
orig_embedding = base_model.get_input_embeddings()

xlora_config = XLoraConfig(
task_type=TaskType.CAUSAL_LM,
hidden_size=base_model.config.hidden_size,
adapters=adapters,
)
xlora_model = get_peft_model(base_model, xlora_config)

# sanity check: with the default embed_scale, the embedding output should be reasonably sized
xlora_embedding = xlora_model.base_model.model.get_input_embeddings()
max_embedding_output = xlora_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()

# set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
# value
orig_embedding.embed_scale.fill_(10000.0)
max_embedding_output = xlora_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output > 100.0).all()

# set embed_scale to zero, then check that the embedding output is also zero
orig_embedding.embed_scale.fill_(0)
embedding_output = xlora_embedding(torch.arange(10))
assert (embedding_output == 0.0).all()