forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: minmax_gradient_ops.cc
66 lines (54 loc) · 2.06 KB
/
minmax_gradient_ops.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "caffe2/operators/minmax_ops.h"
#include <string>
#include <vector>
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
// Shared CPU backward pass for the elementwise Max/Min gradient operators.
//
// Inputs:  Input(0) = Y   (forward output)
//          Input(1) = dY  (gradient w.r.t. Y)
//          Input(2 + i) = X_i (i-th forward input), for i in [0, OutputSize())
// Outputs: Output(i) = dX_i, allocated with the same sizes as X_i.
//
// dX_i takes the value of dY wherever X_i == Y and 0 elsewhere. When several
// inputs tie with Y at a position, every tied input receives the full dY value
// there (the gradient is not split between them).
//
// All tensors are processed as flat arrays of Y.numel() elements, so dY and
// every X_i must have exactly that many elements; this is enforced below.
template <typename T, class Context>
bool SelectGradientOpBase<T, Context>::RunOnDevice() {
  const auto& Y = Input(0);
  const auto& dY = Input(1);
  // numel() is 64-bit; keep it that way so large tensors are not truncated.
  const auto N = Y.numel();
  // Guard the flat Eigen maps below against out-of-bounds reads.
  CAFFE_ENFORCE_EQ(
      dY.numel(), N, "MinMax gradient: dY must have as many elements as Y");
  ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N);
  ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N);
  for (int i = 0; i < OutputSize(); ++i) {
    const auto& Xi = Input(i + 2);
    // Xi is read, and dXi (shaped like Xi) is written, as N flat elements.
    CAFFE_ENFORCE_EQ(
        Xi.numel(),
        N,
        "MinMax gradient: input ",
        i,
        " must have as many elements as Y");
    auto* dXi = Output(i, Xi.sizes(), at::dtype<T>());
    ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N);
    EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N);
    // Mask of positions where this input was the selected value, times dY.
    dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr;
  }
  return true;
}
// CPU (float) kernels and schemas for the Max/Min gradient operators.
// Inputs are Y, dY, then one tensor per forward input (so at least 3);
// outputs are one gradient per forward input (so at least 1).
REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(MaxGradient)
    .NumInputs(3, INT_MAX)
    .NumOutputs(1, INT_MAX);

REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(MinGradient)
    .NumInputs(3, INT_MAX)
    .NumOutputs(1, INT_MAX);
namespace {
class GetMaxGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
std::vector<OperatorDef> GetGradientDefs() override {
std::vector<std::string> inputs = {O(0), GO(0)};
std::vector<std::string> grad_inputs;
for (int i = 0; i < def_.input_size(); ++i) {
inputs.push_back(I(i));
grad_inputs.push_back(GI(i));
}
return SingleGradientDef("MaxGradient", "", inputs, grad_inputs);
}
};
class GetMinGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
std::vector<std::string> inputs = {O(0), GO(0)};
std::vector<std::string> grad_inputs;
for (int i = 0; i < def_.input_size(); ++i) {
inputs.push_back(I(i));
grad_inputs.push_back(GI(i));
}
return SingleGradientDef("MinGradient", "", inputs, grad_inputs);
}
};
} // namespace
REGISTER_GRADIENT(Max, GetMaxGradient);
REGISTER_GRADIENT(Min, GetMinGradient);
} // namespace caffe2