Skip to content

Commit 571b20a

Browse files
committed
speed up by reusing the grid
1 parent 03b53ec commit 571b20a

File tree

3 files changed

+36
-20
lines changed

3 files changed

+36
-20
lines changed

tests/test_deform_conv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ def test_th_batch_map_offsets_grad():
5252
offsets = (np.random.random((4, 100, 100, 2)) * 2)
5353

5454
input = Variable(torch.from_numpy(input), requires_grad=True)
55-
offsets = Variable(torch.from_numpy(offsets), requires_grad=False)
55+
offsets = Variable(torch.from_numpy(offsets), requires_grad=True)
5656

5757
th_mapped_vals = th_batch_map_offsets(input, offsets)
5858
e = torch.from_numpy(np.random.random((4, 100, 100)))
5959
th_mapped_vals.backward(e)
60-
grad = input.grad
61-
assert not np.allclose(grad.data.numpy(), 0)
60+
assert not np.allclose(input.grad.data.numpy(), 0)
61+
assert not np.allclose(offsets.grad.data.numpy(), 0)

torch_deform_conv/deform_conv.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,21 @@ def sp_batch_map_offsets(input, offsets):
138138
return mapped_vals
139139

140140

141-
def th_batch_map_offsets(input, offsets, order=1):
141+
def th_generate_grid(batch_size, input_size, dtype, cuda):
    """Build the constant coordinate grid that offsets are added to.

    Parameters
    ---------
    batch_size : int. Repeat count handed to ``np_repeat_2d``
    input_size : int. Side length; the grid enumerates every (row, col)
        pair of ``range(input_size) x range(input_size)`` in 'ij' order
    dtype : tensor type passed to ``Tensor.type`` for the conversion
    cuda : bool. When truthy the tensor is moved with ``.cuda()``

    Returns
    -------
    Variable wrapping the grid, created with ``requires_grad=False``.
    NOTE(review): final shape depends on ``np_repeat_2d``; presumably
    (batch_size, input_size * input_size, 2) -- confirm against that helper.
    """
    axis = range(input_size)
    rows, cols = np.meshgrid(axis, axis, indexing='ij')

    # Pair up the two index planes, then flatten to one (row, col) per position.
    coords = np.stack((rows, cols), axis=-1).reshape(-1, 2)
    coords = np_repeat_2d(coords, batch_size)

    tensor = torch.from_numpy(coords).type(dtype)
    return Variable(tensor.cuda() if cuda else tensor, requires_grad=False)
153+
154+
155+
def th_batch_map_offsets(input, offsets, grid=None, order=1):
142156
"""Batch map offsets into input
143157
Parameters
144158
---------
@@ -148,23 +162,14 @@ def th_batch_map_offsets(input, offsets, order=1):
148162
-------
149163
torch.Tensor. shape = (b, s, s)
150164
"""
151-
input_shape = input.size()
152-
batch_size = input_shape[0]
153-
input_size = input_shape[1]
165+
batch_size = input.size(0)
166+
input_size = input.size(1)
154167

155168
offsets = offsets.view(batch_size, -1, 2)
156-
grid = np.meshgrid(
157-
range(input_size), range(input_size), indexing='ij'
158-
)
159-
grid = np.stack(grid, axis=-1)
160-
grid = grid.reshape(-1, 2)
169+
if grid is None:
170+
grid = th_generate_grid(batch_size, input_size, offsets.data.type(), offsets.data.is_cuda)
161171

162-
grid = np_repeat_2d(grid, batch_size)
163-
grid = torch.from_numpy(grid).type(offsets.data.type())
164-
if offsets.is_cuda:
165-
grid = grid.cuda()
166-
167-
coords = offsets.add(Variable(grid, requires_grad=False))
172+
coords = offsets + grid
168173

169174
mapped_vals = th_batch_map_coordinates(input, coords)
170175
return mapped_vals

torch_deform_conv/layers.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.nn as nn
55

66
import numpy as np
7-
from torch_deform_conv.deform_conv import th_batch_map_offsets
7+
from torch_deform_conv.deform_conv import th_batch_map_offsets, th_generate_grid
88

99

1010
class ConvOffset2D(nn.Conv2d):
@@ -29,6 +29,7 @@ def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
2929
Pass to superclass. See Conv2d layer in pytorch
3030
"""
3131
self.filters = filters
32+
self._grid_param = None
3233
super(ConvOffset2D, self).__init__(self.filters, self.filters*2, 3, padding=1, bias=False, **kwargs)
3334
self.weight.data.copy_(self._init_weights(self.weight, init_normal_stddev))
3435

@@ -44,13 +45,23 @@ def forward(self, x):
4445
x = self._to_bc_h_w(x, x_shape)
4546

4647
# X_offset: (b*c, h, w)
47-
x_offset = th_batch_map_offsets(x, offsets)
48+
x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(self,x))
4849

4950
# x_offset: (b, h, w, c)
5051
x_offset = self._to_b_c_h_w(x_offset, x_shape)
5152

5253
return x_offset
5354

55+
@staticmethod
def _get_grid(self, x):
    """Return the grid matching ``x``, regenerating it only when needed.

    Declared as a ``@staticmethod`` but takes the layer instance explicitly
    as ``self``; ``forward`` calls it as ``self._get_grid(self, x)``, so the
    signature is kept unchanged.

    The grid is cached on the instance as ``self._grid`` together with
    ``self._grid_param``, the ``(batch_size, input_size, dtype, cuda)``
    tuple it was generated for.

    Parameters
    ---------
    self : instance holding the ``_grid_param`` / ``_grid`` cache
    x : Variable. ``x.size(0)``, ``x.size(1)``, ``x.data.type()`` and
        ``x.data.is_cuda`` form the cache key

    Returns
    -------
    The result of ``th_generate_grid(batch_size, input_size, dtype, cuda)``,
    reused from the cache when the key is unchanged.
    """
    key = (x.size(0), x.size(1), x.data.type(), x.data.is_cuda)
    if self._grid_param != key:
        # Generate first and only then publish grid and key together: if
        # generation raises, the previous key/grid pair stays consistent
        # instead of the new key pointing at a stale or missing grid.
        grid = th_generate_grid(*key)
        self._grid = grid
        self._grid_param = key
    return self._grid
64+
5465
@staticmethod
5566
def _init_weights(weights, std):
5667
fan_out = weights.size(0)

0 commit comments

Comments
 (0)