
TrainableTokensModel doesn't handle embedding scaling #2809

@kimihailv

Description


System Info

peft: 0.17.1
transformers: 4.56.2

Who can help?

@BenjaminBossan

Hello, I want to fine-tune some token embeddings of Gemma 3. Gemma 3 uses Gemma3TextScaledWordEmbedding, which inherits from torch.nn.Embedding. Its forward method multiplies the output of the embedding lookup by a constant (stored as a buffer, embed_scale). The problem is that TrainableTokensLayer relies on an isinstance check. isinstance also returns True for subclasses, so it cannot distinguish torch.nn.Embedding from its children. As a result, the layer skips the custom forward logic of Gemma3TextScaledWordEmbedding and the embeddings come out unscaled. This leads to high gradient norms and high loss.
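The isinstance pitfall can be shown without torch. The sketch below uses hypothetical stand-in classes (Embedding, ScaledEmbedding and adapter_lookup are illustrative names, not the real torch, transformers or PEFT types). It shows that isinstance accepts a subclass while an exact type check does not, and that a wrapper which re-implements the lookup after an isinstance check silently drops the subclass's scaling.

```python
# Hypothetical stand-ins; NOT the real torch / transformers / PEFT classes.
class Embedding:
    def __init__(self, table):
        self.table = table  # list of row vectors, indexed by token id

    def forward(self, ids):
        return [self.table[i] for i in ids]


class ScaledEmbedding(Embedding):
    """Subclass whose forward multiplies the plain lookup by a constant."""

    def __init__(self, table, embed_scale):
        super().__init__(table)
        self.embed_scale = embed_scale

    def forward(self, ids):
        return [[v * self.embed_scale for v in row] for row in super().forward(ids)]


def adapter_lookup(base, ids):
    # Pattern described in the issue: an isinstance check followed by a
    # re-implemented lookup that never calls base.forward().
    if isinstance(base, Embedding):
        return [base.table[i] for i in ids]
    return base.forward(ids)


layer = ScaledEmbedding([[1.0, 2.0], [3.0, 4.0]], embed_scale=10.0)
print(isinstance(layer, Embedding))  # True: subclasses pass the check
print(type(layer) is Embedding)      # False: an exact type check tells them apart
print(layer.forward([1]))            # [[30.0, 40.0]]
print(adapter_lookup(layer, [1]))    # [[3.0, 4.0]]  (scaling lost)
```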

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import TrainableTokensModel, TrainableTokensConfig
import torch

device = "cuda:0"
model1 = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m", device_map=device)
model2 = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m", device_map=device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")

config = TrainableTokensConfig(
    token_indices=[1, 2],
    init_weights=True,
)

model2 = TrainableTokensModel(model2, config, "trainable_tokens")

x = torch.tensor([[1, 2]], device=device, dtype=torch.long)

y1 = model1.model.embed_tokens(x)
y2 = model2.model.model.embed_tokens(x)

print("src embeddings:", y1)
print("embeddings with adapter:", y2)

print("scaled embeddings:", y2 * model1.model.embed_tokens.embed_scale)
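The last line of the reproduction works because the scaled layer's output is just the plain lookup multiplied by embed_scale. The sketch below checks that relationship on a small, hypothetical stand-in (ScaledEmbedding) modeled on the issue's description of Gemma3TextScaledWordEmbedding: a torch.nn.Embedding subclass that multiplies by a buffer. It does not import transformers, and the scale value is arbitrary.

```python
import torch
import torch.nn.functional as F


class ScaledEmbedding(torch.nn.Embedding):
    # Hypothetical stand-in: plain lookup, then multiply by a non-persistent buffer.
    def __init__(self, num_embeddings, embedding_dim, embed_scale):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)


torch.manual_seed(0)
layer = ScaledEmbedding(8, 4, embed_scale=4.0)
x = torch.tensor([[1, 2]])

y_full = layer(x)                     # goes through the subclass forward
y_raw = F.embedding(x, layer.weight)  # bypasses it, like an isinstance-based shortcut

print(torch.allclose(y_full, y_raw * layer.embed_scale))  # True
```

In other words, a path that calls the functional lookup directly is off by exactly the factor embed_scale for every token.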

Expected behavior

The adapter should distinguish a vanilla torch.nn.Embedding layer from custom subclasses such as Gemma3TextScaledWordEmbedding, so that their custom forward logic is not skipped.
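One possible direction, sketched under assumptions: instead of branching on isinstance and re-implementing the lookup, a wrapper could build a weight matrix with the trainable rows swapped in and then run the base layer's own forward with that weight via torch.func.functional_call. TokenOverrideEmbedding, trainable_rows and ScaledEmbedding are hypothetical names; this is not PEFT's TrainableTokensLayer or its actual fix.

```python
import torch
from torch.func import functional_call


class ScaledEmbedding(torch.nn.Embedding):
    # Hypothetical stand-in for a scaled embedding subclass.
    def __init__(self, num_embeddings, embedding_dim, embed_scale):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)


class TokenOverrideEmbedding(torch.nn.Module):
    """Hypothetical wrapper: trains only the rows for `token_indices` and
    routes the lookup through the base layer's own forward()."""

    def __init__(self, base_layer, token_indices):
        super().__init__()
        self.base_layer = base_layer
        self.base_layer.weight.requires_grad_(False)  # freeze the original table
        self.register_buffer("token_indices", torch.tensor(token_indices, dtype=torch.long))
        # Trainable replacement rows, initialized from the current weights.
        self.trainable_rows = torch.nn.Parameter(
            base_layer.weight[self.token_indices].detach().clone()
        )

    def forward(self, input_ids):
        # Out-of-place copy: frozen table with the trainable rows swapped in.
        weight = self.base_layer.weight.index_copy(0, self.token_indices, self.trainable_rows)
        # Run the subclass's forward with the swapped weight, so extra logic
        # such as the embed_scale multiplication is preserved.
        return functional_call(self.base_layer, {"weight": weight}, (input_ids,))


torch.manual_seed(0)
base = ScaledEmbedding(8, 4, embed_scale=4.0)
wrapped = TokenOverrideEmbedding(base, token_indices=[1, 2])
x = torch.tensor([[1, 2, 3]])

print(torch.allclose(wrapped(x), base(x)))  # True: scaling is kept
```

Because the subclass's forward is never bypassed, this design does not need to know about embed_scale or any other subclass-specific detail; gradients reach only trainable_rows.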
