System Info
peft: 0.17.1
transformers: 4.56.2
Who can help?
Hello, I want to fine-tune some embeddings of Gemma 3. Gemma 3 uses `Gemma3TextScaledWordEmbedding`, which inherits from `torch.nn.Embedding`. The `forward` method of this class scales the output of the embedding lookup by a constant stored in a buffer on the module. The problem is that `TrainableTokensLayer` relies on an `isinstance` check, which cannot distinguish between a class and its subclasses, so it treats the layer as a plain embedding and skips the custom forward logic of `Gemma3TextScaledWordEmbedding`. This leads to a high gradient norm and high loss.
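For context, here is a minimal sketch of the scaling behavior (simplified from the `transformers` Gemma 3 implementation; the class below is a stand-in, not the real `Gemma3TextScaledWordEmbedding`):

```python
import torch

class ScaledWordEmbedding(torch.nn.Embedding):
    """Simplified stand-in for Gemma3TextScaledWordEmbedding."""

    def __init__(self, num_embeddings, embedding_dim, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Gemma 3 keeps the scale in a (non-trainable) buffer on the module.
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        # The lookup result is multiplied by the scale; code that only
        # performs the plain nn.Embedding lookup silently drops this step.
        return super().forward(input_ids) * self.embed_scale
```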
Reproduction
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import TrainableTokensModel, TrainableTokensConfig
import torch

device = "cuda:0"
model1 = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m", device_map=device)
model2 = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m", device_map=device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")

config = TrainableTokensConfig(
    token_indices=[1, 2],
    init_weights=True,
)
model2 = TrainableTokensModel(model2, config, "trainable_tokens")

x = torch.tensor([[1, 2]], device=device, dtype=torch.long)
y1 = model1.model.embed_tokens(x)        # original layer: applies embed_scale
y2 = model2.model.model.embed_tokens(x)  # adapter layer: skips embed_scale

print("src embeddings:", y1)
print("embeddings with adapter:", y2)
print("scaled embeddings:", y2 * model1.model.embed_tokens.embed_scale)
```
Expected behavior
The adapter should distinguish between the vanilla `torch.nn.Embedding` layer and custom subclasses such as `Gemma3TextScaledWordEmbedding`, and preserve their forward logic (here, the `embed_scale` multiplication).
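For illustration, a quick check that shows why `isinstance` alone is not enough (a sketch; an exact-type test is one possible way to detect custom variants, not necessarily the right fix for PEFT):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m")
emb = model.model.embed_tokens

# isinstance() accepts subclasses, so the custom scaling goes unnoticed:
print(isinstance(emb, torch.nn.Embedding))  # True
# An exact type check would flag the custom variant:
print(type(emb) is torch.nn.Embedding)      # False
print(type(emb).__name__)                   # "Gemma3TextScaledWordEmbedding"
```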