Skip to content

Commit f2299cc

Browse files
committed
Annotate liveness making some assumptions (write + read pattern for LDS). Also update ReuseLDS pass accordingly.
1 parent 0677f3d commit f2299cc

30 files changed

+1027
-281
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -569,17 +569,23 @@ def Rock_GpuAllocOp:
569569
let hasVerifier = 1;
570570
}
571571

572-
// Annotate lifetime of memory allocation on GPU memory hierachy.
573-
def Rock_GpuDeallocOp:
574-
Rock_Op<"dealloc", [MemoryEffects<[MemFree<DefaultResource>]>]>,
575-
Arguments<(ins AnyMemRef:$memref)> {
576-
let summary = "Annotate lifetime of memory allocation on GPU";
572+
// Annotate the start of lifetime of a memory allocation on GPU.
573+
def Rock_LiveInOp : Rock_Op<"live_in">, Arguments<(ins AnyMemRef:$memref)> {
574+
let summary = "Annotate the start of lifetime of a memory allocation on GPU";
577575
let description = [{
578-
The `rock.dealloc` op annotates lifetime of memory allocation memory on GPU.
579-
- Address space 0 : global.
580-
- Address space 3 : LDS.
581-
- Address space 5 : private (VGPR).
582-
All other values would be considered as allocation on global.
576+
The `rock.live_in` op annotates the start of lifetime of a memory allocation on GPU.
577+
}];
578+
let assemblyFormat = [{
579+
$memref attr-dict `:` type($memref)
580+
}];
581+
let hasVerifier = 1;
582+
}
583+
584+
// Annotate the end of lifetime of a memory allocation on GPU.
585+
def Rock_LiveOutOp : Rock_Op<"live_out">, Arguments<(ins AnyMemRef:$memref)> {
586+
let summary = "Annotate the end of lifetime of a memory allocation on GPU";
587+
let description = [{
588+
The `rock.live_out` op annotates the end of lifetime of a memory allocation on GPU.
583589
}];
584590
let assemblyFormat = [{
585591
$memref attr-dict `:` type($memref)

mlir/include/mlir/Dialect/Rock/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ namespace rock {
4949
#define GEN_PASS_DECL_ROCKSORTDIMENSIONSMEMORYLAYOUTPASS
5050
#define GEN_PASS_DECL_ROCKFINDFIRSTGEMMINDEXPASS
5151
#define GEN_PASS_DECL_ROCKREMOVEOUTPUTALLOCPASS
52+
#define GEN_PASS_DECL_ROCKANNOTATELIVENESSPASS
5253

5354
#define GEN_PASS_REGISTRATION
5455
#include "mlir/Dialect/Rock/Passes.h.inc"

mlir/include/mlir/Dialect/Rock/Passes.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def RockOutputSwizzlePass : Pass<"rock-output-swizzle", "::mlir::func::FuncOp">
159159
}
160160

161161
def RockReuseLDSPass : Pass<"rock-reuse-lds", "::mlir::func::FuncOp"> {
162-
let summary = "This pass re-uses LDS memory by using the lifetime annotations (rock.dealloc)";
162+
let summary = "This pass re-uses LDS memory by using the lifetime "
163+
"annotations (rock.live_in, rock.live_out)";
163164
let dependentDialects = ["rock::RockDialect", "memref::MemRefDialect"];
164165
}
165166

@@ -194,4 +195,11 @@ def RockRemoveOutputAllocPass
194195
let dependentDialects = ["rock::RockDialect", "func::FuncDialect"];
195196
}
196197

198+
def RockAnnotateLivenessPass
199+
: Pass<"rock-annotate-liveness", "::mlir::func::FuncOp"> {
200+
let summary = "This pass annotates LDS memory with liveness ops "
201+
"(rock.live_in, rock.live_out)";
202+
let dependentDialects = ["rock::RockDialect", "memref::MemRefDialect"];
203+
}
204+
197205
#endif // MLIR_DIALECT_ROCK_PASSES

mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ FailureOr<SmallVector<OpOperand *>>
214214
traceGemmOutputToGenericOps(Value matC, func::FuncOp func,
215215
const BufferDependencyAnalysis &deps);
216216

217+
// Get the LDS size of the memref
218+
std::optional<int64_t> getWorkgroupMemorySize(MemRefType type);
219+
217220
} // end namespace rock
218221
} // end namespace mlir
219222
#endif

mlir/lib/Conversion/RockToGPU/RockToGPU.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,6 @@ struct MIGPUAllocRewritePattern : public OpRewritePattern<rock::GpuAllocOp> {
102102
}
103103
};
104104

105-
struct MIGPUDeallocRewritePattern
106-
: public OpRewritePattern<rock::GpuDeallocOp> {
107-
using OpRewritePattern<rock::GpuDeallocOp>::OpRewritePattern;
108-
109-
LogicalResult matchAndRewrite(rock::GpuDeallocOp op,
110-
PatternRewriter &b) const override {
111-
112-
b.eraseOp(op);
113-
return mlir::success();
114-
}
115-
};
116-
117105
template <typename Tmi, typename Tgpu>
118106
struct MIOpRewritePattern : public OpRewritePattern<Tmi> {
119107
using OpRewritePattern<Tmi>::OpRewritePattern;
@@ -345,7 +333,7 @@ void LowerRockOpsToGPUPass::runOnOperation() {
345333
RewritePatternSet patterns(ctx);
346334

347335
// rock-lowering
348-
patterns.add<MIGPUAllocRewritePattern, MIGPUDeallocRewritePattern,
336+
patterns.add<MIGPUAllocRewritePattern,
349337
MIOpRewritePattern<rock::WorkgroupBarrierOp, gpu::BarrierOp>,
350338
MIOpRewritePattern<rock::LDSBarrierOp, amdgpu::LDSBarrierOp>,
351339
WorkgroupIdRewritePattern,

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,20 +1132,29 @@ LogicalResult GpuAllocOp::verify() {
11321132
}
11331133

11341134
//===-----------------------------------------------------===//
1135-
// GpuDeallocOp
1135+
// LiveInOp
11361136
//===-----------------------------------------------------===//
11371137

1138-
LogicalResult GpuDeallocOp::verify() {
1138+
LogicalResult LiveInOp::verify() {
11391139
// Reject operands not produced by a GpuAllocOp; getDefiningOp<T>() is null-safe, so block arguments get a diagnostic instead of tripping isa<>'s null assertion.
1140-
if (auto gpuAlloc = dyn_cast<GpuAllocOp>(getMemref().getDefiningOp())) {
1141-
// Make sure the size is bigger than 0
1142-
if (getByteSize(getMemref().getType()) > 0) {
1143-
return success();
1144-
}
1145-
return emitError("The size of rock.dealloc should be greather than zero.");
1140+
if (!getMemref().getDefiningOp<GpuAllocOp>()) {
1141+
return emitError("The operand of rock.live_in must be the result of a "
1142+
"rock.alloc operation.");
1143+
}
1144+
return success();
1145+
}
1146+
1147+
//===-----------------------------------------------------===//
1148+
// LiveOutOp
1149+
//===-----------------------------------------------------===//
1150+
1151+
LogicalResult LiveOutOp::verify() {
1152+
// Reject operands not produced by a GpuAllocOp; getDefiningOp<T>() is null-safe, so block arguments get a diagnostic instead of tripping isa<>'s null assertion.
1153+
if (!getMemref().getDefiningOp<GpuAllocOp>()) {
1154+
return emitError("The operand of rock.live_out must be the result of a "
1155+
"rock.alloc operation.");
11461156
}
1147-
return emitError("The operand of rock.dealloc must be the result of a "
1148-
"rock.alloc operation.");
1157+
return success();
11491158
}
11501159

11511160
//===-----------------------------------------------------===//

mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,14 @@ void rock::buildKernelPipeline(OpPassManager &pm,
180180
funcPm.addPass(createConvertLinalgToAffineLoopsPass());
181181
funcPm.addPass(rock::createRockVectorizeFusionsPass());
182182
}
183+
// We run reuse LDS before the output swizzle pass because the swizzle pass uses a heuristic
184+
// to determine whether to swizzle or not, and that heuristic needs the actual
185+
// LDS usage. After running output swizzle, we'll create a new LDS buffer and
186+
// we need to run reuse LDS again to be able to reuse LDS memory.
187+
funcPm.addPass(rock::createRockAnnotateLivenessPass());
183188
funcPm.addPass(rock::createRockReuseLDSPass());
184189
funcPm.addPass(rock::createRockOutputSwizzlePass());
190+
funcPm.addPass(rock::createRockAnnotateLivenessPass());
185191
funcPm.addPass(rock::createRockReuseLDSPass());
186192

187193
if (!options.enableApplicability) {

mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,9 +1519,6 @@ static LogicalResult insertBlockwiseReduction(
15191519
/*extraViews=*/nullptr,
15201520
getBlockSize(reduceOp->getParentOfType<func::FuncOp>()).value());
15211521

1522-
ViewLikeOpInterface viewOp =
1523-
ldsWorkspace.getDefiningOp<ViewLikeOpInterface>();
1524-
GpuDeallocOp::create(rewriter, loc, viewOp.getViewSource());
15251522
// Create partial reduction views
15261523
ArrayAttr paddedReducedTrStack;
15271524
{

0 commit comments

Comments
 (0)