From d6fe7d3479e31fc7d7d945526c48b6b452bfd85f Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Thu, 9 Oct 2025 15:08:10 +0530 Subject: [PATCH 1/9] trainable tokens model embedding scaling is here now . --- src/peft/tuners/lora/layer.py | 61 ++++++++++++- src/peft/tuners/trainable_tokens/layer.py | 41 +++++++++ tests/test_trainable_tokens.py | 100 ++++++++++++++++++++++ 3 files changed, 200 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index b01e87e6a5..834fc27ba0 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -870,6 +870,42 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari return DoraEmbeddingVariant() + def _get_embed_scale(self): + """ + Extract embed_scale from base layer if present and valid. + + Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward + method. This method checks for the presence of an `embed_scale` attribute and validates its shape. + + Returns: + torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. + """ + base_layer = self.get_base_layer() + if not hasattr(base_layer, "embed_scale"): + return None + + embed_scale = base_layer.embed_scale + + # Convert scalar values to tensors + if isinstance(embed_scale, (int, float)): + return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype) + + # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting + if isinstance(embed_scale, torch.Tensor): + if embed_scale.numel() == 1: + return embed_scale + else: + # Log warning but don't fail - this maintains backward compatibility + warnings.warn( + f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. " + "Embedding scaling will not be applied. If this is unexpected, please open an issue at " + "https://github.com/huggingface/peft/issues", + PeftWarning, + ) + return None + + return None + def update_layer( self, adapter_name, @@ -1035,6 +1071,10 @@ def _mixed_batch_forward( # extra argument that allows mixing different adapters in the same batch at inference time. result = self.base_layer(x, *args, **kwargs) + # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method. + # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too. + embed_scale = self._get_embed_scale() + unique_adapters = set(adapter_names) sub_batch_indices_list = [] for adapter in unique_adapters: @@ -1054,7 +1094,13 @@ def _mixed_batch_forward( # layer output sub_batch = x[sub_batch_indices_list[i]] after_A = self._embed(sub_batch, embedding_A) - result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling + adapter_output = (after_A @ embedding_B) * scaling + + # Apply embed_scale to match the base layer's scaling + if embed_scale is not None: + adapter_output = adapter_output * embed_scale.to(adapter_output.dtype) + + result[sub_batch_indices_list[i]] += adapter_output return result @@ -1086,6 +1132,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: else: result = self.base_layer(x, *args, **kwargs) torch_result_dtype = result.dtype + + # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method. + # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too. 
+ embed_scale = self._get_embed_scale() + for active_adapter in self.active_adapters: if active_adapter not in self.lora_embedding_A: continue @@ -1095,7 +1146,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: embedding_B = self.lora_embedding_B[active_adapter].T scaling = self.scaling[active_adapter] after_A = self._embed(x, embedding_A) - result = result + (after_A @ embedding_B) * scaling + adapter_output = (after_A @ embedding_B) * scaling + + # Apply embed_scale to match the base layer's scaling + if embed_scale is not None: + adapter_output = adapter_output * embed_scale.to(adapter_output.dtype) + + result = result + adapter_output else: result = self.lora_variant[active_adapter].forward( self, diff --git a/src/peft/tuners/trainable_tokens/layer.py b/src/peft/tuners/trainable_tokens/layer.py index 0f35462224..b8c0565f9b 100644 --- a/src/peft/tuners/trainable_tokens/layer.py +++ b/src/peft/tuners/trainable_tokens/layer.py @@ -79,6 +79,42 @@ def tied_adapter(self): return self._tied_adapter[0] return None + def _get_embed_scale(self): + """ + Extract embed_scale from base layer if present and valid. + + Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward + method. This method checks for the presence of an `embed_scale` attribute and validates its shape. + + Returns: + torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. + """ + base_layer = self.get_base_layer() + if not hasattr(base_layer, "embed_scale"): + return None + + embed_scale = base_layer.embed_scale + + # Convert scalar values to tensors + if isinstance(embed_scale, (int, float)): + return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype) + + # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting + if isinstance(embed_scale, torch.Tensor): + if embed_scale.numel() == 1: + return embed_scale + else: + # Log warning but don't fail - this maintains backward compatibility + warnings.warn( + f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. " + "Embedding scaling will not be applied. If this is unexpected, please open an issue at " + "https://github.com/huggingface/peft/issues", + UserWarning, + ) + return None + + return None + def _collect_token_weights(self, weight: torch.Tensor, rows: torch.Tensor, embed_dim: int) -> torch.Tensor: """DeepSpeed zero3 specific code to initialize trainable tokens. @@ -232,6 +268,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> scale_grad_by_freq=self.base_layer.scale_grad_by_freq, sparse=self.base_layer.sparse, ) + # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method. + # Since we're using F.embedding directly, we need to apply this scaling manually. + embed_scale = self._get_embed_scale() + if embed_scale is not None: + result = result * embed_scale.to(result.dtype) elif isinstance(self.base_layer, torch.nn.Linear): # Probably a tied adapter that wraps an LM head. 
result = F.linear( diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index 38b32b06ed..2eae92dd86 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -918,3 +918,103 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa assert contains_embedding else: assert not contains_embedding + + def test_scaled_embedding_support(self): + """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3).""" + + # Create a mock scaled embedding layer similar to Gemma3TextScaledWordEmbedding + class ScaledEmbedding(torch.nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0): + super().__init__(num_embeddings, embedding_dim) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids): + return super().forward(input_ids) * self.embed_scale + + # Create a model with scaled embedding + class ModelWithScaledEmb(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = ScaledEmbedding(100, 10, embed_scale=3.0) + self.lin0 = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin0(self.emb(x)) + + def get_input_embeddings(self): + return self.emb + + base_model = ModelWithScaledEmb() + peft_config = TrainableTokensConfig(target_modules=["emb"], token_indices=[0, 1, 2]) + peft_model = get_peft_model(base_model, peft_config) + + # Test input + x = torch.tensor([[0, 1, 2, 3]]) + + # Get outputs from base model and peft model + base_output = base_model(x) + peft_output = peft_model(x) + + # The outputs should be scaled - let's verify the embeddings are scaled + base_embeddings = base_model.emb(x) + peft_embeddings = peft_model.model.emb(x) + + # Check that both apply the same scaling factor + assert hasattr(peft_model.model.emb, "_get_embed_scale") + embed_scale = peft_model.model.emb._get_embed_scale() + assert embed_scale is not None + assert torch.allclose(embed_scale, torch.tensor(3.0)) + + # Before training, outputs should be identical (within numerical precision) + assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6) + + # Simulate training + self.simulate_training(peft_model.model.emb) + + # After "training", the scaled embeddings for modified tokens should differ + peft_embeddings_trained = peft_model.model.emb(x) + + # Modified tokens (0, 1, 2) should be different from base + base_emb_modified = base_embeddings[0, :3] + peft_emb_modified = peft_embeddings_trained[0, :3] + assert not torch.allclose(base_emb_modified, peft_emb_modified) + + # Unmodified token (3) should be the same + base_emb_unmodified = base_embeddings[0, 3] + peft_emb_unmodified = peft_embeddings_trained[0, 3] + assert torch.allclose(base_emb_unmodified, peft_emb_unmodified, atol=1e-6) + + def test_scaled_embedding_with_lora(self): + """Test that TrainableTokens works with LoRA on scaled embeddings.""" + + class ScaledEmbedding(torch.nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0): + super().__init__(num_embeddings, embedding_dim) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids): + return super().forward(input_ids) * self.embed_scale + + class ModelWithScaledEmb(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = ScaledEmbedding(100, 10, embed_scale=3.0) + self.lin0 = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin0(self.emb(x)) + + def get_input_embeddings(self): 
+ return self.emb + + base_model = ModelWithScaledEmb() + peft_config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]}) + peft_model = get_peft_model(base_model, peft_config) + + x = torch.tensor([[0, 1, 2, 3]]) + + # Verify embed_scale detection works in combined mode + assert hasattr(peft_model.model.emb.token_adapter, "_get_embed_scale") + embed_scale = peft_model.model.emb.token_adapter._get_embed_scale() + assert embed_scale is not None + assert torch.allclose(embed_scale, torch.tensor(3.0)) From 0d216c7c0113b115f12275b608233facdaf6a800 Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Fri, 10 Oct 2025 18:53:23 +0530 Subject: [PATCH 2/9] all tc fixed , testing behaviour verified --- src/peft/tuners/lora/layer.py | 36 --------- src/peft/tuners/trainable_tokens/layer.py | 36 --------- src/peft/tuners/tuners_utils.py | 38 ++++++++++ tests/test_decoder_models.py | 55 ++++++++++++++ tests/test_trainable_tokens.py | 92 +++++++---------------- 5 files changed, 120 insertions(+), 137 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 834fc27ba0..a338fac0f4 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -870,42 +870,6 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVari return DoraEmbeddingVariant() - def _get_embed_scale(self): - """ - Extract embed_scale from base layer if present and valid. - - Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward - method. This method checks for the presence of an `embed_scale` attribute and validates its shape. - - Returns: - torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. - """ - base_layer = self.get_base_layer() - if not hasattr(base_layer, "embed_scale"): - return None - - embed_scale = base_layer.embed_scale - - # Convert scalar values to tensors - if isinstance(embed_scale, (int, float)): - return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype) - - # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting - if isinstance(embed_scale, torch.Tensor): - if embed_scale.numel() == 1: - return embed_scale - else: - # Log warning but don't fail - this maintains backward compatibility - warnings.warn( - f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. " - "Embedding scaling will not be applied. If this is unexpected, please open an issue at " - "https://github.com/huggingface/peft/issues", - PeftWarning, - ) - return None - - return None - def update_layer( self, adapter_name, diff --git a/src/peft/tuners/trainable_tokens/layer.py b/src/peft/tuners/trainable_tokens/layer.py index b8c0565f9b..da955a6844 100644 --- a/src/peft/tuners/trainable_tokens/layer.py +++ b/src/peft/tuners/trainable_tokens/layer.py @@ -79,42 +79,6 @@ def tied_adapter(self): return self._tied_adapter[0] return None - def _get_embed_scale(self): - """ - Extract embed_scale from base layer if present and valid. - - Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward - method. This method checks for the presence of an `embed_scale` attribute and validates its shape. - - Returns: - torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. 
- """ - base_layer = self.get_base_layer() - if not hasattr(base_layer, "embed_scale"): - return None - - embed_scale = base_layer.embed_scale - - # Convert scalar values to tensors - if isinstance(embed_scale, (int, float)): - return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype) - - # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting - if isinstance(embed_scale, torch.Tensor): - if embed_scale.numel() == 1: - return embed_scale - else: - # Log warning but don't fail - this maintains backward compatibility - warnings.warn( - f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. " - "Embedding scaling will not be applied. If this is unexpected, please open an issue at " - "https://github.com/huggingface/peft/issues", - UserWarning, - ) - return None - - return None - def _collect_token_weights(self, weight: torch.Tensor, rows: torch.Tensor, embed_dim: int) -> torch.Tensor: """DeepSpeed zero3 specific code to initialize trainable tokens. diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 4fd0d12843..1ed1e383bf 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -49,6 +49,7 @@ set_additional_trainable_modules, ) from peft.utils.peft_types import PeftType, TaskType +from peft.utils.warning import PeftWarning from ..config import PeftConfig from ..utils import _get_submodules @@ -1201,6 +1202,43 @@ def get_base_layer(self) -> nn.Module: base_layer = base_layer.base_layer return base_layer + def _get_embed_scale(self): + """ + Extract embed_scale from base layer if present and valid. + + Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward + method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be + a scalar and validates its shape accordingly. + + Returns: + torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. + """ + base_layer = self.get_base_layer() + if not hasattr(base_layer, "embed_scale"): + return None + + embed_scale = base_layer.embed_scale + + # Convert scalar values to tensors + if isinstance(embed_scale, (int, float)): + return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype) + + # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting + if isinstance(embed_scale, torch.Tensor): + if embed_scale.numel() == 1: + return embed_scale + else: + # Log warning but don't fail - this maintains backward compatibility + warnings.warn( + f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. " + "Embedding scaling will not be applied. If this is unexpected, please open an issue at " + "https://github.com/huggingface/peft/issues", + PeftWarning, + ) + return None + + return None + @property def weight(self) -> torch.Tensor: # This is required for some transformers code, e.g. 
for T5, weight is accessed as: diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 12b4a62c3b..fb825378be 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -795,3 +795,58 @@ def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_ ) else: assert not contains_embedding + + def test_lora_embed_scale_is_applied(self): + """Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3).""" + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id) + orig_embedding = base_model.get_input_embeddings() + + peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False) + peft_model = get_peft_model(base_model, peft_config) + + # sanity check: with the default embed_scale, the embedding output should be reasonably sized + peft_embedding = peft_model.base_model.model.get_input_embeddings() + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output < 100.0).all() + + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output > 100.0).all() + + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(torch.arange(10)) + assert (embedding_output == 0.0).all() + + def test_lora_embed_scale_is_applied_mixed_batch(self): + """Test that LoRA correctly handles embeddings with scaling in mixed batch mode.""" + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id) + orig_embedding = base_model.get_input_embeddings() + + peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False) + peft_model = get_peft_model(base_model, peft_config) + peft_model.add_adapter("adapter2", peft_config) + + # sanity check: with the default embed_scale, the embedding output should be reasonably sized + peft_embedding = peft_model.base_model.model.get_input_embeddings() + input_ids = torch.arange(10).unsqueeze(0).repeat(2, 1) + adapter_names = ["default", "adapter2"] + max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max() + assert max_embedding_output < 100.0 + + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max() + assert max_embedding_output > 100.0 + + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(input_ids, adapter_names=adapter_names) + assert (embedding_output == 0.0).all() diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index 2eae92dd86..950636f051 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -919,70 +919,31 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_pa else: assert not contains_embedding - def test_scaled_embedding_support(self): + def test_embed_scale_is_applied(self): """Test that TrainableTokens correctly 
handles embeddings with scaling (e.g., Gemma3).""" + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id) + orig_embedding = base_model.get_input_embeddings() - # Create a mock scaled embedding layer similar to Gemma3TextScaledWordEmbedding - class ScaledEmbedding(torch.nn.Embedding): - def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0): - super().__init__(num_embeddings, embedding_dim) - self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) - - def forward(self, input_ids): - return super().forward(input_ids) * self.embed_scale - - # Create a model with scaled embedding - class ModelWithScaledEmb(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = ScaledEmbedding(100, 10, embed_scale=3.0) - self.lin0 = torch.nn.Linear(10, 1) - - def forward(self, x): - return self.lin0(self.emb(x)) - - def get_input_embeddings(self): - return self.emb - - base_model = ModelWithScaledEmb() - peft_config = TrainableTokensConfig(target_modules=["emb"], token_indices=[0, 1, 2]) - peft_model = get_peft_model(base_model, peft_config) - - # Test input - x = torch.tensor([[0, 1, 2, 3]]) - - # Get outputs from base model and peft model - base_output = base_model(x) - peft_output = peft_model(x) - - # The outputs should be scaled - let's verify the embeddings are scaled - base_embeddings = base_model.emb(x) - peft_embeddings = peft_model.model.emb(x) - - # Check that both apply the same scaling factor - assert hasattr(peft_model.model.emb, "_get_embed_scale") - embed_scale = peft_model.model.emb._get_embed_scale() - assert embed_scale is not None - assert torch.allclose(embed_scale, torch.tensor(3.0)) - - # Before training, outputs should be identical (within numerical precision) - assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6) - - # Simulate training - self.simulate_training(peft_model.model.emb) + peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3]) + peft_model = get_peft_model(base_model, peft_config) - # After "training", the scaled embeddings for modified tokens should differ - peft_embeddings_trained = peft_model.model.emb(x) + # sanity check: with the default embed_scale, the embedding output should be reasonably sized + peft_embedding = peft_model.base_model.model.get_input_embeddings() + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output < 100.0).all() - # Modified tokens (0, 1, 2) should be different from base - base_emb_modified = base_embeddings[0, :3] - peft_emb_modified = peft_embeddings_trained[0, :3] - assert not torch.allclose(base_emb_modified, peft_emb_modified) + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output > 100.0).all() - # Unmodified token (3) should be the same - base_emb_unmodified = base_embeddings[0, 3] - peft_emb_unmodified = peft_embeddings_trained[0, 3] - assert torch.allclose(base_emb_unmodified, peft_emb_unmodified, atol=1e-6) + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(torch.arange(10)) + assert (embedding_output == 0.0).all() def test_scaled_embedding_with_lora(self): """Test 
that TrainableTokens works with LoRA on scaled embeddings.""" @@ -1008,13 +969,14 @@ def get_input_embeddings(self): return self.emb base_model = ModelWithScaledEmb() + x = torch.tensor([[0, 1, 2, 3]]) + + # Get base embeddings before applying PEFT + base_embeddings = base_model.emb(x) + peft_config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]}) peft_model = get_peft_model(base_model, peft_config) - x = torch.tensor([[0, 1, 2, 3]]) - # Verify embed_scale detection works in combined mode - assert hasattr(peft_model.model.emb.token_adapter, "_get_embed_scale") - embed_scale = peft_model.model.emb.token_adapter._get_embed_scale() - assert embed_scale is not None - assert torch.allclose(embed_scale, torch.tensor(3.0)) + peft_embeddings = peft_model.base_model.model.emb(x) + assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6) From e4ab8b1a4679fc4e8cf63c016a8f99866def025b Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Fri, 10 Oct 2025 23:15:23 +0530 Subject: [PATCH 3/9] doc string changed --- src/peft/tuners/tuners_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 1ed1e383bf..a7deeac571 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -1208,7 +1208,7 @@ def _get_embed_scale(self): Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be - a scalar and validates its shape accordingly. + a scalar. Its shape is validated accordingly. Returns: torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. 
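The changes in patches 1-3 all hinge on the same invariant: when the base embedding multiplies its lookup by `embed_scale` (as Gemma3TextScaledWordEmbedding does), the LoRA delta has to be multiplied by the same factor, otherwise the unmerged forward pass drifts away from what merging the delta into the weight would produce. The following standalone sketch is not part of the patch (the `ScaledEmbedding` class, the LoRA tensors and their shapes are invented for illustration), but it reproduces the mismatch and the fix in a few lines.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledEmbedding(nn.Embedding):
    # Stand-in for Gemma3TextScaledWordEmbedding: the lookup is multiplied by embed_scale.
    def __init__(self, num_embeddings, embedding_dim, embed_scale=3.0):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale


torch.manual_seed(0)
base = ScaledEmbedding(10, 4, embed_scale=3.0)
lora_A = torch.randn(10, 2)  # (num_embeddings, r), plays the role of lora_embedding_A.T
lora_B = torch.randn(2, 4)   # (r, embedding_dim), plays the role of lora_embedding_B.T
scaling = 0.5
x = torch.tensor([[0, 1, 2]])

delta = (F.embedding(x, lora_A) @ lora_B) * scaling

# Reference: output of the merged weight, i.e. embed_scale * (W + scaling * A @ B)
merged = F.embedding(x, base.weight + (lora_A @ lora_B) * scaling) * base.embed_scale

without_scale = base(x) + delta                  # LoRA path ignores embed_scale
with_scale = base(x) + delta * base.embed_scale  # what the patched forward computes

assert not torch.allclose(without_scale, merged)
assert torch.allclose(with_scale, merged, atol=1e-6)

The same argument applies to TrainableTokens: its forward bypasses the base layer and calls F.embedding directly, so (per the comment added in patch 1) it has to apply the scaling itself.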
From 2b0b6162752774bbf78988603b65547430d8bb90 Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Fri, 10 Oct 2025 23:16:22 +0530 Subject: [PATCH 4/9] make style removed some unused import --- examples/boft_controlnet/test_controlnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/boft_controlnet/test_controlnet.py b/examples/boft_controlnet/test_controlnet.py index 2080deb0a7..9624b7c341 100644 --- a/examples/boft_controlnet/test_controlnet.py +++ b/examples/boft_controlnet/test_controlnet.py @@ -22,7 +22,6 @@ import numpy as np import torch -import torch.utils.checkpoint from accelerate import Accelerator from diffusers import DDIMScheduler from diffusers.utils import check_min_version From 883b3d3a6c07bc14584eec29142790032523543e Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Mon, 13 Oct 2025 15:10:50 +0530 Subject: [PATCH 5/9] Update test_scaled_embedding_with_lora to use Gemma3 like test_embed_scale_is_applied --- tests/test_trainable_tokens.py | 56 ++++++++++++++++------------------ 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index 950636f051..a531d3d2c3 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -946,37 +946,35 @@ def test_embed_scale_is_applied(self): assert (embedding_output == 0.0).all() def test_scaled_embedding_with_lora(self): - """Test that TrainableTokens works with LoRA on scaled embeddings.""" - - class ScaledEmbedding(torch.nn.Embedding): - def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0): - super().__init__(num_embeddings, embedding_dim) - self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) - - def forward(self, input_ids): - return super().forward(input_ids) * self.embed_scale - - class ModelWithScaledEmb(torch.nn.Module): - def __init__(self): - super().__init__() - self.emb = ScaledEmbedding(100, 10, embed_scale=3.0) - self.lin0 = torch.nn.Linear(10, 1) - - def forward(self, x): - return self.lin0(self.emb(x)) + """ + Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously. - def get_input_embeddings(self): - return self.emb + This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply + embed_scale when used together. 
+ """ + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id) + orig_embedding = base_model.get_input_embeddings() - base_model = ModelWithScaledEmb() - x = torch.tensor([[0, 1, 2, 3]]) + # Apply both TrainableTokens and LoRA to the same model + peft_config = LoraConfig( + target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]} + ) + peft_model = get_peft_model(base_model, peft_config) - # Get base embeddings before applying PEFT - base_embeddings = base_model.emb(x) + # sanity check: with the default embed_scale, the embedding output should be reasonably sized + peft_embedding = peft_model.base_model.model.get_input_embeddings() + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output < 100.0).all() - peft_config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]}) - peft_model = get_peft_model(base_model, peft_config) + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output > 100.0).all() - # Verify embed_scale detection works in combined mode - peft_embeddings = peft_model.base_model.model.emb(x) - assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6) + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(torch.arange(10)) + assert (embedding_output == 0.0).all() From 236de8a8c9ef43c0b77232568fcbc3943b39c013 Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Mon, 13 Oct 2025 15:19:19 +0530 Subject: [PATCH 6/9] Run make style - format test_scaled_embedding_with_lora --- tests/test_trainable_tokens.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index a531d3d2c3..cdd083a451 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -958,9 +958,7 @@ def test_scaled_embedding_with_lora(self): orig_embedding = base_model.get_input_embeddings() # Apply both TrainableTokens and LoRA to the same model - peft_config = LoraConfig( - target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]} - ) + peft_config = LoraConfig(target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]}) peft_model = get_peft_model(base_model, peft_config) # sanity check: with the default embed_scale, the embedding output should be reasonably sized From b3d815011e55742a7449b4bc1617a44b61c8b694 Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Mon, 13 Oct 2025 15:36:37 +0530 Subject: [PATCH 7/9] style fix --- tests/test_trainable_tokens.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index cdd083a451..397bf5495a 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -946,12 +946,7 @@ def test_embed_scale_is_applied(self): assert (embedding_output == 0.0).all() def test_scaled_embedding_with_lora(self): - """ - Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously. 
- - This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply - embed_scale when used together. - """ + """Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously.""" model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" with hub_online_once(model_id): base_model = AutoModelForCausalLM.from_pretrained(model_id) From 8d72a99c69aeb63b1a6ccadd7009122aee908779 Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Mon, 13 Oct 2025 16:00:56 +0530 Subject: [PATCH 8/9] some doc builder changes --- src/peft/tuners/tuners_utils.py | 4 ++-- tests/test_trainable_tokens.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index a7deeac571..bd2d801fb8 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -1207,8 +1207,8 @@ def _get_embed_scale(self): Extract embed_scale from base layer if present and valid. Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward - method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be - a scalar. Its shape is validated accordingly. + method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be a + scalar. Its shape is validated accordingly. Returns: torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index 397bf5495a..a197a89f9f 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -946,7 +946,9 @@ def test_embed_scale_is_applied(self): assert (embedding_output == 0.0).all() def test_scaled_embedding_with_lora(self): - """Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously.""" + """ + Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously. 
+ """ model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" with hub_online_once(model_id): base_model = AutoModelForCausalLM.from_pretrained(model_id) From eb609b8db7a440182dc3f8b27e90443b0f5b4adc Mon Sep 17 00:00:00 2001 From: sambhavnoobcoder Date: Mon, 13 Oct 2025 19:16:50 +0530 Subject: [PATCH 9/9] test fixes --- tests/test_decoder_models.py | 15 ++++++++++----- tests/test_trainable_tokens.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index fb825378be..99fa9cdb8c 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -800,26 +800,31 @@ def test_lora_embed_scale_is_applied(self): """Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3).""" model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" with hub_online_once(model_id): - base_model = AutoModelForCausalLM.from_pretrained(model_id) + base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device) orig_embedding = base_model.get_input_embeddings() peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False) peft_model = get_peft_model(base_model, peft_config) - # sanity check: with the default embed_scale, the embedding output should be reasonably sized + x = torch.arange(10).to(self.torch_device) peft_embedding = peft_model.base_model.model.get_input_embeddings() - max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + embedding_output = peft_embedding(x) + max_embedding_output = embedding_output.abs().max(0)[0] assert (max_embedding_output < 100.0).all() + peft_model.merge_adapter() + embedding_merged = peft_embedding(x) + assert torch.allclose(embedding_output, embedding_merged) + peft_model.unmerge_adapter() # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high # value orig_embedding.embed_scale.fill_(10000.0) - max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + max_embedding_output = peft_embedding(x).abs().max(0)[0] assert (max_embedding_output > 100.0).all() # set embed_scale to zero, then check that the embedding output is also zero orig_embedding.embed_scale.fill_(0) - embedding_output = peft_embedding(torch.arange(10)) + embedding_output = peft_embedding(x) assert (embedding_output == 0.0).all() def test_lora_embed_scale_is_applied_mixed_batch(self): diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index a197a89f9f..a642fe54c6 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -958,18 +958,23 @@ def test_scaled_embedding_with_lora(self): peft_config = LoraConfig(target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]}) peft_model = get_peft_model(base_model, peft_config) - # sanity check: with the default embed_scale, the embedding output should be reasonably sized + x = torch.arange(10) peft_embedding = peft_model.base_model.model.get_input_embeddings() - max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + embedding_output = peft_embedding(x) + max_embedding_output = embedding_output.abs().max(0)[0] assert (max_embedding_output < 100.0).all() + peft_model.merge_adapter() + embedding_merged = peft_embedding(x) + assert torch.allclose(embedding_output, embedding_merged) + peft_model.unmerge_adapter() # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high # value 
orig_embedding.embed_scale.fill_(10000.0) - max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + max_embedding_output = peft_embedding(x).abs().max(0)[0] assert (max_embedding_output > 100.0).all() # set embed_scale to zero, then check that the embedding output is also zero orig_embedding.embed_scale.fill_(0) - embedding_output = peft_embedding(torch.arange(10)) + embedding_output = peft_embedding(x) assert (embedding_output == 0.0).all()
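Taken together, the series makes the end-to-end behaviour checked by the final tests reproducible in a few lines of user code. The sketch below mirrors test_scaled_embedding_with_lora: the tiny Gemma3 checkpoint and the token indices are taken from those tests, Hub access is assumed for the download, and the snippet is an illustration rather than part of the patch.

import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
orig_embedding = base_model.get_input_embeddings()

# LoRA on the attention projections plus trainable tokens on the scaled embedding.
config = LoraConfig(
    target_modules=["q_proj"],
    trainable_token_indices={"embed_tokens": [0, 1, 3]},
)
peft_model = get_peft_model(base_model, config)
peft_embedding = peft_model.base_model.model.get_input_embeddings()

# The adapter contribution now follows the base layer's embed_scale:
# zeroing the scale zeroes the whole embedding output, trainable tokens included.
orig_embedding.embed_scale.fill_(0.0)
assert (peft_embedding(torch.arange(10)) == 0.0).all()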