Handling embeddings scaling for TrainableTokensModel #2809 #2825
base: main
Changes from 2 commits
@@ -870,6 +870,42 @@ def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:

        return DoraEmbeddingVariant()

    def _get_embed_scale(self):
        """
        Extract embed_scale from base layer if present and valid.

        Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
        method. This method checks for the presence of an `embed_scale` attribute and validates its shape.

        Returns:
            torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
        """
        base_layer = self.get_base_layer()
        if not hasattr(base_layer, "embed_scale"):
            return None

        embed_scale = base_layer.embed_scale

        # Convert scalar values to tensors
        if isinstance(embed_scale, (int, float)):
            return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

        # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
        if isinstance(embed_scale, torch.Tensor):
            if embed_scale.numel() == 1:
                return embed_scale
            else:
                # Log warning but don't fail - this maintains backward compatibility
                warnings.warn(
                    f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
                    "Embedding scaling will not be applied. If this is unexpected, please open an issue at "
                    "https://github.com/huggingface/peft/issues",
                    PeftWarning,
                )
                return None

        return None

    def update_layer(
        self,
        adapter_name,

@@ -1035,6 +1071,10 @@ def _mixed_batch_forward(
        # extra argument that allows mixing different adapters in the same batch at inference time.
        result = self.base_layer(x, *args, **kwargs)

        # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
        # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
        embed_scale = self._get_embed_scale()

        unique_adapters = set(adapter_names)
        sub_batch_indices_list = []
        for adapter in unique_adapters:

@@ -1054,7 +1094,13 @@ def _mixed_batch_forward(
            # layer output
            sub_batch = x[sub_batch_indices_list[i]]
            after_A = self._embed(sub_batch, embedding_A)
-           result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling
+           adapter_output = (after_A @ embedding_B) * scaling
+
+           # Apply embed_scale to match the base layer's scaling
+           if embed_scale is not None:
+               adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
+
+           result[sub_batch_indices_list[i]] += adapter_output

        return result

@@ -1086,6 +1132,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        else:
            result = self.base_layer(x, *args, **kwargs)
            torch_result_dtype = result.dtype

            # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
            # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
            embed_scale = self._get_embed_scale()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_embedding_A:
                    continue

@@ -1095,7 +1146,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                embedding_B = self.lora_embedding_B[active_adapter].T
                scaling = self.scaling[active_adapter]
                after_A = self._embed(x, embedding_A)
-               result = result + (after_A @ embedding_B) * scaling
+               adapter_output = (after_A @ embedding_B) * scaling
+
+               # Apply embed_scale to match the base layer's scaling
+               if embed_scale is not None:
+                   adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
+
+               result = result + adapter_output
            else:
                result = self.lora_variant[active_adapter].forward(
                    self,
@@ -79,6 +79,42 @@ def tied_adapter(self):
            return self._tied_adapter[0]
        return None

    def _get_embed_scale(self):
        """
        Extract embed_scale from base layer if present and valid.

        Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
        method. This method checks for the presence of an `embed_scale` attribute and validates its shape.

        Returns:
            torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
        """
        base_layer = self.get_base_layer()
        if not hasattr(base_layer, "embed_scale"):
            return None

        embed_scale = base_layer.embed_scale

        # Convert scalar values to tensors
        if isinstance(embed_scale, (int, float)):
            return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)

        # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
        if isinstance(embed_scale, torch.Tensor):
            if embed_scale.numel() == 1:
                return embed_scale
            else:
                # Log warning but don't fail - this maintains backward compatibility
                warnings.warn(
                    f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. "
                    "Embedding scaling will not be applied. If this is unexpected, please open an issue at "
                    "https://github.com/huggingface/peft/issues",
                    UserWarning,
                )
                return None

        return None

    def _collect_token_weights(self, weight: torch.Tensor, rows: torch.Tensor, embed_dim: int) -> torch.Tensor:
        """DeepSpeed zero3 specific code to initialize trainable tokens.

@@ -232,6 +268,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) -> torch.Tensor:
                scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
                sparse=self.base_layer.sparse,
            )
            # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
            # Since we're using F.embedding directly, we need to apply this scaling manually.
            embed_scale = self._get_embed_scale()
            if embed_scale is not None:
                result = result * embed_scale.to(result.dtype)
        elif isinstance(self.base_layer, torch.nn.Linear):
            # Probably a tied adapter that wraps an LM head.
            result = F.linear(

@@ -918,3 +918,103 @@ def test_save_pretrained_auto(self, model, resize_embedding, peft_config, tmp_path):
            assert contains_embedding
        else:
            assert not contains_embedding

    def test_scaled_embedding_support(self):
        """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
        # Create a mock scaled embedding layer similar to Gemma3TextScaledWordEmbedding
        class ScaledEmbedding(torch.nn.Embedding):
            def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
                super().__init__(num_embeddings, embedding_dim)
                self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

            def forward(self, input_ids):
                return super().forward(input_ids) * self.embed_scale

        # Create a model with scaled embedding
        class ModelWithScaledEmb(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = ScaledEmbedding(100, 10, embed_scale=3.0)
                self.lin0 = torch.nn.Linear(10, 1)

            def forward(self, x):
                return self.lin0(self.emb(x))

            def get_input_embeddings(self):
                return self.emb

        base_model = ModelWithScaledEmb()
        peft_config = TrainableTokensConfig(target_modules=["emb"], token_indices=[0, 1, 2])
        peft_model = get_peft_model(base_model, peft_config)

        # Test input
        x = torch.tensor([[0, 1, 2, 3]])

        # Get outputs from base model and peft model
        base_output = base_model(x)
        peft_output = peft_model(x)

        # The outputs should be scaled - let's verify the embeddings are scaled
        base_embeddings = base_model.emb(x)
        peft_embeddings = peft_model.model.emb(x)

        # Check that both apply the same scaling factor
        assert hasattr(peft_model.model.emb, "_get_embed_scale")

        embed_scale = peft_model.model.emb._get_embed_scale()
        assert embed_scale is not None
        assert torch.allclose(embed_scale, torch.tensor(3.0))

        # Before training, outputs should be identical (within numerical precision)
        assert torch.allclose(base_embeddings, peft_embeddings, atol=1e-6)

        # Simulate training
        self.simulate_training(peft_model.model.emb)

        # After "training", the scaled embeddings for modified tokens should differ
        peft_embeddings_trained = peft_model.model.emb(x)

        # Modified tokens (0, 1, 2) should be different from base
        base_emb_modified = base_embeddings[0, :3]
        peft_emb_modified = peft_embeddings_trained[0, :3]
        assert not torch.allclose(base_emb_modified, peft_emb_modified)

        # Unmodified token (3) should be the same
        base_emb_unmodified = base_embeddings[0, 3]
        peft_emb_unmodified = peft_embeddings_trained[0, 3]
        assert torch.allclose(base_emb_unmodified, peft_emb_unmodified, atol=1e-6)

    def test_scaled_embedding_with_lora(self):
        """Test that TrainableTokens works with LoRA on scaled embeddings."""

        class ScaledEmbedding(torch.nn.Embedding):
            def __init__(self, num_embeddings, embedding_dim, embed_scale=2.0):
                super().__init__(num_embeddings, embedding_dim)
                self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

            def forward(self, input_ids):
                return super().forward(input_ids) * self.embed_scale

        class ModelWithScaledEmb(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = ScaledEmbedding(100, 10, embed_scale=3.0)
                self.lin0 = torch.nn.Linear(10, 1)

            def forward(self, x):
                return self.lin0(self.emb(x))

            def get_input_embeddings(self):
                return self.emb

        base_model = ModelWithScaledEmb()
        peft_config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
        peft_model = get_peft_model(base_model, peft_config)

        x = torch.tensor([[0, 1, 2, 3]])

        # Verify embed_scale detection works in combined mode
        assert hasattr(peft_model.model.emb.token_adapter, "_get_embed_scale")
        embed_scale = peft_model.model.emb.token_adapter._get_embed_scale()
        assert embed_scale is not None
        assert torch.allclose(embed_scale, torch.tensor(3.0))

Let's factor this function out so that it can be reused between the two PEFT methods. It should be fine to put it on the BaseTunerLayer class. This also eliminates the small inconsistency between the two implementations you have.
resolved in 0d216c7
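
For illustration only, a minimal sketch of the suggested refactor, not the contents of commit 0d216c7: the helper would live once on the shared base class and be inherited by both the LoRA Embedding layer and the TrainableTokens layer. BaseTunerLayerSketch and its base_layer attribute are stand-ins for PEFT's real BaseTunerLayer, assumed only to expose get_base_layer().

# Hypothetical sketch of a shared _get_embed_scale helper on the common tuner-layer base class.
import warnings
from typing import Optional

import torch


class BaseTunerLayerSketch:
    """Stand-in for BaseTunerLayer, reduced to what the shared helper needs."""

    base_layer: torch.nn.Module

    def get_base_layer(self) -> torch.nn.Module:
        # The real BaseTunerLayer unwraps nested adapter layers; a single level suffices here.
        return self.base_layer

    def _get_embed_scale(self) -> Optional[torch.Tensor]:
        """Return the base layer's embed_scale as a single-element tensor, or None if absent or invalid."""
        base_layer = self.get_base_layer()
        embed_scale = getattr(base_layer, "embed_scale", None)
        if embed_scale is None:
            return None
        if isinstance(embed_scale, (int, float)):
            # Promote plain Python scalars to a tensor on the weight's device/dtype.
            return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)
        if isinstance(embed_scale, torch.Tensor) and embed_scale.numel() == 1:
            return embed_scale
        warnings.warn(
            f"Found embed_scale with shape {getattr(embed_scale, 'shape', None)}, expected a scalar; "
            "embedding scaling will not be applied."
        )
        return None

Both adapter classes would then inherit this single helper instead of each defining its own copy, which is the deduplication the review comment asks for.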