
Commit 6f40941

Implement DepthToSpace CRD mode recomposition

Committed Jun 18, 2025 · 1 parent b742d71

2 files changed · +227 −0 lines changed
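At a glance, the new pattern collapses the ONNX reshape → transpose → reshape idiom into a single DepthToSpace op in CRD mode. A minimal before/after sketch (illustrative only: blocksize B = 2, toy shapes, and placeholder names %x, %s0, %s1 standing for the input and the two shape constants):

// Before: reshape -> transpose -> reshape (B = 2, C = 8, C' = 2).
%0 = "onnx.Reshape"(%x, %s0) : (tensor<1x8x2x2xf32>, tensor<6xi64>) -> tensor<1x2x2x2x2x2xf32>
%1 = "onnx.Transpose"(%0) {perm = [0, 1, 4, 2, 5, 3]} : (tensor<1x2x2x2x2x2xf32>) -> tensor<1x2x2x2x2x2xf32>
%2 = "onnx.Reshape"(%1, %s1) : (tensor<1x2x2x2x2x2xf32>, tensor<4xi64>) -> tensor<1x2x4x4xf32>

// After recomposition: a single op with identical semantics.
%2 = "onnx.DepthToSpace"(%x) {blocksize = 2 : si64, mode = "CRD"} : (tensor<1x8x2x2xf32>) -> tensor<1x2x4x4xf32>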
 

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 151 additions & 0 deletions
@@ -701,6 +701,151 @@ struct RecomposeGeluFromMulPattern : public OpRewritePattern<ONNXMulOp> {
   }
 };
 
+struct RecomposeDepthToSpaceCRD : public OpRewritePattern<ONNXReshapeOp> {
+  using OpRewritePattern<ONNXReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXReshapeOp reshapeOp, PatternRewriter &rewriter) const final {
+    using namespace onnx_mlir;
+    Location loc = ONNXLoc<ONNXReshapeOp>(reshapeOp);
+
+    std::optional<DepthToSpaceRecompositionResult> result =
+        matchDepthToSpaceCRDPattern(reshapeOp);
+    if (!result) {
+      return failure();
+    }
+
+    MultiDialectBuilder<OnnxBuilder> create(rewriter, result->fusedLocation);
+    rewriter.replaceOp(
+        reshapeOp, create.onnx.createOpAndInferShapes<ONNXDepthToSpaceOp>(
+                       reshapeOp.getType(), result->input, result->blockSize,
+                       result->mode));
+    return success();
+  }
+
+  // Result of attempting to recompose DepthToSpace. Contains the information
+  // gathered during matching.
+  struct DepthToSpaceRecompositionResult {
+    Value input;
+    int64_t blockSize;
+    std::string mode;
+    Location fusedLocation;
+  };
+
+  static std::optional<DepthToSpaceRecompositionResult>
+  matchDepthToSpaceCRDPattern(ONNXReshapeOp reshapeOp) {
+    using namespace onnx_mlir;
+    // DepthToSpace mode CRD match:
+    // DepthToSpace(x) =
+    //   %r0 = reshape %x NxCxHxW -> NxC//(B*B)xBxBxHxW
+    //   %t  = transpose %r0 perm=[0, 1, 4, 2, 5, 3]
+    //   %r1 = reshape NxC//(B*B)xHxBxWxB -> NxC//(B*B)x(HxB)x(WxB)
+
+    ONNXReshapeOp r0;
+    ONNXTransposeOp t;
+    ONNXReshapeOp r1 = reshapeOp;
+
+    t = r1->getOperand(0).getDefiningOp<ONNXTransposeOp>();
+    if (!t) {
+      return reportFailureForCRDMode("missing transpose");
+    }
+    r0 = t->getOperand(0).getDefiningOp<ONNXReshapeOp>();
+    if (!r0) {
+      return reportFailureForCRDMode("missing first reshape");
+    }
+
+    auto hasShapedStaticType = [](Type ty) {
+      auto shapedType = dyn_cast<ShapedType>(ty);
+      return shapedType && shapedType.hasStaticShape();
+    };
+
+    const bool haveOperationsValidTy =
+        llvm::all_of(TypeRange{r0.getOperand(0).getType(), r0.getType(),
+                         t.getType(), r1.getType()},
+            hasShapedStaticType);
+    if (!haveOperationsValidTy) {
+      return reportFailureForCRDMode(
+          "pattern operations do not have statically shaped tensor types");
+    }
+
+    auto fstReshapeInTy = cast<ShapedType>(r0->getOperand(0).getType());
+    ArrayRef<int64_t> fstReshapeInShape = fstReshapeInTy.getShape();
+    const size_t fstReshapeInRank = fstReshapeInTy.getRank();
+    if (fstReshapeInRank != 4) {
+      return reportFailureForCRDMode("input rank is not 4");
+    }
+
+    auto fstReshapeOutTy = cast<ShapedType>(r0.getType());
+    ArrayRef<int64_t> fstReshapeOutShape = fstReshapeOutTy.getShape();
+    const size_t fstReshapeOutRank = fstReshapeOutTy.getRank();
+    if (fstReshapeOutRank != 6) {
+      return reportFailureForCRDMode("output rank of first reshape is not 6");
+    }
+
+    // Check for the concrete reshape pattern:
+    //   reshape %x NxCxHxW -> NxC//(B*B)xBxBxHxW
+    const int64_t blocksize = fstReshapeOutShape[2];
+    if (blocksize != fstReshapeOutShape[3]) {
+      return reportFailureForCRDMode("blocksizes in dims 2 and 3 do not match");
+    }
+
+    if (fstReshapeInShape[0] != fstReshapeOutShape[0] ||
+        fstReshapeInShape[1] != fstReshapeOutShape[1] * blocksize * blocksize ||
+        fstReshapeInShape[2] != fstReshapeOutShape[4] ||
+        fstReshapeInShape[3] != fstReshapeOutShape[5]) {
+      return reportFailureForCRDMode("unexpected first reshape result shape");
+    }
+
+    // Check for the concrete permutation pattern:
+    //   transpose %r0 perm=[0, 1, 4, 2, 5, 3]
+    std::optional<ArrayAttr> permOpt = t.getPerm();
+    if (!permOpt) {
+      return reportFailureForCRDMode("missing permutation on transpose");
+    }
+
+    // Get the transpose permutation.
+    SmallVector<int64_t, 6> perms;
+    ArrayAttrIntVals(*permOpt, perms);
+
+    // Check the transpose permutation.
+    constexpr std::array<int64_t, 6> expectedPerms = {0, 1, 4, 2, 5, 3};
+    if (perms != ArrayRef(expectedPerms)) {
+      return reportFailureForCRDMode("unexpected permutation");
+    }
+
+    // Check for the concrete reshape pattern:
+    //   reshape NxC//(B*B)xHxBxWxB -> NxC//(B*B)x(HxB)x(WxB)
+    auto sndReshapeInTy = cast<ShapedType>(t.getType());
+    ArrayRef<int64_t> sndReshapeInShape = sndReshapeInTy.getShape();
+
+    auto sndReshapeOutTy = cast<ShapedType>(r1.getType());
+    ArrayRef<int64_t> sndReshapeOutShape = sndReshapeOutTy.getShape();
+    const size_t sndReshapeOutRank = sndReshapeOutTy.getRank();
+    if (sndReshapeOutRank != 4) {
+      return reportFailureForCRDMode("output rank of second reshape is not 4");
+    }
+
+    if (sndReshapeInShape[0] != sndReshapeOutShape[0] ||
+        sndReshapeInShape[1] != sndReshapeOutShape[1] ||
+        sndReshapeInShape[2] * sndReshapeInShape[3] != sndReshapeOutShape[2] ||
+        sndReshapeInShape[4] * sndReshapeInShape[5] != sndReshapeOutShape[3]) {
+      return reportFailureForCRDMode("unexpected second reshape result shape");
+    }
+
+    Location fusedLocation = FusedLoc::get(
+        reshapeOp->getContext(), {r0->getLoc(), t->getLoc(), r1->getLoc()});
+
+    return DepthToSpaceRecompositionResult{
+        /*input=*/r0.getOperand(0), blocksize, /*mode=*/"CRD", fusedLocation};
+  }
+
+  static std::nullopt_t reportFailureForCRDMode(std::string msg) {
+    // This logging can be disabled if not needed.
+    LLVM_DEBUG(llvm::dbgs() << "DepthToSpace [CRD] failure: " << msg << "\n");
+    return std::nullopt;
+  }
+};
+
 struct RecomposeQLinearMatMulFromQuantizeLinearPattern
     : public OpRewritePattern<ONNXQuantizeLinearOp> {
   using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;
@@ -815,6 +960,11 @@ void RecomposeONNXToONNXPass::runOnOperation() {
     return true;
   });
 
+  // Recompose DepthToSpace, starting from the final reshape op.
+  target.addDynamicallyLegalOp<ONNXReshapeOp>([](ONNXReshapeOp op) {
+    return !RecomposeDepthToSpaceCRD::matchDepthToSpaceCRDPattern(op);
+  });
+
   // AMD Disabled
   // // Recompose QLinearMatMul, starting from QuantizeLinear.
   // // Pattern: DequantizeLinear + MatMul + QuantizeLinear.
@@ -841,6 +991,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
   MLIRContext *context = patterns.getContext();
   patterns.insert<RecomposeGeluFromMulPattern>(context);
   patterns.insert<RecomposeLayerNormFromMulPattern>(context);
+  patterns.insert<RecomposeDepthToSpaceCRD>(context);
   // AMD Disabled as downstream has no special support for it
   // patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
 }
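For intuition, the chain matched above performs exactly the CRD index shuffle that ONNX defines for DepthToSpace: output[n][c][h*B + bh][w*B + bw] = input[n][c*B*B + bh*B + bw][h][w]. The standalone C++ sketch below (not part of this commit; the function name and flat NCHW indexing are illustrative) spells out that equivalence:

// Standalone reference for DepthToSpace in CRD mode (illustrative only).
// Input is NCHW with C divisible by B*B; output is N x C/(B*B) x H*B x W*B.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<float> depthToSpaceCRD(const std::vector<float> &in, int64_t N,
    int64_t C, int64_t H, int64_t W, int64_t B) {
  assert(C % (B * B) == 0 && "channel count must be divisible by B*B");
  const int64_t outC = C / (B * B);
  std::vector<float> out(N * C * H * W);
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < outC; ++c)
      for (int64_t h = 0; h < H; ++h)
        for (int64_t bh = 0; bh < B; ++bh)
          for (int64_t w = 0; w < W; ++w)
            for (int64_t bw = 0; bw < B; ++bw) {
              // CRD: the block offsets (bh, bw) are the fastest-varying part
              // of the input channel index, inC = c*B*B + bh*B + bw. This is
              // exactly what reshape(N,C',B,B,H,W) encodes in dims 2 and 3.
              const int64_t inC = c * B * B + bh * B + bw;
              const int64_t outH = h * B + bh;
              const int64_t outW = w * B + bw;
              out[((n * outC + c) * (H * B) + outH) * (W * B) + outW] =
                  in[((n * C + inC) * H + h) * W + w];
            }
  return out;
}

In DCR mode, by contrast, the block offsets occupy the slowest-varying channel dimensions (inC = (bh * B + bw) * outC + c), which corresponds to a different reshape/transpose chain, so this pattern deliberately leaves DCR-style chains alone.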

test/mlir/onnx/onnx_recompose.mlir

Lines changed: 76 additions & 0 deletions
@@ -699,3 +699,79 @@ func.func @test_gelu_erf_two_adds(%arg0: tensor<?x?x3072xf32>, %arg1: tensor<307
 // CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_2_]], [[PARAM_1_]]) : (tensor<?x?x3072xf32>, tensor<3072x768xf32>) -> tensor<?x?x768xf32>
 // CHECK: return [[VAR_3_]] : tensor<?x?x768xf32>
 // CHECK: }
+
+// -----
+
+func.func @test_depth_to_space_crd(%arg0: tensor<1x128x540x960xf32>) -> tensor<1x32x1080x1920xf32> {
+  %0 = onnx.Constant dense<[-1, 32, 2, 2, 540, 960]> : tensor<6xi64>
+  %1 = onnx.Constant dense<[-1, 32, 1080, 1920]> : tensor<4xi64>
+  %2 = "onnx.Reshape"(%arg0, %0) {allowzero = 0 : si64} : (tensor<1x128x540x960xf32>, tensor<6xi64>) -> tensor<1x32x2x2x540x960xf32>
+  %3 = "onnx.Transpose"(%2) {perm = [0, 1, 4, 2, 5, 3]} : (tensor<1x32x2x2x540x960xf32>) -> tensor<1x32x540x2x960x2xf32>
+  %4 = "onnx.Reshape"(%3, %1) {allowzero = 0 : si64} : (tensor<1x32x540x2x960x2xf32>, tensor<4xi64>) -> tensor<1x32x1080x1920xf32>
+  return %4 : tensor<1x32x1080x1920xf32>
+}
+// CHECK-LABEL: func.func @test_depth_to_space_crd
+// CHECK-SAME:  (%[[PARAM_1:.+]]: tensor<1x128x540x960xf32>) -> tensor<1x32x1080x1920xf32>
+// CHECK:       %[[DTS:.+]] = "onnx.DepthToSpace"(%[[PARAM_1]]) {blocksize = 2 : si64, mode = "CRD"} : (tensor<1x128x540x960xf32>) -> tensor<1x32x1080x1920xf32>
+// CHECK:       return %[[DTS]] : tensor<1x32x1080x1920xf32>
+// CHECK:       }
+
+// -----
+
+func.func @test_depth_to_space_crd_missing_transpose_perm(%arg0: tensor<1x128x540x960xf32>) -> tensor<1x32x1080x1920xf32> {
+  %0 = onnx.Constant dense<[-1, 32, 2, 2, 540, 960]> : tensor<6xi64>
+  %1 = onnx.Constant dense<[-1, 32, 1080, 1920]> : tensor<4xi64>
+  %2 = "onnx.Reshape"(%arg0, %0) {allowzero = 0 : si64} : (tensor<1x128x540x960xf32>, tensor<6xi64>) -> tensor<1x32x2x2x540x960xf32>
+  %3 = "onnx.Transpose"(%2) : (tensor<1x32x2x2x540x960xf32>) -> tensor<1x32x540x2x960x2xf32>
+  %4 = "onnx.Reshape"(%3, %1) {allowzero = 0 : si64} : (tensor<1x32x540x2x960x2xf32>, tensor<4xi64>) -> tensor<1x32x1080x1920xf32>
+  return %4 : tensor<1x32x1080x1920xf32>
+}
+// CHECK-NOT: onnx.DepthToSpace
+
+// -----
+
+func.func @test_depth_to_space_crd_unexpected_first_reshape_result(%arg0: tensor<1x128x540x960xf32>) -> tensor<1x32x540x3840xf32> {
+  %0 = onnx.Constant dense<[-1, 32, 1, 4, 540, 960]> : tensor<6xi64>
+  %1 = onnx.Constant dense<[-1, 32, 540, 3840]> : tensor<4xi64>
+  %2 = "onnx.Reshape"(%arg0, %0) {allowzero = 0 : si64} : (tensor<1x128x540x960xf32>, tensor<6xi64>) -> tensor<1x32x1x4x540x960xf32>
+  %3 = "onnx.Transpose"(%2) {perm = [0, 1, 4, 2, 5, 3]} : (tensor<1x32x1x4x540x960xf32>) -> tensor<1x32x540x1x960x4xf32>
+  %4 = "onnx.Reshape"(%3, %1) {allowzero = 0 : si64} : (tensor<1x32x540x1x960x4xf32>, tensor<4xi64>) -> tensor<1x32x540x3840xf32>
+  return %4 : tensor<1x32x540x3840xf32>
+}
+// CHECK-NOT: onnx.DepthToSpace
+
+// -----
+
+func.func @test_depth_to_space_crd_unexpected_perm(%arg0: tensor<1x128x540x960xf32>) -> tensor<1x32x1080x1920xf32> {
+  %0 = onnx.Constant dense<[-1, 32, 2, 2, 540, 960]> : tensor<6xi64>
+  %1 = onnx.Constant dense<[-1, 32, 1080, 1920]> : tensor<4xi64>
+  %2 = "onnx.Reshape"(%arg0, %0) {allowzero = 0 : si64} : (tensor<1x128x540x960xf32>, tensor<6xi64>) -> tensor<1x32x2x2x540x960xf32>
+  %3 = "onnx.Transpose"(%2) {perm = [0, 1, 4, 3, 5, 2]} : (tensor<1x32x2x2x540x960xf32>) -> tensor<1x32x540x2x960x2xf32>
+  %4 = "onnx.Reshape"(%3, %1) {allowzero = 0 : si64} : (tensor<1x32x540x2x960x2xf32>, tensor<4xi64>) -> tensor<1x32x1080x1920xf32>
+  return %4 : tensor<1x32x1080x1920xf32>
+}
+// CHECK-NOT: onnx.DepthToSpace
+
+// -----
+
+func.func @test_depth_to_space_crd_unexpected_second_reshape_result(%arg0: tensor<1x128x540x960xf32>) -> tensor<1x1x32x1080x1920xf32> {
+  %0 = onnx.Constant dense<[-1, 32, 2, 2, 540, 960]> : tensor<6xi64>
+  %1 = onnx.Constant dense<[-1, 1, 32, 1080, 1920]> : tensor<5xi64>
+  %2 = "onnx.Reshape"(%arg0, %0) {allowzero = 0 : si64} : (tensor<1x128x540x960xf32>, tensor<6xi64>) -> tensor<1x32x2x2x540x960xf32>
+  %3 = "onnx.Transpose"(%2) {perm = [0, 1, 4, 2, 5, 3]} : (tensor<1x32x2x2x540x960xf32>) -> tensor<1x32x540x2x960x2xf32>
+  %4 = "onnx.Reshape"(%3, %1) {allowzero = 0 : si64} : (tensor<1x32x540x2x960x2xf32>, tensor<5xi64>) -> tensor<1x1x32x1080x1920xf32>
+  return %4 : tensor<1x1x32x1080x1920xf32>
+}
+// CHECK-NOT: onnx.DepthToSpace
+
+// -----
+
+func.func @test_depth_to_space_crd_not_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = onnx.Constant dense<[-1, 32, 2, 2, 540, 960]> : tensor<6xi64>
+  %1 = onnx.Constant dense<[-1, 32, 1080, 1920]> : tensor<4xi64>
+  %2 = "onnx.Reshape"(%arg0, %0) {allowzero = 0 : si64} : (tensor<*xf32>, tensor<6xi64>) -> tensor<*xf32>
+  %3 = "onnx.Transpose"(%2) {perm = [0, 1, 4, 2, 5, 3]} : (tensor<*xf32>) -> tensor<*xf32>
+  %4 = "onnx.Reshape"(%3, %1) {allowzero = 0 : si64} : (tensor<*xf32>, tensor<4xi64>) -> tensor<*xf32>
+  return %4 : tensor<*xf32>
+}
+// CHECK-NOT: onnx.DepthToSpace
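One reading note: onnx_recompose.mlir is a lit/FileCheck test, and the RUN line that drives it sits at the top of the file, outside this hunk. A minimal sketch of such an invocation, assuming the pass is registered as --recompose-onnx on onnx-mlir-opt (the exact flags are not shown in this diff):

// RUN: onnx-mlir-opt --recompose-onnx %s -split-input-file | FileCheck %s

With -split-input-file, each // ----- separator starts an independent test case, which is why every negative test above can assert a bare CHECK-NOT: onnx.DepthToSpace on its own.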
