forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathjsd_op.cc
98 lines (88 loc) · 2.87 KB
/
jsd_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include "caffe2/operators/jsd_op.h"
namespace caffe2 {
namespace {
// Clamp threshold: probabilities are pushed into
// [kLOG_THRESHOLD, 1 - kLOG_THRESHOLD] before any log is taken, so log(0)
// (and division by zero in the logit) can never occur.
// Use a float literal so no double->float narrowing happens here.
static constexpr float kLOG_THRESHOLD() {
  return 1e-20f;
}

// Computes the log-odds log(p / (1 - p)).
// p is clamped away from 0 and 1 first to keep the result finite.
inline float logit(float p) {
  float x = std::min(std::max(p, kLOG_THRESHOLD()), 1 - kLOG_THRESHOLD());
  // The arithmetic is done in double (as before); narrow explicitly instead
  // of relying on an implicit conversion suppressed by NOLINT.
  return static_cast<float>(-log(1. / x - 1.));
}

// Binary entropy H(p) = -p*log(p) - (1-p)*log(1-p), with H defined as 0 when
// p is numerically indistinguishable from 0 or 1.
inline float entropy(float p) {
  if (p < kLOG_THRESHOLD() || 1 - p < kLOG_THRESHOLD()) {
    return 0.;
  } else {
    float q = 1 - p;
    return -p * log(p) - q * log(q);
  }
}
} // namespace
template <>
bool BernoulliJSDOp<float, CPUContext>::RunOnDevice() {
  // Forward pass: elementwise Jensen-Shannon divergence between two
  // Bernoulli distributions, JSD(p, t) = H((p+t)/2) - (H(p) + H(t)) / 2.
  const auto& preds = Input(0);   // predicted probabilities
  const auto& targets = Input(1); // target probabilities
  const int size = preds.numel();
  CAFFE_ENFORCE_EQ(targets.numel(), size);
  auto* loss = Output(0, preds.sizes(), at::dtype<float>()); // JSD loss output
  const auto* pred_data = preds.data<float>();
  const auto* target_data = targets.data<float>();
  auto* loss_data = loss->template mutable_data<float>();
  for (int idx = 0; idx < size; ++idx) {
    const auto p = pred_data[idx];
    const auto t = target_data[idx];
    const auto mid = (p + t) / 2.;
    loss_data[idx] = entropy(mid) - (entropy(p) + entropy(t)) / 2.;
  }
  return true;
}
template <>
bool BernoulliJSDGradientOp<float, CPUContext>::RunOnDevice() {
  auto& go = Input(0); // upstream gradient dL/d(jsd), same shape as X
  auto& X = Input(1);  // predicted probabilities (forward input 0)
  auto& T = Input(2);  // target probabilities (forward input 1)
  int N = X.numel();
  // Validate input sizes before indexing (mirrors the check in the forward
  // op); previously mismatched T or go would read out of bounds.
  CAFFE_ENFORCE_EQ(T.numel(), N);
  CAFFE_ENFORCE_EQ(go.numel(), N);
  auto* gi = Output(0, X.sizes(), at::dtype<float>());
  auto* go_data = go.data<float>();
  auto* x_data = X.data<float>();
  auto* t_data = T.data<float>();
  auto* gi_data = gi->template mutable_data<float>();
  for (int i = 0; i < N; i++) {
    auto p_mdl = x_data[i];
    auto p_emp = t_data[i];
    auto p_avg = (p_mdl + p_emp) / 2.;
    // d(jsd)/d(p_mdl) = (logit(p_mdl) - logit(p_avg)) / 2; chain with the
    // incoming gradient.
    auto g_jsd = (logit(p_mdl) - logit(p_avg)) / 2.;
    gi_data[i] = go_data[i] * g_jsd;
  }
  return true;
}
// Register the CPU implementations of the forward and gradient ops.
REGISTER_CPU_OPERATOR(BernoulliJSD, BernoulliJSDOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BernoulliJSDGradient,
    BernoulliJSDGradientOp<float, CPUContext>);
// Schema for the forward op: two probability arrays in, one loss array out.
OPERATOR_SCHEMA(BernoulliJSD)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Computes the Jensen-Shannon divergence (JSD) between two Bernoulli distributions
where each is parametrized by a single probability.
)DOC")
    .Input(0, "X", "array of probabilities for prediction")
    // BUG FIX: the target was registered with duplicate index 0; the second
    // input must be documented at index 1.
    .Input(1, "T", "array of probabilities for target")
    .Output(0, "L", "array of JSD losses");
// Gradient op: (dL/dJSD, X, T) in, dL/dX out. No doc strings upstream.
OPERATOR_SCHEMA(BernoulliJSDGradient).NumInputs(3).NumOutputs(1);
// Hooks BernoulliJSD into autograd: the gradient op consumes the upstream
// gradient plus both forward inputs and produces the gradient w.r.t. X.
class GetBernoulliJSDGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    const vector<string> grad_inputs{GO(0), I(0), I(1)};
    const vector<string> grad_outputs{GI(0)};
    return SingleGradientDef(
        "BernoulliJSDGradient", "", grad_inputs, grad_outputs);
  }
};
REGISTER_GRADIENT(BernoulliJSD, GetBernoulliJSDGradient);
} // namespace caffe2