-
Notifications
You must be signed in to change notification settings - Fork 177
/
Copy pathinvolution2d_wrapper.h
234 lines (197 loc) · 7.87 KB
/
involution2d_wrapper.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#pragma once
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/autocast_mode.h>
#include <torch/csrc/autograd/custom_function.h>
#include "involution2d_cpu.h"
#ifdef USE_CUDA
# include "involution2d_cuda.cuh"
#endif
namespace involution {
/// Calls the "involution::involution2d" operator through the PyTorch dispatcher.
///
/// The typed operator handle is looked up once, on first call, and cached in a
/// function-local static. findSchemaOrThrow throws if no schema with that name is
/// registered.
///
/// `inline` is required because this function is defined in a header. Without it,
/// every translation unit that includes this file emits its own external
/// definition, and linking more than one such TU fails with duplicate symbols.
///
/// NOTE(review): the cpu/cuda autograd kernels in this header take
/// (input, weight, kernel_size, stride, padding, dilation, groups). This typed
/// handle omits kernel_size and groups. If the registered op uses the same
/// signature as those kernels, `typed<>` may reject this one at runtime.
/// Confirm against the TORCH_LIBRARY registration.
inline at::Tensor involution2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    static auto op = at::Dispatcher::singleton()
        .findSchemaOrThrow("involution::involution2d", "")
        .typed<decltype(involution2d)>();
    return op.call(input, weight, stride, padding, dilation);
}
/// Autocast wrapper around the dispatcher-based involution::involution2d.
///
/// Steps:
/// - Excludes the Autocast dispatch key for the duration of the call, so the
///   inner dispatch does not re-enter this wrapper.
/// - Picks an execution dtype via at::autocast::promote_type, seeded with
///   at::kFloat.
/// - Casts input and weight to that dtype with at::autocast::cached_cast.
/// - Runs the op and converts the result back to input's original scalar type.
///
/// `inline` is required because this function is defined in a header. Without it,
/// including the header in more than one translation unit causes duplicate-symbol
/// link errors.
inline at::Tensor involution2d_autocast(
    const at::Tensor& input,
    const at::Tensor& weight,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
    return involution2d(
               at::autocast::cached_cast(exec_type, input),
               at::autocast::cached_cast(exec_type, weight),
               stride, padding, dilation)
        .to(input.scalar_type());
}
/// Calls the "involution2d::_involution2d_backward_grad_input" operator through
/// the PyTorch dispatcher. The typed handle is looked up once and cached in a
/// function-local static. findSchemaOrThrow throws if the schema is not registered.
///
/// `inline` is required because this function is defined in a header. Without it,
/// including the header in more than one translation unit causes duplicate-symbol
/// link errors.
///
/// NOTE(review): this schema uses the "involution2d::" namespace, while the
/// forward op in this header uses "involution::". Confirm which namespace the
/// backward op is actually registered under.
inline at::Tensor _involution2d_backward_grad_input(
    const at::Tensor& grad,
    const at::Tensor& weight,
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    static auto op = at::Dispatcher::singleton()
        .findSchemaOrThrow("involution2d::_involution2d_backward_grad_input", "")
        .typed<decltype(_involution2d_backward_grad_input)>();
    return op.call(grad, weight, input_shape, stride, padding, dilation);
}
/// Calls the "involution2d::_involution2d_backward_grad_weight" operator through
/// the PyTorch dispatcher. The typed handle is looked up once and cached in a
/// function-local static. findSchemaOrThrow throws if the schema is not registered.
///
/// `inline` is required because this function is defined in a header. Without it,
/// including the header in more than one translation unit causes duplicate-symbol
/// link errors.
///
/// NOTE(review): this schema uses the "involution2d::" namespace, while the
/// forward op in this header uses "involution::". Confirm which namespace the
/// backward op is actually registered under.
inline at::Tensor _involution2d_backward_grad_weight(
    const at::Tensor& grad,
    const at::Tensor& input,
    const std::vector<int64_t>& weight_shape,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    static auto op = at::Dispatcher::singleton()
        .findSchemaOrThrow("involution2d::_involution2d_backward_grad_weight", "")
        .typed<decltype(_involution2d_backward_grad_weight)>();
    return op.call(grad, input, weight_shape, stride, padding, dilation);
}
namespace cpu {

/// Custom autograd function around the CPU involution kernels.
///
/// involution2d_forward and involution2d_backward are called unqualified and
/// resolve from this namespace (presumably declared in "involution2d_cpu.h").
class Involution2dFunctionCPU : public torch::autograd::Function<Involution2dFunctionCPU>
{
public:
    /// Runs the forward kernel and records what backward() needs.
    /// The hyper-parameters go into ctx->saved_data, and input and weight are
    /// saved with save_for_backward.
    /// Returns a one-element list holding the output tensor.
    static torch::autograd::variable_list forward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::Variable& input,
        const torch::autograd::Variable& weight,
        const std::vector<int64_t>& kernel_size,
        const std::vector<int64_t>& stride,
        const std::vector<int64_t>& padding,
        const std::vector<int64_t>& dilation,
        const int64_t groups
    ) {
        ctx->saved_data["kernel_size"] = kernel_size;
        ctx->saved_data["stride"] = stride;
        ctx->saved_data["padding"] = padding;
        ctx->saved_data["dilation"] = dilation;
        ctx->saved_data["groups"] = groups;
        ctx->save_for_backward({input, weight});
        auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);
        return {output};
    }

    /// Computes gradients from grad_output[0] and the values saved in forward().
    ///
    /// Returns one entry per forward() argument after ctx:
    /// - gradients for input and weight;
    /// - five undefined Variables for the non-differentiable kernel_size, stride,
    ///   padding, dilation and groups.
    ///
    /// NOTE(review): assumes involution2d_backward returns {grad_input, grad_weight}
    /// in that order. Confirm against its declaration.
    static torch::autograd::variable_list backward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::variable_list grad_output
    ) {
        const torch::autograd::variable_list saved = ctx->get_saved_variables();
        // Bind by reference: `saved` owns the tensors for the rest of this call.
        const torch::autograd::Variable& input = saved[0];
        const torch::autograd::Variable& weight = saved[1];
        const auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
        const auto stride = ctx->saved_data["stride"].toIntVector();
        const auto padding = ctx->saved_data["padding"].toIntVector();
        const auto dilation = ctx->saved_data["dilation"].toIntVector();
        const auto groups = ctx->saved_data["groups"].toInt();
        auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);
        return {
            grads[0],
            grads[1],
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()
        };
    }
};

/// Applies Involution2dFunctionCPU and returns its single output tensor.
///
/// `inline` is required because this function is defined in a header. Without it,
/// including the header in more than one translation unit causes duplicate-symbol
/// link errors.
inline at::Tensor involution2d_autograd(
    const torch::autograd::Variable& input,
    const torch::autograd::Variable& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
) {
    return Involution2dFunctionCPU::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
}

} // namespace cpu
#ifdef USE_CUDA
namespace cuda {

/// Custom autograd function around the CUDA involution kernels.
///
/// involution2d_forward and involution2d_backward are called unqualified and
/// resolve from this namespace (presumably declared in "involution2d_cuda.cuh").
class Involution2dFunctionCUDA : public torch::autograd::Function<Involution2dFunctionCUDA>
{
public:
    /// Runs the forward kernel and records what backward() needs.
    /// The hyper-parameters go into ctx->saved_data, and input and weight are
    /// saved with save_for_backward.
    /// Returns a one-element list holding the output tensor.
    static torch::autograd::variable_list forward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::Variable& input,
        const torch::autograd::Variable& weight,
        const std::vector<int64_t>& kernel_size,
        const std::vector<int64_t>& stride,
        const std::vector<int64_t>& padding,
        const std::vector<int64_t>& dilation,
        const int64_t groups
    ) {
        ctx->saved_data["kernel_size"] = kernel_size;
        ctx->saved_data["stride"] = stride;
        ctx->saved_data["padding"] = padding;
        ctx->saved_data["dilation"] = dilation;
        ctx->saved_data["groups"] = groups;
        ctx->save_for_backward({input, weight});
        auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);
        return {output};
    }

    /// Computes gradients from grad_output[0] and the values saved in forward().
    ///
    /// Returns one entry per forward() argument after ctx:
    /// - gradients for input and weight;
    /// - five undefined Variables for the non-differentiable kernel_size, stride,
    ///   padding, dilation and groups.
    ///
    /// NOTE(review): assumes involution2d_backward returns {grad_input, grad_weight}
    /// in that order. Confirm against its declaration.
    static torch::autograd::variable_list backward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::variable_list grad_output
    ) {
        const torch::autograd::variable_list saved = ctx->get_saved_variables();
        // Bind by reference: `saved` owns the tensors for the rest of this call.
        const torch::autograd::Variable& input = saved[0];
        const torch::autograd::Variable& weight = saved[1];
        const auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
        const auto stride = ctx->saved_data["stride"].toIntVector();
        const auto padding = ctx->saved_data["padding"].toIntVector();
        const auto dilation = ctx->saved_data["dilation"].toIntVector();
        const auto groups = ctx->saved_data["groups"].toInt();
        auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);
        return {
            grads[0],
            grads[1],
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()
        };
    }
};

/// Applies Involution2dFunctionCUDA and returns its single output tensor.
///
/// `inline` is required because this function is defined in a header. Without it,
/// including the header in more than one translation unit causes duplicate-symbol
/// link errors.
inline at::Tensor involution2d_autograd(
    const torch::autograd::Variable& input,
    const torch::autograd::Variable& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
) {
    return Involution2dFunctionCUDA::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
}

/// Autocast entry point for the CUDA path.
///
/// Steps:
/// - Excludes the Autocast dispatch key while running, so nested dispatch does
///   not re-enter autocast.
/// - Chooses an execution dtype via at::autocast::promote_type, seeded with
///   at::kFloat.
/// - Casts input and weight to it with at::autocast::cached_cast.
/// - Runs the autograd function above.
///
/// `inline` is required because this function is defined in a header.
///
/// NOTE(review): unlike the top-level involution::involution2d_autocast, this does
/// not convert the result back to input.scalar_type(), so the output keeps
/// exec_type. Confirm this is intended.
inline at::Tensor involution2d_autocast(
    const torch::autograd::Variable& input,
    const torch::autograd::Variable& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
    return involution2d_autograd(
        at::autocast::cached_cast(exec_type, input),
        at::autocast::cached_cast(exec_type, weight),
        kernel_size, stride, padding, dilation, groups
    );
}

} // namespace cuda
#endif
} // namespace involution