forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcross_entropy_op.h
159 lines (135 loc) · 4.32 KB
/
cross_entropy_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
#define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// LabelCrossEntropyOp: cross entropy between a matrix of per-class
// probabilities X and a vector of integer class labels.
// NOTE(review): the forward math lives in the per-device RunOnDevice()
// definitions (.cc/.cu), not in this header — presumably
// Y[i] = -log(max(X[i, label[i]], kLOG_THRESHOLD())); confirm there.
template <typename T, class Context>
class LabelCrossEntropyOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
// Lower clamp applied before taking log(), so that a zero probability
// cannot produce -inf/NaN in the loss.
static constexpr T kLOG_THRESHOLD() {
return static_cast<T>(1e-20);
}
// Input: X, label
// Output: Y
};
// Gradient of LabelCrossEntropyOp: given the forward inputs X and label
// plus the upstream gradient dY, produces dX. Implementation is in the
// per-device RunOnDevice() definitions (.cc/.cu).
template <typename T, class Context>
class LabelCrossEntropyGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyGradientOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
// Input: X, label, dY
// Output: dX. There is no gradient with respect to the label.
// Same clamp constant as the forward op, to keep forward/backward
// numerics consistent near zero probabilities.
static constexpr T kLOG_THRESHOLD() {
return static_cast<T>(1e-20);
}
};
// Hacky: turns a vector of probabilities into a 2-column matrix with
// complementary probabilities for binary classification, so that downstream
// ops that expect one column per class can be reused.
template <typename T, class Context>
class MakeTwoClassOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MakeTwoClassOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
// Input: X
// Output: Y = vstack(1-X, X)
};
// Gradient of MakeTwoClassOp: folds the two-column upstream gradient dY
// back into the single-column gradient dX of the original probabilities.
template <typename T, class Context>
class MakeTwoClassGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(MakeTwoClassGradientOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
// Input: dY
// Output: dX
};
// SigmoidCrossEntropyWithLogitsOp: elementwise sigmoid cross entropy
// computed directly from raw logits (not probabilities). Two optional,
// mutually exclusive variants are selected via operator arguments:
//  - "log_D_trick":      alternate loss formulation — NOTE(review): presumably
//                        the GAN "log D" trick; confirm in the RunOnDevice impl.
//  - "unjoined_lr_loss": alternate loss formulation — NOTE(review): semantics
//                        defined in the .cc/.cu implementation; confirm there.
template <typename T, class Context>
class SigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
// Reads both flags from the operator arguments (default false) and
// rejects the unsupported combination of both being enabled.
template <class... Args>
explicit SigmoidCrossEntropyWithLogitsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
log_D_trick_(
this->template GetSingleArgument<bool>("log_D_trick", false)),
unjoined_lr_loss_(
this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {
CAFFE_ENFORCE(
!(log_D_trick_ && unjoined_lr_loss_),
"log_D_trick_ and unjoined_lr_loss_ cannot be set as True simultaneously");
}
bool RunOnDevice() override;
protected:
bool log_D_trick_;
bool unjoined_lr_loss_;
};
// Gradient of SigmoidCrossEntropyWithLogitsOp. Reads the same two flags as
// the forward op so the backward pass matches whichever loss variant was
// used; unlike the forward op, it does not re-validate their combination
// (the forward constructor already enforces mutual exclusion).
template <typename T, class Context>
class SigmoidCrossEntropyWithLogitsGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit SigmoidCrossEntropyWithLogitsGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
log_D_trick_(
this->template GetSingleArgument<bool>("log_D_trick", false)),
unjoined_lr_loss_(
this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {
}
bool RunOnDevice() override;
protected:
bool log_D_trick_;
bool unjoined_lr_loss_;
};
// Weighted variant of sigmoid cross entropy from logits.
// NOTE(review): input/output arity and the per-element weighting scheme are
// defined by the OPERATOR_SCHEMA and RunOnDevice() in the .cc/.cu files —
// not visible from this header.
template <typename T, class Context>
class WeightedSigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
};
// Gradient of WeightedSigmoidCrossEntropyWithLogitsOp; implementation in
// the per-device .cc/.cu files.
template <typename T, class Context>
class WeightedSigmoidCrossEntropyWithLogitsGradientOp final
: public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsGradientOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
};
// CrossEntropyOp: cross entropy where the label input is itself a
// probability distribution (soft labels), in contrast to
// LabelCrossEntropyOp's integer labels. TORCH_API exports the symbol for
// use outside this library.
template <typename T, class Context>
class TORCH_API CrossEntropyOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CrossEntropyOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
// Input: X, label
// Output: Y
// Lower clamp applied before log() to avoid -inf/NaN on zero
// probabilities; same constant as the label-based ops above.
static constexpr T kLOG_THRESHOLD() {
return static_cast<T>(1e-20);
}
};
// Gradient of CrossEntropyOp: given forward inputs X and label plus the
// upstream gradient dY, produces dX. Exported with TORCH_API to match the
// forward op.
template <typename T, class Context>
class TORCH_API CrossEntropyGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(CrossEntropyGradientOp);
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
protected:
// Input: X, label, dY
// Output: dX. There is no gradient with respect to the label.
// Same clamp constant as the forward op for consistent numerics.
static constexpr T kLOG_THRESHOLD() {
return static_cast<T>(1e-20);
}
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_