
Conversation

@Cyrilvallez (Member) commented Nov 13, 2025

What does this PR do?

Follow-up to #41580.
This PR is focused on speed and efficiency, as well as clarity. It avoids the following issues/bottlenecks:

  • No longer needs to dynamically switch the class of the Parameters (which is hard to understand and may lead to issues with quantization etc., as it's not always clear whether the memory stays the same or is doubled)
  • Avoids an extra loop over all the modules just to switch classes
  • Avoids initializing the tied weights, which takes a long time for big embeddings (e.g. for "google/gemma-2-2b" the embedding has size 256k * 2304, i.e. more than 1GB of data in float16 in a single parameter, and it takes 3-4s to initialize it with normal_, see profiling below), when we overwrite them anyway later with the tied weights
  • Avoids an extra loop over all modules for tied weights (the correct names are gathered in advance)
[Screenshot: profiling trace of the from_pretrained call]

On that trace, the call to Parameter.normal_ takes 3.4s and is only for the lm_head (the only "missing" weight), even though it's not truly missing: it's a tied weight that gets overwritten later!
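The idea can be sketched as follows; this is a minimal illustration and not the actual loading code (the toy model, the TIED_KEYS mapping and both helper functions are made up for the example): if the names of the tied parameters are known in advance, the random init simply skips them, since tying the weights overwrites them anyway.

import torch.nn as nn

# Toy stand-in for a causal LM whose lm_head is tied to the embedding.
class TinyLM(nn.Module):
    def __init__(self, vocab_size=1000, hidden_size=16):  # gemma-2-2b would be 256k * 2304
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Hypothetical mapping, known in advance: tied param name -> source param name.
TIED_KEYS = {"lm_head.weight": "embed_tokens.weight"}

def init_missing_weights(model, missing_keys):
    # Only initialize weights that are truly missing from the checkpoint.
    # Tied parameters are skipped: running normal_ on a 256k * 2304 matrix
    # takes seconds, and the result is thrown away once the weights are tied.
    for name, param in model.named_parameters():
        if name in missing_keys and name not in TIED_KEYS:
            nn.init.normal_(param, mean=0.0, std=0.02)

def tie_weights(model):
    # Point each tied parameter at its source parameter.
    for target, source in TIED_KEYS.items():
        module_path, _, attr = target.rpartition(".")
        setattr(model.get_submodule(module_path), attr, model.get_parameter(source))

model = TinyLM()
init_missing_weights(model, missing_keys={"lm_head.weight"})  # nothing to init: it is tied
tie_weights(model)
assert model.lm_head.weight is model.embed_tokens.weight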

So basically, the following snippet

from transformers import AutoModelForCausalLM
import torch
import time

model_id = "google/gemma-2-2b"
device = 1  # target GPU index

# Synchronize the target device so the timing covers the full load.
torch.cuda.synchronize(device)
t0 = time.time()
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, dtype="auto")
torch.cuda.synchronize(device)
dt = time.time() - t0
print(f"Took {dt:.2f} s")

takes about 7s, whereas it was taking about 3s on main before #41580 (i.e. with the old loading). After this PR, it takes about 3s as well, effectively being as performant as before.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

Very nice, make sure to rebase! Also, models with a specific weight init are, I think, like

elif isinstance(module, RTDetrV2MultiscaleDeformableAttention):

this one!

@Cyrilvallez Cyrilvallez changed the title Much more efficient and clear weight initialization Much more efficient and clear weight initialization and tie weights Nov 14, 2025
@ArthurZucker (Collaborator) left a comment

In general good; you potentially need to overwrite the getter and setter for tie_word_embeddings to update _tied_weights_keys.

Usually, all models use the init from `transformers`, which is already guarded, but just to make extra sure,
and for remote code, we also use this context manager.
"""
Collaborator:

This won't work for direct tensor manipulation in remote code / code outside our scope, but it's fine.

Member (Author):

Yes I know, this is very unfortunate but we cannot really make it work for remote code 🥲
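For context, such an init guard can be sketched as a context manager that temporarily swaps the in-place torch.nn.init functions for no-ops, so anything that initializes weights through them becomes free. The sketch below is an illustration under that assumption (the name and function list are made up; the actual transformers implementation may differ):

import contextlib
import torch

@contextlib.contextmanager
def no_init_weights_guard():
    # In-place init functions to neutralize (non-exhaustive list).
    names = ["normal_", "uniform_", "constant_", "trunc_normal_",
             "xavier_normal_", "xavier_uniform_",
             "kaiming_normal_", "kaiming_uniform_"]
    originals = {name: getattr(torch.nn.init, name) for name in names}
    try:
        for name in names:
            # Replace with a no-op that returns the tensor unchanged.
            setattr(torch.nn.init, name, lambda tensor, *args, **kwargs: tensor)
        yield
    finally:
        # Always restore the original functions.
        for name, fn in originals.items():
            setattr(torch.nn.init, name, fn)

Anything that writes into a parameter without going through torch.nn.init (e.g. param.data.normal_() called directly in remote code) slips past such a guard, which is exactly the limitation discussed here.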

_prefix = f"{self.base_model_prefix}."
unexpected_keys = {k.removeprefix(_prefix) for k in unexpected_keys}
# Set the flag (very important to avoid initializing them!!)
for tied_param in self._tied_weights_keys.keys():
Collaborator:

Only if tie_word_embeddings or tie_encoder_decoder.

Collaborator:

Ah ok, they only exist if you have tie_word_embeddings. But in that case you need a setter and getter.

Member (Author):

Yes, they are set correctly in advance in post_init.
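A minimal sketch of the getter/setter idea discussed above; the mixin, the property, and the concrete key mapping are assumptions for illustration (in the PR itself the mapping is set up once in post_init):

class TiedEmbeddingsMixin:
    # Hypothetical mixin keeping _tied_weights_keys in sync with the config flag.

    @property
    def tie_word_embeddings(self):
        return self.config.tie_word_embeddings

    @tie_word_embeddings.setter
    def tie_word_embeddings(self, value):
        self.config.tie_word_embeddings = value
        # Recompute the mapping so the loading / tying logic knows which
        # parameters are tied and can skip initializing them.
        if value:
            self._tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
        else:
            self._tied_weights_keys = {}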

@ArthurZucker (Collaborator) left a comment

Big Bird failing IS related, let's fix it.
Also, can you update

# Unless required by applicable law or agreed to in writing, software

please.

Also add a big TODO for the dynamic part, I think it is important!

@github-actions (Contributor) commented

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, albert, align, altclip, aria, audio_spectrogram_transformer, audioflamingo3, autoformer, bamba, bark, bart, beit, bert, bert_generation, big_bird, bigbird_pegasus

@Cyrilvallez Cyrilvallez merged commit 8598421 into main Nov 14, 2025
20 of 24 checks passed
@Cyrilvallez Cyrilvallez deleted the better-init-2 branch November 14, 2025 23:34