
Commit f33c91e

more post_init
1 parent c3b5f3e commit f33c91e

File tree

3 files changed: +8 −4 lines changed


src/transformers/models/layoutlmv3/modeling_layoutlmv3.py

Lines changed: 4 additions & 4 deletions
@@ -591,7 +591,7 @@ def __init__(self, config):
 
         self.encoder = LayoutLMv3Encoder(config)
 
-        self.init_weights()
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -881,7 +881,7 @@ def __init__(self, config):
         else:
             self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
 
-        self.init_weights()
+        self.post_init()
 
     @auto_docstring
     def forward(
@@ -980,7 +980,7 @@ def __init__(self, config):
         self.layoutlmv3 = LayoutLMv3Model(config)
         self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
 
-        self.init_weights()
+        self.post_init()
 
     @auto_docstring
     def forward(
@@ -1099,7 +1099,7 @@ def __init__(self, config):
         self.layoutlmv3 = LayoutLMv3Model(config)
         self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
 
-        self.init_weights()
+        self.post_init()
 
     @auto_docstring
     def forward(
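The pattern in each hunk is the same: the legacy self.init_weights() call at the end of __init__ is replaced with self.post_init(), which runs weight initialization plus any final setup the PreTrainedModel base class expects, once all submodules exist. A minimal sketch of the convention (MyConfig/MyModel are illustrative names, not from this commit):

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class MyConfig(PretrainedConfig):
    model_type = "my-model"

    def __init__(self, hidden_size=32, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class MyModel(PreTrainedModel):
    config_class = MyConfig

    def __init__(self, config):
        super().__init__(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # post_init() supersedes the older init_weights() call here: it
        # initializes weights and performs final setup, so it should be
        # the last statement of __init__.
        self.post_init()

    def forward(self, hidden_states):
        return self.dense(hidden_states)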

src/transformers/models/udop/modeling_udop.py

Lines changed: 1 addition & 0 deletions
@@ -1074,6 +1074,7 @@ def __init__(self, config):
 
         # get weights from encoder position bias
         self.relative_bias = self._get_relative_bias(config)
+        self.post_init()
 
     @staticmethod
     def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated:

tests/trainer/test_trainer.py

Lines changed: 3 additions & 0 deletions
@@ -447,6 +447,7 @@ def __init__(self, config):
         self.a = nn.Parameter(torch.tensor(config.a).float())
         self.b = nn.Parameter(torch.tensor(config.b).float())
         self.double_output = config.double_output
+        self.post_init()
 
     def forward(self, input_x, labels=None, **kwargs):
         y = input_x * self.a + self.b
@@ -466,6 +467,7 @@ def __init__(self, config):
         self.head = nn.Linear(config.hidden_size, 1)
         self.gradient_checkpointing = False
         self.double_output = config.double_output
+        self.post_init()
 
     def forward(self, input_x, labels=None, **kwargs):
         y = input_x.unsqueeze(0)
@@ -496,6 +498,7 @@ def __init__(self, config):
         self.a = nn.Parameter(torch.tensor(config.a).float())
         self.b = nn.Parameter(torch.tensor(config.b).float())
         self.random_torch = config.random_torch
+        self.post_init()
 
     def forward(self, input_x, labels=None, **kwargs):
         y = input_x * self.a + self.b
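The same convention is applied to the toy regression models used by the Trainer tests, so that these fixtures also go through the standard post_init() initialization path. A rough self-contained analogue of one such fixture (names and loss are illustrative, not copied from the test file):

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class RegressionConfig(PretrainedConfig):
    model_type = "regression"

    def __init__(self, a=0.0, b=0.0, **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.b = b


class RegressionModel(PreTrainedModel):
    config_class = RegressionConfig

    def __init__(self, config):
        super().__init__(config)
        self.a = nn.Parameter(torch.tensor(config.a).float())
        self.b = nn.Parameter(torch.tensor(config.b).float())
        # Without post_init(), the _init_weights hooks never run for
        # this model; with it, the fixture follows the standard path.
        self.post_init()

    def forward(self, input_x, labels=None, **kwargs):
        y = input_x * self.a + self.b
        if labels is None:
            return (y,)
        loss = nn.functional.mse_loss(y, labels)
        return (loss, y)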
