From 9740efab3e62425835f923a1d9784e10998d3e87 Mon Sep 17 00:00:00 2001 From: Jayden9912 <85933053+Jayden9912@users.noreply.github.com> Date: Mon, 16 Jan 2023 14:56:24 +0800 Subject: [PATCH 1/2] Adding .contiguous() after transpose or permutation --- mmseg/models/backbones/mix_transformer.py | 10 +++++----- mmseg/models/decode_heads/segformer_head.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mmseg/models/backbones/mix_transformer.py b/mmseg/models/backbones/mix_transformer.py index e4b7ea2..661efa5 100644 --- a/mmseg/models/backbones/mix_transformer.py +++ b/mmseg/models/backbones/mix_transformer.py @@ -106,11 +106,11 @@ def forward(self, x, H, W): kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = ((q @ k.transpose(-2, -1)) * self.scale).contiguous() attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(B, N, C).contiguous() x = self.proj(x) x = self.proj_drop(x) @@ -194,7 +194,7 @@ def _init_weights(self, m): def forward(self, x): x = self.proj(x) _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) + x = x.flatten(2).transpose(1, 2).contiguous() x = self.norm(x) return x, H, W @@ -362,9 +362,9 @@ def __init__(self, dim=768): def forward(self, x, H, W): B, N, C = x.shape - x = x.transpose(1, 2).view(B, C, H, W) + x = x.transpose(1, 2).view(B, C, H, W).contiguous() x = self.dwconv(x) - x = x.flatten(2).transpose(1, 2) + x = x.flatten(2).transpose(1, 2).contiguous() return x diff --git a/mmseg/models/decode_heads/segformer_head.py b/mmseg/models/decode_heads/segformer_head.py index 8ada6d7..be6f587 100644 --- a/mmseg/models/decode_heads/segformer_head.py +++ b/mmseg/models/decode_heads/segformer_head.py @@ -26,7 +26,7 @@ def __init__(self, input_dim=2048, embed_dim=768): self.proj = nn.Linear(input_dim, embed_dim) def forward(self, x): - x = x.flatten(2).transpose(1, 2) + x = x.flatten(2).transpose(1, 2).contiguous() x = self.proj(x) return x @@ -68,16 +68,16 @@ def forward(self, inputs): ############## MLP decoder on C1-C4 ########### n, _, h, w = c4.shape - _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) + _c4 = (self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])).contiguous() _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) - _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) + _c3 = (self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])).contiguous() _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) - _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) + _c2 = (self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])).contiguous() _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) - _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) + _c1 = (self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])).contiguous() _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) From 4c733d0948e2681f9b218adc6ae35dff0092f491 Mon Sep 17 00:00:00 2001 From: Jayden9912 <85933053+Jayden9912@users.noreply.github.com> Date: Tue, 7 Mar 2023 15:47:15 +0800 Subject: [PATCH 2/2] pushing current work --- mmseg/__init__.py | 4 +- mmseg/models/backbones/mix_transformer.py | 364 +++++++++++++++----- mmseg/models/decode_heads/segformer_head.py | 42 ++- tools/get_flops.py | 1 + 4 files changed, 312 insertions(+), 99 deletions(-) diff --git a/mmseg/__init__.py b/mmseg/__init__.py index f301a5d..dd8f4c8 100755 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -3,8 +3,8 @@ from .version import __version__, version_info MMCV_MIN = '1.1.4' -MMCV_MAX = '1.3.0' - +# MMCV_MAX = '1.3.0' +MMCV_MAX = '1.6.2' def digit_version(version_str): digit_version = [] diff --git a/mmseg/models/backbones/mix_transformer.py b/mmseg/models/backbones/mix_transformer.py index e4b7ea2..4f6c382 100644 --- a/mmseg/models/backbones/mix_transformer.py +++ b/mmseg/models/backbones/mix_transformer.py @@ -18,7 +18,14 @@ class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -32,7 +39,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -56,14 +63,25 @@ def forward(self, x, H, W): class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + ): super().__init__() - assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." self.dim = dim self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) @@ -80,7 +98,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0. def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -95,22 +113,34 @@ def _init_weights(self, m): def forward(self, x, H, W): B, N, C = x.shape - q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) if self.sr_ratio > 1: x_ = x.permute(0, 2, 1).reshape(B, C, H, W) x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) x_ = self.norm(x_) - kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = ( + self.kv(x_) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) else: - kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + kv = ( + self.kv(x) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) k, v = kv[0], kv[1] - attn = (q @ k.transpose(-2, -1)) * self.scale + attn = ((q @ k.transpose(-2, -1)) * self.scale).contiguous() attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = ((attn @ v).transpose(1, 2).reshape(B, N, C)).contiguous() x = self.proj(x) x = self.proj_drop(x) @@ -118,26 +148,47 @@ def forward(self, x, H, W): class Block(nn.Module): - - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, - num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -158,8 +209,7 @@ def forward(self, x, H, W): class OverlapPatchEmbed(nn.Module): - """ Image to Patch Embedding - """ + """Image to Patch Embedding""" def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): super().__init__() @@ -170,15 +220,20 @@ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=7 self.patch_size = patch_size self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, - padding=(patch_size[0] // 2, patch_size[1] // 2)) + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) self.norm = nn.LayerNorm(embed_dim) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -194,63 +249,147 @@ def _init_weights(self, m): def forward(self, x): x = self.proj(x) _, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) + x = x.flatten(2).transpose(1, 2).contiguous() x = self.norm(x) return x, H, W class MixVisionTransformer(nn.Module): - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], - num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, - depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + ): super().__init__() self.num_classes = num_classes self.depths = depths # patch_embed - self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, - embed_dim=embed_dims[0]) - self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], - embed_dim=embed_dims[1]) - self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], - embed_dim=embed_dims[2]) - self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], - embed_dim=embed_dims[3]) + self.patch_embed1 = OverlapPatchEmbed( + img_size=img_size, + patch_size=7, + stride=4, + in_chans=in_chans, + embed_dim=embed_dims[0], + ) + self.patch_embed2 = OverlapPatchEmbed( + img_size=img_size // 4, + patch_size=3, + stride=2, + in_chans=embed_dims[0], + embed_dim=embed_dims[1], + ) + self.patch_embed3 = OverlapPatchEmbed( + img_size=img_size // 8, + patch_size=3, + stride=2, + in_chans=embed_dims[1], + embed_dim=embed_dims[2], + ) + self.patch_embed4 = OverlapPatchEmbed( + img_size=img_size // 16, + patch_size=3, + stride=2, + in_chans=embed_dims[2], + embed_dim=embed_dims[3], + ) # transformer encoder - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule cur = 0 - self.block1 = nn.ModuleList([Block( - dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[0]) - for i in range(depths[0])]) + self.block1 = nn.ModuleList( + [ + Block( + dim=embed_dims[0], + num_heads=num_heads[0], + mlp_ratio=mlp_ratios[0], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[0], + ) + for i in range(depths[0]) + ] + ) self.norm1 = norm_layer(embed_dims[0]) cur += depths[0] - self.block2 = nn.ModuleList([Block( - dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[1]) - for i in range(depths[1])]) + self.block2 = nn.ModuleList( + [ + Block( + dim=embed_dims[1], + num_heads=num_heads[1], + mlp_ratio=mlp_ratios[1], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[1], + ) + for i in range(depths[1]) + ] + ) self.norm2 = norm_layer(embed_dims[1]) cur += depths[1] - self.block3 = nn.ModuleList([Block( - dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[2]) - for i in range(depths[2])]) + self.block3 = nn.ModuleList( + [ + Block( + dim=embed_dims[2], + num_heads=num_heads[2], + mlp_ratio=mlp_ratios[2], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[2], + ) + for i in range(depths[2]) + ] + ) self.norm3 = norm_layer(embed_dims[2]) cur += depths[2] - self.block4 = nn.ModuleList([Block( - dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, - sr_ratio=sr_ratios[3]) - for i in range(depths[3])]) + self.block4 = nn.ModuleList( + [ + Block( + dim=embed_dims[3], + num_heads=num_heads[3], + mlp_ratio=mlp_ratios[3], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + i], + norm_layer=norm_layer, + sr_ratio=sr_ratios[3], + ) + for i in range(depths[3]) + ] + ) self.norm4 = norm_layer(embed_dims[3]) # classification head @@ -260,7 +399,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -276,7 +415,9 @@ def _init_weights(self, m): def init_weights(self, pretrained=None): if isinstance(pretrained, str): logger = get_root_logger() - load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + load_checkpoint( + self, pretrained, map_location="cpu", strict=False, logger=logger + ) def reset_drop_path(self, drop_path_rate): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] @@ -301,14 +442,22 @@ def freeze_patch_emb(self): @torch.jit.ignore def no_weight_decay(self): - return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better + return { + "pos_embed1", + "pos_embed2", + "pos_embed3", + "pos_embed4", + "cls_token", + } # has pos_embed may be better def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=""): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) def forward_features(self, x): B = x.shape[0] @@ -362,63 +511,110 @@ def __init__(self, dim=768): def forward(self, x, H, W): B, N, C = x.shape - x = x.transpose(1, 2).view(B, C, H, W) + x = x.transpose(1, 2).view(B, C, H, W).contiguous() x = self.dwconv(x) - x = x.flatten(2).transpose(1, 2) + x = x.flatten(2).transpose(1, 2).contiguous() return x - @BACKBONES.register_module() class mit_b0(MixVisionTransformer): def __init__(self, **kwargs): super(mit_b0, self).__init__( - patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) + patch_size=4, + embed_dims=[32, 64, 160, 256], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) @BACKBONES.register_module() class mit_b1(MixVisionTransformer): def __init__(self, **kwargs): super(mit_b1, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) @BACKBONES.register_module() class mit_b2(MixVisionTransformer): def __init__(self, **kwargs): super(mit_b2, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) @BACKBONES.register_module() class mit_b3(MixVisionTransformer): def __init__(self, **kwargs): super(mit_b3, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 4, 18, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) @BACKBONES.register_module() class mit_b4(MixVisionTransformer): def __init__(self, **kwargs): super(mit_b4, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 8, 27, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) @BACKBONES.register_module() class mit_b5(MixVisionTransformer): def __init__(self, **kwargs): super(mit_b5, self).__init__( - patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, drop_path_rate=0.1) \ No newline at end of file + patch_size=4, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[3, 6, 40, 3], + sr_ratios=[8, 4, 2, 1], + drop_rate=0.0, + drop_path_rate=0.1, + ) diff --git a/mmseg/models/decode_heads/segformer_head.py b/mmseg/models/decode_heads/segformer_head.py index 8ada6d7..8ef7b6d 100644 --- a/mmseg/models/decode_heads/segformer_head.py +++ b/mmseg/models/decode_heads/segformer_head.py @@ -17,10 +17,12 @@ from IPython import embed + class MLP(nn.Module): """ Linear Embedding """ + def __init__(self, input_dim=2048, embed_dim=768): super().__init__() self.proj = nn.Linear(input_dim, embed_dim) @@ -36,16 +38,22 @@ class SegFormerHead(BaseDecodeHead): """ SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers """ + def __init__(self, feature_strides, **kwargs): - super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs) + super(SegFormerHead, self).__init__(input_transform="multiple_select", **kwargs) assert len(feature_strides) == len(self.in_channels) assert min(feature_strides) == feature_strides[0] self.feature_strides = feature_strides - c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels + ( + c1_in_channels, + c2_in_channels, + c3_in_channels, + c4_in_channels, + ) = self.in_channels - decoder_params = kwargs['decoder_params'] - embedding_dim = decoder_params['embed_dim'] + decoder_params = kwargs["decoder_params"] + embedding_dim = decoder_params["embed_dim"] self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) @@ -53,10 +61,10 @@ def __init__(self, feature_strides, **kwargs): self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) self.linear_fuse = ConvModule( - in_channels=embedding_dim*4, + in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1, - norm_cfg=dict(type='SyncBN', requires_grad=True) + norm_cfg=dict(type="BN", requires_grad=True), ) self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) @@ -68,16 +76,24 @@ def forward(self, inputs): ############## MLP decoder on C1-C4 ########### n, _, h, w = c4.shape - _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) - _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) + _c4 = ( + self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) + ).contiguous() + _c4 = resize(_c4, size=c1.size()[2:], mode="bilinear", align_corners=False) - _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) - _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) + _c3 = ( + self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) + ).contiguous() + _c3 = resize(_c3, size=c1.size()[2:], mode="bilinear", align_corners=False) - _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) - _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) + _c2 = ( + self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) + ).contiguous() + _c2 = resize(_c2, size=c1.size()[2:], mode="bilinear", align_corners=False) - _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) + _c1 = ( + self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) + ).contiguous() _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) diff --git a/tools/get_flops.py b/tools/get_flops.py index 2643369..b892682 100755 --- a/tools/get_flops.py +++ b/tools/get_flops.py @@ -15,6 +15,7 @@ def parse_args(): type=int, nargs='+', default=[2048, 1024], + # default=[512, 288], help='input image size') args = parser.parse_args() return args