7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " air/Conversion/AIRLoweringPass.h"
10
- #include " air/Conversion/AIRPipeline.h"
11
10
#include " air/Dialect/AIR/AIRDialect.h"
12
11
#include " air/Dialect/AIRRt/AIRRtDialect.h"
13
12
#include " air/Dialect/AIRRt/AIRRtOps.h"
@@ -279,59 +278,6 @@ class AIRHerdConversion : public ConversionPattern {
279
278
}
280
279
};
281
280
282
- class AIRPipelineConversion : public ConversionPattern {
283
- public:
284
- explicit AIRPipelineConversion (MLIRContext *context)
285
- : ConversionPattern(air::HerdPipelineOp::getOperationName(), 1, context) {
286
- }
287
-
288
- LogicalResult
289
- matchAndRewrite (Operation *op, ArrayRef<Value> operands,
290
- ConversionPatternRewriter &rewriter) const override {
291
- auto pipeOp = cast<air::HerdPipelineOp>(op);
292
- Block &bb = pipeOp.getBody ().front ();
293
- rewriter.eraseOp (pipeOp.getBody ().back ().getTerminator ());
294
- bb.getOperations ().splice (Block::iterator (op), bb.getOperations ());
295
- rewriter.eraseOp (op);
296
- return success ();
297
- }
298
- };
299
-
300
- class AIRPipelinePutConversion : public ConversionPattern {
301
- public:
302
- explicit AIRPipelinePutConversion (MLIRContext *context)
303
- : ConversionPattern(air::PipelinePutOp::getOperationName(), 1, context) {}
304
-
305
- LogicalResult
306
- matchAndRewrite (Operation *op, ArrayRef<Value> operands,
307
- ConversionPatternRewriter &rewriter) const override {
308
- rewriter.eraseOp (op);
309
- return success ();
310
- }
311
- };
312
-
313
- class AIRPipelineGetConversion : public ConversionPattern {
314
- public:
315
- explicit AIRPipelineGetConversion (MLIRContext *context)
316
- : ConversionPattern(air::PipelineGetOp::getOperationName(), 1, context) {}
317
-
318
- LogicalResult
319
- matchAndRewrite (Operation *op, ArrayRef<Value> operands,
320
- ConversionPatternRewriter &rewriter) const override {
321
- auto getOp = cast<air::PipelineGetOp>(op);
322
- SmallVector<Value, 2 > gets;
323
- for (auto r : getOp.getResults ()) {
324
- if (auto ty = llvm::dyn_cast<RankedTensorType>(r.getType ()))
325
- gets.push_back (rewriter.create <bufferization::AllocTensorOp>(
326
- op->getLoc (), ty, ValueRange{}));
327
- else
328
- return failure ();
329
- }
330
- rewriter.replaceOp (op, gets);
331
- return success ();
332
- }
333
- };
334
-
335
281
class AIRWaitAllToAIRRtConversion : public OpConversionPattern <air::WaitAllOp> {
336
282
public:
337
283
using OpConversionPattern<air::WaitAllOp>::OpConversionPattern;
@@ -1136,32 +1082,6 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
1136
1082
signalPassFailure ();
1137
1083
}
1138
1084
1139
- // Replace the PipelineStageOps first, followed by the
1140
- // HerdPipelineOps, then run the rest of the patterns.
1141
- // This avoids creating invalid intermediate code with respect
1142
- // to the herd->pipeline->stages nesting requirements.
1143
-
1144
- // PipelineStageOp conversion
1145
- RewritePatternSet air_pipe_stage_patterns (context);
1146
- air_pipe_stage_patterns.insert <air::AIRPipeStageConversion>(
1147
- context, air::AIRPipeStageConversion::LoweringType::AllocBuffer);
1148
- if (failed (applyPartialConversion (module , target,
1149
- std::move (air_pipe_stage_patterns)))) {
1150
- emitError (UnknownLoc::get (context),
1151
- " error lowering air.pipeline.stage\n " );
1152
- signalPassFailure ();
1153
- }
1154
-
1155
- // HerdPipelineOp conversion
1156
- RewritePatternSet air_pipe_patterns (context);
1157
- air_pipe_patterns.insert <AIRPipelineConversion, AIRPipelineGetConversion,
1158
- AIRPipelinePutConversion>(context);
1159
- if (failed (applyPartialConversion (module , target,
1160
- std::move (air_pipe_patterns)))) {
1161
- emitError (UnknownLoc::get (context), " error lowering air.pipeline\n " );
1162
- signalPassFailure ();
1163
- }
1164
-
1165
1085
// DMA and HerdOp conversion
1166
1086
RewritePatternSet air_patterns (context);
1167
1087
@@ -1528,62 +1448,6 @@ class AIRLoweringPass : public air::impl::AIRLoweringBase<AIRLoweringPass> {
1528
1448
}
1529
1449
};
1530
1450
1531
- class AIRPipelineToAffinePass
1532
- : public air::impl::AIRPipelineToAffineBase<AIRPipelineToAffinePass> {
1533
-
1534
- public:
1535
- AIRPipelineToAffinePass () = default ;
1536
- AIRPipelineToAffinePass (const AIRPipelineToAffinePass &pass) {}
1537
-
1538
- void getDependentDialects (::mlir::DialectRegistry ®istry) const override {
1539
- registry.insert <affine::AffineDialect>();
1540
- }
1541
-
1542
- void runOnOperation () override {
1543
- auto module = getOperation ();
1544
- auto context = module .getContext ();
1545
-
1546
- ConversionTarget target (*context);
1547
-
1548
- target.addLegalDialect <
1549
- LLVM::LLVMDialect, func::FuncDialect, arith::ArithDialect,
1550
- affine::AffineDialect, scf::SCFDialect, linalg::LinalgDialect,
1551
- memref::MemRefDialect, bufferization::BufferizationDialect,
1552
- airrt::AIRRtDialect, air::airDialect>();
1553
-
1554
- target.addIllegalOp <air::PipelineStageOp, air::PipelineYieldOp>();
1555
-
1556
- // PipelineStageOp conversion
1557
- RewritePatternSet air_pipe_stage_patterns (context);
1558
- auto loweringType =
1559
- air::AIRPipeStageConversion::LoweringType::PipelineGetPut;
1560
- if (clLoweringType == " buffer" )
1561
- loweringType = air::AIRPipeStageConversion::LoweringType::AllocBuffer;
1562
- air_pipe_stage_patterns.insert <air::AIRPipeStageConversion>(context,
1563
- loweringType);
1564
- if (failed (applyPartialConversion (module , target,
1565
- std::move (air_pipe_stage_patterns)))) {
1566
- emitError (UnknownLoc::get (context),
1567
- " error lowering air.pipeline.stage\n " );
1568
- signalPassFailure ();
1569
- }
1570
-
1571
- SmallVector<Operation *, 8 > pipelines;
1572
- module .walk ([&](air::HerdPipelineOp p) { pipelines.push_back (p); });
1573
-
1574
- for (auto p : pipelines) {
1575
- auto pipeOp = cast<air::HerdPipelineOp>(p);
1576
- OpBuilder b (p);
1577
- Block &bb = pipeOp.getBody ().front ();
1578
- IRMapping remap;
1579
- bb.getTerminator ()->erase ();
1580
- for (auto &o : bb)
1581
- b.clone (o, remap);
1582
- p->erase ();
1583
- }
1584
- }
1585
- };
1586
-
1587
1451
} // namespace
1588
1452
1589
1453
namespace xilinx {
@@ -1593,9 +1457,5 @@ std::unique_ptr<mlir::Pass> createAIRLoweringPass() {
1593
1457
return std::make_unique<AIRLoweringPass>();
1594
1458
}
1595
1459
1596
- std::unique_ptr<mlir::Pass> createAIRPipelineToAffinePass () {
1597
- return std::make_unique<AIRPipelineToAffinePass>();
1598
- }
1599
-
1600
1460
} // namespace air
1601
1461
} // namespace xilinx
0 commit comments