
Conversation

busycalibrating

In the case of untied Embed/LM heads, it seems natural for the Trainable Tokens implementation to also support LM heads (especially for reserved special tokens). Currently, my understanding is that it only lets you fine-tune the input embeddings (unless I'm horribly mistaken?). Is there any reason to restrict it to just input embeddings? I hacked my local PEFT install to lift this restriction, and it could be broadly useful for others as well.

I didn't do any rigorous testing, so let me know if I'm missing anything obvious. I'd be happy to help out with this if there's interest.

@busycalibrating marked this pull request as draft on September 20, 2025, 23:22
@githubnemo
Collaborator

Thanks for the suggestion!

I think that this is already supported by passing target_modules=["embed_tokens", "lm_head"] in TrainableTokensConfig (or trainable_token_indices={"embed_tokens": [...], "lm_head": [...]}) unless I'm missing something. Can you confirm that?
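For reference, a minimal sketch of that configuration (the checkpoint name is a placeholder, and the module names and token indices are model-specific assumptions, not values from this PR):

    from transformers import AutoModelForCausalLM
    from peft import TrainableTokensConfig, get_peft_model

    # Placeholder checkpoint; use a model whose lm_head is untied from embed_tokens.
    model = AutoModelForCausalLM.from_pretrained("<your-untied-checkpoint>")

    config = TrainableTokensConfig(
        target_modules=["embed_tokens", "lm_head"],  # module names are model-specific
        token_indices=[128002, 128003],              # e.g. reserved special tokens
    )
    peft_model = get_peft_model(model, config)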

@busycalibrating
Author

busycalibrating commented Sep 22, 2025

No, there's a tiny issue (which is what my PR addresses): passing the lm_head triggers an error when it tries to access the embedding dim, because the base implementation does this:

embed_dim = self.get_base_layer().embedding_dim

Since the LM head (at least for some models, like Llama 3.2) is an nn.Linear, this fails. I changed it to this in my local install:

        base = self.get_base_layer()
        # nn.Embedding exposes embedding_dim; nn.Linear (e.g. an untied lm_head) does not
        embed_dim = getattr(base, "embedding_dim", None)
        if embed_dim is None:
            # linear LM head: the input dimension plays the role of the embedding dim
            embed_dim = getattr(base, "in_features", None)
        if embed_dim is None:
            # last resort: infer it from the weight itself
            embed_dim = weight.shape[-1]
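For context, a standalone illustration of why the original attribute lookup fails on a linear LM head (not code from the PR, just the attribute difference between the two module types):

    import torch.nn as nn

    emb = nn.Embedding(num_embeddings=128256, embedding_dim=2048)
    head = nn.Linear(in_features=2048, out_features=128256, bias=False)

    print(emb.embedding_dim)               # 2048
    print(hasattr(head, "embedding_dim"))  # False -> AttributeError in the original code
    print(head.in_features)                # 2048, the equivalent dimension on the LM head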

@githubnemo (Collaborator) left a comment

No there's a tiny issue (which is what my PR addresses) - passing the lm_head triggers an error when it tries to access the embedding dim

Right, makes sense. I think supporting nn.Linear here is worthwhile. Some comments below.

Let's also add some tests for this case, for example in tests/test_trainable_tokens.py.
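Something along these lines could work as a starting point (a sketch only: the model id is a placeholder for a small checkpoint with an untied nn.Linear lm_head, and the test name is made up):

    import torch
    from transformers import AutoModelForCausalLM

    from peft import TrainableTokensConfig, get_peft_model


    def test_trainable_tokens_targets_linear_lm_head():
        # Placeholder: any small causal LM whose lm_head is an untied nn.Linear.
        model = AutoModelForCausalLM.from_pretrained("<tiny-untied-model>")
        config = TrainableTokensConfig(target_modules=["embed_tokens", "lm_head"], token_indices=[0, 1])
        peft_model = get_peft_model(model, config)

        input_ids = torch.tensor([[0, 1, 2]])
        output = peft_model(input_ids)
        assert output.logits.shape[-1] == model.config.vocab_size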

Comment on lines +238 to +242
    bias = getattr(self.base_layer, "bias", None)
    result = F.linear(
        input=x,
        weight=W,
        bias=bias,

I think we should use self.get_base_layer() instead to be consistent with update_layer(). If you want, you can update the lines above for the F.embedding call and instance check as well.
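Roughly what that could look like in the forward path (a sketch of the suggestion, assuming x, W, and the F alias for torch.nn.functional already exist in the surrounding method; the exact F.embedding arguments in the real code may differ):

    base_layer = self.get_base_layer()
    if isinstance(base_layer, torch.nn.Embedding):
        result = F.embedding(input=x, weight=W, padding_idx=base_layer.padding_idx)
    else:
        # untied LM head: an nn.Linear, so dispatch to F.linear instead
        bias = getattr(base_layer, "bias", None)
        result = F.linear(input=x, weight=W, bias=bias)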

Comment on lines +123 to +124
        if embed_dim is None:
            embed_dim = weight.shape[-1]

Can you elaborate on the purpose of the last case?

Comment on lines +45 to +64
    if isinstance(targets, str):
        targets = [targets]

    # If embeddings are untied, also include the output embedding (lm head) module name
    try:
        tied_cfg = model_config.get("tie_word_embeddings", False)
        tied_keys = getattr(self.model, "_tied_weights_keys", None)
        are_tied = bool(tied_cfg and tied_keys is not None)
    except Exception:
        are_tied = False

    if not are_tied and hasattr(self.model, "get_output_embeddings"):
        out_emb = self.model.get_output_embeddings()
        if out_emb is not None:
            for name, module in self.model.named_modules():
                if module is out_emb:
                    targets.append(name)
                    break

    peft_config.target_modules = list(dict.fromkeys(targets))

What's the idea behind targeting the output embedding automatically when the weights are untied? I don't see the benefit, and I think this also breaks backward compatibility with existing checkpoints.

@busycalibrating
Author

busycalibrating commented Sep 24, 2025

Thanks for the feedback, I'll follow up by the end of the week/next week!
