@@ -579,6 +579,7 @@ def GetMatMulLayoutStringAttr : NativeCodeCall<
579
579
"$_builder.getStringAttr(($0 == 2) ? LAYOUT_2D : LAYOUT_3DS)"
580
580
>;
581
581
582
+ // Stacked or bcast1 => 2DS, otherwise 1D
582
583
def GetMatMulBiasLayoutStringAttr : NativeCodeCall<
583
584
"$_builder.getStringAttr(((($0 == 3) && ($1 == 3)) || (($0 == 2) && ($1 == 3))) ? LAYOUT_2DS : LAYOUT_1D)"
584
585
>;
@@ -591,13 +592,20 @@ def GetMatMulBiasLayoutStringAttr : NativeCodeCall<
591
592
// CreateNoneValue)
592
593
//===----------------------------------------------------------------------===//
593
594
595
+ def IsMatMulLegalForZDNN: Constraint<
596
+ CPred<"isSuitableForZDNN<ONNXMatMulOp>(" #
597
+ "dyn_cast_or_null<ONNXMatMulOp>($0.getDefiningOp()))">,
598
+ "MatMul is legal for zDNN"
599
+ >;
600
+
594
601
def replaceONNXMatMulPattern : Pat<
595
602
(ONNXMatMulOp:$res $x, $y),
596
603
(ZHighUnstickOp
597
604
(ZHighMatMulOp
598
605
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
599
606
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
600
- (CreateNoneValue), (GetZeroI64Attr), (GetZeroI64Attr)))
607
+ (CreateNoneValue), (GetZeroI64Attr), (GetZeroI64Attr))),
608
+ [(IsMatMulLegalForZDNN $res)]
601
609
>;
602
610
603
611
//===----------------------------------------------------------------------===//
@@ -609,13 +617,8 @@ def replaceONNXMatMulPattern : Pat<
609
617
// (ZHighStickOp %B))
610
618
//===----------------------------------------------------------------------===//
611
619
612
- def IsMatMulLegalForZDNN: Constraint<
613
- CPred<"isSuitableForZDNN<ONNXMatMulOp>(" #
614
- "dyn_cast_or_null<ONNXMatMulOp>($0.getDefiningOp()))">,
615
- "MatMul is legal for zDNN"
616
- >;
617
-
618
620
// Be careful, this check is very specific to '$0' of rank 2 and '$1' of rank 1.
621
+ // Limitation: requires dim to be static.
619
622
def HaveSameLastDimR2R1: Constraint<
620
623
CPred<"(!mlir::cast<ShapedType>($0.getType()).isDynamicDim(1))" #
621
624
" && (!mlir::cast<ShapedType>($1.getType()).isDynamicDim(0))" #
@@ -624,9 +627,8 @@ def HaveSameLastDimR2R1: Constraint<
624
627
"Have the same last dimension"
625
628
>;
626
629
627
- // Only 1D bias is suitable for this transformation since only then
628
- // the semantics of bias addition is the same for both ONNX and zDNN.
629
- def replaceONNXMatMulAddPattern1 : Pat<
630
+ // Unstacked or BCast23: only 1D bias is suitable for this transformation.
631
+ def replaceONNXMatMulAddUnstackedOrBCast23Pattern1 : Pat<
630
632
// From Add $b, (MatMul $x, $y)
631
633
(ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)),
632
634
// To ZHighMatMulOp
@@ -637,12 +639,13 @@ def replaceONNXMatMulAddPattern1 : Pat<
637
639
(ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
638
640
(GetRank $y)), (GetDefaultSaturation)),
639
641
(GetZeroI64Attr), (GetZeroI64Attr))),
640
- [(IsMatMulLegalForZDNN $m), ( HasRankOf<2> $y), (HasRankOf<1> $b),
641
- (HaveSameLastDimR2R1 $y, $b )], [],
642
+ [(HasRankOf<2> $y), (HasRankOf<1> $b), (HaveSameLastDimR2R1 $y, $b),
643
+ (IsMatMulLegalForZDNN $m )], [],
642
644
(addBenefit 0)
643
645
>;
644
646
645
- def replaceONNXMatMulAddPattern2 : Pat<
647
+ // Same as above, multiplication in first slot of add.
648
+ def replaceONNXMatMulAddUnstackedOrBCast23Pattern2 : Pat<
646
649
// From Add (MatMul $x, $y), $b
647
650
(ONNXAddOp (ONNXMatMulOp:$m $x, $y), $b),
648
651
// To ZHighMatMulOp
@@ -653,44 +656,114 @@ def replaceONNXMatMulAddPattern2 : Pat<
653
656
(ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
654
657
(GetRank $y)), (GetDefaultSaturation)),
655
658
(GetZeroI64Attr), (GetZeroI64Attr))),
656
- [(IsMatMulLegalForZDNN $m), (HasRankOf<2> $y), (HasRankOf<1> $b),
657
- (HaveSameLastDimR2R1 $y, $b)], [],
659
+ [(HasRankOf<2> $y), (HasRankOf<1> $b), (HaveSameLastDimR2R1 $y, $b),
660
+ (IsMatMulLegalForZDNN $m)], [],
661
+ (addBenefit 0)
662
+ >;
663
+
664
+ //===----------------------------------------------------------------------===//
665
+ // Replace onnx.add and onnx.matmul with stacked tensors with ZHighMatMul
666
+ //===----------------------------------------------------------------------===//
667
+
668
+ // Be careful, this check if the first and the last dim (of ranked 3 vectors)
669
+ // are the same: testing s & p in (s,n,p) with (s, 1, p)
670
+ def HaveSameFirstAndLastDimR3: Constraint<
671
+ CPred<"(!mlir::cast<ShapedType>($0.getType()).isDynamicDim(0))" #
672
+ " && (!mlir::cast<ShapedType>($0.getType()).isDynamicDim(2))" #
673
+ " && (!mlir::cast<ShapedType>($1.getType()).isDynamicDim(0))" #
674
+ " && (!mlir::cast<ShapedType>($1.getType()).isDynamicDim(2))" #
675
+ " && (mlir::cast<ShapedType>($0.getType()).getShape()[0]" #
676
+ " == mlir::cast<ShapedType>($1.getType()).getShape()[0])" #
677
+ " && (mlir::cast<ShapedType>($0.getType()).getShape()[2]" #
678
+ " == mlir::cast<ShapedType>($1.getType()).getShape()[2])">,
679
+ "Have the same R1 and R3 dimensions"
680
+ >;
681
+
682
+ // For bias with 3dims, need to ensure that the 2nd dim is 1 as was augmented
683
+ // just to please shape analysis: want [s, p] but had to use [s, 1, p]
684
+ def Has1InMiddleR3: Constraint<
685
+ CPred<"(!mlir::cast<ShapedType>($0.getType()).isDynamicDim(1))" #
686
+ " && (mlir::cast<ShapedType>($0.getType()).getShape()[1]==1)">,
687
+ "Has shape(1) == 1"
688
+ >;
689
+
690
+ def replaceONNXMatMulAddStackedPattern1 : Pat<
691
+ // From Add $b, (MatMul $x, $y)
692
+ (ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)),
693
+ // To ZHighMatMulOp
694
+ (ZHighUnstickOp
695
+ (ZHighMatMulOp
696
+ (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
697
+ (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
698
+ (ZHighStickOp
699
+ (ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
700
+ (GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
701
+ (GetZeroI64Attr), (GetZeroI64Attr))),
702
+ [(HasRankOf<3> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
703
+ (HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
704
+ [],
705
+ (addBenefit 0)
706
+ >;
707
+
708
+ def replaceONNXMatMulAddStackedPattern2 : Pat<
709
+ // From Add (MatMul $x, $y), $b
710
+ (ONNXAddOp (ONNXMatMulOp:$m $x, $y), $b),
711
+ // To ZHighMatMulOp
712
+ (ZHighUnstickOp
713
+ (ZHighMatMulOp
714
+ (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
715
+ (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
716
+ (ZHighStickOp
717
+ (ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
718
+ //$b,
719
+ (GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
720
+ (GetZeroI64Attr), (GetZeroI64Attr))),
721
+ [(HasRankOf<3> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
722
+ (HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
723
+ [],
658
724
(addBenefit 0)
659
725
>;
660
726
661
727
//===----------------------------------------------------------------------===//
662
728
// Replace onnx.add and onnx.matmul with bcast1 tensors with ZHighMatMul
663
729
//===----------------------------------------------------------------------===//
664
730
665
- def replaceONNXMatMulAddPatternBcast1A : Pat<
731
+ def replaceONNXMatMulAddBCast1Pattern1 : Pat<
666
732
// From Add $b, (MatMul $x, $y)
667
733
(ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)),
668
734
// To ZHighMatMulOp
669
735
(ZHighUnstickOp
670
736
(ZHighMatMulOp
671
737
(ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
672
738
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
673
- (ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
674
- (GetRank $y)), (GetDefaultSaturation)),
739
+ (ZHighStickOp
740
+ (ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
741
+ (GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
675
742
(GetZeroI64Attr), (GetZeroI64Attr))),
676
- [(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m),
677
- (HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<2> $b)], [],
743
+ [(IsCompatibleWithNNPALevelArch15),
744
+ (HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
745
+ (HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
746
+ [],
678
747
(addBenefit 0)
679
748
>;
680
749
681
- def replaceONNXMatMulAddPatternBcast1B : Pat<
750
+ def replaceONNXMatMulAddBCast1Pattern2 : Pat<
682
751
// From Add (MatMul $x, $y), $b
683
752
(ONNXAddOp (ONNXMatMulOp:$m $x, $y), $b),
684
753
// To ZHighMatMulOp
685
754
(ZHighUnstickOp
686
755
(ZHighMatMulOp
687
- (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
756
+ (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)),
688
757
(ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)),
689
- (ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x),
690
- (GetRank $y)), (GetDefaultSaturation)),
691
- (GetZeroI64Attr), (GetZeroI64Attr))),
692
- [(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m),
693
- (HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<2> $b)], [],
758
+ (ZHighStickOp
759
+ (ONNXSqueezeV11Op $b, (GetI64ArrayAttr<1>)),
760
+ //$b,
761
+ (GetMatMulBiasLayoutStringAttr (GetRank $x), (GetRank $y)), (GetDefaultSaturation)),
762
+ (GetZeroI64Attr), (GetZeroI64Attr))),
763
+ [(IsCompatibleWithNNPALevelArch15),
764
+ (HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<3> $b),
765
+ (HaveSameFirstAndLastDimR3 $y, $b), (Has1InMiddleR3 $b), (IsMatMulLegalForZDNN $m)],
766
+ [],
694
767
(addBenefit 0)
695
768
>;
696
769
0 commit comments