forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathreverse_packed_segs_op.cc
33 lines (30 loc) · 1.09 KB
/
reverse_packed_segs_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include "caffe2/operators/reverse_packed_segs_op.h"
namespace caffe2 {

// CPU registration: ReversePackedSegs is served by the
// ReversePackedSegsOp<CPUContext> kernel from the included header.
REGISTER_CPU_OPERATOR(ReversePackedSegs, ReversePackedSegsOp<CPUContext>);

// Schema: input 0 is the padded data tensor, input 1 the per-segment lengths;
// the single output has every segment reversed while padding is left alone.
//
// Note on the Output description below: the two adjacent string literals are
// concatenated by the compiler with no separator, so the first one must end
// with a space (previously the doc rendered as "...reversedand paddings...").
OPERATOR_SCHEMA(ReversePackedSegs)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Reverse segments in a 3-D tensor (lengths, segments, embeddings,), leaving
paddings unchanged. This operator is used to reverse input of a recurrent neural
network to make it a BRNN.
)DOC")
    .Input(0, "data", "a 3-D (lengths, segments, embeddings,) tensor.")
    .Input(1, "lengths", "length of each segment.")
    .Output(
        0,
        "reversed data",
        "a (lengths, segments, embeddings,) tensor with each segment reversed "
        "and paddings unchanged.");

// Gradient maker for ReversePackedSegs.
//
// Reversing each segment (and leaving padding untouched) only permutes the
// elements of `data`, and applying the same reversal twice restores the
// original order. The gradient w.r.t. `data` is therefore obtained by running
// ReversePackedSegs on the output gradient GO(0) with the same lengths I(1).
// Only GI(0) is produced; no gradient is emitted for `lengths` (input 1).
class GetReversePackedSegsGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "ReversePackedSegs",
        "",
        vector<string>{GO(0), I(1)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(ReversePackedSegs, GetReversePackedSegsGradient);

} // namespace caffe2