-
Notifications
You must be signed in to change notification settings - Fork 400
/
dcn_v2.py
303 lines (261 loc) · 11.8 KB
/
dcn_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import math
import torch
from torch import nn
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from torch.autograd.function import once_differentiable
import _ext as _backend
class _DCNv2(Function):
    """Autograd bridge to the compiled ``_ext`` DCNv2 convolution kernels.

    ``forward`` passes the tensors plus the convolution geometry to
    ``_backend.dcn_v2_forward``; ``backward`` passes the saved tensors and the
    incoming gradient to ``_backend.dcn_v2_backward``.  Only the five tensor
    inputs (input, offset, mask, weight, bias) receive gradients; the four
    geometry arguments get ``None``.
    """

    @staticmethod
    def _dcn_geometry(ctx):
        """Flatten the geometry stored on ``ctx`` into the backend's int args.

        Order: kernel (h, w), stride (h, w), padding (h, w), dilation (h, w),
        then the number of deformable groups.
        """
        flat = []
        for pair in (ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation):
            flat.extend((pair[0], pair[1]))
        flat.append(ctx.deformable_groups)
        return flat

    @staticmethod
    def forward(ctx, input, offset, mask, weight, bias,
                stride, padding, dilation, deformable_groups):
        # Geometry lives on ctx because backward must replay exactly the same values.
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        # Kernel extent is read from dims 2 and 3 of the weight, not passed in.
        ctx.kernel_size = _pair(weight.shape[2:4])
        ctx.deformable_groups = deformable_groups

        geometry = _DCNv2._dcn_geometry(ctx)
        result = _backend.dcn_v2_forward(input, weight, bias, offset, mask,
                                         *geometry)
        ctx.save_for_backward(input, offset, mask, weight, bias)
        return result

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, offset, mask, weight, bias = ctx.saved_tensors
        grads = _backend.dcn_v2_backward(input, weight, bias, offset, mask,
                                         grad_output,
                                         *_DCNv2._dcn_geometry(ctx))
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = grads
        # One slot per forward() input: five tensor gradients, then None for
        # stride, padding, dilation and deformable_groups.
        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None)


# Functional entry point: dcn_v2_conv(input, offset, mask, weight, bias,
#                                     stride, padding, dilation, deformable_groups)
dcn_v2_conv = _DCNv2.apply
class DCNv2(nn.Module):
    """Deformable convolution layer whose offsets and mask come from the caller.

    Owns ``weight`` of shape ``(out_channels, in_channels, *kernel_size)`` and
    ``bias`` of shape ``(out_channels,)``.  ``forward`` checks the channel
    counts of ``offset`` and ``mask`` and then delegates to ``dcn_v2_conv``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding, dilation=1, deformable_groups=1):
        super(DCNv2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # _pair lets each geometry argument be given as an int or an (h, w) pair.
        self.kernel_size, self.stride = _pair(kernel_size), _pair(stride)
        self.padding, self.dilation = _pair(padding), _pair(dilation)
        self.deformable_groups = deformable_groups

        weight_shape = (out_channels, in_channels) + tuple(self.kernel_size)
        # Storage is allocated uninitialised here and filled by reset_parameters().
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) and zero the bias.

        ``fan_in`` is ``in_channels`` times the product of the kernel extents.
        """
        receptive = 1
        for extent in self.kernel_size:
            receptive *= extent
        bound = 1. / math.sqrt(self.in_channels * receptive)
        self.weight.data.uniform_(-bound, bound)
        self.bias.data.zero_()

    def forward(self, input, offset, mask):
        """Apply the deformable convolution.

        ``offset`` must have ``2 * deformable_groups * kh * kw`` channels and
        ``mask`` ``deformable_groups * kh * kw`` channels (dim 1); otherwise an
        ``AssertionError`` is raised before the backend is called.
        """
        taps = self.deformable_groups * self.kernel_size[0] * self.kernel_size[1]
        assert offset.shape[1] == 2 * taps
        assert mask.shape[1] == taps
        return dcn_v2_conv(input, offset, mask,
                           self.weight, self.bias,
                           self.stride, self.padding, self.dilation,
                           self.deformable_groups)
class DCN(DCNv2):
    """DCNv2 layer that predicts its own offsets and mask from the input.

    A plain ``nn.Conv2d`` (``conv_offset_mask``) produces
    ``3 * deformable_groups * kh * kw`` channels.  ``forward`` splits them into
    thirds along dim 1: the first two thirds are concatenated into ``offset``,
    and the last third goes through a sigmoid to become ``mask``.  The
    predictor is zero-initialised, so at construction every offset is 0 and
    every mask value is sigmoid(0) = 0.5.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding,
                 dilation=1, deformable_groups=1):
        super(DCN, self).__init__(in_channels, out_channels,
                                  kernel_size, stride, padding, dilation, deformable_groups)
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        # The predictor shares the deformable conv's full geometry, INCLUDING
        # dilation.  nn.Conv2d's output size depends on dilation, so without it
        # the offset/mask map is spatially larger than a dilated conv's output
        # whenever dilation > 1.  This assumes the backend sizes its output like
        # nn.Conv2d with the same stride/padding/dilation (TODO confirm in _ext).
        # For the default dilation=1 this is identical to the old behaviour.
        self.conv_offset_mask = nn.Conv2d(self.in_channels,
                                          channels_,
                                          kernel_size=self.kernel_size,
                                          stride=self.stride,
                                          padding=self.padding,
                                          dilation=self.dilation,
                                          bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the offset/mask predictor (offsets start at 0, mask at 0.5)."""
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input):
        """Predict offset and mask from ``input``, then run ``dcn_v2_conv``."""
        out = self.conv_offset_mask(input)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask,
                           self.weight, self.bias,
                           self.stride,
                           self.padding,
                           self.dilation,
                           self.deformable_groups)
class _DCNv2Pooling(Function):
    """Autograd bridge to the compiled ``_ext`` deformable PSROI pooling kernels.

    ``forward`` calls ``_backend.dcn_v2_psroi_pooling_forward`` and keeps the
    ``output_count`` tensor it returns for ``backward``, which calls
    ``_backend.dcn_v2_psroi_pooling_backward``.  Gradients flow to ``input``
    and ``offset`` only; ``rois`` and every scalar setting receive ``None``.
    """

    @staticmethod
    def _pooling_settings(ctx):
        """Scalar kernel arguments, in the order both backend calls take them."""
        return (ctx.no_trans, ctx.spatial_scale, ctx.output_dim,
                ctx.group_size, ctx.pooled_size, ctx.part_size,
                ctx.sample_per_part, ctx.trans_std)

    @staticmethod
    def forward(ctx, input, rois, offset,
                spatial_scale,
                pooled_size,
                output_dim,
                no_trans,
                group_size=1,
                part_size=None,
                sample_per_part=4,
                trans_std=.0):
        # no_trans is normalised to an int before it reaches the backend.
        ctx.no_trans = int(no_trans)
        ctx.spatial_scale = spatial_scale
        ctx.output_dim = output_dim
        ctx.group_size = group_size
        ctx.pooled_size = pooled_size
        # An unset part_size falls back to the pooled grid size.
        if part_size is None:
            part_size = pooled_size
        ctx.part_size = part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        settings = _DCNv2Pooling._pooling_settings(ctx)
        output, output_count = _backend.dcn_v2_psroi_pooling_forward(
            input, rois, offset, *settings)
        ctx.save_for_backward(input, rois, offset, output_count)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, rois, offset, output_count = ctx.saved_tensors
        grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward(
            grad_output, input, rois, offset, output_count,
            *_DCNv2Pooling._pooling_settings(ctx))
        # forward() takes 11 inputs: gradients for input and offset, None for
        # rois and the eight scalar settings.
        return (grad_input, None, grad_offset) + (None,) * 8


# Functional entry point mirroring _DCNv2Pooling.forward's argument order.
dcn_v2_pooling = _DCNv2Pooling.apply
class DCNv2Pooling(nn.Module):
    """Module front-end for ``dcn_v2_pooling`` that remembers the pooling settings.

    Has no learnable parameters.  When ``no_trans`` is truthy, the caller's
    ``offset`` is ignored and an empty tensor (``input.new()``) is passed instead.
    """

    def __init__(self,
                 spatial_scale,
                 pooled_size,
                 output_dim,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DCNv2Pooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.pooled_size = pooled_size
        self.output_dim = output_dim
        self.no_trans = no_trans
        self.group_size = group_size
        # An unset part_size falls back to the pooled grid size.
        self.part_size = part_size if part_size is not None else pooled_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, input, rois, offset):
        """Pool ``input`` over ``rois``; ``input`` must have ``output_dim`` channels."""
        assert input.shape[1] == self.output_dim
        effective_offset = input.new() if self.no_trans else offset
        return dcn_v2_pooling(input, rois, effective_offset,
                              self.spatial_scale, self.pooled_size,
                              self.output_dim, self.no_trans,
                              self.group_size, self.part_size,
                              self.sample_per_part, self.trans_std)
class DCNPooling(DCNv2Pooling):
    """Deformable PSROI pooling that predicts its own per-bin offsets and mask.

    With ``no_trans`` false, ``forward`` runs an offset-free pooling pass
    first.  A small MLP (``offset_mask_fc``) maps each pooled ROI to
    ``3 * pooled_size * pooled_size`` values, viewed as ``(n, 3, P, P)``.  The
    first two channels become the offset, and the third, after a sigmoid,
    scales the result of a second pooling pass that uses that offset.  With
    ``no_trans`` true, only the offset-free pass runs and no MLP is built.
    """

    def __init__(self, spatial_scale, pooled_size, output_dim, no_trans,
                 group_size=1, part_size=None, sample_per_part=4,
                 trans_std=.0, deform_fc_dim=1024):
        super(DCNPooling, self).__init__(spatial_scale, pooled_size,
                                         output_dim, no_trans, group_size,
                                         part_size, sample_per_part, trans_std)
        self.deform_fc_dim = deform_fc_dim
        if no_trans:
            return

        bins = self.pooled_size * self.pooled_size
        hidden = self.deform_fc_dim
        # Layer positions 0..4 appear in state_dict keys; keep this order.
        self.offset_mask_fc = nn.Sequential(
            nn.Linear(bins * self.output_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, bins * 3),
        )
        # Zero the output projection: offsets start at 0, mask at sigmoid(0).
        head = self.offset_mask_fc[-1]
        head.weight.data.zero_()
        head.bias.data.zero_()

    def forward(self, input, rois):
        def pool(offset, no_trans):
            # All passes share the stored settings; only offset / no_trans vary.
            return dcn_v2_pooling(input, rois, offset,
                                  self.spatial_scale, self.pooled_size,
                                  self.output_dim, no_trans,
                                  self.group_size, self.part_size,
                                  self.sample_per_part, self.trans_std)

        empty = input.new()
        if self.no_trans:
            # Offset-free pooling only.
            return pool(empty, self.no_trans)

        num_rois = rois.shape[0]
        # First pass without offsets feeds the offset/mask predictor.
        pooled = pool(empty, True)
        predicted = self.offset_mask_fc(pooled.view(num_rois, -1))
        predicted = predicted.view(num_rois, 3, self.pooled_size, self.pooled_size)
        first, second, logits = torch.chunk(predicted, 3, dim=1)
        offset = torch.cat((first, second), dim=1)
        mask = torch.sigmoid(logits)
        # Second pass uses the predicted offsets; the mask broadcasts over channels.
        return pool(offset, self.no_trans) * mask