Skip to content

Commit d37fb52

Browse files
asraacopybara-github
authored andcommitted
linalg: add pattern to rewrite matvec/vecmat with transposed matrices
PiperOrigin-RevId: 769621770
1 parent d2ded60 commit d37fb52

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
1818
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
1919
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
20+
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
2021
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
2122
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
2223

@@ -272,6 +273,50 @@ struct BroadcastToExpandShape
272273
}
273274
};
274275

276+
struct RewriteTransposedVecmat
277+
: public OpRewritePattern<mlir::linalg::VecmatOp> {
278+
public:
279+
RewriteTransposedVecmat(MLIRContext *context)
280+
: OpRewritePattern<mlir::linalg::VecmatOp>(context) {}
281+
282+
using OpRewritePattern::OpRewritePattern;
283+
284+
LogicalResult matchAndRewrite(mlir::linalg::VecmatOp vecmatOp,
285+
PatternRewriter &rewriter) const override {
286+
auto transposeOp =
287+
vecmatOp.getInputs()[1].getDefiningOp<linalg::TransposeOp>();
288+
if (!transposeOp) return failure();
289+
290+
rewriter.replaceOpWithNewOp<linalg::MatvecOp>(
291+
vecmatOp, vecmatOp.getResultTypes()[0],
292+
ValueRange{transposeOp.getInput(), vecmatOp.getInputs()[0]},
293+
vecmatOp.getDpsInits()[0]);
294+
return success();
295+
}
296+
};
297+
298+
struct RewriteTransposedMatvec
299+
: public OpRewritePattern<mlir::linalg::MatvecOp> {
300+
public:
301+
RewriteTransposedMatvec(MLIRContext *context)
302+
: OpRewritePattern<mlir::linalg::MatvecOp>(context) {}
303+
304+
using OpRewritePattern::OpRewritePattern;
305+
306+
LogicalResult matchAndRewrite(mlir::linalg::MatvecOp matvecOp,
307+
PatternRewriter &rewriter) const override {
308+
auto transposeOp =
309+
matvecOp.getInputs()[0].getDefiningOp<linalg::TransposeOp>();
310+
if (!transposeOp) return failure();
311+
312+
rewriter.replaceOpWithNewOp<linalg::VecmatOp>(
313+
matvecOp, matvecOp.getResultTypes()[0],
314+
ValueRange{matvecOp.getInputs()[1], transposeOp.getInput()},
315+
matvecOp.getDpsInits()[0]);
316+
return success();
317+
}
318+
};
319+
275320
struct LinalgCanonicalizations
276321
: public impl::LinalgCanonicalizationsBase<LinalgCanonicalizations> {
277322
void runOnOperation() override {
@@ -281,7 +326,8 @@ struct LinalgCanonicalizations
281326
RewritePatternSet patterns(context);
282327
patterns.add<FoldConstantLinalgTranspose, FoldConstantFill,
283328
FoldConstantBroadcast, LinalgMapToElementwise,
284-
BroadcastToExpandShape>(context);
329+
BroadcastToExpandShape, RewriteTransposedVecmat,
330+
RewriteTransposedMatvec>(context);
285331

286332
// Run pattern matching and conversion
287333
// TODO (#1221): Investigate whether folding (default: on) can be skipped
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: heir-opt --linalg-canonicalizations --split-input-file %s | FileCheck %s
2+
3+
module {
4+
// CHECK: func @main
5+
// CHECK-SAME: %[[arg0:.*]]: tensor<512x784xf32>,
6+
// CHECK-SAME: %[[arg1:.*]]: tensor<784xf32>)
7+
// CHECK: %[[cst:.*]] = arith.constant dense<0.{{0*}}e+00> : tensor<512xf32>
8+
// CHECK: %[[v0:.*]] = linalg.matvec ins(%[[arg0]], %[[arg1]] : tensor<512x784xf32>, tensor<784xf32>) outs(%[[cst]] : tensor<512xf32>)
9+
// CHECK: return %[[v0]] : tensor<512xf32>
10+
func.func @main(%arg0: tensor<512x784xf32>, %arg2: tensor<784xf32>) -> tensor<512xf32> {
11+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<512xf32>
12+
%0 = tensor.empty() : tensor<784x512xf32>
13+
%transposed = linalg.transpose ins(%arg0 : tensor<512x784xf32>) outs(%0 : tensor<784x512xf32>) permutation = [1, 0]
14+
%1 = linalg.vecmat ins(%arg2, %transposed : tensor<784xf32>, tensor<784x512xf32>) outs(%cst_0 : tensor<512xf32>) -> tensor<512xf32>
15+
return %1 : tensor<512xf32>
16+
}
17+
}
18+
19+
// -----
20+
21+
module {
22+
// CHECK: func @main
23+
// CHECK-SAME: %[[arg0:.*]]: tensor<784x512xf32>,
24+
// CHECK-SAME: %[[arg1:.*]]: tensor<784xf32>)
25+
// CHECK: %[[cst:.*]] = arith.constant dense<0.{{0*}}e+00> : tensor<512xf32>
26+
// CHECK: %[[v0:.*]] = linalg.vecmat ins(%[[arg1]], %[[arg0]] : tensor<784xf32>, tensor<784x512xf32>) outs(%[[cst]] : tensor<512xf32>)
27+
// CHECK: return %[[v0]] : tensor<512xf32>
28+
func.func @main(%arg0: tensor<784x512xf32>, %arg2: tensor<784xf32>) -> tensor<512xf32> {
29+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<512xf32>
30+
%0 = tensor.empty() : tensor<512x784xf32>
31+
%transposed = linalg.transpose ins(%arg0 : tensor<784x512xf32>) outs(%0 : tensor<512x784xf32>) permutation = [1, 0]
32+
%1 = linalg.matvec ins(%transposed, %arg2 : tensor<512x784xf32>, tensor<784xf32>) outs(%cst_0 : tensor<512xf32>) -> tensor<512xf32>
33+
return %1 : tensor<512xf32>
34+
}
35+
}

0 commit comments

Comments
 (0)