
Conversation

sambhavnoobcoder (Contributor)

Fix: Add support for scaled embeddings (Gemma3TextScaledWordEmbedding) in TrainableTokens and LoRA

Fixes #2809
 

Original Issue

 
TrainableTokensModel and LoRA don't handle embedding scaling correctly for models like Gemma3 that use scaled embeddings.
 
The Problem:

  • Gemma3 uses Gemma3TextScaledWordEmbedding which extends nn.Embedding and applies scaling (embed_scale) in its forward method
  • TrainableTokensLayer bypasses the custom forward logic by calling F.embedding() directly, skipping the scaling step
  • LoRA applies adapter modifications but doesn't scale them to match the base embeddings
  • This results in magnitude mismatches, high gradient norms, and loss issues during training
     

 

Problem Identification

 

Root Cause Analysis

 

  1. TrainableTokens Issue:
  • TrainableTokensLayer uses an isinstance(layer, nn.Embedding) check, which passes for Gemma3TextScaledWordEmbedding (it's a subclass)
  • However, it calls F.embedding() directly instead of base_layer.forward(), bypassing Gemma3's custom scaling logic
  • Result: embeddings are unscaled

  2. LoRA Issue:
  • LoRA calls base_layer(x), which correctly applies Gemma3's scaling to the base embeddings
  • However, the LoRA adapter contribution (A @ B) * scaling is NOT multiplied by embed_scale
  • Result: base embeddings are scaled (~33x larger for Gemma3) but adapter deltas are not, creating a magnitude mismatch (see the sketch just below)
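To make the mismatch concrete, here is a minimal sketch (the ScaledEmbedding class, the sizes, and the 0.01 delta are illustrative assumptions, not PEFT or transformers code): the base path multiplies by embed_scale inside forward(), so a bare F.embedding() call, or an adapter delta added without the scale, ends up roughly 8x off for this toy configuration.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledEmbedding(nn.Embedding):
    """Toy stand-in for a Gemma3-style scaled embedding (illustrative only)."""

    def __init__(self, num_embeddings, embedding_dim, embed_scale):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        # The scaling lives in forward(), not in the weight matrix itself.
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)


emb = ScaledEmbedding(100, 64, embed_scale=64 ** 0.5)  # sqrt(hidden_size) = 8
x = torch.tensor([[1, 2, 3]])

scaled = emb(x)                         # what the base model actually produces
bypassed = F.embedding(x, emb.weight)   # what a direct F.embedding() call produces
print(torch.allclose(scaled, bypassed * emb.embed_scale))  # True: the 8x scale was skipped

# A LoRA-style delta added without the scale is ~8x too small relative to the base output
lora_delta = 0.01 * torch.randn_like(scaled)
mismatched = scaled + lora_delta                      # before the fix
consistent = scaled + lora_delta * emb.embed_scale    # after the fix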
     
     

 

Implementation Details

 

Changes Made

 

1. TrainableTokensLayer (src/peft/tuners/trainable_tokens/layer.py)

 
Added helper method _get_embed_scale():

  • Detects embed_scale attribute on base layer
  • Validates it's a scalar or 1-element tensor
  • Logs warning if shape is invalid
  • Returns None if not found or invalid
     
Modified forward_adapters() method:
  • After calling F.embedding(), checks for embed_scale
  • Applies scaling: result = result * embed_scale.to(result.dtype)
  • Only applies when embed_scale is detected and valid (a sketch follows below)
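A minimal sketch of the approach described above, written as standalone functions rather than the actual TrainableTokensLayer methods (the function signatures, the simplified weight handling, and the warning text are illustrative assumptions; the real helper is the _get_embed_scale() method mentioned above):

import warnings

import torch
import torch.nn.functional as F


def get_embed_scale(base_layer):
    """Return base_layer.embed_scale if present and scalar, else None (sketch)."""
    embed_scale = getattr(base_layer, "embed_scale", None)
    if embed_scale is None:
        return None
    embed_scale = torch.as_tensor(embed_scale)
    if embed_scale.numel() != 1:
        warnings.warn(
            f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
            "Embedding scaling will not be applied."
        )
        return None
    return embed_scale


def forward_with_trainable_tokens(base_layer, input_ids, patched_weight):
    # patched_weight: the embedding matrix with the trainable-token rows swapped in (simplified)
    result = F.embedding(input_ids, patched_weight)
    embed_scale = get_embed_scale(base_layer)
    if embed_scale is not None:
        # Re-apply the scaling that calling F.embedding() directly bypassed
        result = result * embed_scale.to(result.dtype)
    return result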
     

2. LoRA Embedding (src/peft/tuners/lora/layer.py)

 
Added helper method _get_embed_scale():

  • Same logic as TrainableTokensLayer for consistency
     
Modified forward() method:
  • Detects embed_scale once before processing adapters
  • For vanilla LoRA: scales the adapter output before adding it to the result
  • Applies: adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
     
Modified _mixed_batch_forward() method:
  • Handles mixed-batch inference with multiple adapters
  • Applies the same scaling logic to each adapter's contribution (a sketch of the scaled forward path follows below)
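A minimal sketch of the scaled adapter path described above, assuming PEFT's usual LoRA embedding shapes (lora_embedding_A is (r, num_embeddings), lora_embedding_B is (embedding_dim, r)); the function name and signature are illustrative, not the exact layer code:

import torch
import torch.nn.functional as F


def lora_embedding_forward(base_layer, x, lora_embedding_A, lora_embedding_B, scaling):
    # base_layer(x) already includes embed_scale because it dispatches to the subclass's
    # forward (e.g. Gemma3TextScaledWordEmbedding); only the adapter delta needs rescaling.
    result = base_layer(x)
    after_A = F.embedding(x, lora_embedding_A.T)                # (batch, seq, r)
    adapter_output = (after_A @ lora_embedding_B.T) * scaling   # (batch, seq, embedding_dim)

    embed_scale = getattr(base_layer, "embed_scale", None)      # see the helper sketched earlier
    if embed_scale is not None and torch.as_tensor(embed_scale).numel() == 1:
        # Bring the adapter delta to the same magnitude as the scaled base embeddings
        adapter_output = adapter_output * torch.as_tensor(embed_scale).to(adapter_output.dtype)

    return result + adapter_output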
     

3. Tests (tests/test_trainable_tokens.py)

 
Added test_scaled_embedding_support():

  • Creates mock ScaledEmbedding class mimicking Gemma3's behavior
  • Tests TrainableTokens standalone with scaled embeddings
  • Verifies scaling is detected and applied correctly
  • Checks modified and unmodified tokens separately
     
Added test_scaled_embedding_with_lora():
  • Tests the combined LoRA + TrainableTokens scenario
  • Verifies embed_scale detection works in combined mode
  • Ensures no regression when not targeting embeddings directly
     

 

Reproducer

 
The following reproducer demonstrates the fix working correctly with Gemma3-like scaled embeddings:
 

"""
Reproduction script from issue #2809 to verify the fix.
This tests that TrainableTokensModel correctly handles Gemma3's embed_scale.
"""
import torch
import torch.nn as nn
 
# Mock Gemma3TextScaledWordEmbedding since we can't load the actual model easily
class Gemma3TextScaledWordEmbedding(nn.Embedding):
"""Mock implementation matching transformers' Gemma3TextScaledWordEmbedding"""
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, embed_scale=1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
 
def forward(self, input_ids):
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
 
 
class MockGemmaModel(nn.Module):
"""Mock model structure similar to Gemma3"""
def __init__(self, vocab_size=256, hidden_size=64):
super().__init__()
# Use scaled embedding like Gemma3 does
embed_scale = hidden_size ** 0.5  # sqrt(hidden_size)
self.embed_tokens = Gemma3TextScaledWordEmbedding(
vocab_size, hidden_size, embed_scale=embed_scale
)
self.linear = nn.Linear(hidden_size, hidden_size)
 
def forward(self, input_ids):
return self.linear(self.embed_tokens(input_ids))
 
 
def test_trainable_tokens_with_scaled_embedding():
"""Test from issue #2809"""
from peft import TrainableTokensModel, TrainableTokensConfig
 
print("=" * 60)
print("Testing TrainableTokensModel with scaled embeddings")
print("=" * 60)
 
device = "cpu"
 
# Create ONE model then copy it to ensure same initial weights
base_model = MockGemmaModel()
model1 = base_model
 
# Deep copy for model2
import copy
model2 = copy.deepcopy(base_model)
 
# Apply TrainableTokens to model2
config = TrainableTokensConfig(
target_modules=["embed_tokens"],
token_indices=[1, 2],
init_weights=True,
)
 
model2 = TrainableTokensModel(model2, config, "trainable_tokens")
 
x = torch.tensor([[1, 2]], dtype=torch.long)
 
# Check if embed_scale is detected
has_method = hasattr(model2.model.embed_tokens, "_get_embed_scale")
print(f"\nHas _get_embed_scale method: {has_method}")
 
if has_method:
detected_scale = model2.model.embed_tokens._get_embed_scale()
print(f"Detected embed_scale: {detected_scale}")
 
# Get embeddings from both models
y1 = model1.embed_tokens(x)
y2 = model2.model.embed_tokens(x)
 
print(f"\nEmbed scale value: {model1.embed_tokens.embed_scale.item():.4f}")
print(f"\nOriginal embeddings (scaled by Gemma3):\n{y1}")
print(f"\nTrainableTokens embeddings (should also be scaled):\n{y2}")
 
# Before the fix, y2 would not be scaled
# After the fix, y2 should be scaled
print(f"\nManually scaled embeddings (what user had to do before fix):\n{y2 * model1.embed_tokens.embed_scale}")
 
# Verify they match (within numerical precision)
if torch.allclose(y1, y2, atol=1e-5):
print("\n✅ SUCCESS: Embeddings are correctly scaled!")
print("   TrainableTokens now handles embed_scale properly.")
return True
else:
print("\n❌ FAILURE: Embeddings are NOT correctly scaled!")
print(f"   Max difference: {(y1 - y2).abs().max().item():.6f}")
return False
 
 
def test_lora_with_scaled_embedding():
"""Test LoRA with scaled embeddings"""
from peft import LoraConfig, get_peft_model
 
print("\n" + "=" * 60)
print("Testing LoRA with scaled embeddings")
print("=" * 60)
 
model = MockGemmaModel()
 
config = LoraConfig(
r=4,
lora_alpha=16,
target_modules=["linear"],
lora_dropout=0.1,
)
 
# Get base embeddings
x = torch.tensor([[1, 2, 3]], dtype=torch.long)
base_embeddings = model.embed_tokens(x)
 
# Apply LoRA (doesn't target embeddings, but embeddings should still work)
peft_model = get_peft_model(model, config)
peft_embeddings = peft_model.base_model.model.embed_tokens(x)
 
print(f"\nBase embeddings:\n{base_embeddings}")
print(f"\nPEFT embeddings:\n{peft_embeddings}")
 
if torch.allclose(base_embeddings, peft_embeddings, atol=1e-5):
print("\n✅ SUCCESS: LoRA preserves embedding scaling!")
return True
else:
print("\n❌ FAILURE: LoRA doesn't preserve embedding scaling!")
print(f"   Max difference: {(base_embeddings - peft_embeddings).abs().max().item():.6f}")
return False
 
 
if __name__ == "__main__":
import sys
 
# Run tests
results = []
 
try:
results.append(test_trainable_tokens_with_scaled_embedding())
except Exception as e:
print(f"\n❌ ERROR in TrainableTokens test: {e}")
import traceback
traceback.print_exc()
results.append(False)
 
try:
results.append(test_lora_with_scaled_embedding())
except Exception as e:
print(f"\n❌ ERROR in LoRA test: {e}")
import traceback
traceback.print_exc()
results.append(False)
 
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"TrainableTokens test: {'PASSED ✅' if results[0] else 'FAILED ❌'}")
if len(results) > 1:
print(f"LoRA test: {'PASSED ✅' if results[1] else 'FAILED ❌'}")
 
if all(results):
print("\n🎉 All tests passed! Issue #2809 is fixed.")
sys.exit(0)
else:
print("\n⚠️  Some tests failed. Please review the implementation.")
sys.exit(1)

 

Test Results

 

New Tests Added

  1. test_scaled_embedding_support - Tests TrainableTokens with scaled embeddings ✅
  2. test_scaled_embedding_with_lora - Tests combined LoRA + TrainableTokens ✅
     

Regression Tests

  • All 37 TrainableTokens tests: ✅ PASSED (no regression)
  • LoRA embedding initialization test: ✅ PASSED (no regression)
     

Test Commands

 

# Run new scaled embedding tests
python3 -m pytest tests/test_trainable_tokens.py::TestTrainableTokens::test_scaled_embedding_support -xvs
python3 -m pytest tests/test_trainable_tokens.py::TestTrainableTokens::test_scaled_embedding_with_lora -xvs
 
# Run all TrainableTokens tests (regression check)
python3 -m pytest tests/test_trainable_tokens.py -x
 
# Run LoRA embedding test (regression check)
python3 -m pytest tests/test_initialization.py::TestLoraInitialization::test_lora_embedding_default -xvs

 

 

Screenshots

 

============================================================
Testing TrainableTokensModel with scaled embeddings
============================================================

Has _get_embed_scale method: True
Detected embed_scale: 8.0

Embed scale value: 8.0000

Original embeddings (scaled by Gemma3):
tensor([[[  2.3135,  -5.1571,  -9.9790,   9.7044,  12.0080,   8.2798,  ...,
           -1.4904,   6.9351,   8.2466],
         [ -6.1600,   0.9357,   1.8082,   0.2661,  -4.0639, -15.9327,  ...,
           -0.8121,  -4.0469,  -6.9121]]], grad_fn=<MulBackward0>)

TrainableTokens embeddings (should also be scaled):
tensor([[[  2.3135,  -5.1571,  -9.9790,   9.7044,  12.0080,   8.2798,  ...,
           -1.4904,   6.9351,   8.2466],
         [ -6.1600,   0.9357,   1.8082,   0.2661,  -4.0639, -15.9327,  ...,
           -0.8121,  -4.0469,  -6.9121]]], grad_fn=<MulBackward0>)

Manually scaled embeddings (what user had to do before fix):
tensor([[[  18.5078,  -41.2570,  -79.8316,   77.6354,   96.0641,   66.2384,  ...,
           -11.9230,   55.4806,   65.9728],
         [ -49.2799,    7.4854,   14.4654,    2.1285,  -32.5112, -127.4619,  ...,
            -6.4971,  -32.3748,  -55.2965]]], grad_fn=<MulBackward0>)

✅ SUCCESS: Embeddings are correctly scaled!
   TrainableTokens now handles embed_scale properly.

============================================================
Testing LoRA with scaled embeddings
============================================================

Base embeddings:
tensor([[[-1.1619e+01, -7.6770e+00,  2.4871e+00,  1.5184e+00,  9.3255e+00,  ...,
          -6.8317e+00, -8.5168e+00,  1.2764e+01],
         [-2.2004e+00, -6.8917e+00, -1.4325e+01, -1.3367e+00, -1.0980e+00,  ...,
          -1.2389e+01, -6.6812e+00,  1.4681e+01],
         [-2.6244e+01,  6.3272e+00, -1.4973e+01,  4.7564e+00,  2.8202e+00,  ...,
           1.4307e+01, -1.2801e+00, -8.2278e+00]]], grad_fn=<MulBackward0>)

PEFT embeddings:
tensor([[[-1.1619e+01, -7.6770e+00,  2.4871e+00,  1.5184e+00,  9.3255e+00,  ...,
          -6.8317e+00, -8.5168e+00,  1.2764e+01],
         [-2.2004e+00, -6.8917e+00, -1.4325e+01, -1.3367e+00, -1.0980e+00,  ...,
          -1.2389e+01, -6.6812e+00,  1.4681e+01],
         [-2.6244e+01,  6.3272e+00, -1.4973e+01,  4.7564e+00,  2.8202e+00,  ...,
           1.4307e+01, -1.2801e+00, -8.2278e+00]]])

✅ SUCCESS: LoRA preserves embedding scaling!

============================================================
SUMMARY
============================================================
TrainableTokens test: PASSED ✅
LoRA test: PASSED ✅

🎉 All tests passed! Issue #2809 is fixed.



 

Checklist

 

  • Added _get_embed_scale() helper method to TrainableTokensLayer
  • Modified TrainableTokensLayer to apply scaling in forward pass
  • Added _get_embed_scale() helper method to LoRA Embedding
  • Modified LoRA Embedding to apply scaling in forward pass
  • Modified LoRA Embedding to apply scaling in mixed batch forward
  • Added comprehensive tests for scaled embeddings
  • All existing tests pass (no regression)
  • Reproducer from issue #2809 (TrainableTokensModel doesn't handle embedding scaling) works correctly
  • Shape validation with helpful warnings
  • Non-breaking: code without embed_scale works unchanged
  • Documentation: clear comments explaining the implementation
     

 

Related Issues

 

@sambhavnoobcoder changed the title from "Handling embeddings scaling for TrainableTokensModel" to "Handling embeddings scaling for TrainableTokensModel #2809" on Oct 9, 2025
@BenjaminBossan (Member) left a comment

Thanks a lot for taking this issue and providing a PR to fix it. I especially like the detailed descriptions of the steps you took and how you checked the results.

Regarding the implementation, it already looks quite good, I only added a few small comments. Regarding the testing, I proposed a different approach, please check if it makes sense. Furthermore, please add tests for LoRA as well, not only for trainable tokens. These tests can be added to test_decoder_models.py (only for LoRA and Gemma3, no need to parametrize the test). There should be one test for the normal forward and one for the mixed batch forward (good job spotting that it needs to be updated too).

Searching through the PEFT code base, I think XLoraEmbeddingLayer and DoraEmbeddingLayer need to be updated too, but feel free to skip those for this PR. Especially for DoRA, I'm not even quite sure how (or if) the scale should be applied when calculating the weight norm.

Extract embed_scale from base layer if present and valid.
Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute and validates its shape.
BenjaminBossan (Member):

Let's mention that if it exists, it is assumed to be a scalar. From my search through the transformers code base, this is always the case right now.

sambhavnoobcoder (Contributor, Author):

resolved in 0d216c7


return DoraEmbeddingVariant()

def _get_embed_scale(self):
BenjaminBossan (Member):

Let's factor this function out so that it can be reused between the two PEFT methods. It should be fine to put it on the BaseTunerLayer class. This also eliminates the small inconsistency between the two implementations you have.

sambhavnoobcoder (Contributor, Author):

resolved in 0d216c7

f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
"Embedding scaling will not be applied. If this is unexpected, please open an issue at "
"https://github.com/huggingface/peft/issues",
UserWarning,
BenjaminBossan (Member):

For LoRA, you use PeftWarning. I think that's the better choice.

sambhavnoobcoder (Contributor, Author):

resolved in 0d216c7

x = torch.tensor([[0, 1, 2, 3]])

# Get outputs from base model and peft model
base_output = base_model(x)
BenjaminBossan (Member):

The base_output needs to be determined before calling get_peft_model, since this function will modify the base model in-place. In this particular case, it doesn't matter, but as a general precaution, it's better to do it correctly.

sambhavnoobcoder (Contributor, Author):

resolved in 0d216c7

peft_embeddings = peft_model.model.emb(x)

# Check that both apply the same scaling factor
assert hasattr(peft_model.model.emb, "_get_embed_scale")
BenjaminBossan (Member):

Can be removed.

sambhavnoobcoder (Contributor, Author):

resolved in 0d216c7

else:
assert not contains_embedding

def test_scaled_embedding_support(self):
BenjaminBossan (Member):

Thanks for including tests for the correct application of embed_scale. However, I think the test is not focused enough to test its goal. To prove it, you can change _get_embed_scale to return a fixed scalar of 1.0 and the test would still pass, except for the torch.allclose(embed_scale, torch.tensor(3.0)) assert, which by itself doesn't check if the scale is correctly applied.

Let me propose a test that I think will be more focused. It uses a real (but small) transformers model and checks the outputs of the embedding layer. By carefully varying the embed_scale, we can ensure that it's being applied. Let me know what you think.

    def test_embed_scale_is_applied(self):
        """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)
            orig_embedding = base_model.get_input_embeddings()

            peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
            peft_model = get_peft_model(base_model, peft_config)

            # sanity check: with the default embed_scale, the embedding output should be reasonably sized
            peft_embedding = peft_model.base_model.model.get_input_embeddings()
            max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
            assert (max_embedding_output < 100.0).all()

            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
            # value
            orig_embedding.embed_scale.fill_(10000.0)
            max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
            assert (max_embedding_output > 100.0).all()

            # set embed_scale to zero, then check that the embedding output is also zero
            orig_embedding.embed_scale.fill_(0)
            embedding_output = peft_embedding(torch.arange(10))
            assert (embedding_output == 0.0).all()

sambhavnoobcoder (Contributor, Author):

Resolved in 0d216c7. Implemented this with a mini version of Gemma3 for more realistic testing. All test cases pass.

@sambhavnoobcoder (Contributor, Author) commented:

Hi @BenjaminBossan. I went through all the changes you suggested and implemented them as well. I refactored some existing tests, wrote some new ones, and tested them; all pass successfully.

About the changes to XLoraEmbeddingLayer and DoraEmbeddingLayer: I had come across this issue once when I was planning for this PR, and on your suggestion I had a look into it again. I may have some ideas about how both of those could be updated (less so about DoRA), but I suppose I will create separate issues for them both, tackling them one at a time, and if you agree we can move forward with this PR, keeping it independent of the two.

Kindly have a look at my present changes and tell me if I need to make any more amendments; I'll make those as soon as possible.

Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be
a scalar and validates its shape accordingly.
BenjaminBossan (Member):

I think otherwise the sentence doesn't quite make sense:

Suggested change:
- a scalar and validates its shape accordingly.
+ a scalar. Its shape is validated accordingly.

sambhavnoobcoder (Contributor, Author):

Cool. Changed that in commit e4ab8b1.

@BenjaminBossan (Member) left a comment

Thanks for the update, we're almost done. Just some small comments, please check. Also, don't forget to run make style.

About the changes to XLoraEmbeddingLayer and DoraEmbeddingLayer i had come across this issue once when i was planning for this PR , and upon your suggestion , i had a look into it again . i may have some ideas about how both of those could be updated , less so about dora , but i suppose i will create separate issues for them both , tackling them one at a time , and if you agree we can move forward with this PR keeping it independent of the two .

Yes, having those in a separate PR works for me.

assert (embedding_output == 0.0).all()

def test_scaled_embedding_with_lora(self):
"""Test that TrainableTokens works with LoRA on scaled embeddings."""
BenjaminBossan (Member):

Could you please explain in a bit more detail what this tests? Would it not be possible to have a similar test setup to the one above here?

sambhavnoobcoder (Contributor, Author):

This test validates the combined scenario where both TrainableTokens and LoRA are active simultaneously (line 977: trainable_token_indices and target_modules both set). So test_embed_scale_is_applied tests TrainableTokens alone on Gemma3, while test_scaled_embedding_with_lora tests TrainableTokens + LoRA together on the mock model.

This ensures embed_scale is correctly applied when both PEFT methods interact on the same model.

Would you prefer Gemma3 instead of the mock model for this test as well? Or should test_embed_scale_is_applied be updated to use this combined approach instead? The third alternative is that we keep the tests as they are now, with no change.

BenjaminBossan (Member):

Ideally, we should have a test that is similar to test_embed_scale_is_applied, using a real model architecture and following similar steps. It would be great if you could rewrite the test accordingly.

sambhavnoobcoder (Contributor, Author):

Cool. Rewritten in 883b3d3.

@sambhavnoobcoder (Contributor, Author) commented on Oct 10, 2025:

@BenjaminBossan I’ve gone through the comments and applied the requested fixes. We can discuss and decide how to handle the test_scaled_embedding_with_lora test.

I also ran make style, which removed an unused import from examples/boft_controlnet/test_controlnet.py in commit 2b0b616 . I double-checked, and it seems like the correct change from the ruff check, so I included it here even though the file isn’t directly related to this PR’s logic.

Lastly, I’ll raise a separate issue for XLoraEmbeddingLayer and DoraEmbeddingLayer and address those there.

@sambhavnoobcoder (Contributor, Author):

@BenjaminBossan made that test change in 883b3d3 and ran make style as well, so if all changes are now handled, we can merge this in. I'll address any other fixes that need to go in here.

Comment on lines 951 to 953
This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply
embed_scale when used together.
BenjaminBossan (Member):

Suggested change:
- This test uses a real model (Gemma3) and verifies that both TrainableTokens and LoRA correctly apply
- embed_scale when used together.

No need to mention this is a real model; we don't explicitly mention it in other tests.

sambhavnoobcoder (Contributor, Author):

Done, @BenjaminBossan. Changed the docstring in b3d8150.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

I did a final pass and found a few smaller changes that would improve the tests. Could you please take a look? The rest looks ready to go.

Comment on lines 809 to 812
# sanity check: with the default embed_scale, the embedding output should be reasonably sized
peft_embedding = peft_model.base_model.model.get_input_embeddings()
max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
assert (max_embedding_output < 100.0).all()
BenjaminBossan (Member):

Oh, I had another idea how to extend the test:

Suggested change:
-     # sanity check: with the default embed_scale, the embedding output should be reasonably sized
-     peft_embedding = peft_model.base_model.model.get_input_embeddings()
-     max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
-     assert (max_embedding_output < 100.0).all()
+     peft_embedding = peft_model.base_model.model.get_input_embeddings()
+     embedding_output = peft_embedding(torch.arange(10))
+     max_embedding_output = embedding_output.abs().max(0)[0]
+     assert (max_embedding_output < 100.0).all()
+     peft_model.merge_adapter()
+     embedding_merged = peft_embedding(torch.arange(10))
+     assert torch.allclose(embedding_output, embedding_merged)
+     peft_model.unmerge_adapter()

The point is that if the embedding scale is not applied correctly, we would expect results to differ between a merged vs non merged output. Thus, by checking the merged output too, we can better ensure that everything works as expected.

The test for trainable tokens can also be updated to check the merged output. test_lora_embed_scale_is_applied_mixed_batch doesn't need to test this, as this assumes unmerged adapters.

@sambhavnoobcoder (Contributor, Author) commented on Oct 13, 2025:

Yup, cool. Implemented this in eb609b8.

orig_embedding = base_model.get_input_embeddings()

peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
peft_model = get_peft_model(base_model, peft_config)
BenjaminBossan (Member):

Let's assign x = torch.arange(10).to(self.torch_device) here and use it as input below for better readability.

sambhavnoobcoder (Contributor, Author):

resolved in eb609b8

"""Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
BenjaminBossan (Member):

Suggested change:
- base_model = AutoModelForCausalLM.from_pretrained(model_id)
+ base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)

So that this test can also run on GPU.

sambhavnoobcoder (Contributor, Author):

resolved in eb609b8

@sambhavnoobcoder (Contributor, Author):

Great. I had a look at all of the recommended changes and implemented and pushed them to the PR as well. I think this is ready to go now.


Successfully merging this pull request may close issue #2809: TrainableTokensModel doesn't handle embedding scaling.