@@ -701,6 +701,151 @@ struct RecomposeGeluFromMulPattern : public OpRewritePattern<ONNXMulOp> {
701
701
}
702
702
};
703
703
704
+ struct RecomposeDepthToSpaceCRD : public OpRewritePattern <ONNXReshapeOp> {
705
+ using OpRewritePattern<ONNXReshapeOp>::OpRewritePattern;
706
+
707
+ LogicalResult matchAndRewrite (
708
+ ONNXReshapeOp reshapeOp, PatternRewriter &rewriter) const final {
709
+ using namespace onnx_mlir ;
710
+ Location loc = ONNXLoc<ONNXReshapeOp>(reshapeOp);
711
+
712
+ std::optional<DepthToSpaceRecompositionResult> result =
713
+ matchDepthToSpaceCRDPattern (reshapeOp);
714
+ if (!result) {
715
+ return failure ();
716
+ }
717
+
718
+ MultiDialectBuilder<OnnxBuilder> create (rewriter, result->fusedLocation );
719
+ rewriter.replaceOp (
720
+ reshapeOp, create.onnx .createOpAndInferShapes <ONNXDepthToSpaceOp>(
721
+ reshapeOp.getType (), result->input , result->blockSize ,
722
+ result->mode ));
723
+ return success ();
724
+ }
725
+
726
+ // Result of attempting recomposing DepthToSpace. Contains useful information
727
+ // for the matching
728
+ struct DepthToSpaceRecompositionResult {
729
+ Value input;
730
+ int64_t blockSize;
731
+ std::string mode;
732
+ Location fusedLocation;
733
+ };
734
+
735
+ static std::optional<DepthToSpaceRecompositionResult>
736
+ matchDepthToSpaceCRDPattern (ONNXReshapeOp reshapeOp) {
737
+ using namespace onnx_mlir ;
738
+ // DepthToSpace mode CRD match:
739
+ // DepthToSpace(x) =
740
+ // %r0 = reshape %x NxCxHxW -> NxC//(B*B)xBxBxHxW
741
+ // %t = transpose %r0 perm=[0, 1, 4, 2, 5, 3]
742
+ // %r1 = reshape NxC//(B*B)xHxBxWxB -> NxC//(B*B)x(HxB)x(WxB)
743
+
744
+ ONNXReshapeOp r0;
745
+ ONNXTransposeOp t;
746
+ ONNXReshapeOp r1 = reshapeOp;
747
+
748
+ t = r1->getOperand (0 ).getDefiningOp <ONNXTransposeOp>();
749
+ if (!t) {
750
+ return reportFailureForCRDMode (" missing transpose" );
751
+ }
752
+ r0 = t->getOperand (0 ).getDefiningOp <ONNXReshapeOp>();
753
+ if (!r0) {
754
+ return reportFailureForCRDMode (" missing first reshape" );
755
+ }
756
+
757
+ auto hasShapedStaticType = [](Type ty) {
758
+ auto shapedType = dyn_cast<ShapedType>(ty);
759
+ return shapedType && shapedType.hasStaticShape ();
760
+ };
761
+
762
+ const bool haveOperationsValidTy =
763
+ llvm::all_of (TypeRange{r0.getOperand (0 ).getType (), r0.getType (),
764
+ t.getType (), r1.getType ()},
765
+ hasShapedStaticType);
766
+ if (!haveOperationsValidTy) {
767
+ return reportFailureForCRDMode (
768
+ " pattern operations have no shaped static tensor types" );
769
+ }
770
+
771
+ auto fstReshapeInTy = cast<ShapedType>(r0->getOperand (0 ).getType ());
772
+ ArrayRef<int64_t > fstReshapeInShape = fstReshapeInTy.getShape ();
773
+ const size_t fstReshapeInRank = fstReshapeInTy.getRank ();
774
+ if (fstReshapeInRank != 4 ) {
775
+ return reportFailureForCRDMode (" input rank is not 4D " );
776
+ }
777
+
778
+ auto fstReshapeOutTy = cast<ShapedType>(r0.getType ());
779
+ ArrayRef<int64_t > fstReshapeOutShape = fstReshapeOutTy.getShape ();
780
+ const size_t fstReshapeOutRank = fstReshapeOutTy.getRank ();
781
+ if (fstReshapeOutRank != 6 ) {
782
+ return reportFailureForCRDMode (" output rank of first reshape is not 6D" );
783
+ }
784
+
785
+ // Check for concrete reshape pattern:
786
+ // reshape %x NxCxHxW -> NxC//(B*B)xBxBxHxW
787
+ const int64_t blocksize = fstReshapeOutShape[2 ];
788
+ if (blocksize != fstReshapeOutShape[3 ]) {
789
+ return reportFailureForCRDMode (" blocksize do not match in dim 2 and 3" );
790
+ }
791
+
792
+ if (fstReshapeInShape[0 ] != fstReshapeOutShape[0 ] ||
793
+ fstReshapeInShape[1 ] != fstReshapeOutShape[1 ] * blocksize * blocksize ||
794
+ fstReshapeInShape[2 ] != fstReshapeOutShape[4 ] ||
795
+ fstReshapeInShape[3 ] != fstReshapeOutShape[5 ]) {
796
+ return reportFailureForCRDMode (" unexpected first reshape result shape" );
797
+ }
798
+
799
+ // Check for concrete permutation pattern:
800
+ // transpose %r0 perm=[0, 1, 4, 2, 5, 3]
801
+ std::optional<ArrayAttr> permOpt = t.getPerm ();
802
+ if (!permOpt) {
803
+ return reportFailureForCRDMode (" missing permutation on transpose" );
804
+ }
805
+
806
+ // Get transpose permutation
807
+ SmallVector<int64_t , 6 > perms;
808
+ ArrayAttrIntVals (*permOpt, perms);
809
+
810
+ // Check for transpose permutation
811
+ constexpr std::array<int64_t , 6 > expectedPerms = {0 , 1 , 4 , 2 , 5 , 3 };
812
+ if (perms != ArrayRef (expectedPerms)) {
813
+ return reportFailureForCRDMode (" unexpected permutations" );
814
+ }
815
+
816
+ // Check for concrete reshape pattern:
817
+ // reshape NxC//(B*B)xHxBxWxB -> NxC//(B*B)x(HxB)x(WxB)
818
+ auto sndReshapeInTy = cast<ShapedType>(t.getType ());
819
+ ArrayRef<int64_t > sndReshapeInShape = sndReshapeInTy.getShape ();
820
+
821
+ auto sndReshapeOutTy = cast<ShapedType>(r1.getType ());
822
+ ArrayRef<int64_t > sndReshapeOutShape = sndReshapeOutTy.getShape ();
823
+ const size_t sndReshapeOutRank = sndReshapeOutTy.getRank ();
824
+ if (sndReshapeOutRank != 4 ) {
825
+ return reportFailureForCRDMode (" out rank of second reshape is not 4D" );
826
+ }
827
+
828
+ if (sndReshapeInShape[0 ] != sndReshapeOutShape[0 ] ||
829
+ sndReshapeInShape[1 ] != sndReshapeOutShape[1 ] ||
830
+ sndReshapeInShape[2 ] * sndReshapeInShape[3 ] != sndReshapeOutShape[2 ] ||
831
+ sndReshapeInShape[4 ] * sndReshapeInShape[5 ] != sndReshapeOutShape[3 ]) {
832
+ return reportFailureForCRDMode (" unexpected second reshape result shape" );
833
+ }
834
+
835
+ Location fusedLocation = FusedLoc::get (
836
+ reshapeOp->getContext (), {r0->getLoc (), t->getLoc (), r1->getLoc ()});
837
+
838
+ return DepthToSpaceRecompositionResult{
839
+ /* input=*/ r0.getOperand (0 ), blocksize, /* mode=*/ " CRD" , fusedLocation};
840
+ }
841
+
842
+ static std::nullopt_t reportFailureForCRDMode (std::string msg) {
843
+ // Can disable line below if not needed.
844
+ LLVM_DEBUG (llvm::dbgs () << " DepthToSpace [CRD] failure: " << msg << " \n " );
845
+ return std::nullopt;
846
+ }
847
+ };
848
+
704
849
struct RecomposeQLinearMatMulFromQuantizeLinearPattern
705
850
: public OpRewritePattern<ONNXQuantizeLinearOp> {
706
851
using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;
@@ -815,6 +960,11 @@ void RecomposeONNXToONNXPass::runOnOperation() {
815
960
return true ;
816
961
});
817
962
963
+ // Recompose DepthToSpace starting from reshape op
964
+ target.addDynamicallyLegalOp <ONNXReshapeOp>([](ONNXReshapeOp op) {
965
+ return !RecomposeDepthToSpaceCRD::matchDepthToSpaceCRDPattern (op);
966
+ });
967
+
818
968
// AMD Disabled
819
969
// // Recompose QLinearMatMul, starting from QuantizeLinear.
820
970
// // Pattern: DequanizeLinear + MatMul + QuantizeLinear.
@@ -841,6 +991,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
841
991
MLIRContext *context = patterns.getContext ();
842
992
patterns.insert <RecomposeGeluFromMulPattern>(context);
843
993
patterns.insert <RecomposeLayerNormFromMulPattern>(context);
994
+ patterns.insert <RecomposeDepthToSpaceCRD>(context);
844
995
// AMD Disabled as downstream has no special support for it
845
996
// patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
846
997
}
0 commit comments