forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpiecewise_linear_transform_op.h
256 lines (225 loc) · 8.21 KB
/
piecewise_linear_transform_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
#ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
#define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
#include <algorithm>

#include <c10/util/irange.h>

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"
// Declares the c10 export hook for the Caffe2 PiecewiseLinearTransform
// operator (macro presumably provided by the export_caffe2_op_to_c10.h include
// above). NOTE(review): presumably paired with a matching
// C10_EXPORT_CAFFE2_OP_TO_C10_* definition in the operator's .cc file — confirm.
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(PiecewiseLinearTransform);
namespace caffe2 {
// Operator that maps values through piecewise linear functions.
//
// A "group" is one piecewise linear function with P pieces, described by
// P + 1 bounds, P slopes and P intercepts: piece k is
//   y = slopes[k] * x + intercepts[k]   on [bounds[k], bounds[k + 1]].
// Inputs below the first bound / above the last bound are clamped to that
// bound before the first / last piece is applied.
//
// Non-binary mode: X is N x M and column j is transformed by group j, so
// exactly M groups are required.
// Binary mode: X is N x 1 or N x 2 and exactly one group is used (see
// TransformBinary).
//
// Parameters come either from the "bounds", "slopes" and "intercepts"
// arguments (all three set) or, when none of them is set, from the input
// blobs BOUNDS, SLOPES and INTERCEPTS. Per-group sortedness of the bounds is
// enforced only for argument-supplied parameters.
template <typename T, class Context>
class PiecewiseLinearTransformOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit PiecewiseLinearTransformOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {
    // binary_ must be assigned before CheckTransParamFromArg(): that call
    // validates the group count via InferNumFunctionsPerGroup, which reads it.
    binary_ = this->template GetSingleArgument<bool>("binary", false);
    // Retrieve transform params (i.e., the linear functions).
    bounds_from_arg_ = this->template GetRepeatedArgument<T>("bounds");
    slopes_from_arg_ = this->template GetRepeatedArgument<T>("slopes");
    intercepts_from_arg_ = this->template GetRepeatedArgument<T>("intercepts");
    transform_param_from_arg_ = CheckTransParamFromArg();
  }

  bool RunOnDevice() override {
    return binary_ ? TransformBinary() : TransformGeneral();
  }

 private:
  // Derives the group layout from the flat parameter counts.
  //   num_func_per_group: number of linear pieces in each group.
  //   num_group: number of groups; each group transforms one column of the
  //              predictions.
  // Relies on two facts: every group has one more bound than slopes, and all
  // groups have the same number of pieces. Throws (CAFFE_ENFORCE) when the
  // counts are inconsistent, or when binary mode is set and num_group != 1.
  void InferNumFunctionsPerGroup(
      const int64_t num_bounds,
      const int64_t num_slopes,
      const int64_t num_intercepts,
      int64_t* num_func_per_group,
      int64_t* num_group) {
    CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);
    // Each group contributes exactly one extra bound, so the surplus of bounds
    // over slopes is the number of groups.
    *num_group = num_bounds - num_slopes;
    CAFFE_ENFORCE_GT(*num_group, 0);
    if (binary_) {
      CAFFE_ENFORCE_EQ(*num_group, 1);
    }
    *num_func_per_group = num_slopes / *num_group;
    CAFFE_ENFORCE_GT(*num_func_per_group, 0);
    CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
  }

  // Returns true iff each of the num_group consecutive runs of
  // num_bounds_per_group values starting at `bounds` is sorted ascending.
  bool CheckBoundsSorted(
      const T* bounds,
      const int64_t num_bounds_per_group,
      const int64_t num_group) {
    const T* start = bounds;
    for (const auto i : c10::irange(num_group)) {
      (void)i; // CUDA-10.2 on Windows crashes when C10_UNUSED macro is used
      if (!std::is_sorted(start, start + num_bounds_per_group)) {
        return false;
      }
      start += num_bounds_per_group;
    }
    return true;
  }

  // Returns true if the transform params from arg are valid.
  // Otherwise, the transform params are expected to come from Input blobs.
  // Throws if only some of bounds/slopes/intercepts are set, if their sizes
  // are inconsistent, or if any group's bounds are not sorted.
  bool CheckTransParamFromArg() {
    int good_param = 0;
    good_param += bounds_from_arg_.size() > 0;
    good_param += slopes_from_arg_.size() > 0;
    good_param += intercepts_from_arg_.size() > 0;
    CAFFE_ENFORCE(
        good_param == 0 || good_param == 3,
        "bounds, slopes, intercepts must be all set or all not set");
    if (good_param == 3) {
      int64_t num_func_per_group;
      int64_t num_group;
      InferNumFunctionsPerGroup(
          bounds_from_arg_.size(),
          slopes_from_arg_.size(),
          intercepts_from_arg_.size(),
          &num_func_per_group,
          &num_group);
      CAFFE_ENFORCE(
          CheckBoundsSorted(
              bounds_from_arg_.data(), num_func_per_group + 1, num_group),
          "bounds must be sorted for each group");
    }
    return good_param == 3;
  }

  // Declared here; its definition is not in this header (presumably a
  // device-specific specialization elsewhere — NOTE(review): confirm).
  void setUpTensors(int64_t& num_func_per_group, int64_t& num_group, int64_t M);

  // Points bounds/slopes/intercepts at the parameter data (arguments or input
  // blobs, depending on transform_param_from_arg_) and derives the group
  // layout. Enforces 1 input when parameters come from arguments, 4 otherwise.
  void GetTransParamData(
      const T** bounds,
      const T** slopes,
      const T** intercepts,
      int64_t* num_func_per_group,
      int64_t* num_group) {
    int64_t num_bounds;
    int64_t num_slopes;
    int64_t num_intercepts;
    if (transform_param_from_arg_) {
      CAFFE_ENFORCE_EQ(InputSize(), 1);
      *bounds = bounds_from_arg_.data();
      *slopes = slopes_from_arg_.data();
      *intercepts = intercepts_from_arg_.data();
      num_bounds = bounds_from_arg_.size();
      num_slopes = slopes_from_arg_.size();
      num_intercepts = intercepts_from_arg_.size();
    } else {
      CAFFE_ENFORCE_EQ(InputSize(), 4);
      auto& bounds_input = Input(BOUNDS);
      auto& slopes_input = Input(SLOPES);
      auto& intercepts_input = Input(INTERCEPTS);
      *bounds = bounds_input.template data<T>();
      *slopes = slopes_input.template data<T>();
      *intercepts = intercepts_input.template data<T>();
      num_bounds = bounds_input.numel();
      num_slopes = slopes_input.numel();
      num_intercepts = intercepts_input.numel();
    }
    InferNumFunctionsPerGroup(
        num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
  }

  // Non-binary mode: X must be 2-D (N x M) and there must be exactly M groups;
  // Y[i][j] = f_j(X[i][j]) where f_j is group j. Y has X's shape.
  bool TransformGeneral() {
    auto& X = Input(PREDICTIONS);
    CAFFE_ENFORCE_EQ(X.dim(), 2);
    int64_t N = X.dim32(0);
    int64_t M = X.dim32(1);
    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    const auto* Xdata = X.template data<T>();
    T* Ydata = Y->template mutable_data<T>();

    const T* bounds;
    const T* slopes;
    const T* intercepts;
    int64_t num_func_per_group;
    int64_t num_group;
    GetTransParamData(
        &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
    CAFFE_ENFORCE_EQ(num_group, M);

    for (const auto j : c10::irange(M)) {
      // Group j's parameters are stored contiguously, one group after another.
      const T* bounds_group = bounds + j * (num_func_per_group + 1);
      const T* slopes_group = slopes + j * num_func_per_group;
      const T* intercepts_group = intercepts + j * num_func_per_group;
      for (const auto i : c10::irange(N)) {
        Ydata[i * M + j] = PiecewiseLinearTransform(
            Xdata[i * M + j],
            bounds_group,
            slopes_group,
            intercepts_group,
            num_func_per_group);
      }
    }
    return true;
  }

  // Binary mode: X is N x 1 or N x 2 and exactly one group is used.
  //   N x 1: every value is transformed.
  //   N x 2: column 1 is transformed and column 0 is set to 1 minus the
  //          transformed column 1 (NOTE(review): presumably the two columns
  //          are complementary class probabilities — confirm).
  bool TransformBinary() {
    auto& X = Input(PREDICTIONS);
    CAFFE_ENFORCE(X.dim() == 1 || X.dim() == 2);
    int64_t N = X.dim32(0);
    int64_t M = X.dim() == 2 ? X.dim32(1) : 1;
    CAFFE_ENFORCE(
        M == 1 || M == 2,
        "If binary is set to true, the input must be Nx2 or Nx1 tensor");
    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    const auto* Xdata = X.template data<T>();
    T* Ydata = Y->template mutable_data<T>();

    const T* bounds;
    const T* slopes;
    const T* intercepts;
    int64_t num_func_per_group;
    int64_t num_group;
    GetTransParamData(
        &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
    CAFFE_ENFORCE_EQ(num_group, 1);

    if (M == 1) {
      for (const auto i : c10::irange(N)) {
        Ydata[i] = PiecewiseLinearTransform(
            Xdata[i], bounds, slopes, intercepts, num_func_per_group);
      }
    } else {
      for (const auto i : c10::irange(N)) {
        Ydata[i * M + 1] = PiecewiseLinearTransform(
            Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
        Ydata[i * M] = static_cast<T>(1) - Ydata[i * M + 1];
      }
    }
    return true;
  }

  // Evaluates one group's piecewise linear function at x.
  // `bounds` has num_func_per_group + 1 entries; `slopes` and `intercepts`
  // have num_func_per_group entries each.
  //   x <= bounds[0]:   first piece evaluated at bounds[0];
  //   x >= bounds[last]: last piece evaluated at bounds[last];
  //   otherwise:        piece k with bounds[k] < x <= bounds[k + 1].
  // A NaN x yields NaN.
  T PiecewiseLinearTransform(
      const T x,
      const T* bounds,
      const T* slopes,
      const T* intercepts,
      const int64_t num_func_per_group) {
    const int64_t last = num_func_per_group - 1;
    // Deal with samples out of bounds: make them the same as the value at the
    // lower/upper bound.
    if (x <= bounds[0]) {
      return slopes[0] * bounds[0] + intercepts[0];
    }
    if (x >= bounds[num_func_per_group]) {
      return slopes[last] * bounds[num_func_per_group] + intercepts[last];
    }
    // For ordinary x we now know bounds[0] < x < bounds[num_func_per_group],
    // so only the interior bounds [1, num_func_per_group) need searching; the
    // result is the same as searching the full range. Searching the full
    // range was unsafe for NaN: both checks above are false, every
    // `bound < NaN` comparison is false, lower_bound returned `bounds`, and
    // the index became -1 (out-of-bounds read). With the interior range the
    // index always lies in [0, num_func_per_group - 1] and NaN propagates.
    const T* upper =
        std::lower_bound(bounds + 1, bounds + num_func_per_group, x);
    const int64_t bounds_idx = upper - bounds - 1;
    // compute the piecewise linear transformation as Y
    return slopes[bounds_idx] * x + intercepts[bounds_idx];
  }

 private:
  // "binary" argument: selects TransformBinary over TransformGeneral.
  bool binary_;
  // Parameters read from the operator arguments (empty when they are supplied
  // through input blobs instead).
  vector<T> bounds_from_arg_;
  vector<T> slopes_from_arg_;
  vector<T> intercepts_from_arg_;
  // Device-side parameter tensors and a copy-done flag. Not used by the code
  // in this header; presumably used by device-specific specializations
  // defined elsewhere (NOTE(review): confirm before removing).
  Tensor bounds_device_{Context::GetDeviceType()};
  Tensor intercepts_device_{Context::GetDeviceType()};
  Tensor slopes_device_{Context::GetDeviceType()};
  bool gpu_copied_ = false;
  // If true, the piecewise linear functions are passed through args,
  // otherwise, they are passed through Input blobs.
  bool transform_param_from_arg_;

  INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_