11 changes: 9 additions & 2 deletions src/peft/tuners/trainable_tokens/layer.py
@@ -116,13 +116,18 @@ def update_layer(self, adapter_name, **kwargs):
         # onto the new values, we would get undefined behavior. By replacing the specific token values we always
         # get defined behavior.
         weight = self.get_base_layer().weight
-        embed_dim = self.get_base_layer().embedding_dim
+        base = self.get_base_layer()
+        embed_dim = getattr(base, "embedding_dim", None)
+        if embed_dim is None:
+            embed_dim = getattr(base, "in_features", None)
+        if embed_dim is None:
+            embed_dim = weight.shape[-1]
Comment on lines +123 to +124
Collaborator: Can you elaborate the purpose of the last case?


         if init_weights:
             if check_deepspeed_zero3_enabled():
                 values = self._collect_token_weights(weight, self.token_indices[adapter_name], embed_dim)
             else:
-                values = self.weight[self.token_indices[adapter_name]]
+                values = weight[self.token_indices[adapter_name]]
         else:
             # random init with matching dtype/device
             values = torch.randn(
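Regarding the question on lines +123 to +124, here is a minimal standalone sketch (not part of the PR) of the three cases the fallback chain presumably distinguishes: an nn.Embedding exposes embedding_dim, a Linear LM head exposes in_features, and a custom wrapper may expose neither, leaving only the weight's last dimension.

import torch

# nn.Embedding: dimension available as `embedding_dim`
emb = torch.nn.Embedding(10, 16)
print(getattr(emb, "embedding_dim", None))   # 16

# nn.Linear (e.g. an untied LM head): no `embedding_dim`, but `in_features`
head = torch.nn.Linear(16, 10, bias=False)
print(getattr(head, "embedding_dim", None))  # None
print(getattr(head, "in_features", None))    # 16

# hypothetical wrapper exposing only a weight tensor (neither attribute present)
class WeightOnly(torch.nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)

wrapped = WeightOnly(torch.randn(10, 16))
print(wrapped.weight.shape[-1])              # 16, the last-resort fallback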
@@ -230,9 +235,11 @@ def forward_adapters(self, x: torch.Tensor, active_adapters, *args, **kwargs) ->
                 )
             elif isinstance(self.base_layer, torch.nn.Linear):
                 # Probably a tied adapter that wraps an LM head.
+                bias = getattr(self.base_layer, "bias", None)
                 result = F.linear(
                     input=x,
                     weight=W,
+                    bias=bias,
Comment on lines +238 to +242
Collaborator: I think we should use self.get_base_layer() instead to be consistent with update_layer(). If you want, you can update the lines above for the F.embedding call and instance check as well.

                 )
             else:
                 raise ValueError(
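For illustration, one way the Linear branch could look with the reviewer's suggestion applied, resolving the wrapped layer once via get_base_layer() as update_layer() already does (a sketch, not a change committed in this PR):

# sketch only: resolve the wrapped layer via get_base_layer(), then forward
# its bias (if any) alongside the adapter-modified weight W
base = self.get_base_layer()
if isinstance(base, torch.nn.Linear):
    # Probably a tied adapter that wraps an LM head.
    bias = getattr(base, "bias", None)
    result = F.linear(input=x, weight=W, bias=bias)

The F.embedding call and the isinstance checks above it could reuse the same `base` reference, as the comment suggests.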
22 changes: 21 additions & 1 deletion src/peft/tuners/trainable_tokens/model.py
@@ -41,7 +41,27 @@ def __getattr__(self, name: str):
     def _prepare_adapter_config(self, peft_config, model_config):
         # target_modules can be none which prompts us to infer the embedding layer name ourselves.
         if peft_config.target_modules is None:
-            peft_config.target_modules = _get_input_embeddings_name(self.model, "embed_tokens")
+            targets = _get_input_embeddings_name(self.model, "embed_tokens")
+            if isinstance(targets, str):
+                targets = [targets]
+
+            # If embeddings are untied, also include the output embedding (lm head) module name
+            try:
+                tied_cfg = model_config.get("tie_word_embeddings", False)
+                tied_keys = getattr(self.model, "_tied_weights_keys", None)
+                are_tied = bool(tied_cfg and tied_keys is not None)
+            except Exception:
+                are_tied = False
+
+            if not are_tied and hasattr(self.model, "get_output_embeddings"):
+                out_emb = self.model.get_output_embeddings()
+                if out_emb is not None:
+                    for name, module in self.model.named_modules():
+                        if module is out_emb:
+                            targets.append(name)
+                            break
+
+            peft_config.target_modules = list(dict.fromkeys(targets))
Comment on lines +45 to +64
Collaborator: What's the idea behind targeting the output embedding automatically in case of untied weights? I don't see the benefit and I think this is also breaking backward compatibility with existing checkpoints.


         return peft_config

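For context on the tied/untied distinction raised in the comment above: with a transformers model, tied word embeddings mean the input and output embeddings share a single Parameter, which can be checked directly. This snippet is illustrative only and not part of the PR; "gpt2" is used merely as an example of a model whose embeddings are tied by default.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

in_emb = model.get_input_embeddings()    # e.g. wte / embed_tokens
out_emb = model.get_output_embeddings()  # e.g. lm_head

# True when the LM head reuses the embedding matrix (tied), False when untied
print(in_emb.weight is out_emb.weight)
print(model.config.tie_word_embeddings)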