forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprepend_dim_op.cc
46 lines (38 loc) · 1.31 KB
/
prepend_dim_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include "caffe2/operators/prepend_dim_op.h"
namespace caffe2 {
// Register the CPU variants of both operators under the names "PrependDim"
// and "MergeDim". Each is the corresponding operator template instantiated
// with CPUContext (templates presumably declared in prepend_dim_op.h).
REGISTER_CPU_OPERATOR(PrependDim, PrependDimOp<CPUContext>);
REGISTER_CPU_OPERATOR(MergeDim, MergeDimOp<CPUContext>);
// Schema for the PrependDim operator: exactly one input and one output.
OPERATOR_SCHEMA(PrependDim)
.NumInputs(1)
.NumOutputs(1)
// {input index, output index}: output 0 is allowed to be computed in place
// on input 0.
.AllowInplace({{0, 0}})
.SetDoc(R"DOC(
Reshape the tensor by prepending a dimension of fixed size and dividing the
size of the next dimension by that amount.
)DOC")
// The only argument documented for this operator.
.Arg("dim_size", "Size of the dimension to prepend.")
.Input(0, "data", "An input tensor.")
.Output(0, "reshaped", "Reshaped tensor.");
// Schema for the MergeDim operator: exactly one input and one output, and no
// documented arguments. MergeDim is also the operator emitted as the gradient
// of PrependDim by GetPrependDimGradient in this file.
OPERATOR_SCHEMA(MergeDim)
.NumInputs(1)
.NumOutputs(1)
// {input index, output index}: output 0 is allowed to be computed in place
// on input 0.
.AllowInplace({{0, 0}})
.SetDoc(R"DOC(
Merge first two dimensions in a single dimension with size dim(0) * dim(1).
)DOC")
.Input(0, "data", "An input tensor.")
.Output(0, "reshaped", "Reshaped tensor.")
// Associates this operator with the ONNX "Reshape" schema.
// NOTE(review): exact effect of the inheritance is defined by OpSchema, not
// visible in this file.
.InheritOnnxSchema("Reshape");
// Gradient maker for PrependDim.
//
// The backward pass is expressed as a single MergeDim operator that consumes
// the gradient of PrependDim's output 0 and produces the gradient of its
// input 0.
class GetPrependDimGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  // The emitted MergeDim operator does not need PrependDim's arguments, so
  // they are not copied onto the gradient operator.
  bool CopyArguments() const override {
    return false;
  }

  // Builds the one-operator gradient definition: MergeDim(GO(0)) -> GI(0).
  vector<OperatorDef> GetGradientDefs() override {
    const vector<string> grad_op_inputs{GO(0)};
    const vector<string> grad_op_outputs{GI(0)};
    return SingleGradientDef("MergeDim", "", grad_op_inputs, grad_op_outputs);
  }
};
// Use GetPrependDimGradient as the gradient maker for the PrependDim operator.
REGISTER_GRADIENT(PrependDim, GetPrependDimGradient);
} // namespace caffe2