forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfc_inference.cc
135 lines (116 loc) · 4.86 KB
/
fc_inference.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include "caffe2/operators/fc_inference.h"
#include "caffe2/core/types.h"
namespace caffe2 {
std::vector<TensorShape> FCShapeInference(
const OperatorDef& def,
const vector<TensorShape>& in,
bool pretransposed_weight) {
vector<TensorShape> out(1);
if (in[0].unknown_shape() || in[1].unknown_shape()) {
out[0].set_unknown_shape(true);
return out;
}
ArgumentHelper helper(def);
auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
const int canonical_axis_w =
canonical_axis_index_(axis_w, in[1].dims().size());
const int64_t N = pretransposed_weight
? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
: size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
vector<int64_t> y_shape(in[0].dims().begin(), in[0].dims().end());
CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size());
y_shape.resize(canonical_axis + 1);
y_shape[canonical_axis] = N;
out[0] = CreateTensorShape(y_shape, in[0].data_type());
return out;
}
// Cost-model estimate for the forward FC operator.
//
// Inputs: X = in[0] and W = in[1]; a third input is required, and the
// reported byte counts include an N-element bias term.
//
// Sizes:
//   - M = size_to_dim_(axis, X.dims) and K = size_from_dim_(axis, X.dims),
//     using the canonicalized "axis" (default 1).
//   - N comes from W at the canonicalized "axis_w" (default 1):
//     size_from_dim_ when pretransposed_weight, otherwise size_to_dim_.
//
// Reported values, all scaled by X's element size:
//   flops         = M * N * (2 * K + 1)
//   bytes_read    = (K * (M + N) + N) * elem
//   bytes_written = M * N * elem
//   params_bytes  = (K * N + N) * elem
//
// Throws (CAFFE_ENFORCE) when fewer than three inputs are supplied.
OpSchema::Cost CostInferenceForFC(
    const OperatorDef& def,
    const vector<TensorShape>& in,
    bool pretransposed_weight) {
  CAFFE_ENFORCE_GE(in.size(), 3, "FC requires at least three inputs");

  const auto x_dims = GetDimsVector(in[0]);
  const auto w_dims = GetDimsVector(in[1]);

  ArgumentHelper args(def);
  const auto x_axis = canonical_axis_index_(
      args.GetSingleArgument<int32_t>("axis", 1), in[0].dims().size());
  const uint64_t M = size_to_dim_(x_axis, x_dims);
  const uint64_t K = size_from_dim_(x_axis, x_dims);

  const int w_axis = canonical_axis_index_(
      args.GetSingleArgument<int32_t>("axis_w", 1), in[1].dims().size());
  uint64_t N = 0;
  if (pretransposed_weight) {
    N = size_from_dim_(w_axis, w_dims);
  } else {
    N = size_to_dim_(w_axis, w_dims);
  }

  const auto elem_bytes = DataTypeToTypeMeta(in[0].data_type()).itemsize();

  OpSchema::Cost cost;
  cost.flops = M * N * (2 * K + 1);
  cost.bytes_read = (K * (M + N) + N) * elem_bytes;
  cost.bytes_written = M * N * elem_bytes;
  cost.params_bytes = (K * N + N) * elem_bytes;
  return cost;
}
std::vector<TensorShape> FCGradientShapeInference(
const OperatorDef& def,
const vector<TensorShape>& in,
bool pretransposed_weight) {
vector<TensorShape> out(2);
ArgumentHelper helper(def);
auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
const int canonical_axis_w =
canonical_axis_index_(axis_w, in[1].dims().size());
const int N = pretransposed_weight
? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
: size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
vector<int> dW_shape(in[1].dims().begin(), in[1].dims().end());
out[0] = CreateTensorShape(dW_shape, in[1].data_type());
out[1] = CreateTensorShape(vector<int>{N}, in[1].data_type()); // db
if (def.output_size() == 3) {
vector<int> dX_shape(in[0].dims().begin(), in[0].dims().end());
out.push_back(CreateTensorShape(dX_shape, in[0].data_type()));
}
return out;
}
// Cost-model estimate for the FC gradient operator.
//
// Output shapes come from FCGradientShapeInference: dW, db, and optionally
// dX.
//
// Sizes:
//   - M and K are taken from X = in[0] split at the canonicalized "axis"
//     (default 1).
//   - N comes from W = in[1] at the canonicalized "axis_w" (default 1).
//
// Reported values:
//   flops         = M * N * (2 * K + 1), plus 2 * M * N * K when dX is
//                   also produced.
//   bytes_written = element counts of dW and db (and dX, when produced),
//                   each scaled by its own inferred element size.
//   params_bytes  = K * N weight elements plus N bias elements, scaled by
//                   the element sizes inferred for dW and db.
// bytes_read is not set here.
// NOTE(review): X, W and dY are read by this op; consider populating
// bytes_read.
//
// Throws (CAFFE_ENFORCE) on fewer than two inputs.
OpSchema::Cost CostInferenceForFCGradient(
    const OperatorDef& def,
    const vector<TensorShape>& in,
    bool pretransposed_weight) {
  // X and W are indexed below and by FCGradientShapeInference.
  CAFFE_ENFORCE_GE(
      in.size(), 2, "FCGradient requires at least two inputs (X, W)");

  struct OpSchema::Cost c;
  ArgumentHelper helper(def);
  std::vector<TensorShape> out =
      FCGradientShapeInference(def, in, pretransposed_weight);

  // Both out[0] (dW) and out[1] (db) are read below, so require at least two
  // entries. The previous `0 < size` check only guaranteed out[0].
  CAFFE_ENFORCE_GE(out.size(), 2);
  const TensorShape& dW = out[0];
  const auto dW_element_size_byte =
      DataTypeToTypeMeta(dW.data_type()).itemsize();
  const TensorShape& db = out[1];
  const auto db_element_size_byte =
      DataTypeToTypeMeta(db.data_type()).itemsize();

  auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
  const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
  const uint64_t M = size_to_dim_(canonical_axis, GetDimsVector(in[0]));
  const uint64_t K = size_from_dim_(canonical_axis, GetDimsVector(in[0]));
  auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
  const int canonical_axis_w =
      canonical_axis_index_(axis_w, in[1].dims().size());
  const uint64_t N = pretransposed_weight
      ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
      : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));

  const uint64_t size_dW = nElemFromDim(dW);
  const uint64_t size_db = nElemFromDim(db);

  c.flops = M * N * (2 * K + 1);
  c.bytes_written =
      size_dW * dW_element_size_byte + size_db * db_element_size_byte;
  // Size parameters by their inferred data types instead of assuming float.
  // This matches CostInferenceForFC, which scales by the tensor element size.
  c.params_bytes = K * N * dW_element_size_byte + N * db_element_size_byte;

  if (out.size() == 3) {
    const TensorShape& dX = out[2];
    const uint64_t size_dX = nElemFromDim(dX);
    const auto dX_element_size_byte =
        DataTypeToTypeMeta(dX.data_type()).itemsize();
    // dX = dY * W: an (M x N) by (N x K) product.
    c.flops += 2 * M * N * K;
    c.bytes_written += size_dX * dX_element_size_byte;
  }
  return c;
}
} // namespace caffe2