@@ -191,10 +191,11 @@ std::vector<std::string> OnnxConverter::Convert(
     // Please check out "dabnn_*" passes in
     // https://github.com/daquexian/onnx/blob/optimizer_for_bnn/onnx/optimizer/passes
     // for details.
-    vector<string> optimizers{"eliminate_nop_pad",
-                              "extract_constant_to_initializer",
-                              "dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape",
-                              "dabnn_bconv_strict"};
+    vector<string> optimizers{
+        "eliminate_nop_pad", "extract_constant_to_initializer",
+        "dabnn_eliminate_dropout",
+        "dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape",
+        "dabnn_bconv_strict"};
     if (level == Level::kModerate || level == Level::kAggressive) {
         optimizers.push_back("dabnn_bconv_moderate");
     }
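
Note: the pass list above is handed to the ONNX optimizer in a single call elsewhere in OnnxConverter::Convert (not shown in this diff). A minimal sketch of that step, assuming the linked fork keeps the stock optimizer entry point ONNX_NAMESPACE::optimization::Optimize:

    #include "onnx/optimizer/optimize.h"

    // Run the selected passes over the model and keep the optimized copy
    // for the rest of the conversion. `model_proto_` is the converter's
    // ONNX ModelProto member.
    model_proto_ =
        ONNX_NAMESPACE::optimization::Optimize(model_proto_, optimizers);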
@@ -231,13 +232,23 @@ std::vector<std::string> OnnxConverter::Convert(
         }

         Shape shape;
-        for (const auto &dim : input.type().tensor_type().shape().dim()) {
+        const auto &dims = input.type().tensor_type().shape().dim();
+        FORZ(i, dims.size()) {
+            if (i == 0) {
+                // We ignore the value of batch dimension since dabnn doesn't
+                // support batch input
+                shape.push_back(1);
+                continue;
+            }
+            const auto &dim = dims.Get(i);
             if (dim.value_case() ==
                 ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue) {
                 shape.push_back(static_cast<uint32_t>(dim.dim_value()));
             } else {
                 throw std::invalid_argument(
-                    "The input of graph doesn't have dim_value");
+                    "Dim " + std::to_string(i) + " of input \"" + input.name() +
+                    "\" is not static, please re-export your ONNX model with "
+                    "static input shape");
             }
         }
         Shape nhwc_shape{shape[0], shape[2], shape[3], shape[1]};
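
Note: FORZ is dabnn's zero-based loop macro from its common helpers; roughly, assuming the usual definition:

    // Rough equivalent of FORZ(i, n): iterate i over [0, n).
    #define FORZ(i, n) for (auto i = decltype(n){0}; i < (n); i++)

The NHWC line that closes the hunk then simply reorders the NCHW shape {n, c, h, w} into {n, h, w, c}.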
@@ -248,17 +259,16 @@ std::vector<std::string> OnnxConverter::Convert(
     }

     vector<string> binary_conv_outputs;
-    vector<string> skipped_act;
     bool has_reshape = false;
     for (const auto &node : model_proto_.graph().node()) {
-        if (has_reshape) {
-            throw std::invalid_argument(
-                "Reshape can only be the last layer for now");
-        }
         NodeAttrHelper helper(node);
         const auto &op = node.op_type();
         VLOG(5) << "Node " << node.name();
         if (op == "Conv") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Conv";
             auto strides = helper.get("strides", vector<int>{1, 1});
             auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
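
Note: helper.get(name, default) reads an attribute off the ONNX node and falls back to the default when it is absent. The real NodeAttrHelper is defined elsewhere in dabnn and is not part of this diff; a minimal sketch of only the int-list overload used here:

    #include <string>
    #include <vector>

    #include "onnx/onnx_pb.h"

    class NodeAttrHelper {
       public:
        explicit NodeAttrHelper(const ONNX_NAMESPACE::NodeProto &node)
            : node_(node) {}

        // Return the ints of attribute `name`, or `def` when it is missing.
        std::vector<int> get(const std::string &name,
                             std::vector<int> def) const {
            for (const auto &attr : node_.attribute()) {
                if (attr.name() == name) {
                    return std::vector<int>(attr.ints().begin(),
                                            attr.ints().end());
                }
            }
            return def;
        }

       private:
        const ONNX_NAMESPACE::NodeProto &node_;
    };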
@@ -308,6 +318,10 @@ std::vector<std::string> OnnxConverter::Convert(
             VLOG(5) << "Converting Conv completed";
         } else if (op == "AveragePool" || op == "MaxPool" ||
                    op == "GlobalAveragePool" || op == "GlobalMaxPool") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Pool";
             auto input_name = m(node.input(0));
             auto output_name = m(node.output(0));
@@ -396,6 +410,10 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Relu completed";
         } else if (op == "Add") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Add";
             auto input1_name = m(node.input(0));
             auto input2_name = m(node.input(1));
@@ -409,6 +427,9 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Add completed";
         } else if (op == "Gemm") {
+            if (has_reshape) {
+                has_reshape = false;
+            }
             VLOG(5) << "Start converting Gemm";
             auto transA = helper.get("transA", 0);
             auto transB = helper.get("transB", 0);
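
Note the asymmetry here: before Gemm the flag is cleared instead of throwing. A Reshape or Flatten feeding a Gemm is exactly the pattern the dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape pass above is meant to absorb, so a preceding Reshape is acceptable for Gemm while it remains fatal before Conv, Pool, Add, and Concat.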
@@ -467,6 +488,10 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Softmax completed";
         } else if (op == "Concat") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Concat";
             vector<std::string> concat_inputs_str;
             for (const auto &onnx_input : node.input()) {
@@ -486,11 +511,6 @@ std::vector<std::string> OnnxConverter::Convert(
                 0, 0, 0, 0, 0, 0, param);
             layers_.push_back(layer);
             VLOG(5) << "Converting Concat completed";
-        } else if (op == "Dropout") {
-            VLOG(5) << "Start converting Dropout";
-            // Dropout does nothing, so the output is the same as the input
-            name_map_[node.output(0)] = m(node.input(0));
-            VLOG(5) << "Converting Dropout completed";
         } else if (op == "Reshape") {
             VLOG(5) << "Start converting Reshape";
             has_reshape = true;
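
Note: the inline Dropout branch is removed because the dabnn_eliminate_dropout pass added at the top of this diff now strips Dropout nodes before conversion reaches this loop. In the style of the linked passes, such a pass amounts to rewiring uses and deleting the node; a rough sketch, assuming the legacy onnx optimizer's PredicateBasedPass interface:

    struct EliminateDropout final : public PredicateBasedPass {
        EliminateDropout()
            : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete,
                                 PassOptimizationType::Compute) {}

        std::string getPassName() const override {
            return "dabnn_eliminate_dropout";
        }

        bool patternMatchPredicate(Node *node) override {
            return node->kind() == Symbol("Dropout");
        }

        bool runTransform(Node *node, Graph &,
                          NodeDestroyType &destroy_current) override {
            // Dropout is an identity at inference time: route every use of
            // its data output back to its input, then delete the node.
            node->outputs()[0]->replaceAllUsesWith(node->inputs()[0]);
            destroy_current = NodeDestroyType::DestroyOne;
            return true;
        }
    };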