Skip to content

Commit 41e0c92

Browse files
simonreisequbvel
andauthored
Improve auxiliary_in_channels default behavior in UperNet (#37540)
Improve auxiliary_in_channels behavior in UperNet Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent c61ca64 commit 41e0c92

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/transformers/models/upernet/configuration_upernet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
pool_scales=[1, 2, 3, 6],
9797
use_auxiliary_head=True,
9898
auxiliary_loss_weight=0.4,
99-
auxiliary_in_channels=384,
99+
auxiliary_in_channels=None,
100100
auxiliary_channels=256,
101101
auxiliary_num_convs=1,
102102
auxiliary_concat_input=False,

src/transformers/models/upernet/modeling_upernet.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,14 @@ class UperNetFCNHead(nn.Module):
218218
"""
219219

220220
def __init__(
221-
self, config, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1
221+
self, config, in_channels, in_index: int = 2, kernel_size: int = 3, dilation: Union[int, Tuple[int, int]] = 1
222222
) -> None:
223223
super().__init__()
224224

225225
self.config = config
226-
self.in_channels = config.auxiliary_in_channels
226+
self.in_channels = (
227+
in_channels[in_index] if config.auxiliary_in_channels is None else config.auxiliary_in_channels
228+
)
227229
self.channels = config.auxiliary_channels
228230
self.num_convs = config.auxiliary_num_convs
229231
self.concat_input = config.auxiliary_concat_input
@@ -292,7 +294,9 @@ def __init__(self, config):
292294

293295
# Semantic segmentation head(s)
294296
self.decode_head = UperNetHead(config, in_channels=self.backbone.channels)
295-
self.auxiliary_head = UperNetFCNHead(config) if config.use_auxiliary_head else None
297+
self.auxiliary_head = (
298+
UperNetFCNHead(config, in_channels=self.backbone.channels) if config.use_auxiliary_head else None
299+
)
296300

297301
# Initialize weights and apply final processing
298302
self.post_init()

0 commit comments

Comments
 (0)