61 changes: 59 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -870,6 +870,42 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari

return DoraEmbeddingVariant()

def _get_embed_scale(self):
Member:
Let's factor this function out so that it can be reused between the two PEFT methods. It should be fine to put it on the BaseTunerLayer class. This also eliminates the small inconsistency between the two implementations you have.

Contributor Author:
resolved in 0d216c7

"""
Extract embed_scale from base layer if present and valid.

Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute and validates its shape.
Member:
Let's mention that if it exists, it is assumed to be a scalar. From my search through the transformers code base, this is always the case right now.

Contributor Author:
resolved in 0d216c7


Returns:
torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
"""
base_layer = self.get_base_layer()
if not hasattr(base_layer, "embed_scale"):
return None

embed_scale = base_layer.embed_scale

# Convert scalar values to tensors
if isinstance(embed_scale, (int, float)):
return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

# Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
if isinstance(embed_scale, torch.Tensor):
if embed_scale.numel() == 1:
return embed_scale
else:
# Log warning but don't fail - this maintains backward compatibility
warnings.warn(
f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
"Embedding scaling will not be applied. If this is unexpected, please open an issue at "
"https://github.com/huggingface/peft/issues",
PeftWarning,
)
return None

return None
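
As a rough, illustrative sketch of the de-duplication suggested in the review comment above (not the actual refactor from 0d216c7), the helper could live on BaseTunerLayer so that the LoRA Embedding layer and TrainableTokensLayer share one implementation. It is written here as a free function over the base layer to keep the sketch self-contained, and it emits a plain UserWarning where the PR uses PeftWarning:

    import warnings

    import torch


    def get_embed_scale(base_layer: torch.nn.Module):
        """Return the base layer's embed_scale as a scalar tensor, or None if absent or invalid."""
        embed_scale = getattr(base_layer, "embed_scale", None)
        if embed_scale is None:
            return None
        # promote plain Python numbers to tensors on the weight's device/dtype
        if isinstance(embed_scale, (int, float)):
            return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)
        if isinstance(embed_scale, torch.Tensor):
            if embed_scale.numel() == 1:
                return embed_scale
            warnings.warn(f"Found embed_scale with shape {embed_scale.shape}, expected scalar; ignoring it.")
        return None


    # minimal usage check on a plain nn.Embedding
    emb = torch.nn.Embedding(10, 4)
    assert get_embed_scale(emb) is None
    emb.register_buffer("embed_scale", torch.tensor(3.0))
    assert torch.allclose(get_embed_scale(emb), torch.tensor(3.0))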

def update_layer(
self,
adapter_name,
@@ -1035,6 +1071,10 @@ def _mixed_batch_forward(
# extra argument that allows mixing different adapters in the same batch at inference time.
result = self.base_layer(x, *args, **kwargs)

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
@@ -1054,7 +1094,13 @@
# layer output
sub_batch = x[sub_batch_indices_list[i]]
after_A = self._embed(sub_batch, embedding_A)
result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result[sub_batch_indices_list[i]] += adapter_output

return result
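
For intuition on why the LoRA delta needs the same scaling as the base output, here is a small standalone sketch (toy shapes and a hypothetical scalar scale s; this is not PEFT's layer code) showing that scaling the adapter contribution keeps the unmerged forward pass equivalent to merging the delta into the embedding weight first:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab, dim, r, s = 10, 4, 2, 3.0        # s stands in for a scalar embed_scale
    W = torch.randn(vocab, dim)             # base embedding weight
    A = torch.randn(vocab, r)               # LoRA A as per-token rows (illustration only)
    B = torch.randn(r, dim)                 # LoRA B
    x = torch.tensor([1, 5, 7])

    base = F.embedding(x, W) * s            # what a scaled embedding's forward returns
    delta = F.embedding(x, A) @ B           # LoRA contribution, analogous to after_A @ embedding_B
    unmerged = base + delta * s             # scale the adapter output as well
    merged = F.embedding(x, W + A @ B) * s  # same result as merging the delta into W first
    assert torch.allclose(unmerged, merged, atol=1e-5)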

@@ -1086,6 +1132,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype

# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
embed_scale = self._get_embed_scale()

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_embedding_A:
continue
@@ -1095,7 +1146,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
embedding_B = self.lora_embedding_B[active_adapter].T
scaling = self.scaling[active_adapter]
after_A = self._embed(x, embedding_A)
result = result + (after_A @ embedding_B) * scaling
adapter_output = (after_A @ embedding_B) * scaling

# Apply embed_scale to match the base layer's scaling
if embed_scale is not None:
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)

result = result + adapter_output
else:
result = self.lora_variant[active_adapter].forward(
self,
41 changes: 41 additions & 0 deletions src/peft/tuners/trainable_tokens/layer.py
@@ -79,6 +79,42 @@ def tied_adapter(self):
return self._tied_adapter[0]
return None

def _get_embed_scale(self):
"""
Extract embed_scale from base layer if present and valid.

Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute and validates its shape.

Returns:
torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
"""
base_layer = self.get_base_layer()
if not hasattr(base_layer, "embed_scale"):
return None

embed_scale = base_layer.embed_scale

# Convert scalar values to tensors
if isinstance(embed_scale, (int, float)):
return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

# Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
if isinstance(embed_scale, torch.Tensor):
if embed_scale.numel() == 1:
return embed_scale
else:
# Log warning but don't fail - this maintains backward compatibility
warnings.warn(
f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
"Embedding scaling will not be applied. If this is unexpected, please open an issue at "
"https://github.com/huggingface/peft/issues",
UserWarning,
Member:
For LoRA, you use PeftWarning. I think that's the better choice.

Contributor Author:
resolved in 0d216c7

)
return None

return None

def _collect_token_weights(self, weight: torch.Tensor, rows: torch.Tensor, embed_dim: int) -> torch.Tensor:
"""DeepSpeed zero3 specific code to initialize trainable tokens.

@@ -232,6 +268,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
sparse=self.base_layer.sparse,
)
# Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
# Since we're using F.embedding directly, we need to apply this scaling manually.
embed_scale = self._get_embed_scale()
if embed_scale is not None:
result = result * embed_scale.to(result.dtype)
elif isinstance(self.base_layer, torch.nn.Linear):
# Probably a tied adapter that wraps an LM head.
result = F.linear(
100 changes: 100 additions & 0 deletions tests/test_trainable_tokens.py
@@ -918,3 +918,103 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa
assert contains_embedding
else:
assert not contains_embedding

def test_scaled_embedding_support(self):
Member:
Thanks for including tests for the correct application of embed_scale. However, I think the test is not focused enough to test its goal. To prove it, you can change _get_embed_scale to return a fixed scalar of 1.0 and the test would still pass, except for the torch.allclose(embed_scale, torch.tensor(3.0)) assert, which by itself doesn't check if the scale is correctly applied.

Let me propose a test that I think will be more focused. It uses a real (but small) transformers model and checks the outputs of the embedding layer. By carefully varying the embed_scale, we can ensure that it's being applied. Let me know what you think.

    def test_embed_scale_is_applied(self):
        """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)
            orig_embedding = base_model.get_input_embeddings()

            peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
            peft_model = get_peft_model(base_model, peft_config)

            # sanity check: with the default embed_scale, the embedding output should be reasonably sized
            peft_embedding = peft_model.base_model.model.get_input_embeddings()
            max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
            assert (max_embedding_output < 100.0).all()

            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
            # value
            orig_embedding.embed_scale.fill_(10000.0)
            max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
            assert (max_embedding_output > 100.0).all()

            # set embed_scale to zero, then check that the embedding output is also zero
            orig_embedding.embed_scale.fill_(0)
            embedding_output = peft_embedding(torch.arange(10))
            assert (embedding_output == 0.0).all()

Contributor Author:
resolved in 0d216c7. Implemented this with a mini version of Gemma3 for more realistic testing; all test cases pass.

"""Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""

# Create a mock scaled embedding layer similar to Gemma3TextScaledWordEmbedding
class ScaledEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
super().__init__(num_embeddings, embedding_dim)
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

def forward(self, input_ids):
return super().forward(input_ids) * self.embed_scale

# Create a model with scaled embedding
class ModelWithScaledEmb(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = ScaledEmbedding(100, 10, embed_scale=3.0)
self.lin0 = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin0(self.emb(x))

def get_input_embeddings(self):
return self.emb

base_model = ModelWithScaledEmb()
peft_config = TrainableTokensConfig(target_modules=["emb"], token_indices=[0, 1, 2])
peft_model = get_peft_model(base_model, peft_config)

# Test input
x = torch.tensor([[0, 1, 2, 3]])

# Get outputs from base model and peft model
base_output = base_model(x)
Member:
The base_output needs to be determined before calling get_peft_model, since this function will modify the base model in-place. In this particular case, it doesn't matter, but as a general precaution, it's better to do it correctly.

Contributor Author:
resolved in 0d216c7
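
A minimal sketch of the ordering the comment above asks for, reusing this test's names (the reference output is captured before get_peft_model, which modifies the base model in place):

    base_model = ModelWithScaledEmb()       # toy model defined earlier in this test
    x = torch.tensor([[0, 1, 2, 3]])
    base_output = base_model(x)             # capture the reference output first

    peft_config = TrainableTokensConfig(target_modules=["emb"], token_indices=[0, 1, 2])
    peft_model = get_peft_model(base_model, peft_config)
    peft_output = peft_model(x)             # now safe to compare against base_output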

peft_output = peft_model(x)

# The outputs should be scaled - let's verify the embeddings are scaled
base_embeddings = base_model.emb(x)
peft_embeddings = peft_model.model.emb(x)

# Check that both apply the same scaling factor
assert hasattr(peft_model.model.emb, "_get_embed_scale")
Member:
Can be removed.

Contributor Author:
resolved in 0d216c7

embed_scale = peft_model.model.emb._get_embed_scale()
assert embed_scale is not None
assert torch.allclose(embed_scale, torch.tensor(3.0))

# Before training, outputs should be identical (within numerical precision)
assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6)

# Simulate training
self.simulate_training(peft_model.model.emb)

# After "training", the scaled embeddings for modified tokens should differ
peft_embeddings_trained = peft_model.model.emb(x)

# Modified tokens (0, 1, 2) should be different from base
base_emb_modified = base_embeddings[0, :3]
peft_emb_modified = peft_embeddings_trained[0, :3]
assert not torch.allclose(base_emb_modified, peft_emb_modified)

# Unmodified token (3) should be the same
base_emb_unmodified = base_embeddings[0, 3]
peft_emb_unmodified = peft_embeddings_trained[0, 3]
assert torch.allclose(base_emb_unmodified, peft_emb_unmodified, atol=1e-6)

def test_scaled_embedding_with_lora(self):
"""Test that TrainableTokens works with LoRA on scaled embeddings."""
Member:
Could you please explain in a bit more detail what this tests? Would it not be possible to have a similar test setup to the one above here?

Contributor Author:
This test validates the combined scenario where both TrainableTokens AND LoRA are active simultaneously (line 977: trainable_token_indices and target_modules both set). So in test_embed_scale_is_applied we test TrainableTokens alone on Gemma3, while in test_scaled_embedding_with_lora we test TrainableTokens + LoRA together on a mock model.

This ensures embed_scale is correctly applied when both PEFT methods interact on the same model.

Would you prefer Gemma3 instead of the mock model for this test as well? Or should I update test_embed_scale_is_applied to use this combined approach instead? The third alternative is to keep the tests as they are now with no change.

Member:
Ideally, we should have a test that is similar to test_embed_scale_is_applied, using a real model architecture and following similar steps. It would be great if you could rewrite the test accordingly.

Contributor Author:
cool, rewritten in 883b3d3.
(screenshot of the passing test run attached)


class ScaledEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
super().__init__(num_embeddings, embedding_dim)
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

def forward(self, input_ids):
return super().forward(input_ids) * self.embed_scale

class ModelWithScaledEmb(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = ScaledEmbedding(100, 10, embed_scale=3.0)
self.lin0 = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin0(self.emb(x))

def get_input_embeddings(self):
return self.emb

base_model = ModelWithScaledEmb()
peft_config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
peft_model = get_peft_model(base_model, peft_config)

x = torch.tensor([[0, 1, 2, 3]])

# Verify embed_scale detection works in combined mode
assert hasattr(peft_model.model.emb.token_adapter, "_get_embed_scale")
embed_scale = peft_model.model.emb.token_adapter._get_embed_scale()
assert embed_scale is not None
assert torch.allclose(embed_scale, torch.tensor(3.0))