@@ -62,19 +62,19 @@ def subtractOrNeg: NativeCodeCall<
6262def getRankOf :
6363 NativeCodeCall<"mlir::cast<ShapedType>($0.getType()).getRank()">;
6464
65- // Create an ArrayAttr of IntergerAttr (s) of [$0].
65+ // Create an ArrayAttr of IntegerAttr (s) of [$0].
6666def createDenseElementsAttrOf : NativeCodeCall<
6767 "onnx_mlir::createDenseElementsAttrOfNToM($_builder, $0, $0)">;
6868
69- // Create an ArrayAttr of IntergerAttr (s) of values in [1, N-1].
69+ // Create an ArrayAttr of IntegerAttr (s) of values in [1, N-1].
7070def createDenseElementsAttrOfOneToRankOf : NativeCodeCall<
7171 "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast<ShapedType>($0.getType()).getRank() - 1)">;
7272
73- // Create an ArrayAttr of IntergerAttr (s) of values in [1, N-2].
73+ // Create an ArrayAttr of IntegerAttr (s) of values in [1, N-2].
7474def createDenseElementsAttrOfOneToRankOfExclusive : NativeCodeCall<
7575 "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast<ShapedType>($0.getType()).getRank() - 2)">;
7676
77- // Create an ArrayAttr of IntergerAttr (s) of values in [2, rank - 1].
77+ // Create an ArrayAttr of IntegerAttr (s) of values in [2, rank - 1].
7878def createArrayAttrOfTwoToRankOf : NativeCodeCall<
7979 "onnx_mlir::createArrayAttrOfNToM($_builder, 2, mlir::cast<ShapedType>($0.getType()).getRank() - 1)">;
8080
@@ -167,7 +167,7 @@ def HaveSameElementType : Constraint<
167167def HaveSameElementTypeBitWidth: Constraint<
168168 CPred<"(mlir::dyn_cast<ShapedType>($0.getType()).getElementTypeBitWidth() == "
169169 "mlir::dyn_cast<ShapedType>($1.getType()).getElementTypeBitWidth())">,
170- "has same element type bitwidth ">;
170+ "has same element type bit-width ">;
171171
172172def ElementTypeIsNotUnsigned: Constraint<
173173 CPred<"!mlir::dyn_cast<ShapedType>($_self.getType()).getElementType().isUnsignedInteger()">,
@@ -334,8 +334,10 @@ def FuseAddConvNullBiasPattern: Pat<
334334 [(HasShapeAndRank:$res),
335335 (HasNoneType $b),
336336 (AttributeIsNotNull:$denseAttr),
337+ (RankXMinusRankYIs<1> $res, $y),
338+ (HasRankGT<0> $y),
337339 (AllDimsFromAxisToEndAre<1, 1>:$y),
338- (RankXMinusRankYIs<1> $res, $y) ]
340+ ]
339341>;
340342
341343def FuseAddConvPattern: Pat<
@@ -356,8 +358,9 @@ def FuseAddConvPattern: Pat<
356358 [(HasShapeAndRank:$res),
357359 (NotNoneType $b),
358360 (AttributeIsNotNull:$denseAttr),
359- (AllDimsFromAxisToEndAre<1, 1>:$y),
360- (RankXMinusRankYIs<1> $res, $y)]
361+ (RankXMinusRankYIs<1> $res, $y),
362+ (HasRankGT<0> $y),
363+ (AllDimsFromAxisToEndAre<1, 1>:$y)]
361364>;
362365
363366//===----------------------------------------------------------------------===//
@@ -403,10 +406,11 @@ def FuseMulConvNullBiasPattern: Pat<
403406 (HasRankGT<1> $w), // rank of $w must be at least 2.
404407 (RankXMinusRankYIs<1> $w, $y), // rank($y) must be equal to rank($w)-1.
405408 (HaveSameDim<0> $w, $y), // the first dimension of $w and $y must be equal.
409+ (HasRankGT<0> $y), // constant cannot be a scalar.
406410 (AllDimsFromAxisToEndAre<1, 1>:$y)] // all dimensions of $y must be 1 except for the first one.
407411>;
408412
409- // TODO add pattern for non-null bias with contraints :
413+ // TODO add pattern for non-null bias with constraints :
410414// - bias must be have rank equal to 1 and
411415// - bias element data type must be the same as mul constant
412416// - bias dimension (0) must be equal to mul constant dim(0)
@@ -904,7 +908,7 @@ def RewriteBatchNormInferenceModeConvPattern1: Pat<
904908
905909// Special case of BatchNorm whose input shape is [N]. In this case, 'scale',
906910// 'bias', 'mean', and 'var' will have shape of [1], according to ONNXBatchNorm
907- // decription : https://github.com/onnx/onnx/blob/main/docs/Operators.md#inputs-12.
911+ // description : https://github.com/onnx/onnx/blob/main/docs/Operators.md#inputs-12.
908912// Thus, we need not unsqueeze intermediate results.
909913def RewriteBatchNormInferenceModeConvPattern2: Pat<
910914 (ONNXBatchNormalizationInferenceModeOp:$res
@@ -1089,7 +1093,7 @@ def ShapeTransformComposePattern : Pat<
10891093
10901094// In this pattern, the condition in onnx.Where is always false, so we can replace
10911095// onnx.Where by its "false" value.
1092- // Condition in this pattern is a comparision between dimension sizes and negative values.
1096+ // Condition in this pattern is a comparison between dimension sizes and negative values.
10931097// Since dimension sizes are always positive, the condition is evaluated to false.
10941098
10951099// This pattern was found in xlm-roberta-base-language-detection model in HuggingFace.
0 commit comments