Commit f61355c

Merge pull request #68 from JDAI-CV/ignore_batch_size
Ignore batch size and improve error msg in onnx2daq
2 parents: c349f00 + 4898919

File tree: 2 files changed (+37, -17 lines)

third_party/onnx

Submodule onnx updated 833 files

tools/onnx2bnn/OnnxConverter.cpp

Lines changed: 36 additions & 16 deletions
```diff
@@ -191,10 +191,11 @@ std::vector<std::string> OnnxConverter::Convert(
     // Please check out "dabnn_*" pases in
     // https://github.com/daquexian/onnx/blob/optimizer_for_bnn/onnx/optimizer/passes
     // for details.
-    vector<string> optimizers{"eliminate_nop_pad",
-                              "extract_constant_to_initializer",
-                              "dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape",
-                              "dabnn_bconv_strict"};
+    vector<string> optimizers{
+        "eliminate_nop_pad", "extract_constant_to_initializer",
+        "dabnn_eliminate_dropout",
+        "dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape",
+        "dabnn_bconv_strict"};
     if (level == Level::kModerate || level == Level::kAggressive) {
         optimizers.push_back("dabnn_bconv_moderate");
     }
```
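
The only functional change in this hunk is the new `dabnn_eliminate_dropout` entry (the rest is a reflow of the initializer list); as the comment above notes, the `dabnn_*` passes live in daquexian's `optimizer_for_bnn` branch of ONNX. For context, here is a minimal sketch of how such a pass list is typically applied, assuming the pre-`onnxoptimizer` C++ API from `onnx/optimizer/optimize.h`; the actual call site in dabnn may differ:

```cpp
#include <string>
#include <vector>

#include "onnx/optimizer/optimize.h"

// Sketch only: runs this commit's pass list over a loaded model proto.
ONNX_NAMESPACE::ModelProto RunBnnPasses(
    const ONNX_NAMESPACE::ModelProto &model) {
    const std::vector<std::string> optimizers{
        "eliminate_nop_pad", "extract_constant_to_initializer",
        "dabnn_eliminate_dropout",  // new in this commit
        "dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape",
        "dabnn_bconv_strict"};
    return ONNX_NAMESPACE::optimization::Optimize(model, optimizers);
}
```

Running Dropout elimination as a graph pass means every later stage of the converter sees a Dropout-free graph, which is what allows the converter's own Dropout branch to be deleted in the last hunk below.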
```diff
@@ -231,13 +232,23 @@ std::vector<std::string> OnnxConverter::Convert(
         }

         Shape shape;
-        for (const auto &dim : input.type().tensor_type().shape().dim()) {
+        const auto &dims = input.type().tensor_type().shape().dim();
+        FORZ(i, dims.size()) {
+            if (i == 0) {
+                // We ignore the value of batch dimension since dabnn doesn't
+                // support batch input
+                shape.push_back(1);
+                continue;
+            }
+            const auto &dim = dims.Get(i);
             if (dim.value_case() ==
                 ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue) {
                 shape.push_back(static_cast<uint32_t>(dim.dim_value()));
             } else {
                 throw std::invalid_argument(
-                    "The input of graph doesn't have dim_value");
+                    "Dim " + std::to_string(i) + " of input \"" + input.name() +
+                    "\" is not static, please re-export your ONNX model with "
+                    "static input shape");
             }
         }
         Shape nhwc_shape{shape[0], shape[2], shape[3], shape[1]};
```
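
This hunk is the "ignore batch size" half of the commit message: the batch dimension is pinned to 1 rather than read from the model, and a non-static dimension now produces an error naming the input and the offending dim instead of the old generic message. `FORZ` is dabnn's index-loop helper, assumed here to expand to a plain `for (i = 0; i < n; ++i)` loop. A standalone sketch of the same logic, with the protobuf dimension type replaced by a plain vector (`ReadInputShape` and the `dims[i] <= 0` convention are illustrative, not dabnn API):

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

using Shape = std::vector<uint32_t>;

// dims[i] <= 0 stands in for an ONNX dim that carries no static dim_value.
Shape ReadInputShape(const std::string &input_name,
                     const std::vector<int64_t> &dims) {
    Shape shape;
    for (size_t i = 0; i < dims.size(); i++) {
        if (i == 0) {
            // Batch dimension is ignored: dabnn doesn't support batch input.
            shape.push_back(1);
            continue;
        }
        if (dims[i] <= 0) {
            throw std::invalid_argument(
                "Dim " + std::to_string(i) + " of input \"" + input_name +
                "\" is not static, please re-export your ONNX model with "
                "static input shape");
        }
        shape.push_back(static_cast<uint32_t>(dims[i]));
    }
    return shape;
}
```

In practice this means a model exported with a dynamic batch, e.g. shape (None, 3, 224, 224), was previously rejected outright; now only the non-batch dimensions have to be static.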
```diff
@@ -248,17 +259,16 @@ std::vector<std::string> OnnxConverter::Convert(
     }

     vector<string> binary_conv_outputs;
-    vector<string> skipped_act;
     bool has_reshape = false;
     for (const auto &node : model_proto_.graph().node()) {
-        if (has_reshape) {
-            throw std::invalid_argument(
-                "Reshape can only be the last layer for now");
-        }
         NodeAttrHelper helper(node);
         const auto &op = node.op_type();
         VLOG(5) << "Node " << node.name();
         if (op == "Conv") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Conv";
             auto strides = helper.get("strides", vector<int>{1, 1});
             auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
```
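
Two things change here: the unused `skipped_act` vector is dropped, and the blanket check at the top of the node loop, which rejected any node following a Reshape with "Reshape can only be the last layer for now", moves into the individual op branches. The net effect is that a Reshape may now be the last layer or sit directly before a Gemm (the Gemm branch below simply clears the flag, since a Reshape/Flatten + Gemm pair has already been rewritten by the `dabnn_convert_gemm_with_reshape_or_flatten_to_conv_and_reshape` pass), while Conv, the pooling ops, Add, and Concat still reject a preceding Reshape, now with an error that names the op. Condensed, the reworked control flow looks roughly like this (conversion bodies elided, only ops touched by this commit shown):

```cpp
// Illustrative condensation of the node loop, not the verbatim converter code.
bool has_reshape = false;
for (const auto &node : model_proto_.graph().node()) {
    const auto &op = node.op_type();
    if (op == "Reshape") {
        has_reshape = true;   // remember it; legality depends on what follows
    } else if (op == "Gemm") {
        has_reshape = false;  // Reshape+Gemm was already fused by an optimizer pass
        // ... convert Gemm ...
    } else if (op == "Conv" || op == "AveragePool" || op == "MaxPool" ||
               op == "GlobalAveragePool" || op == "GlobalMaxPool" ||
               op == "Add" || op == "Concat") {
        if (has_reshape) {
            throw std::invalid_argument("Reshape before " + op +
                                        " is not supported");
        }
        // ... convert op ...
    }
}
```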
```diff
@@ -308,6 +318,10 @@ std::vector<std::string> OnnxConverter::Convert(
             VLOG(5) << "Converting Conv completed";
         } else if (op == "AveragePool" || op == "MaxPool" ||
                    op == "GlobalAveragePool" || op == "GlobalMaxPool") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Pool";
             auto input_name = m(node.input(0));
             auto output_name = m(node.output(0));
@@ -396,6 +410,10 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Relu completed";
         } else if (op == "Add") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Add";
             auto input1_name = m(node.input(0));
             auto input2_name = m(node.input(1));
@@ -409,6 +427,9 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Add completed";
         } else if (op == "Gemm") {
+            if (has_reshape) {
+                has_reshape = false;
+            }
             VLOG(5) << "Start converting Gemm";
             auto transA = helper.get("transA", 0);
             auto transB = helper.get("transB", 0);
@@ -467,6 +488,10 @@ std::vector<std::string> OnnxConverter::Convert(
             layers_.push_back(layer);
             VLOG(5) << "Converting Softmax completed";
         } else if (op == "Concat") {
+            if (has_reshape) {
+                throw std::invalid_argument("Reshape before " + op +
+                                            " is not supported");
+            }
             VLOG(5) << "Start converting Concat";
             vector<std::string> concat_inputs_str;
             for (const auto &onnx_input : node.input()) {
```
```diff
@@ -486,11 +511,6 @@ std::vector<std::string> OnnxConverter::Convert(
                                 0, 0, 0, 0, 0, 0, param);
             layers_.push_back(layer);
             VLOG(5) << "Converting Concat completed";
-        } else if (op == "Dropout") {
-            VLOG(5) << "Start converting Dropout";
-            // Dropout does nothing, so the output is the same as the input
-            name_map_[node.output(0)] = m(node.input(0));
-            VLOG(5) << "Converting Dropout completed";
         } else if (op == "Reshape") {
             VLOG(5) << "Start converting Reshape";
             has_reshape = true;
```
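
The dedicated Dropout branch can go away because the `dabnn_eliminate_dropout` pass added in the first hunk strips Dropout nodes before the converter's node loop ever runs, so the Reshape-tracking logic above never has to special-case them. As a rough illustration only (not the actual pass source, which lives in the `optimizer_for_bnn` ONNX branch), an identity-elimination pass under the old ONNX optimizer's `PredicateBasedPass` interface might look like this:

```cpp
#include <string>

#include "onnx/optimizer/pass.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// Illustrative sketch: at inference time Dropout is a no-op, so consumers of
// its output can be rewired to its input and the node deleted. A real pass
// would also handle Dropout's optional mask output.
struct EliminateDropout final : public PredicateBasedPass {
    EliminateDropout()
        : PredicateBasedPass(PassType::Nop, PassEfficiency::Complete,
                             PassOptimizationType::Compute) {}

    std::string getPassName() const override {
        return "dabnn_eliminate_dropout";
    }

    bool patternMatchPredicate(Node *node) override {
        return node->kind() == Symbol("Dropout");
    }

    bool runTransform(Node *node, Graph &,
                      NodeDestroyType &destroy_current) override {
        node->outputs()[0]->replaceAllUsesWith(node->inputs()[0]);
        destroy_current = NodeDestroyType::DestroyOne;
        return true;
    }
};

}  // namespace optimization
}  // namespace ONNX_NAMESPACE
```

Compared with the deleted converter branch, the rewrite now happens once, on the ONNX graph itself, before any conversion logic sees the model.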
