Commit 2bb3513

Fuse locations when fusing convs

Signed-off-by: Jonas Rickert <[email protected]>

Parent: ad65f0e

2 files changed, +51 -0 lines changed

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"

+
 #include "src/Dialect/ONNX/DialectBuilder.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -846,7 +847,11 @@ struct CombineParallelConv2DPattern : public OpRewritePattern<ONNXConvOp> {
     SmallVector<ONNXConvOp> parallelConvs = candidateConvs;

     bool allHaveBias = !mlir::isa<NoneType>(parallelConvs[0].getB().getType());
+
     Location loc = convOp1.getLoc();
+    for (auto conv : parallelConvs) {
+      loc = rewriter.getFusedLoc({loc, conv.getLoc()});
+    }
     auto inputType = mlir::cast<ShapedType>(input.getType());
     Type elementType = inputType.getElementType();
     onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
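The added loop builds one location that records every conv being merged: rewriter.getFusedLoc is MLIR's Builder::getFusedLoc, which wraps its operands in a FusedLoc attribute. A minimal standalone sketch of the same accumulation pattern, not taken from this commit (the helper name fuseOpLocations and its signature are hypothetical):

// Sketch only (not from this commit): fold the locations of several
// ops into one FusedLoc, seeded with an initial location, mirroring
// the loop added above.
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"

static mlir::Location fuseOpLocations(mlir::Builder &builder,
                                      mlir::Location seed,
                                      llvm::ArrayRef<mlir::Operation *> ops) {
  mlir::Location loc = seed;
  for (mlir::Operation *op : ops)
    loc = builder.getFusedLoc({loc, op->getLoc()});
  return loc;
}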

test/mlir/onnx/onnx_recompose_locations.mlir

Lines changed: 46 additions & 0 deletions
@@ -25,3 +25,49 @@ func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768x
 // CHECK-DAG: [[LOC_LN_MUL:#.+]] = loc("lnMul")
 // CHECK: [[LOC_FUSED]] = loc(fused[[[LOC_M_REDUCE]], [[LOC_SUB]], [[LOC_DD_MUL]], [[LOC_V_REDUCE]], [[LOC_ADD]], [[LOC_SQRT]], [[LOC_DIV]], [[LOC_LN_MUL]]])
 }
+
+
+// -----
+
+func.func @test_combine_conv_split(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x96x512x512xf32> {
+  %0 = onnx.Constant dense<0.00999999976> : tensor<32x1x3x3xf32>
+  %1 = onnx.Constant dense<0.00999999976> : tensor<32xf32>
+  %2 = onnx.Constant dense<0.00999999976> : tensor<32x1x3x3xf32>
+  %3 = onnx.Constant dense<0.00999999976> : tensor<32xf32>
+  %4 = onnx.Constant dense<0.00999999976> : tensor<32x1x3x3xf32>
+  %5 = onnx.Constant dense<0.00999999976> : tensor<32xf32>
+  %6 = "onnx.Conv"(%arg0, %0, %1) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32> loc("conv1")
+  %7 = "onnx.Conv"(%arg0, %2, %3) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32> loc("conv2")
+  %8 = "onnx.Conv"(%arg0, %4, %5) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32> loc("conv3")
+  %9 = "onnx.Relu"(%6) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc("relu")
+  %10 = "onnx.Sigmoid"(%7) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc("sigmoid")
+  %11 = "onnx.Tanh"(%8) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc("tanh")
+  %12 = "onnx.Concat"(%9, %10, %11) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc("concat")
+  return %12 : tensor<1x96x512x512xf32>
+
+// CHECK-LABEL: func.func @test_combine_conv_split
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<32> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>) -> tensor<96x1x3x3xf32> loc([[LOC_FUSED:#.+]])
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) -> tensor<96xf32> loc([[LOC_FUSED:#.+]])
+// CHECK: [[VAR_5_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_3_]], [[VAR_4_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<96x1x3x3xf32>, tensor<96xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_FUSED:#.+]])
+// CHECK: [[VAR_6_:%.+]]:3 = "onnx.Split"([[VAR_5_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x96x512x512xf32>, tensor<3xi64>) -> (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) loc([[LOC_FUSED:#.+]])
+// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[VAR_6_]]#2) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_RELU:#.+]])
+// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Sigmoid"([[VAR_6_]]#1) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_SIGMOID:#.+]])
+// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Tanh"([[VAR_6_]]#0) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_TANH:#.+]])
+// CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_7_]], [[VAR_8_]], [[VAR_9_]]) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_ORIGINAL_CONCAT:#.+]])
+// CHECK: return [[VAR_10_]] : tensor<1x96x512x512xf32>
+// CHECK: }
+
+// CHECK-DAG: [[LOC_RELU:#.+]] = loc("relu")
+// CHECK-DAG: [[LOC_SIGMOID:#.+]] = loc("sigmoid")
+// CHECK-DAG: [[LOC_TANH:#.+]] = loc("tanh")
+// CHECK-DAG: [[LOC_ORIGINAL_CONCAT:#.+]] = loc("concat")
+// CHECK-DAG: [[LOC_CONV1:#.+]] = loc("conv1")
+// CHECK-DAG: [[LOC_CONV2:#.+]] = loc("conv2")
+// CHECK-DAG: [[LOC_CONV3:#.+]] = loc("conv3")
+// CHECK-DAG: [[LOC_FUSED]] = loc(fused[[[LOC_CONV1]], [[LOC_CONV3]], [[LOC_CONV2]]])
+}
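A note on the expected fused location: the pattern seeds loc with convOp1.getLoc() and then fuses the location of every conv in parallelConvs, including convOp1 itself, yet each conv name appears only once in loc(fused[...]). This relies on FusedLoc::get (which getFusedLoc calls) uniquing its operands and flattening nested fused locations. A small standalone illustration of that behavior, not part of the commit:

// Illustration only (not from this commit): FusedLoc uniques its
// members and flattens nested fused locations, so re-fusing an
// already-included location does not duplicate it.
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::Builder b(&ctx);
  mlir::Location c1 = mlir::NameLoc::get(b.getStringAttr("conv1"));
  mlir::Location c2 = mlir::NameLoc::get(b.getStringAttr("conv2"));

  mlir::Location fused = b.getFusedLoc({c1, c2});
  mlir::Location refused = b.getFusedLoc({fused, c1});
  refused.dump(); // loc(fused["conv1", "conv2"]) -- "conv1" not repeated
}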
