From b0cf72da113e6c7282733f8ba6bfcb7754a7495c Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Wed, 11 Jun 2025 09:20:06 -0700 Subject: [PATCH] linalg: add pattern to rewrite matvec/vecmat with transposed matrices PiperOrigin-RevId: 770181162 --- .../LinalgCanonicalizations.cpp | 48 ++++++++++++++++++- .../transposed_matmuls.mlir | 35 ++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 tests/Transforms/linalg_canonicalizations/transposed_matmuls.mlir diff --git a/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp b/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp index 87b2dce88..b1ee6884d 100644 --- a/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp +++ b/lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp @@ -17,6 +17,7 @@ #include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -272,6 +273,50 @@ struct BroadcastToExpandShape } }; +/* Rewrite pattern matched on mlir::linalg::VecmatOp. When input operand 1 is produced by a defining op of the expected kind, the vecmat is replaced by a new op whose inputs are the input of that defining op followed by the original input 0, keeping init 0 as the output. NOTE(review): the template arguments of OpRewritePattern, getDefiningOp and replaceOpWithNewOp are missing in this copy of the patch; judging by getInput() and the accompanying test they are presumably linalg::TransposeOp and linalg::MatvecOp - confirm against the upstream commit. */ struct RewriteTransposedVecmat + : public OpRewritePattern { + public: + RewriteTransposedVecmat(MLIRContext *context) + : OpRewritePattern(context) {} + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::linalg::VecmatOp vecmatOp, + PatternRewriter &rewriter) const override { + auto transposeOp = + vecmatOp.getInputs()[1].getDefiningOp(); + if (!transposeOp) return failure(); + + /* Inputs are swapped relative to the matched op: matrix source first, vector second. NOTE(review): result type 0 is read unconditionally - presumably only ops that have a result reach here; confirm. */ rewriter.replaceOpWithNewOp( + vecmatOp, vecmatOp.getResultTypes()[0], + ValueRange{transposeOp.getInput(), vecmatOp.getInputs()[0]}, + vecmatOp.getDpsInits()[0]); + return success(); + } +}; + +/* Rewrite pattern matched on mlir::linalg::MatvecOp, mirroring the vecmat pattern: when input operand 0 has a defining op of the expected kind, the matvec is replaced by a new op whose inputs are the original input 1 followed by the input of that defining op, keeping init 0 as the output. NOTE(review): template arguments are missing in this copy; presumably linalg::TransposeOp and linalg::VecmatOp per the accompanying test - confirm. */ struct RewriteTransposedMatvec + : public OpRewritePattern { + public: + 
RewriteTransposedMatvec(MLIRContext *context) + : OpRewritePattern(context) {} + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::linalg::MatvecOp matvecOp, + PatternRewriter &rewriter) const override { + auto transposeOp = + matvecOp.getInputs()[0].getDefiningOp(); + if (!transposeOp) return failure(); + + /* Vector operand first, then the matrix feeding the matched defining op. NOTE(review): result type 0 is read unconditionally - confirm a result always exists. */ rewriter.replaceOpWithNewOp( + matvecOp, matvecOp.getResultTypes()[0], + ValueRange{matvecOp.getInputs()[1], transposeOp.getInput()}, + matvecOp.getDpsInits()[0]); + return success(); + } +}; + struct LinalgCanonicalizations : public impl::LinalgCanonicalizationsBase { void runOnOperation() override { @@ -281,7 +326,8 @@ struct LinalgCanonicalizations RewritePatternSet patterns(context); patterns.add(context); + BroadcastToExpandShape, RewriteTransposedVecmat, /* the two patterns introduced by this patch */ + RewriteTransposedMatvec>(context); // Run pattern matching and conversion // TODO (#1221): Investigate whether folding (default: on) can be skipped diff --git a/tests/Transforms/linalg_canonicalizations/transposed_matmuls.mlir b/tests/Transforms/linalg_canonicalizations/transposed_matmuls.mlir new file mode 100644 index 000000000..48c10c2ba --- /dev/null +++ b/tests/Transforms/linalg_canonicalizations/transposed_matmuls.mlir @@ -0,0 +1,35 @@ +// RUN: heir-opt --linalg-canonicalizations --split-input-file %s | FileCheck %s + +module { + // CHECK: func @main + // CHECK-SAME: %[[arg0:.*]]: tensor<512x784xf32>, + // CHECK-SAME: %[[arg1:.*]]: tensor<784xf32>) + // CHECK: %[[cst:.*]] = arith.constant dense<0.{{0*}}e+00> : tensor<512xf32> + // CHECK: %[[v0:.*]] = linalg.matvec ins(%[[arg0]], %[[arg1]] : tensor<512x784xf32>, tensor<784xf32>) outs(%[[cst]] : tensor<512xf32>) + // CHECK: return %[[v0]] : tensor<512xf32> + func.func @main(%arg0: tensor<512x784xf32>, %arg2: tensor<784xf32>) -> tensor<512xf32> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<512xf32> + %0 = tensor.empty() : tensor<784x512xf32> + %transposed = linalg.transpose ins(%arg0 : tensor<512x784xf32>) 
outs(%0 : tensor<784x512xf32>) permutation = [1, 0] + %1 = linalg.vecmat ins(%arg2, %transposed : tensor<784xf32>, tensor<784x512xf32>) outs(%cst_0 : tensor<512xf32>) -> tensor<512xf32> + return %1 : tensor<512xf32> + } +} + +// ----- + +module { + // CHECK: func @main + // CHECK-SAME: %[[arg0:.*]]: tensor<784x512xf32>, + // CHECK-SAME: %[[arg1:.*]]: tensor<784xf32>) + // CHECK: %[[cst:.*]] = arith.constant dense<0.{{0*}}e+00> : tensor<512xf32> + // CHECK: %[[v0:.*]] = linalg.vecmat ins(%[[arg1]], %[[arg0]] : tensor<784xf32>, tensor<784x512xf32>) outs(%[[cst]] : tensor<512xf32>) + // CHECK: return %[[v0]] : tensor<512xf32> + func.func @main(%arg0: tensor<784x512xf32>, %arg2: tensor<784xf32>) -> tensor<512xf32> { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<512xf32> + %0 = tensor.empty() : tensor<512x784xf32> + %transposed = linalg.transpose ins(%arg0 : tensor<784x512xf32>) outs(%0 : tensor<512x784xf32>) permutation = [1, 0] + %1 = linalg.matvec ins(%transposed, %arg2 : tensor<512x784xf32>, tensor<784xf32>) outs(%cst_0 : tensor<512xf32>) -> tensor<512xf32> + return %1 : tensor<512xf32> + } +}