 import torch
 import torch.nn as nn

-from timm.models.layers import ConvBnAct, DropPath, AvgPool2dSame, create_attn
+from timm.models.layers import create_conv2d, create_act_layer
+from timm.models.layers import DropPath, AvgPool2dSame, create_attn


 from detectron2.layers import ShapeSpec, FrozenBatchNorm2d

...
 )
 )

+class ConvBnAct(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
+                 bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None,
+                 drop_block=None):
+        super(ConvBnAct, self).__init__()
+        use_aa = aa_layer is not None
+
+        self.conv = create_conv2d(
+            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias)
+
+        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+        self.bn = norm_layer(out_channels)
+        self.act = act_layer() if apply_act else nn.Identity()
+        self.aa = aa_layer(
+            channels=out_channels) if stride == 2 and use_aa else None
+
+    @property
+    def in_channels(self):
+        return self.conv.in_channels
+
+    @property
+    def out_channels(self):
+        return self.conv.out_channels
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.act(x)
+        if self.aa is not None:
+            x = self.aa(x)
+        return x
+

 def create_stem(
         in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='',
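A minimal smoke-test sketch for the ConvBnAct shim added above (not part of the commit; it assumes timm is installed so create_conv2d can resolve the default padding, and the tensor sizes are illustrative):

# Hypothetical check that the shim behaves as a drop-in conv -> BN -> activation block.
import torch
block = ConvBnAct(3, 32, kernel_size=3, stride=2)
out = block(torch.randn(1, 3, 64, 64))
print(out.shape)  # expected: torch.Size([1, 32, 32, 32])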
@@ -394,8 +428,7 @@ def build_cspnet_backbone(cfg, input_shape=None):
     if norm_name == "FrozenBN":
         norm = FrozenBatchNorm2d
     elif norm_name == "SyncBN":
-        from detectron2.layers import NaiveSyncBatchNorm
-        norm = NaiveSyncBatchNorm
+        norm = nn.SyncBatchNorm
     else:
         norm = nn.BatchNorm2d

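With this change, requesting "SyncBN" resolves to PyTorch's built-in nn.SyncBatchNorm instead of detectron2's NaiveSyncBatchNorm. A brief sketch of what that selection amounts to (the channel count is illustrative, not from the commit):

# Illustrative only: the layer picked when cfg asks for "SyncBN".
import torch.nn as nn
norm = nn.SyncBatchNorm
bn = norm(64)  # normalizes a 64-channel feature map, synced across processes during DDP training
# Existing BatchNorm2d layers can also be converted in place:
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model)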