17
17
#include " mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
18
18
#include " mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
19
19
#include " mlir/include/mlir/IR/Value.h" // from @llvm-project
20
+ #include " mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
20
21
#include " mlir/include/mlir/Support/LLVM.h" // from @llvm-project
21
22
#include " mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
22
23
@@ -272,6 +273,50 @@ struct BroadcastToExpandShape
272
273
}
273
274
};
274
275
276
+ struct RewriteTransposedVecmat
277
+ : public OpRewritePattern<mlir::linalg::VecmatOp> {
278
+ public:
279
+ RewriteTransposedVecmat (MLIRContext *context)
280
+ : OpRewritePattern<mlir::linalg::VecmatOp>(context) {}
281
+
282
+ using OpRewritePattern::OpRewritePattern;
283
+
284
+ LogicalResult matchAndRewrite (mlir::linalg::VecmatOp vecmatOp,
285
+ PatternRewriter &rewriter) const override {
286
+ auto transposeOp =
287
+ vecmatOp.getInputs ()[1 ].getDefiningOp <linalg::TransposeOp>();
288
+ if (!transposeOp) return failure ();
289
+
290
+ rewriter.replaceOpWithNewOp <linalg::MatvecOp>(
291
+ vecmatOp, vecmatOp.getResultTypes ()[0 ],
292
+ ValueRange{transposeOp.getInput (), vecmatOp.getInputs ()[0 ]},
293
+ vecmatOp.getDpsInits ()[0 ]);
294
+ return success ();
295
+ }
296
+ };
297
+
298
+ struct RewriteTransposedMatvec
299
+ : public OpRewritePattern<mlir::linalg::MatvecOp> {
300
+ public:
301
+ RewriteTransposedMatvec (MLIRContext *context)
302
+ : OpRewritePattern<mlir::linalg::MatvecOp>(context) {}
303
+
304
+ using OpRewritePattern::OpRewritePattern;
305
+
306
+ LogicalResult matchAndRewrite (mlir::linalg::MatvecOp matvecOp,
307
+ PatternRewriter &rewriter) const override {
308
+ auto transposeOp =
309
+ matvecOp.getInputs ()[0 ].getDefiningOp <linalg::TransposeOp>();
310
+ if (!transposeOp) return failure ();
311
+
312
+ rewriter.replaceOpWithNewOp <linalg::VecmatOp>(
313
+ matvecOp, matvecOp.getResultTypes ()[0 ],
314
+ ValueRange{matvecOp.getInputs ()[1 ], transposeOp.getInput ()},
315
+ matvecOp.getDpsInits ()[0 ]);
316
+ return success ();
317
+ }
318
+ };
319
+
275
320
struct LinalgCanonicalizations
276
321
: public impl::LinalgCanonicalizationsBase<LinalgCanonicalizations> {
277
322
void runOnOperation () override {
@@ -281,7 +326,8 @@ struct LinalgCanonicalizations
281
326
RewritePatternSet patterns (context);
282
327
patterns.add <FoldConstantLinalgTranspose, FoldConstantFill,
283
328
FoldConstantBroadcast, LinalgMapToElementwise,
284
- BroadcastToExpandShape>(context);
329
+ BroadcastToExpandShape, RewriteTransposedVecmat,
330
+ RewriteTransposedMatvec>(context);
285
331
286
332
// Run pattern matching and conversion
287
333
// TODO (#1221): Investigate whether folding (default: on) can be skipped
0 commit comments