@@ -25,3 +25,49 @@ func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768x
// CHECK-DAG: [[LOC_LN_MUL:#.+]] = loc("lnMul")
// CHECK: [[LOC_FUSED]] = loc(fused[[[LOC_M_REDUCE]], [[LOC_SUB]], [[LOC_DD_MUL]], [[LOC_V_REDUCE]], [[LOC_ADD]], [[LOC_SQRT]], [[LOC_DIV]], [[LOC_LN_MUL]]])
}
+
+
+ // -----
+
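+ // Three convolutions over the same input should be combined into a single
+ // Conv with concatenated weights and biases, followed by an onnx.Split.
+ // The combined ops carry a location fused from the original conv locations.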
+ func.func @test_combine_conv_split(%arg0: tensor<1x1x512x512xf32>) -> tensor<1x96x512x512xf32> {
+ %0 = onnx.Constant dense<0.00999999976> : tensor<32x1x3x3xf32>
+ %1 = onnx.Constant dense<0.00999999976> : tensor<32xf32>
+ %2 = onnx.Constant dense<0.00999999976> : tensor<32x1x3x3xf32>
+ %3 = onnx.Constant dense<0.00999999976> : tensor<32xf32>
+ %4 = onnx.Constant dense<0.00999999976> : tensor<32x1x3x3xf32>
+ %5 = onnx.Constant dense<0.00999999976> : tensor<32xf32>
+ %6 = "onnx.Conv"(%arg0, %0, %1) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32> loc("conv1")
+ %7 = "onnx.Conv"(%arg0, %2, %3) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32> loc("conv2")
+ %8 = "onnx.Conv"(%arg0, %4, %5) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<32x1x3x3xf32>, tensor<32xf32>) -> tensor<1x32x512x512xf32> loc("conv3")
+ %9 = "onnx.Relu"(%6) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc("relu")
+ %10 = "onnx.Sigmoid"(%7) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc("sigmoid")
+ %11 = "onnx.Tanh"(%8) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc("tanh")
+ %12 = "onnx.Concat"(%9, %10, %11) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc("concat")
+ return %12 : tensor<1x96x512x512xf32>
+
+ // CHECK-LABEL: func.func @test_combine_conv_split
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x512x512xf32>
+ // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<32> : tensor<3xi64>
+ // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32x1x3x3xf32>
+ // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.00999999977> : tensor<32xf32>
+ // CHECK-NOT: separator of consecutive DAGs
+ // CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>, tensor<32x1x3x3xf32>) -> tensor<96x1x3x3xf32> loc([[LOC_FUSED:#.+]])
+ // CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_2_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) -> tensor<96xf32> loc([[LOC_FUSED:#.+]])
+ // CHECK: [[VAR_5_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_3_]], [[VAR_4_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 1, 1, 1]} : (tensor<1x1x512x512xf32>, tensor<96x1x3x3xf32>, tensor<96xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_FUSED:#.+]])
+ // CHECK: [[VAR_6_:%.+]]:3 = "onnx.Split"([[VAR_5_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x96x512x512xf32>, tensor<3xi64>) -> (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) loc([[LOC_FUSED:#.+]])
+ // CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[VAR_6_]]#2) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_RELU:#.+]])
+ // CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Sigmoid"([[VAR_6_]]#1) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_SIGMOID:#.+]])
+ // CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Tanh"([[VAR_6_]]#0) : (tensor<1x32x512x512xf32>) -> tensor<1x32x512x512xf32> loc([[LOC_TANH:#.+]])
+ // CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_7_]], [[VAR_8_]], [[VAR_9_]]) {axis = 1 : si64} : (tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>, tensor<1x32x512x512xf32>) -> tensor<1x96x512x512xf32> loc([[LOC_ORIGINAL_CONCAT:#.+]])
+ // CHECK: return [[VAR_10_]] : tensor<1x96x512x512xf32>
+ // CHECK: }
+
+ // CHECK-DAG: [[LOC_RELU:#.+]] = loc("relu")
+ // CHECK-DAG: [[LOC_SIGMOID:#.+]] = loc("sigmoid")
+ // CHECK-DAG: [[LOC_TANH:#.+]] = loc("tanh")
+ // CHECK-DAG: [[LOC_ORIGINAL_CONCAT:#.+]] = loc("concat")
+ // CHECK-DAG: [[LOC_CONV1:#.+]] = loc("conv1")
+ // CHECK-DAG: [[LOC_CONV2:#.+]] = loc("conv2")
+ // CHECK-DAG: [[LOC_CONV3:#.+]] = loc("conv3")
+ // CHECK-DAG: [[LOC_FUSED]] = loc(fused[[[LOC_CONV1]], [[LOC_CONV3]], [[LOC_CONV2]]])
+ }