Much more efficient and clear weight initialization and tie weights #42191
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker left a comment
Very nice, make sure to rebase + models with a specific weight init I think are like:
`elif isinstance(module, RTDetrV2MultiscaleDeformableAttention):`
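For context, here is a generic sketch of the kind of model-specific `_init_weights` branch being referenced; the module class and its fields are hypothetical placeholders, not the actual RT-DETRv2 code:

```python
import torch.nn as nn

class MyDeformableAttention(nn.Module):
    """Hypothetical stand-in for a module that needs a non-standard init."""
    def __init__(self):
        super().__init__()
        self.sampling_offsets = nn.Linear(8, 16)

def init_weights(module: nn.Module, std: float = 0.02) -> None:
    # Default path: plain Gaussian init for ordinary layers.
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    # Modules with their own init scheme get a dedicated branch, like the
    # RTDetrV2MultiscaleDeformableAttention line quoted above.
    elif isinstance(module, MyDeformableAttention):
        nn.init.constant_(module.sampling_offsets.bias, 0.0)
```

In `transformers` these branches live in each model's `_init_weights`, which is why a PR touching the generic init path has to account for them.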
Force-pushed from f33c91e to 3ede287
ArthurZucker left a comment
In general good. You need to override the getter and setter for `tie_word_embeddings` to potentially update `_tied_weights_keys`.
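A minimal sketch of what such a getter/setter could look like, as an illustration only (this is an assumption, not the code that ended up in the PR):

```python
class TiedWeightsExample:
    """Toy illustration: keep _tied_weights_keys in sync with tie_word_embeddings."""

    def __init__(self, tie_word_embeddings: bool = True):
        self._tied_weights_keys: list[str] = []
        self.tie_word_embeddings = tie_word_embeddings  # routes through the setter below

    @property
    def tie_word_embeddings(self) -> bool:
        return self._tie_word_embeddings

    @tie_word_embeddings.setter
    def tie_word_embeddings(self, value: bool) -> None:
        self._tie_word_embeddings = value
        # The lm_head weight is only a duplicate when tying is on; listing it
        # here lets loading skip both its initialization and its "missing key" report.
        self._tied_weights_keys = ["lm_head.weight"] if value else []
```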
Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure
and for remote code, we also use this context manager.
"""
This won't work for any tensor manipulation in remote code / code outside our scope, but it's fine.
Yes I know, this is very unfortunate but we cannot really make it work for remote code 🥲
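For illustration, a rough sketch of this kind of init guard (an assumption about the general technique, not the library's actual implementation): it patches the `torch.nn.init` functions to no-ops, which is exactly why direct tensor manipulation in remote code, e.g. `tensor.normal_()`, is not covered.

```python
import contextlib
import torch

@contextlib.contextmanager
def no_init_guard():
    """Temporarily turn the torch.nn.init functions into no-ops."""
    names = ["normal_", "uniform_", "constant_", "xavier_uniform_", "kaiming_uniform_"]
    originals = {name: getattr(torch.nn.init, name) for name in names}
    try:
        for name in names:
            # Each patched function returns the tensor untouched.
            setattr(torch.nn.init, name, lambda tensor, *args, **kwargs: tensor)
        yield
    finally:
        # Always restore the real init functions, even if the init code raised.
        for name, fn in originals.items():
            setattr(torch.nn.init, name, fn)
```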
src/transformers/modeling_utils.py (Outdated)
_prefix = f"{self.base_model_prefix}."
unexpected_keys = {k.removeprefix(_prefix) for k in unexpected_keys}
# Set the flag (very important to avoid initializing them!!)
for tied_param in self._tied_weights_keys.keys():
Only if `tie_word_embeddings` or `tie_encoder_decoder`.
Ah ok, they only exist if you have `tie_word_embeddings`. But in that case you need a setter and getter.
Yes, they are set correctly in advance in `post_init`.
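Reduced to a toy sketch, the effect being discussed is that tied parameters are flagged ahead of time so the init pass leaves them alone (the loop and names below are simplified assumptions, not the PR's code):

```python
import torch.nn as nn

model = nn.Module()
model.embed = nn.Embedding(256, 16)
model.lm_head = nn.Linear(16, 256, bias=False)
model.lm_head.weight = model.embed.weight  # tie the weights

tied_keys = {"lm_head.weight"}
for name, param in model.named_parameters(remove_duplicate=False):
    if name in tied_keys:
        continue  # skip: this weight is replaced by the tied embedding anyway
    nn.init.normal_(param, mean=0.0, std=0.02)
```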
Force-pushed from 79e84f9 to 557ef75
Big bird failing IS related, let's fix.
Also, can you update
# Unless required by applicable law or agreed to in writing, software
Also add a big TODO for the dynamic part, I think it is important!
[For maintainers] Suggested jobs to run (before merge): run-slow: aimv2, albert, align, altclip, aria, audio_spectrogram_transformer, audioflamingo3, autoformer, bamba, bark, bart, beit, bert, bert_generation, big_bird, bigbird_pegasus
What does this PR do?
Follow-up to #41580.
This PR is focused on speed and efficiency, as well as clarity. It avoids the following issues/bottlenecks:
"google/gemma-2-2b", the embedding has size256k * 2304, which is more than 1GB of data in float16 in a single parameter, and it takes more than 3-4s to initialize it withnormal_, see profiling below), when we will overwrite them anyway later (with the tied weights)On that trace, the call to
Parameter.normal_takes 3.4s, and is only thelm_head(the only "missing" weight), even though it's not truly missing because it's a tied weight which is being overwritten later!So basically, the following snippet
takes about 7s, when it was taking about 3s on main before #41580 /with the old loading). After this PR, it takes 3s as well, effectively being as performant as before.
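The snippet itself is not reproduced in this excerpt; a minimal timing harness of the kind being described might look like the following (the exact arguments are an assumption):

```python
import time
import torch
from transformers import AutoModelForCausalLM

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", torch_dtype=torch.float16)
print(f"from_pretrained took {time.perf_counter() - start:.1f}s")
```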