
Commit 6f6095e

ArthurZucker, Cyrilvallez, and SunMarc authored
Refactor weight loading (#41580)
* ah actually we don't discard lm head if missing -> needs to be moved to correct device and etc
* fix some tests
* small fixes
* up
* up
* dik why we tie weights twice but,..,,.
* ups
* removeunused
* fix hunyuan
* small fix
* nits
* ish
* up
* rev
* fix more tie weights keys
* small fixes
* nit
* update
* fix and fix
* fix a test
* glubs
* current shitty changes
* ship validated ones
* more
* more update
* more
* more
* more
* mllama
* more up
* fix ernie
* fix xopies
* up more
* more fixes
* up
* up
* fix-copies
* fix more
* more updates
* AI UPDATE
* up
* hoey
* make it fast
* fix
* lol
* fix asjusting
* more fixes
* _dtype nit
* up
* nit
* update
* update
* remove semaphores
* fix import to avoid jit execution
* try to remove custom tiing logic when its stupid
* fix more individual models
* fix whisper as well
* fix?
* fox umt5
* improve tqdm bar
* cleanup a bit
* oupsi
* some updates
* improve
* remove all buffering -> much faster without it
* remove some tie_weights custome funcs when not needed
* more fixes related to strict matching regex
* remove ALL custom tie weights
* small update
* revert change to init scheme (no need for params)
* mixtral init
* try less strict source check
* tied weight first shot to the fiiiixxxxxx
* does this help?
* :)
* fix some ppolry defined tied_weights_keys for now
* subclass nn.Parameters
* up
* lol
* Ouiiii
* fix led
* fix long cat flash
* fix qwen and long cat flash
* properly fix qwen init
* just push this for now
* propnet is dumb
* update
* push
* remove explict sharing of some tied keys.
* update decoder.bias
* moe case
* more changes to untangle old hardcoded ting
* fixup
* fix big faileurs
* fix prophnet
* fix resize token embeddings
* nits
* fix xcodex
* asyncio?
* fix smart apply
* fix data-2-vec
* [build-ci-image]
* checkout
* uupdate
* fix hunyuan
* update error message
* fix deformable detr
* fixes
* fix init weights for non param gate up projs
* shared todo?
* update some models
* big revert, don't break this behaviour
* ty @SunMarc this fixes the buffers Co-authored-by: SunMarc <[email protected]>
* mt5 fuck
* fix lxmbert
* nuke slow test fetcher
* fix zamba and deepcopy for now
* fix zamba tied weight keys! ~
* fix-copies
* update fetch terst
* fix gradient for test modeling common!
* break "shared" for now I will fix tomorrow changes are properly isoalted now :)
* does this fix marian? probably not
* fix some vlms
* D fine seems to handle this well
* glob is fine actually
* fix dab detr
* small steps
* opusy
* fix some more models?
* yups
* better erro
* fix?
* fix double escape
* escape wehere it makes sense
* ??
* fix ibert
* fix tvp as well
* more fxes
* try always download ref PR
* ONONONO
* big fixup
* more fixup
* small step
* small nits
* nits
* brut force some stuff
* fix vilt
* make sure special models that always need tie always tie
* cleaning up
* small nits
* fix zamba and bridge tower!
* just fixup
* potential culprits
* revert bark and fix bridgetower
* remove now non existant tie_weights
* ?
* lol reformer actually had nothing tied!
* wow these two fucking models were really not well made
* fix sam family!
* fix bark revision
* fix speech2test ?
* push this for now....
* upsy
* the fuck
* fix rtdetr
* update
* proper
* wow that one 's annoying
* update
* try to find the culprit
* get some help on common
* nit about general init and cls.padding_idx
* revert num workers update
* remove old loading func
* fix glob
* add annotations
* fix re
* small improvements
* clean some stuff
* improvements
* someone did not understannnnnnd what I tried to dooo or does BNB not support that either?
* gluos
* fix case when `.` is just not there
* remove unused arg
* recover orignal parameter/buffer using _original
* fix glob issu
* this?
* deepspeed best-effort
* remove unused stuff
* Update tie weight keys as they were just wroong Co-authored-by: Benjamin Bossan <[email protected]>"
* up
* augustuc clauss, a gloubs gloups gloubs
* fixup
* fixup
* there was fucking typo
* mrain
* nits
* fix marian 3 remaining tests
* one more
* fix some of the copies, not all :)
* small cleanup
* one propertest
* fix core model loadig tes
* attempt a new test
* fix some of the annoying tests by supporting reading .bin sometimes
* push
* push more small fixes
* remove 1 useless test
* up
* fix audio flamingo post rebase
* fixup
* some small updatess
* fix sam models
* nits
* up
* updates
* onem ore
* skip this stupid test
* some other fixes
* fixup
* update
* skip more offloaded stuff
* oups
* ups
* update mixtral
* skip this one
* LET"SGO
* fixup
* rope delta order
* fix csm
* small nit

---------

Co-authored-by: Cyril Vallez <[email protected]>
Co-authored-by: SunMarc <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
1 parent c4cfc2e · commit 6f6095e

File tree

586 files changed: +8733 additions, −7189 deletions


.circleci/config.yml

Lines changed: 4 additions & 4 deletions

@@ -46,8 +46,8 @@ jobs:
     - run: uv pip install -U -e .
     - run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
     - run: mkdir -p test_preparation
-    - run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
-    - run: python utils/tests_fetcher.py --filter_tests
+    - run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt || true
+    - run: python utils/tests_fetcher.py --filter_tests || true
     - run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
     - run: |
         if [ ! -s test_preparation/generated_config.yml ]; then
@@ -98,8 +98,8 @@ jobs:
     - run: uv pip install -U -e .
     - run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
     - run: mkdir -p test_preparation
-    - run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt
-    - run: python utils/tests_fetcher.py --filter_tests
+    - run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt || true
+    - run: python utils/tests_fetcher.py --filter_tests || true
     - run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
     - run: |
         if [ ! -s test_preparation/generated_config.yml ]; then
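
Note: the appended `|| true` makes the fetcher steps best-effort — the shell reports exit status 0 even if the fetcher crashes, so the job can still reach `create_circleci_config.py` and inspect `test_preparation`. A rough Python analogue of that shell idiom, for illustration only (not code from this repo):

```python
import subprocess

# Shell's `cmd || true` always yields exit status 0; the Python equivalent is
# to run the command without raising on a non-zero return code, then ignore it.
completed = subprocess.run(
    ["python", "utils/tests_fetcher.py"],
    check=False,  # do not raise CalledProcessError on failure
)
print(f"fetcher exited with {completed.returncode}; job continues regardless")
```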

Makefile

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ repo-consistency:
     python utils/check_modular_conversion.py
     python utils/check_dummies.py
     python utils/check_repo.py
+    python utils/check_init_weights_data.py
     python utils/check_inits.py
     python utils/check_pipeline_typing.py
     python utils/check_config_docstrings.py
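
The new `check_init_weights_data.py` script itself is not shown in this view; judging from the doc changes below, it presumably flags leftover `.data` mutations in `_init_weights` methods. A minimal sketch of such a checker, assuming that purpose (the real script may differ):

```python
import pathlib
import re

# Hypothetical sketch: report `.data.<in-place op>(...)` calls, which this
# refactor replaces with in-place ops directly on the parameter.
PATTERN = re.compile(r"\.data\.(normal_|zero_|fill_|copy_|uniform_)\(")

for path in pathlib.Path("src/transformers").rglob("modeling_*.py"):
    for lineno, line in enumerate(path.read_text().splitlines(), start=1):
        if PATTERN.search(line):
            print(f"{path}:{lineno}: {line.strip()}")
```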

docs/source/de/add_new_model.md

Lines changed: 7 additions & 7 deletions

@@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 You can use further custom schemes if you need special initialization for some modules. For example, in
@@ -533,9 +533,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 The `_is_hf_initialized` flag is used internally to make sure we initialize a submodule only once. If you set it to
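
For context on that flag: a weight-initialization traversal can use it as a guard so a module initialized by a custom branch is not re-initialized by the generic one. A minimal sketch of that contract, with a hypothetical helper name for illustration (the real traversal lives in `PreTrainedModel`):

```python
import torch.nn as nn

def init_all_weights(model: nn.Module, init_fn) -> None:
    # Skip any submodule already marked as initialized, then mark the rest.
    for module in model.modules():
        if getattr(module, "_is_hf_initialized", False):
            continue
        init_fn(module)
        module._is_hf_initialized = True
```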

docs/source/en/add_new_model.md

Lines changed: 7 additions & 7 deletions

@@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
@@ -339,9 +339,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 ### Convert checkpoints to Transformers
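
One note on why dropping `.data` works: an in-place op on an `nn.Parameter` that requires grad must run with autograd disabled, and presumably the refactored loading path invokes `_init_weights` under `torch.no_grad()` (an assumption based on this diff, not stated in it). A standalone sketch:

```python
import torch
import torch.nn as nn

linear = nn.Linear(8, 8)

# Outside no_grad, `linear.weight.normal_()` raises:
# "a leaf Variable that requires grad is being used in an in-place operation"
with torch.no_grad():
    linear.weight.normal_(mean=0.0, std=0.02)
    if linear.bias is not None:
        linear.bias.zero_()
```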

docs/source/en/perf_infer_gpu_multi.md

Lines changed: 1 addition & 1 deletion

@@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
 ```python
 class Llama4TextExperts(nn.Module):
     ...
-    self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+    self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
 ```
 
 Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
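
The switch from `torch.empty` to `torch.zeros` presumably guarantees a deterministic starting value when a checkpoint never overwrites the tensor. As a concrete illustration of the batched matmul the last sentence refers to, a minimal sketch with assumed shapes (names and sizes here are illustrative, not from the modeling code):

```python
import torch

num_experts, hidden_size, expert_dim, tokens = 4, 16, 8, 32
gate_up_proj = torch.zeros(num_experts, hidden_size, 2 * expert_dim)
hidden_states = torch.randn(num_experts, tokens, hidden_size)  # tokens grouped per expert

# One matmul per expert, all in a single batched call:
gate_up = torch.bmm(hidden_states, gate_up_proj)  # [num_experts, tokens, 2 * expert_dim]
gate, up = gate_up.chunk(2, dim=-1)               # split the packed projection
```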

docs/source/it/migration.md

Lines changed: 1 addition & 1 deletion

@@ -170,7 +170,7 @@ Per quanto riguarda la classe `TrainingArguments`:
 - The `evaluate_during_training` argument of `TrainingArguments` is deprecated in favor of `eval_strategy`.
 
 Regarding the Transfo-XL model:
-- The `tie_weight` configuration attribute of Transfo-XL becomes `tie_words_embeddings`.
+- The `tie_weight` configuration attribute of Transfo-XL becomes `tie_word_embeddings`.
 - The `reset_length` modeling method of Transfo-XL becomes `reset_memory_length`.
 
 Regarding pipelines:
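
The corrected name matches the standard flag on `PretrainedConfig`, which can be verified directly (a quick check, not tied to Transfo-XL):

```python
from transformers import PretrainedConfig

# `tie_word_embeddings` (not `tie_words_embeddings`) is the real attribute:
config = PretrainedConfig(tie_word_embeddings=False)
print(config.tie_word_embeddings)  # False
```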

docs/source/ja/add_new_model.md

Lines changed: 7 additions & 7 deletions

@@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 If specific modules need special initialization, you can add further custom schemes. For example,
@@ -431,9 +431,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 The `_is_hf_initialized` flag is used internally to ensure that a submodule is initialized only once.

docs/source/ko/add_new_model.md

Lines changed: 7 additions & 7 deletions

@@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 If a few modules need special initialization, you can also use a custom scheme. For example, in `Wav2Vec2ForPreTraining`, the last two linear layers must use the regular PyTorch `nn.Linear` initialization, while all other layers should use the initialization above. This is coded as follows:
@@ -371,9 +371,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 The `_is_hf_initialized` flag is used internally to initialize a submodule only once. By setting it to `True` for `module.project_q` and `module.project_hid`, we ensure that the custom initialization we performed is not overwritten later, i.e. the `_init_weights` function will not be applied to them.

docs/source/ko/perf_infer_gpu_multi.md

Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
 ```python
 class Llama4TextExperts(nn.Module):
     ...
-    self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+    self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
 ```
 
 Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.

examples/modular-transformers/modeling_dummy_bert.py

Lines changed: 7 additions & 13 deletions

@@ -502,16 +502,10 @@ def __init__(self, config):
 
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
 
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
 
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
-
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, DummyBertLMPredictionHead):
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 @auto_docstring(
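
Why the removed `_tie_weights` existed at all: assigning one `Parameter` to two attributes shares storage, but the tie is silently lost whenever either attribute is later reassigned (e.g. on resize or reload), so it had to be re-applied by hand. With `bias=True` and the refactored loading path, that manual patching goes away. A standalone demonstration of the old failure mode:

```python
import torch
import torch.nn as nn

head_bias = nn.Parameter(torch.zeros(10))
decoder = nn.Linear(5, 10, bias=True)

decoder.bias = head_bias                      # tied: two names, one Parameter
assert decoder.bias is head_bias

decoder.bias = nn.Parameter(torch.zeros(10))  # recreated, e.g. after a resize
assert decoder.bias is not head_bias          # the tie is silently gone
```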
