forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinference_lstm_op.h
311 lines (277 loc) · 9.69 KB
/
inference_lstm_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
#ifndef LSTM_OP_H_
#define LSTM_OP_H_
#include <algorithm>
#include <sstream>
#include <unordered_map>
#include <vector>
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include <c10/util/irange.h>
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#include "lstm_utils.h"
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(LSTMOp);
namespace caffe2 {
namespace {
// LSTM state pair threaded through the cell: element 0 is the hidden state
// (h), element 1 is the cell state (c).
using t_tuple = std::tuple<Tensor, Tensor>;
// Weights (and optional biases) of one LSTM cell, plus the CPU context used
// to evaluate its linear projections.
//
// Copies are made through initParams(), which runs copy_ctor() on every tensor
// rather than copying members directly. Because copy operations are
// user-declared, the move operations are declared explicitly below (Rule of
// Five); otherwise every move of a CellParams would silently use the copy path.
struct CellParams {
  CellParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    initParams(_w_ih, _w_hh, _b_ih, _b_hh, _context);
  }
  CellParams(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
  }
  CellParams& operator=(const CellParams& rhs) {
    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
    return *this;
  }
  // Member-wise moves: transfer the tensors instead of re-running copy_ctor.
  CellParams(CellParams&&) = default;
  CellParams& operator=(CellParams&&) = default;
  ~CellParams() = default;

  // (Re)initializes every member from the given tensors via copy_ctor().
  void initParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      CPUContext* _context) {
    w_ih = copy_ctor(_w_ih);
    w_hh = copy_ctor(_w_hh);
    b_ih = copy_ctor(_b_ih);
    b_hh = copy_ctor(_b_hh);
    context = _context;
  }

  Tensor w_ih;
  Tensor w_hh;
  Tensor b_ih; /* optional */
  Tensor b_hh; /* optional */
  CPUContext* context = nullptr;

  // Input-to-hidden projection: linear(input, w_ih, b_ih).
  Tensor linear_ih(const Tensor& input) const {
    return linear(input, w_ih, b_ih, context);
  }
  // Hidden-to-hidden projection: linear(h, w_hh, b_hh).
  Tensor linear_hh(const Tensor& h) const {
    return linear(h, w_hh, b_hh, context);
  }
};
// A single LSTM step evaluated on the CPU context it was constructed with.
struct LSTMCell {
  explicit LSTMCell(CPUContext* context) : context_(context) {}

  // Advances the (h, c) state by one time step for `input`.
  //
  // The gate pre-activations linear_ih(input) + linear_hh(h) are split into
  // four equal chunks along dim 1, interpreted in order as the input, forget,
  // cell-candidate and output gates. Returns the new (h, c) pair.
  t_tuple operator()(
      const Tensor& input,
      const t_tuple& hidden,
      const CellParams& params) const {
    const Tensor& h_prev = std::get<0>(hidden);
    const Tensor& c_prev = std::get<1>(hidden);

    auto input_proj = params.linear_ih(input);
    auto hidden_proj = params.linear_hh(h_prev);
    auto pre_activation = add(input_proj, hidden_proj, context_);
    auto gate_slices = chunk(pre_activation, 4, 1, context_);

    auto i_gate = sigmoid(gate_slices[0]);
    auto f_gate = sigmoid(gate_slices[1]);
    auto g_gate = tanh(gate_slices[2], context_);
    auto o_gate = sigmoid(gate_slices[3]);

    // c' = f * c + i * g ;  h' = o * tanh(c')
    auto retained = mul(f_gate, c_prev, context_);
    auto admitted = mul(i_gate, g_gate, context_);
    auto c_next = add(retained, admitted, context_);
    auto h_next = mul(o_gate, tanh(c_next, context_), context_);
    return std::make_tuple(std::move(h_next), std::move(c_next));
  }

  CPUContext* context_;
};
// Result of running one layer: its outputs and its final hidden state. Both
// members are initialized through copy_ctor() from the supplied values.
template <typename output_type, typename hidden_type>
struct LayerOutput {
  output_type outputs;
  hidden_type final_hidden;

  LayerOutput(const output_type& _outputs, const hidden_type& _hidden)
      : outputs(copy_ctor(_outputs)), final_hidden(copy_ctor(_hidden)) {}
};
// Abstract interface for a recurrent layer: maps an input tensor, an initial
// hidden state and layer parameters to a LayerOutput.
template <typename hidden_type, typename param_type>
struct Layer {
  using output_type = LayerOutput<Tensor, hidden_type>;

  virtual ~Layer() = default;

  virtual output_type operator()(
      const Tensor& input,
      const hidden_type& input_hidden,
      const param_type& params) const = 0;
};
// Unidirectional LSTM layer that runs `cell_` over every time step.
struct FullLSTMLayer : Layer<t_tuple, CellParams> {
  // The cell is copied into cell_, so it is taken by const reference; existing
  // callers passing a non-const LSTMCell lvalue still bind.
  FullLSTMLayer(const LSTMCell& cell, CPUContext* context)
      : cell_(cell), context_(context) {}

  // Runs the cell over `step_inputs` in order, starting from `input_hidden`.
  // Returns the per-step hidden outputs (element 0 of each new state) and the
  // state after the last step.
  LayerOutput<std::vector<Tensor>, t_tuple> operator()(
      const std::vector<Tensor>& step_inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const {
    std::vector<Tensor> step_outputs;
    // One output per step: allocate once instead of growing per push_back.
    step_outputs.reserve(step_inputs.size());
    auto hidden = copy_ctor(input_hidden);
    for (const auto& step_input : step_inputs) {
      hidden = cell_(step_input, hidden, params);
      step_outputs.push_back(copy_ctor(std::get<0>(hidden)));
    }
    return {step_outputs, hidden};
  }

  // Tensor overload: splits `inputs` along dim 0 into time steps, runs them,
  // and re-stacks the step outputs along dim 0.
  LayerOutput<Tensor, t_tuple> operator()(
      const Tensor& inputs,
      const std::tuple<Tensor, Tensor>& input_hidden,
      const CellParams& params) const override {
    auto unstacked_output =
        (*this)(unbind(inputs, 0, context_), input_hidden, params);
    return {stack(unstacked_output.outputs, 0, context_),
            unstacked_output.final_hidden};
  }

  LSTMCell cell_;
  CPUContext* context_;
};
// Bidirectional LSTM layer: a forward pass and a time-reversed pass that share
// one FullLSTMLayer but use separate hidden states and parameters (the
// `.first` / `.second` halves of each pair).
struct FullBidirectionalLSTMLayer
    : Layer<std::pair<t_tuple, t_tuple>, std::pair<CellParams, CellParams>> {
  using bidir_hidden_type = std::pair<t_tuple, t_tuple>;
  using param_type = std::pair<CellParams, CellParams>;
  using output_type = LayerOutput<Tensor, bidir_hidden_type>;

  FullBidirectionalLSTMLayer(LSTMCell& cell, CPUContext* context)
      : layer_(cell, context), context_(context) {}

  // Splits `input` along dim 0 into time steps and runs them forward and in
  // reverse. The reverse pass's outputs are flipped back into time order, and
  // both directions are concatenated along the last dimension.
  output_type operator()(
      const Tensor& input,
      const bidir_hidden_type& input_hidden,
      const param_type& params) const override {
    auto steps = unbind(input, 0, context_);

    // Forward direction.
    auto forward = layer_(steps, input_hidden.first, params.first);
    auto forward_stacked = stack(forward.outputs, 0, context_);
    std::vector<Tensor> directions;
    directions.push_back(copy_ctor(forward_stacked));

    // Backward direction: feed steps last-to-first, then restore time order.
    auto reversed_steps = reverse(std::move(steps));
    auto backward =
        layer_(reversed_steps, input_hidden.second, params.second);
    std::reverse(backward.outputs.begin(), backward.outputs.end());
    auto backward_stacked = stack(backward.outputs, 0, context_);
    directions.push_back(copy_ctor(backward_stacked));

    const auto last_dim = forward_stacked.dim() - 1;
    return {cat(directions, last_dim, context_),
            std::make_pair(
                std::move(forward.final_hidden),
                std::move(backward.final_hidden))};
  }

  // Reverses `x` in place and hands it back. The explicit move is needed
  // because `x` is an rvalue-reference parameter, not a local.
  inline std::vector<Tensor> reverse(std::vector<Tensor>&& x) const {
    std::reverse(x.begin(), x.end());
    return std::move(x);
  }

 private:
  FullLSTMLayer layer_;
  CPUContext* context_;
};
// Runs `num_layers` applications of `layer`. The output of layer l becomes the
// input of layer l + 1; hiddens[l] and weights[l] parameterize layer l.
// Returns the last layer's outputs together with every layer's final hidden
// state, in layer order.
template <typename hidden_type, typename weight_type>
LayerOutput<Tensor, std::vector<hidden_type>> apply_layer_stack(
    const Layer<hidden_type, weight_type>& layer,
    const Tensor& input,
    const std::vector<hidden_type>& hiddens,
    const std::vector<weight_type>& weights,
    int64_t num_layers) {
  // Compare in size_t to avoid signed/unsigned mixing. A negative num_layers
  // wraps to a huge value and fails these checks, as before.
  const auto expected_layers = static_cast<size_t>(num_layers);
  CAFFE_ENFORCE_EQ(
      expected_layers,
      hiddens.size(),
      "Expected one hidden state per layer in stacked_rnn");
  CAFFE_ENFORCE_EQ(
      expected_layers,
      weights.size(),
      "Expected one set of weights per layer in stacked_rnn");
  auto layer_input = input.UnsafeSharedInstance();
  std::vector<hidden_type> final_hiddens(expected_layers);
  for (const auto l : c10::irange(expected_layers)) {
    auto layer_output = layer(layer_input, hiddens[l], weights[l]);
    final_hiddens[l] = std::move(layer_output.final_hidden);
    layer_input = std::move(layer_output.outputs);
  }
  return {layer_input, final_hiddens};
}
std::tuple<Tensor, Tensor, Tensor> _lstm_impl(
const Tensor& input,
const std::vector<CellParams>& params,
const Tensor& hx,
const Tensor& cx,
int64_t num_layers,
bool bidirectional,
CPUContext* context) {
using stack_output = LayerOutput<Tensor, std::vector<t_tuple>>;
auto layer_hx = unbind(hx, 0, context);
auto layer_cx = unbind(cx, 0, context);
int64_t total_layers = layer_hx.size();
std::vector<std::tuple<Tensor, Tensor>> hiddens;
hiddens.reserve(total_layers);
for (const auto i : c10::irange(total_layers)) {
hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
}
LSTMCell cell(context);
std::shared_ptr<stack_output> stack_output_ptr;
if (bidirectional) {
auto bidir_result = apply_layer_stack(
FullBidirectionalLSTMLayer{cell, context},
input,
pair_vec(hiddens),
pair_vec(params),
num_layers);
stack_output_ptr.reset(new stack_output(
bidir_result.outputs,
unpair_vec(std::move(bidir_result.final_hidden))));
} else {
auto result = apply_layer_stack(
FullLSTMLayer{cell, context}, input, hiddens, params, num_layers);
stack_output_ptr = std::make_shared<stack_output>(std::move(result));
}
std::vector<Tensor> hy, cy;
hy.reserve(total_layers);
cy.reserve(total_layers);
for (auto& hidden : stack_output_ptr->final_hidden) {
hy.push_back(std::move(std::get<0>(hidden)));
cy.push_back(std::move(std::get<1>(hidden)));
}
return std::make_tuple(
std::move(stack_output_ptr->outputs),
stack(hy, 0, context),
stack(cy, 0, context));
}
// Parses a flat list of parameter tensors into a list of CellParams
std::vector<CellParams> gather_params(
const std::vector<Tensor>& params,
bool has_biases,
CPUContext* context) {
Tensor undefined;
std::vector<CellParams> result;
if (has_biases) {
CAFFE_ENFORCE_EQ(
params.size() % 4, 0, "got an incorrect number of LSTM parameters");
for (size_t i = 0; i < params.size(); i += 4) {
result.emplace_back(
params[i], params[i + 1], params[i + 2], params[i + 3], context);
}
} else {
CAFFE_ENFORCE_EQ(
params.size() % 2, 0, "got an incorrect number of LSTM parameters");
for (size_t i = 0; i < params.size(); i += 2) {
result.emplace_back(
params[i], params[i + 1], undefined, undefined, context);
}
}
return result;
}
// CPU operator wrapping the LSTM implementation above. Its configuration is
// read once from the operator arguments at construction.
class InferenceLSTMOp : public Operator<CPUContext> {
 public:
  // Argument defaults when absent: num_layers = 1, bidirectional = false,
  // has_biases = true, batch_first = false. The base Operator<CPUContext> is
  // not a dependent type, so its members are named directly.
  template <class... Args>
  explicit InferenceLSTMOp(Args&&... args)
      : Operator(std::forward<Args>(args)...),
        num_layers_(GetSingleArgument<int64_t>("num_layers", 1)),
        bidirectional_(GetSingleArgument<bool>("bidirectional", false)),
        has_biases_(GetSingleArgument<bool>("has_biases", true)),
        batch_first_(GetSingleArgument<bool>("batch_first", false)) {}

  bool RunOnDevice() override;

 protected:
  int64_t num_layers_;
  bool bidirectional_;
  bool has_biases_;
  bool batch_first_;
};
} // namespace
} // namespace caffe2
#endif // LSTM_OP_H_