Skip to content

Commit 5a95ed5

Browse files
🚨🚨 Fix initialization of Mask2Former (#38864)
* Correctly fix init Co-authored-by: BUI Van Tuan <[email protected]> * add back the block, breaking BC, but this is the correct author's code * override the test for params needing it --------- Co-authored-by: BUI Van Tuan <[email protected]>
1 parent 309e8c9 commit 5a95ed5

File tree

4 files changed

+68
-32
lines changed

4 files changed

+68
-32
lines changed

src/transformers/models/mask2former/modeling_mask2former.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,30 +2127,20 @@ def _init_weights(self, module: nn.Module):
21272127
for p in module.parameters():
21282128
if p.dim() > 1:
21292129
nn.init.xavier_uniform_(p, gain=xavier_std)
2130-
2131-
elif isinstance(module, Mask2FormerPixelLevelModule):
2132-
for submodule in module.modules():
2133-
if isinstance(submodule, (nn.Conv2d, nn.Linear)):
2134-
submodule.weight.data.normal_(mean=0.0, std=std)
2135-
if submodule.bias is not None:
2136-
submodule.bias.data.zero_()
2130+
module.cross_attn.in_proj_bias.data.zero_()
21372131

21382132
elif isinstance(module, Mask2FormerPixelDecoder):
2139-
for p in module.parameters():
2140-
if p.dim() > 1:
2141-
nn.init.xavier_uniform_(p)
21422133
nn.init.normal_(module.level_embed, std=0)
21432134

2144-
elif isinstance(module, Mask2FormerPixelDecoderEncoderOnly):
2145-
for p in module.parameters():
2146-
if p.dim() > 1:
2147-
nn.init.xavier_uniform_(p)
2148-
21492135
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
21502136
module.weight.data.normal_(mean=0.0, std=std)
21512137
if module.bias is not None:
21522138
module.bias.data.zero_()
21532139

2140+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
2141+
module.weight.data.fill_(1.0)
2142+
module.bias.data.zero_()
2143+
21542144
elif isinstance(module, nn.Embedding):
21552145
module.weight.data.normal_(mean=0.0, std=std)
21562146
if module.padding_idx is not None:

src/transformers/utils/backbone_utils.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,7 @@ def load_backbone(config):
324324
raise ValueError("Cannot specify both config.backbone_config and config.backbone")
325325

326326
# If any of thhe following are set, then the config passed in is from a model which contains a backbone.
327-
if (
328-
backbone_config is None
329-
and use_timm_backbone is None
330-
and backbone_checkpoint is None
331-
and backbone_checkpoint is None
332-
):
327+
if backbone_config is None and use_timm_backbone is None and backbone_checkpoint is None:
333328
return AutoBackbone.from_config(config=config, **backbone_kwargs)
334329

335330
# config from the parent model that has a backbone

tests/models/deformable_detr/test_modeling_deformable_detr.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -590,15 +590,14 @@ def test_initialization(self):
590590
model = model_class(config=configs_no_init)
591591
for name, param in model.named_parameters():
592592
if param.requires_grad:
593-
if param.requires_grad:
594-
if (
595-
"level_embed" in name
596-
or "sampling_offsets.bias" in name
597-
or "value_proj" in name
598-
or "output_proj" in name
599-
or "reference_points" in name
600-
):
601-
continue
593+
if (
594+
"level_embed" in name
595+
or "sampling_offsets.bias" in name
596+
or "value_proj" in name
597+
or "output_proj" in name
598+
or "reference_points" in name
599+
):
600+
continue
602601
self.assertIn(
603602
((param.data.mean() * 1e9).round() / 1e9).item(),
604603
[0.0, 1.0],

tests/models/mask2former/test_modeling_mask2former.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import numpy as np
1919

2020
from tests.test_modeling_common import floats_tensor
21-
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
21+
from transformers import AutoModelForImageClassification, Mask2FormerConfig, is_torch_available, is_vision_available
2222
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
2323
from transformers.testing_utils import (
2424
require_timm,
@@ -33,7 +33,7 @@
3333
from transformers.utils import cached_property
3434

3535
from ...test_configuration_common import ConfigTester
36-
from ...test_modeling_common import ModelTesterMixin
36+
from ...test_modeling_common import ModelTesterMixin, _config_zero_init
3737
from ...test_pipeline_mixin import PipelineTesterMixin
3838

3939

@@ -350,6 +350,58 @@ def test_backbone_selection(self):
350350
elif model.__class__.__name__ == "Mask2FormerForUniversalSegmentation":
351351
self.assertEqual(model.model.pixel_level_module.encoder.out_indices, [1, 2, 3])
352352

353+
def test_initialization(self):
354+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
355+
356+
configs_no_init = _config_zero_init(config)
357+
for model_class in self.all_model_classes:
358+
model = model_class(config=configs_no_init)
359+
for name, param in model.named_parameters():
360+
if param.requires_grad:
361+
if (
362+
"self_attn.sampling_offsets.bias" in name
363+
or "self_attn.value_proj.weight" in name
364+
or "self_attn.output_proj.weight" in name
365+
):
366+
continue
367+
self.assertIn(
368+
((param.data.mean() * 1e9).round() / 1e9).item(),
369+
[0.0, 1.0],
370+
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
371+
)
372+
373+
def test_initialization_pretrained_backbone(self):
374+
backbone_name = "microsoft/resnet-18"
375+
376+
# load Mask2Former config with a pretrained backbone
377+
config = Mask2FormerConfig(
378+
backbone=backbone_name,
379+
use_pretrained_backbone=True,
380+
)
381+
382+
# load pretrained backbone
383+
backbone_model = AutoModelForImageClassification.from_pretrained(backbone_name, device_map=torch_device)
384+
385+
def params_match(params1, params2):
386+
return all((p1 == p2).all() for p1, p2 in zip(params1, params2))
387+
388+
for model_class in self.all_model_classes:
389+
model = model_class(config).to(torch_device).eval()
390+
if model.__class__.__name__ == "Mask2FormerModel":
391+
self.assertTrue(
392+
params_match(
393+
backbone_model.base_model.encoder.parameters(),
394+
model.pixel_level_module.encoder.encoder.parameters(),
395+
)
396+
)
397+
elif model.__class__.__name__ == "Mask2FormerForUniversalSegmentation":
398+
self.assertTrue(
399+
params_match(
400+
backbone_model.base_model.encoder.parameters(),
401+
model.model.pixel_level_module.encoder.encoder.parameters(),
402+
)
403+
)
404+
353405

354406
TOLERANCE = 1e-4
355407

0 commit comments

Comments (0)