Skip to content

Commit a43efb0

Browse files
authored
Merge pull request #357 from Xilinx/bump_to_6d728e82
[AutoBump] Merge with 6d728e8 (May 01) (3)
2 parents d776119 + e0d8651 commit a43efb0

35 files changed

+1846
-820
lines changed

docs/SupportedONNXOps-NNPA-supplement.md

Lines changed: 0 additions & 20 deletions
This file was deleted.

docs/SupportedONNXOps-cpu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 22. Limitatio
4141
| **CastMap** |none | | | |
4242
| **CategoryMapper** |none | | | |
4343
| **Ceil** |6 - * | | |
44-
| **Celu** |none | | | |
44+
| **Celu** |12 - * | | | |
4545
| **CenterCropPad** |none | | | |
4646
| **Clip** |6 - * |No support for short integers. | |
4747
| **Col2Im** |none | | | |

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ bool isSuitableForZDNN<ONNXAddOp>(
369369
return false;
370370
if (!isValidElementTypeAndRank(op.getOperation(), op.getB()))
371371
return false;
372+
// Rule below is true for adds that are not fused into matmul.
372373
if (!dimAnalysis->sameShape(op.getA(), op.getB()))
373374
return onnxToZHighUnsupportedReport(op.getOperation(),
374375
"The dynamic dimension analysis couldn't identify "
@@ -677,7 +678,7 @@ bool isSuitableForZDNN<ONNXMatMulOp>(
677678
// (https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul) on zDNN
678679
// by using broadcasting etc.
679680
if ((shapeA.size() == 2) && (shapeB.size() == 2)) {
680-
// unstacked case
681+
// Unstacked case.
681682
if (aType.hasStaticShape() && bType.hasStaticShape()) {
682683
if (shapeA[1] != shapeB[0]) {
683684
std::string message = "Unstacked case: the 2nd dim of A (" +
@@ -689,7 +690,7 @@ bool isSuitableForZDNN<ONNXMatMulOp>(
689690
}
690691
return true;
691692
} else if ((shapeA.size() == 3) && (shapeB.size() == 3)) {
692-
// stacked w/o bcast case
693+
// Stacked w/o bcast case.
693694
if (aType.hasStaticShape() && bType.hasStaticShape()) {
694695
if ((shapeA[0] != shapeB[0]) || (shapeA[2] != shapeB[1])) {
695696
std::string message =
@@ -704,7 +705,7 @@ bool isSuitableForZDNN<ONNXMatMulOp>(
704705
}
705706
return true;
706707
} else if ((shapeA.size() == 3) && (shapeB.size() == 2)) {
707-
// stacked w/ bcast23 case
708+
// Bcast23 case.
708709
if (aType.hasStaticShape() && bType.hasStaticShape()) {
709710
if (shapeA[2] != shapeB[0]) {
710711
std::string message = "Stacked w/ bcast23 case: the 3rd dim of A (" +
@@ -716,7 +717,7 @@ bool isSuitableForZDNN<ONNXMatMulOp>(
716717
}
717718
return true;
718719
} else if ((shapeA.size() == 2) && (shapeB.size() == 3)) {
719-
// stacked w/ bcast1 case
720+
// Bcast1 case.
720721
if (!isCompatibleWithNNPALevel(NNPALevel::M15))
721722
return onnxToZHighInCompatibilityReport(
722723
op.getOperation(), NNPALevel::M15);

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,8 +1612,14 @@ void getONNXToZHighOneOpDynamicallyLegal(
16121612

16131613
void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
16141614
MLIRContext *context = patterns.getContext();
1615-
patterns.insert<replaceONNXMatMulAddPattern1>(context);
1616-
patterns.insert<replaceONNXMatMulAddPattern2>(context);
1615+
// Matmul add patterns.
1616+
patterns.insert<replaceONNXMatMulAddUnstackedOrBCast23Pattern1>(context);
1617+
patterns.insert<replaceONNXMatMulAddUnstackedOrBCast23Pattern2>(context);
1618+
patterns.insert<replaceONNXMatMulAddStackedPattern1>(context);
1619+
patterns.insert<replaceONNXMatMulAddStackedPattern2>(context);
1620+
patterns.insert<replaceONNXMatMulAddBCast1Pattern1>(context);
1621+
patterns.insert<replaceONNXMatMulAddBCast1Pattern2>(context);
1622+
// Other patterns.
16171623
patterns.insert<replaceONNXReluConvPattern>(context);
16181624
patterns.insert<replaceONNXLogSoftmaxPattern>(context);
16191625
patterns.insert<replaceONNXTransAMatMulPattern>(context);

src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td

Lines changed: 100 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ def GetMatMulLayoutStringAttr : NativeCodeCall<
579579
"$_builder.getStringAttr(($0 == 2) ? LAYOUT_2D : LAYOUT_3DS)"
580580
>;
581581

582+
// Stacked or bcast1 => 2DS, otherwise 1D
582583
def GetMatMulBiasLayoutStringAttr : NativeCodeCall<
583584
"$_builder.getStringAttr(((($0 == 3) && ($1 == 3)) || (($0 == 2) && ($1 == 3))) ? LAYOUT_2DS : LAYOUT_1D)"
584585
>;
@@ -591,13 +592,20 @@ def GetMatMulBiasLayoutStringAttr : NativeCodeCall<
591592
// CreateNoneValue)
592593
//===----------------------------------------------------------------------===//
593594

595+
def IsMatMulLegalForZDNN: Constraint<
596+
CPred<"isSuitableForZDNN<ONNXMatMulOp>(" #
597+
"dyn_cast_or_null<ONNXMatMulOp>($0.getDefiningOp()))">,
598+
"MatMul is legal for zDNN"
599+
>;
600+
594601
def replaceONNXMatMulPattern : Pat<
595602
(ONNXMatMulOp:$res $x, $y),
596603
(ZHighUnstickOp
597604
(ZHighMatMulOp
598605
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
599606
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
600-
(CreateNoneValue), (GetZeroI64Attr), (GetZeroI64Attr)))
607+
(CreateNoneValue), (GetZeroI64Attr), (GetZeroI64Attr))),
608+
[(IsMatMulLegalForZDNN $res)]
601609
>;
602610

603611
//===----------------------------------------------------------------------===//
@@ -609,13 +617,8 @@ def replaceONNXMatMulPattern : Pat<
609617
// (ZHighStickOp %B))
610618
//===----------------------------------------------------------------------===//
611619

612-
def IsMatMulLegalForZDNN: Constraint<
613-
CPred<"isSuitableForZDNN<ONNXMatMulOp>(" #
614-
"dyn_cast_or_null<ONNXMatMulOp>($0.getDefiningOp()))">,
615-
"MatMul is legal for zDNN"
616-
>;
617-
618620
// Be careful, this check is very specific to '$0' of rank 2 and '$1' of rank 1.
621+
// Limitation: requires both compared dims to be static.
619622
def HaveSameLastDimR2R1: Constraint<
620623
CPred<"(!mlir::cast<ShapedType>($0.getType()).isDynamicDim(1))" #
621624
" && (!mlir::cast<ShapedType>($1.getType()).isDynamicDim(0))" #
@@ -624,9 +627,8 @@ def HaveSameLastDimR2R1: Constraint<
624627
"Have the same last dimension"
625628
>;
626629

627-
// Only 1D bias is suitable for this transformation since only then
628-
// the semantics of bias addition is the same for both ONNX and zDNN.
629-
def replaceONNXMatMulAddPattern1 : Pat<
630+
// Unstacked or BCast23: only 1D bias is suitable for this transformation.
631+
def replaceONNXMatMulAddUnstackedOrBCast23Pattern1 : Pat<
630632
// From Add $b, (MatMul $x, $y)
631633
(ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)),
632634
// To ZHighMatMulOp
@@ -637,12 +639,13 @@ def replaceONNXMatMulAddPattern1 : Pat<
637639
(ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
638640
(GetRank $y)), (GetDefaultSaturation)),
639641
(GetZeroI64Attr), (GetZeroI64Attr))),
640-
[(IsMatMulLegalForZDNN $m), (HasRankOf<2> $y), (HasRankOf<1> $b),
641-
(HaveSameLastDimR2R1 $y, $b)], [],
642+
[(HasRankOf<2> $y), (HasRankOf<1> $b), (HaveSameLastDimR2R1 $y, $b),
643+
(IsMatMulLegalForZDNN $m)], [],
642644
(addBenefit 0)
643645
>;
644646

645-
def replaceONNXMatMulAddPattern2 : Pat<
647+
// Same as above, but with the MatMul as the first operand of the Add.
648+
def replaceONNXMatMulAddUnstackedOrBCast23Pattern2 : Pat<
646649
// From Add (MatMul $x, $y), $b
647650
(ONNXAddOp (ONNXMatMulOp:$m $x, $y), $b),
648651
// To ZHighMatMulOp
@@ -653,44 +656,114 @@ def replaceONNXMatMulAddPattern2 : Pat<
653656
(ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
654657
(GetRank $y)), (GetDefaultSaturation)),
655658
(GetZeroI64Attr), (GetZeroI64Attr))),
656-
[(IsMatMulLegalForZDNN $m), (HasRankOf<2> $y), (HasRankOf<1> $b),
657-
(HaveSameLastDimR2R1 $y, $b)], [],
659+
[(HasRankOf<2> $y), (HasRankOf<1> $b), (HaveSameLastDimR2R1 $y, $b),
660+
(IsMatMulLegalForZDNN $m)], [],
661+
(addBenefit 0)
662+
>;
663+
664+
//===----------------------------------------------------------------------===//
665+
// Replace onnx.add and onnx.matmul with stacked tensors with ZHighMatMul
666+
//===----------------------------------------------------------------------===//
667+
668+
// Be careful, this checks whether the first and the last dims (of rank-3 tensors)
669+
// are the same: testing s & p in (s,n,p) with (s, 1, p)
670+
def HaveSameFirstAndLastDimR3: Constraint<
671+
CPred<"(!mlir::cast<ShapedType>($0.getType()).isDynamicDim(0))" #
672+
" && (!mlir::cast<ShapedType>($0.getType()).isDynamicDim(2))" #
673+
" && (!mlir::cast<ShapedType>($1.getType()).isDynamicDim(0))" #
674+
" && (!mlir::cast<ShapedType>($1.getType()).isDynamicDim(2))" #
675+
" && (mlir::cast<ShapedType>($0.getType()).getShape()[0]" #
676+
" == mlir::cast<ShapedType>($1.getType()).getShape()[0])" #
677+
" && (mlir::cast<ShapedType>($0.getType()).getShape()[2]" #
678+
" == mlir::cast<ShapedType>($1.getType()).getShape()[2])">,
679+
"Have the same R1 and R3 dimensions"
680+
>;
681+
682+
// For a bias with 3 dims, need to ensure that the 2nd dim is 1, as it was added
683+
// just to please shape analysis: want [s, p] but had to use [s, 1, p]
684+
def Has1InMiddleR3: Constraint<
685+
CPred<"(!mlir::cast<ShapedType>($0.getType()).isDynamicDim(1))" #
686+
" && (mlir::cast<ShapedType>($0.getType()).getShape()[1]==1)">,
687+
"Has shape(1) == 1"
688+
>;
689+
690+
def replaceONNXMatMulAddStackedPattern1 : Pat<
691+
// From Add $b, (MatMul $x, $y)
692+
(ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)),
693+
// To ZHighMatMulOp
694+
(ZHighUnstickOp
695+
(ZHighMatMulOp
696+
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
697+
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
698+
(ZHighStickOp
699+
(ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
700+
(GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
701+
(GetZeroI64Attr), (GetZeroI64Attr))),
702+
[(HasRankOf<3> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
703+
(HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
704+
[],
705+
(addBenefit 0)
706+
>;
707+
708+
def replaceONNXMatMulAddStackedPattern2 : Pat<
709+
// From Add (MatMul $x, $y), $b
710+
(ONNXAddOp (ONNXMatMulOp:$m $x, $y), $b),
711+
// To ZHighMatMulOp
712+
(ZHighUnstickOp
713+
(ZHighMatMulOp
714+
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
715+
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
716+
(ZHighStickOp
717+
(ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
718+
//$b,
719+
(GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
720+
(GetZeroI64Attr), (GetZeroI64Attr))),
721+
[(HasRankOf<3> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
722+
(HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
723+
[],
658724
(addBenefit 0)
659725
>;
660726

661727
//===----------------------------------------------------------------------===//
662728
// Replace onnx.add and onnx.matmul with bcast1 tensors with ZHighMatMul
663729
//===----------------------------------------------------------------------===//
664730

665-
def replaceONNXMatMulAddPatternBcast1A : Pat<
731+
def replaceONNXMatMulAddBCast1Pattern1 : Pat<
666732
// From Add $b, (MatMul $x, $y)
667733
(ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)),
668734
// To ZHighMatMulOp
669735
(ZHighUnstickOp
670736
(ZHighMatMulOp
671737
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
672738
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
673-
(ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
674-
(GetRank $y)), (GetDefaultSaturation)),
739+
(ZHighStickOp
740+
(ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
741+
(GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
675742
(GetZeroI64Attr), (GetZeroI64Attr))),
676-
[(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m),
677-
(HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<2> $b)], [],
743+
[(IsCompatibleWithNNPALevelArch15),
744+
(HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
745+
(HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
746+
[],
678747
(addBenefit 0)
679748
>;
680749

681-
def replaceONNXMatMulAddPatternBcast1B : Pat<
750+
def replaceONNXMatMulAddBCast1Pattern2 : Pat<
682751
// From Add (MatMul $x, $y), $b
683752
(ONNXAddOp (ONNXMatMulOp:$m $x, $y), $b),
684753
// To ZHighMatMulOp
685754
(ZHighUnstickOp
686755
(ZHighMatMulOp
687-
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
756+
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
688757
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
689-
(ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
690-
(GetRank $y)), (GetDefaultSaturation)),
691-
(GetZeroI64Attr), (GetZeroI64Attr))),
692-
[(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m),
693-
(HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<2> $b)], [],
758+
(ZHighStickOp
759+
(ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
760+
//$b,
761+
(GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
762+
(GetZeroI64Attr), (GetZeroI64Attr))),
763+
[(IsCompatibleWithNNPALevelArch15),
764+
(HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
765+
(HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
766+
[],
694767
(addBenefit 0)
695768
>;
696769

0 commit comments

Comments
 (0)