
Commit 66e82a9

Conv add const where the constant is a scalar (#3145)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 30602ba commit 66e82a9

2 files changed: +39 additions, −11 deletions

src/Dialect/ONNX/ONNXOps/Canonicalize.td

Lines changed: 15 additions & 11 deletions
@@ -62,19 +62,19 @@ def subtractOrNeg: NativeCodeCall<
 def getRankOf :
     NativeCodeCall<"mlir::cast<ShapedType>($0.getType()).getRank()">;
 
-// Create an ArrayAttr of IntergerAttr(s) of [$0].
+// Create an ArrayAttr of IntegerAttr(s) of [$0].
 def createDenseElementsAttrOf : NativeCodeCall<
   "onnx_mlir::createDenseElementsAttrOfNToM($_builder, $0, $0)">;
 
-// Create an ArrayAttr of IntergerAttr(s) of values in [1, N-1].
+// Create an ArrayAttr of IntegerAttr(s) of values in [1, N-1].
 def createDenseElementsAttrOfOneToRankOf : NativeCodeCall<
   "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast<ShapedType>($0.getType()).getRank() - 1)">;
 
-// Create an ArrayAttr of IntergerAttr(s) of values in [1, N-2].
+// Create an ArrayAttr of IntegerAttr(s) of values in [1, N-2].
 def createDenseElementsAttrOfOneToRankOfExclusive : NativeCodeCall<
   "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast<ShapedType>($0.getType()).getRank() - 2)">;
 
-// Create an ArrayAttr of IntergerAttr(s) of values in [2, rank - 1].
+// Create an ArrayAttr of IntegerAttr(s) of values in [2, rank - 1].
 def createArrayAttrOfTwoToRankOf : NativeCodeCall<
   "onnx_mlir::createArrayAttrOfNToM($_builder, 2, mlir::cast<ShapedType>($0.getType()).getRank() - 1)">;
 
@@ -167,7 +167,7 @@ def HaveSameElementType : Constraint<
 def HaveSameElementTypeBitWidth: Constraint<
   CPred<"(mlir::dyn_cast<ShapedType>($0.getType()).getElementTypeBitWidth() == "
         "mlir::dyn_cast<ShapedType>($1.getType()).getElementTypeBitWidth())">,
-  "has same element type bitwidth">;
+  "has same element type bit-width">;
 
 def ElementTypeIsNotUnsigned: Constraint<
   CPred<"!mlir::dyn_cast<ShapedType>($_self.getType()).getElementType().isUnsignedInteger()">,
@@ -334,8 +334,10 @@ def FuseAddConvNullBiasPattern: Pat<
   [(HasShapeAndRank:$res),
    (HasNoneType $b),
    (AttributeIsNotNull:$denseAttr),
+   (RankXMinusRankYIs<1> $res, $y),
+   (HasRankGT<0> $y),
    (AllDimsFromAxisToEndAre<1, 1>:$y),
-   (RankXMinusRankYIs<1> $res, $y)]
+  ]
 >;
 
 def FuseAddConvPattern: Pat<
@@ -356,8 +358,9 @@ def FuseAddConvPattern: Pat<
   [(HasShapeAndRank:$res),
    (NotNoneType $b),
    (AttributeIsNotNull:$denseAttr),
-   (AllDimsFromAxisToEndAre<1, 1>:$y),
-   (RankXMinusRankYIs<1> $res, $y)]
+   (RankXMinusRankYIs<1> $res, $y),
+   (HasRankGT<0> $y),
+   (AllDimsFromAxisToEndAre<1, 1>:$y)]
 >;
 
 //===----------------------------------------------------------------------===//
@@ -403,10 +406,11 @@ def FuseMulConvNullBiasPattern: Pat<
    (HasRankGT<1> $w),                  // rank of $w must be at least 2.
    (RankXMinusRankYIs<1> $w, $y),      // rank($y) must be equal to rank($w)-1.
    (HaveSameDim<0> $w, $y),            // the first dimension of $w and $y must be equal.
+   (HasRankGT<0> $y),                  // constant cannot be a scalar.
    (AllDimsFromAxisToEndAre<1, 1>:$y)] // all dimensions of $y must be 1 except for the first one.
 >;
 
-// TODO add pattern for non-null bias with contraints:
+// TODO add pattern for non-null bias with constraints:
 // - bias must be have rank equal to 1 and
 // - bias element data type must be the same as mul constant
 // - bias dimension (0) must be equal to mul constant dim(0)
@@ -904,7 +908,7 @@ def RewriteBatchNormInferenceModeConvPattern1: Pat<
 
 // Special case of BatchNorm whose input shape is [N]. In this case, 'scale',
 // 'bias', 'mean', and 'var' will have shape of [1], according to ONNXBatchNorm
-// decription: https://github.com/onnx/onnx/blob/main/docs/Operators.md#inputs-12.
+// description: https://github.com/onnx/onnx/blob/main/docs/Operators.md#inputs-12.
 // Thus, we need not unsqueeze intermediate results.
 def RewriteBatchNormInferenceModeConvPattern2: Pat<
   (ONNXBatchNormalizationInferenceModeOp:$res
@@ -1089,7 +1093,7 @@ def ShapeTransformComposePattern : Pat<
 
 // In this pattern, the condition in onnx.Where is always false, so we can replace
 // onnx.Where by its "false" value.
-// Condition in this pattern is a comparision between dimension sizes and negative values.
+// Condition in this pattern is a comparison between dimension sizes and negative values.
 // Since dimension sizes are always positive, the condition is evaluated to false.
 
 // This pattern was found in xlm-roberta-base-language-detection model in HuggingFace.
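Taken together, the Canonicalize.td changes add a (HasRankGT<0> $y) guard to each Conv fusion pattern and move the rank checks ahead of AllDimsFromAxisToEndAre, so the dimension walk is never attempted on a rank-0 constant. Below is a minimal MLIR sketch, not part of this commit, of the scalar Mul case that the guarded FuseMulConvNullBiasPattern now declines to match; the shapes, function name, and scalar value are illustrative assumptions modeled on the tests in the second file.

// Hypothetical repro for the Mul variant: the multiplier is a scalar
// (rank 0), so the new (HasRankGT<0> $y) constraint fails and the pattern
// bails out before AllDimsFromAxisToEndAre inspects any dimensions; the
// Conv and Mul are simply left unfused.
func.func @mul_conv_scalar_const(%arg0 : tensor<1x1x28x28xf32>, %arg1 : tensor<8x1x5x5xf32>) -> tensor<1x8x28x28xf32> {
  %none = "onnx.NoValue"() {value} : () -> none
  %0 = "onnx.Conv"(%arg0, %arg1, %none) {auto_pad = "SAME_UPPER", kernel_shape = [5, 5]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, none) -> tensor<1x8x28x28xf32>
  %1 = onnx.Constant dense<3.0> : tensor<f32>
  %2 = "onnx.Mul"(%0, %1) : (tensor<1x8x28x28xf32>, tensor<f32>) -> tensor<1x8x28x28xf32>
  onnx.Return %2 : tensor<1x8x28x28xf32>
}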

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 24 additions & 0 deletions
@@ -887,6 +887,30 @@ func.func @test_fuse_add_conv_bias_unranked(%arg0 : tensor<*xf32>, %arg1 : tenso
 
 // -----
 
+// A bug was discovered when the constant being added was a scalar. This test
+// ensures that the compiler does not crash in such cases. Note that the fusion
+// does not occur, as we would need to first expand the constant to the right shape.
+
+func.func @test_fuse_add_conv_with_scalar_const(%arg0 : tensor<1x1x28x28xf32>, %arg1 : tensor<8x1x5x5xf32>) -> tensor<1x8x28x28xf32> {
+  %cst = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], onnx_node_name = "Convolution28", strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, none) -> tensor<1x8x28x28xf32>
+  %1 = "onnx.Constant"() {value = dense<2.0> : tensor<f32>} : () -> tensor<f32>
+  %2 = "onnx.Add"(%0, %1) : (tensor<1x8x28x28xf32>, tensor<f32>) -> tensor<1x8x28x28xf32>
+  onnx.Return %2 : tensor<1x8x28x28xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL:  func.func @test_fuse_add_conv_with_scalar_const
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1x1x28x28xf32>, [[PARAM_1_:%.+]]: tensor<8x1x5x5xf32>) -> tensor<1x8x28x28xf32> {
+// CHECK-DAG:    [[VAR_0_:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<f32>
+// CHECK-DAG:    [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK:        [[VAR_2_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[PARAM_1_]], [[VAR_1_]]) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], onnx_node_name = "Convolution28", strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, none) -> tensor<1x8x28x28xf32>
+// CHECK:        [[VAR_3_:%.+]] = "onnx.Add"([[VAR_2_]], [[VAR_0_]]) : (tensor<1x8x28x28xf32>, tensor<f32>) -> tensor<1x8x28x28xf32>
+// CHECK:        onnx.Return [[VAR_3_]] : tensor<1x8x28x28xf32>
+// CHECK:        }
+}
+
+// -----
+
 func.func @test_fuse_mul_conv(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xf32> {
   %0 = onnx.Constant dense<[[[[0.0234164055, 0.0228030644], [2.442580e-02, 0.0237577036]]], [[[-0.0410864502, 0.0488203131], [0.164448678, -0.0200194642]]], [[[-4.34581793E-9, 0.025325032], [0.0373019315, 0.165243402]]], [[[-0.0198689923, 0.131284416], [0.0572107285, 2.33985098E-8]]], [[[0.0187684372, -0.148515195], [0.0154875498, 0.019133633]]], [[[0.0176953916, -0.0154658081], [0.0233727545, -0.274110436]]], [[[-0.021181887, 0.0936150252], [0.135688141, -0.0202601217]]], [[[-0.0201558527, 0.0192655921], [0.227748245, -0.196346223]]]]> : tensor<8x1x2x2xf32>
   %1 = "onnx.NoValue"() {value} : () -> none