Skip to content

Commit 953eba7

Browse files
[SYCL-MLIR] Fix merge
Signed-off-by: Tsang, Whitney <[email protected]>
1 parent da60cd4 commit 953eba7

File tree

16 files changed

+62
-49
lines changed

16 files changed

+62
-49
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ include "mlir/IR/OpBase.td"
1717
include "mlir/IR/AttrTypeBase.td"
1818

1919
include "mlir/Dialect/SYCL/IR/SYCLOpInterfaces.td"
20-
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
20+
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
2121
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
2222
include "mlir/IR/BuiltinTypes.td"
2323
include "mlir/IR/BuiltinTypeInterfaces.td"
@@ -34,7 +34,6 @@ def SYCL_Dialect : Dialect {
3434
let name = "sycl";
3535
let cppNamespace = "::mlir::sycl";
3636
let useDefaultTypePrinterParser = 1;
37-
let useFoldAPI = kEmitFoldAdaptorFolder;
3837
let extraClassDeclaration = [{
3938
MethodRegistry methods;
4039

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,18 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
382382
ArrayRef<Type> getCallableResults() { return getFunctionType()
383383
.getResults(); }
384384

385+
/// Returns the argument attributes for all callable region arguments or
386+
/// null if there are none.
387+
::mlir::ArrayAttr getCallableArgAttrs() {
388+
return getArgAttrs().value_or(nullptr);
389+
}
390+
391+
/// Returns the result attributes for all callable region results or
392+
/// null if there are none.
393+
::mlir::ArrayAttr getCallableResAttrs() {
394+
return getResAttrs().value_or(nullptr);
395+
}
396+
385397
}];
386398
let hasCustomAssemblyFormat = 1;
387399

polygeist/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def Polygeist_Dialect : Dialect {
1515
let name = "polygeist";
1616
let cppNamespace = "::mlir::polygeist";
1717
let description = [{}];
18-
let useFoldAPI = kEmitFoldAdaptorFolder;
1918
}
2019

2120
#endif // POLYGEIST_BASE

polygeist/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#define POLYGEIST_OPS
1111

1212
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
13-
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
13+
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
1414
include "mlir/Dialect/Polygeist/IR/PolygeistBase.td"
1515
include "mlir/Interfaces/SideEffectInterfaces.td"
1616
include "mlir/Interfaces/ViewLikeInterface.td"

polygeist/lib/Dialect/Polygeist/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,7 @@ OpFoldResult Pointer2MemrefOp::fold(FoldAdaptor operands) {
18791879
}
18801880
if (auto mc = getSource().getDefiningOp<LLVM::GEPOp>()) {
18811881
const LLVM::GEPIndicesAdaptor<ValueRange> &indices = mc.getIndices();
1882-
for (auto &iter : llvm::enumerate(indices)) {
1882+
for (const auto &iter : llvm::enumerate(indices)) {
18831883
if (indices.isDynamicIndex(iter.index()))
18841884
return nullptr;
18851885
if (!isa<IntegerAttr>(iter.value()))
@@ -2288,7 +2288,7 @@ struct AggressiveAllocaScopeInliner
22882288
Block *block = &op.getRegion().front();
22892289
Operation *terminator = block->getTerminator();
22902290
ValueRange results = terminator->getOperands();
2291-
rewriter.mergeBlockBefore(block, op);
2291+
rewriter.inlineBlockBefore(block, op);
22922292
rewriter.replaceOp(op, results);
22932293
rewriter.eraseOp(terminator);
22942294
return success();

polygeist/lib/Dialect/Polygeist/Transforms/LoopRestructure.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
669669
SmallVector<Value> results;
670670
llvm::append_range(results, terminator->getOperands());
671671
terminator->erase();
672-
B.mergeBlockBefore(block, exec);
672+
B.inlineBlockBefore(block, exec);
673673
exec.replaceAllUsesWith(results);
674674
exec.erase();
675675
}

polygeist/lib/Dialect/Polygeist/Transforms/OpenMPOpt.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ struct ParallelForInterchange : public OpRewritePattern<omp::ParallelOp> {
246246
PrevFor.getUpperBound(), PrevFor.getStep());
247247
auto *Yield = NextParallel.getRegion().front().getTerminator();
248248
NewFor.getRegion().takeBody(PrevFor.getRegion());
249-
Rewriter.mergeBlockBefore(&NextParallel.getRegion().front(),
250-
NewFor.getBody()->getTerminator());
249+
Rewriter.inlineBlockBefore(&NextParallel.getRegion().front(),
250+
NewFor.getBody()->getTerminator());
251251
Rewriter.setInsertionPoint(NewFor.getBody()->getTerminator());
252252
Rewriter.create<omp::BarrierOp>(NextParallel.getLoc());
253253

polygeist/lib/Dialect/Polygeist/Transforms/ParallelLoopDistribute.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
448448
Value scaled = rewriter.create<MulIOp>(
449449
op.getLoc(), newForOp.getInductionVar(), op.getStep());
450450
Value iv = rewriter.create<AddIOp>(op.getLoc(), op.getLowerBound(), scaled);
451-
rewriter.mergeBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv});
451+
rewriter.inlineBlockBefore(op.getBody(), &newForOp.getBody()->back(), {iv});
452452
rewriter.eraseOp(&newForOp.getBody()->back());
453453
rewriter.eraseOp(op);
454454
return success();
@@ -460,11 +460,11 @@ struct NormalizeLoop : public OpRewritePattern<scf::ForOp> {
460460
static bool isNormalized(scf::ParallelOp op) {
461461
auto isZero = [](Value v) {
462462
APInt value;
463-
return matchPattern(v, m_ConstantInt(&value)) && value.isNullValue();
463+
return matchPattern(v, m_ConstantInt(&value)) && value.isZero();
464464
};
465465
auto isOne = [](Value v) {
466466
APInt value;
467-
return matchPattern(v, m_ConstantInt(&value)) && value.isOneValue();
467+
return matchPattern(v, m_ConstantInt(&value)) && value.isOne();
468468
};
469469
return llvm::all_of(op.getLowerBound(), isZero) &&
470470
llvm::all_of(op.getStep(), isOne);
@@ -519,8 +519,8 @@ struct NormalizeParallel : public OpRewritePattern<scf::ParallelOp> {
519519
inductionVars.push_back(shifted);
520520
}
521521

522-
rewriter.mergeBlockBefore(op.getBody(), &newOp.getBody()->back(),
523-
inductionVars);
522+
rewriter.inlineBlockBefore(op.getBody(), &newOp.getBody()->back(),
523+
inductionVars);
524524
rewriter.eraseOp(&newOp.getBody()->back());
525525
rewriter.eraseOp(op);
526526
return success();
@@ -1222,8 +1222,8 @@ static void moveBodiesIf(PatternRewriter &rewriter, T op, IfType ifOp,
12221222
}
12231223

12241224
rewriter.eraseOp(&getThenBlock(ifOp)->back());
1225-
rewriter.mergeBlockBefore(getThenBlock(ifOp),
1226-
&newParallel.getBody()->back());
1225+
rewriter.inlineBlockBefore(getThenBlock(ifOp),
1226+
&newParallel.getBody()->back());
12271227

12281228
insertRecomputables(rewriter, op, newParallel, ifOp);
12291229
}
@@ -1248,8 +1248,8 @@ static void moveBodiesIf(PatternRewriter &rewriter, T op, IfType ifOp,
12481248
});
12491249
}
12501250
rewriter.eraseOp(&getElseBlock(ifOp)->back());
1251-
rewriter.mergeBlockBefore(getElseBlock(ifOp),
1252-
&newParallel.getBody()->back());
1251+
rewriter.inlineBlockBefore(getElseBlock(ifOp),
1252+
&newParallel.getBody()->back());
12531253

12541254
insertRecomputables(rewriter, op, newParallel, ifOp);
12551255
}
@@ -1303,12 +1303,12 @@ static void moveBodiesFor(PatternRewriter &rewriter, T op, ForType forLoop,
13031303

13041304
// Merge in two stages so we can properly replace uses of two induction
13051305
// varibales defined in different blocks.
1306-
rewriter.mergeBlockBefore(op.getBody(), &newParallel.getBody()->back(),
1307-
newParallel.getBody()->getArguments());
1306+
rewriter.inlineBlockBefore(op.getBody(), &newParallel.getBody()->back(),
1307+
newParallel.getBody()->getArguments());
13081308
rewriter.eraseOp(&newParallel.getBody()->back());
13091309
rewriter.eraseOp(&forLoop.getBody()->back());
1310-
rewriter.mergeBlockBefore(forLoop.getBody(), &newParallel.getBody()->back(),
1311-
newForLoop.getBody()->getArguments());
1310+
rewriter.inlineBlockBefore(forLoop.getBody(), &newParallel.getBody()->back(),
1311+
newForLoop.getBody()->getArguments());
13121312
rewriter.eraseOp(op);
13131313
rewriter.eraseOp(forLoop);
13141314
}
@@ -1459,8 +1459,8 @@ template <typename T> struct InterchangeWhilePFor : public OpRewritePattern<T> {
14591459
auto beforeParallelOp = makeNewParallelOp();
14601460
auto afterParallelOp = makeNewParallelOp();
14611461

1462-
rewriter.mergeBlockBefore(&whileOp.getBefore().front(),
1463-
beforeParallelOp.getBody()->getTerminator());
1462+
rewriter.inlineBlockBefore(&whileOp.getBefore().front(),
1463+
beforeParallelOp.getBody()->getTerminator());
14641464
whileOp.getBefore().push_back(new Block());
14651465
conditionOp->moveBefore(&whileOp.getBefore().front(),
14661466
whileOp.getBefore().front().begin());
@@ -1469,8 +1469,8 @@ template <typename T> struct InterchangeWhilePFor : public OpRewritePattern<T> {
14691469

14701470
auto yieldOp = cast<scf::YieldOp>(whileOp.getAfter().front().back());
14711471

1472-
rewriter.mergeBlockBefore(&whileOp.getAfter().front(),
1473-
afterParallelOp.getBody()->getTerminator());
1472+
rewriter.inlineBlockBefore(&whileOp.getAfter().front(),
1473+
afterParallelOp.getBody()->getTerminator());
14741474
whileOp.getAfter().push_back(new Block());
14751475
yieldOp->moveBefore(&whileOp.getAfter().front(),
14761476
whileOp.getAfter().front().begin());
@@ -1578,8 +1578,8 @@ struct RotateWhile : public OpRewritePattern<scf::WhileOp> {
15781578
rewriter.setInsertionPoint(condition);
15791579
auto conditional =
15801580
rewriter.create<scf::IfOp>(op.getLoc(), condition.getCondition());
1581-
rewriter.mergeBlockBefore(&op.getAfter().front(),
1582-
&conditional.getBody()->back());
1581+
rewriter.inlineBlockBefore(&op.getAfter().front(),
1582+
&conditional.getBody()->back());
15831583
rewriter.eraseOp(&conditional.getBody()->back());
15841584

15851585
rewriter.createBlock(&op.getAfter());
@@ -1637,8 +1637,8 @@ struct Reg2MemFor : public OpRewritePattern<T> {
16371637
newRegionArguments);
16381638

16391639
auto oldTerminator = op.getBody()->getTerminator();
1640-
rewriter.mergeBlockBefore(op.getBody(), newOp.getBody()->getTerminator(),
1641-
newRegionArguments);
1640+
rewriter.inlineBlockBefore(op.getBody(), newOp.getBody()->getTerminator(),
1641+
newRegionArguments);
16421642
SmallVector<Value> oldOps;
16431643
llvm::append_range(oldOps, oldTerminator->getOperands());
16441644
rewriter.eraseOp(oldTerminator);
@@ -1652,7 +1652,8 @@ struct Reg2MemFor : public OpRewritePattern<T> {
16521652
}
16531653
rewriter.setInsertionPoint(IP);
16541654
for (auto en : llvm::enumerate(oldOps)) {
1655-
if (!en.value().getDefiningOp<LLVM::UndefOp>())
1655+
Value oldOp = en.value();
1656+
if (!oldOp.getDefiningOp<LLVM::UndefOp>())
16561657
rewriter.create<memref::StoreOp>(op.getLoc(), en.value(),
16571658
allocated[en.index()], ValueRange());
16581659
}

polygeist/lib/Dialect/Polygeist/Transforms/ParallelLower.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,9 @@ void ParallelLower::runOnOperation() {
326326
launchArgs.push_back(launchOp.getBlockSizeX());
327327
launchArgs.push_back(launchOp.getBlockSizeY());
328328
launchArgs.push_back(launchOp.getBlockSizeZ());
329-
builder.mergeBlockBefore(&launchOp.getRegion().front(),
330-
threadr.getRegion().front().getTerminator(),
331-
launchArgs);
329+
builder.inlineBlockBefore(&launchOp.getRegion().front(),
330+
threadr.getRegion().front().getTerminator(),
331+
launchArgs);
332332

333333
auto container = threadr;
334334

polygeist/lib/Dialect/Polygeist/Transforms/RaiseToAffine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
6767
return failure();
6868
}
6969

70-
auto getBounds = [](TypedValue<IndexType> bound, CmpIPredicate cmpPred,
70+
auto getBounds = [](Value bound, CmpIPredicate cmpPred,
7171
SmallVectorImpl<Value> &bounds) {
7272
SmallVector<Value> todo = {bound};
7373
while (todo.size()) {

0 commit comments

Comments
 (0)