
Commit 557ef75

fixes

1 parent: 99961fc

File tree: 6 files changed, +12 / -11 lines

src/transformers/models/electra/modeling_electra.py

Lines changed: 1 addition & 1 deletion

@@ -1302,7 +1302,7 @@ def __init__(self, config):
         self.generator_predictions = ElectraGeneratorPredictions(config)
         self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)

-        self.init_weights()
+        self.post_init()

     def get_output_embeddings(self):
         return self.generator_lm_head
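The electra hunk (and the plbart hunks further down) follow the same migration: the constructor builds its submodules and then ends with post_init() instead of init_weights(), since post_init() runs weight initialization together with the rest of the end-of-constructor setup. A minimal sketch of that constructor shape, with a hypothetical head module (only PreTrainedModel, PretrainedConfig, and post_init() come from transformers; the class and field names are illustrative):

    from torch import nn
    from transformers import PretrainedConfig, PreTrainedModel


    class TinyLMHead(PreTrainedModel):
        config_class = PretrainedConfig

        def __init__(self, config):
            super().__init__(config)
            # Build submodules first, as the electra/plbart constructors do.
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
            # Previously: self.init_weights(); post_init() is now the expected
            # final call, covering weight init plus other post-construction steps.
            self.post_init()


    model = TinyLMHead(PretrainedConfig(hidden_size=16, vocab_size=32))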

src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py

Lines changed: 1 addition & 1 deletion

@@ -1008,7 +1008,7 @@ def _init_weights(self, module):
             init.zeros_(module.bias)
             init.ones_(module.weight)
         elif isinstance(module, nn.Embedding):
-            module.weight.normal_()
+            init.normal_(module.weight)
             if module.padding_idx is not None:
                 init.zeros_(module.weight[module.padding_idx])
         elif isinstance(module, FastSpeech2ConformerAttention):
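The remaining _init_weights hunks all make the same move: the in-place tensor method (weight.normal_()) becomes a functional helper that takes the tensor as its first argument (init.normal_(weight)). torch.nn.init exposes helpers with the same call shape, so a rough stand-alone illustration of the two forms looks like this, using torch.nn.init as a stand-in for the internal `initialization as init` alias imported in the modular file below (that substitution is an assumption):

    import torch
    from torch import nn
    from torch.nn import init  # stand-in for transformers' `initialization as init` alias (assumption)

    embedding = nn.Embedding(10, 4, padding_idx=0)

    # Old form: in-place tensor method; an enclosing no_grad() is needed when
    # the parameter requires grad.
    with torch.no_grad():
        embedding.weight.normal_()

    # New form: functional helper that takes the tensor first and manages the
    # no_grad context itself, mirroring the fastspeech2_conformer hunk above.
    init.normal_(embedding.weight)
    if embedding.padding_idx is not None:
        init.zeros_(embedding.weight[embedding.padding_idx])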

src/transformers/models/groupvit/modeling_groupvit.py

Lines changed: 1 addition & 1 deletion

@@ -755,7 +755,7 @@ def _init_weights(self, module):

         init_range = self.config.initializer_range
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.normal_(mean=0.0, std=init_range)
+            init.normal_(module.weight, mean=0.0, std=init_range)
             if module.bias is not None:
                 init.zeros_(module.bias)
         elif isinstance(module, nn.LayerNorm):

src/transformers/models/plbart/modular_plbart.py

Lines changed: 2 additions & 2 deletions

@@ -82,7 +82,7 @@ def __init__(self, config: PLBartConfig):
         self.encoder = PLBartEncoder(config)
         self.decoder = PLBartDecoder(config)

-        self.init_weights()
+        self.post_init()

     def get_input_embeddings(self):
         return self.shared

@@ -211,7 +211,7 @@ def __init__(self, config: PLBartConfig):
         self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

-        self.init_weights()
+        self.post_init()

     def get_encoder(self):
         return self.model.get_encoder()

src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py

Lines changed: 3 additions & 3 deletions

@@ -82,9 +82,9 @@ def _init_weights(self, module):
         super()._init_weights(module)
         std = self.config.initializer_range
         if isinstance(module, Qwen3OmniMoeThinkerTextSparseMoeBlock):
-            module.experts.gate_up_proj.normal_(mean=0.0, std=std)
-            module.experts.down_proj.normal_(mean=0.0, std=std)
-            module.router.weight.normal_(mean=0.0, std=std)
+            init.normal_(module.experts.gate_up_proj, mean=0.0, std=std)
+            init.normal_(module.experts.down_proj, mean=0.0, std=std)
+            init.normal_(module.router.weight, mean=0.0, std=std)


 def _get_feat_extract_output_lengths(input_lengths):
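The Qwen3-Omni MoE hunks also show that the functional form covers bare nn.Parameter tensors (the fused expert projections) and a Linear weight with one and the same helper, since it only needs a tensor to write into. A small stand-alone illustration with invented shapes (the attribute names mirror the diff; the dimensions are made up, and torch.nn.init again stands in for the internal alias):

    import torch
    from torch import nn
    from torch.nn import init  # stand-in for the internal `initialization as init` alias (assumption)

    num_experts, hidden, intermediate = 4, 8, 16
    std = 0.02  # plays the role of config.initializer_range

    # Fused expert projections held as bare parameters, plus a router Linear.
    gate_up_proj = nn.Parameter(torch.empty(num_experts, hidden, 2 * intermediate))
    down_proj = nn.Parameter(torch.empty(num_experts, intermediate, hidden))
    router = nn.Linear(hidden, num_experts, bias=False)

    # One helper handles a bare Parameter and a module's .weight alike.
    init.normal_(gate_up_proj, mean=0.0, std=std)
    init.normal_(down_proj, mean=0.0, std=std)
    init.normal_(router.weight, mean=0.0, std=std)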

src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py

Lines changed: 4 additions & 3 deletions

@@ -25,6 +25,7 @@
 from torch import nn
 from torch.nn import functional as F

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...audio_utils import AudioInput
 from ...cache_utils import Cache, DynamicCache

@@ -796,9 +797,9 @@ def _init_weights(self, module):
         PreTrainedModel._init_weights(self, module)
         std = self.config.initializer_range
         if isinstance(module, Qwen3OmniMoeThinkerTextSparseMoeBlock):
-            module.experts.gate_up_proj.normal_(mean=0.0, std=std)
-            module.experts.down_proj.normal_(mean=0.0, std=std)
-            module.router.weight.normal_(mean=0.0, std=std)
+            init.normal_(module.experts.gate_up_proj, mean=0.0, std=std)
+            init.normal_(module.experts.down_proj, mean=0.0, std=std)
+            init.normal_(module.router.weight, mean=0.0, std=std)


 class Qwen3OmniMoePreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
