Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen][CPU] Enable scalable transfer lowerings #18170

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
pipelineOpts.enableVectorMasking =
isX86(target) || isRISCV(target) ||
(isAArch64(target) && hasAnySVEFeature(target));
pipelineOpts.enableAArch64SSVE =
pipelineOpts.enableAArch64SME =
isAArch64(target) && hasAnySVEFeature(target) && hasSMEFeature(target);
pipelineOpts.enableAArch64I8mm = isAArch64(target) && hasI8mmFeature(target);
pipelineOpts.enablePeeling = isLoopPeelingEnabled(funcOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void LLVMCPUVectorTransferLoweringPass::runOnOperation() {
/*maxTransferRank=*/1);
auto vectorTransferToSCFOptions =
VectorTransferToSCFOptions().enableFullUnroll();
if (enableScalableLowerings) {
vectorTransferToSCFOptions.enableLowerScalable();
}

populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
Expand Down
18 changes: 16 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,15 @@ void buildLLVMCPUVectorLoweringPipeline(
// lower them and can't be optimized away anymore.
funcPassManager.addPass(createCanonicalizerPass());

funcPassManager.addPass(createLLVMCPUVectorTransferLoweringPass());
LLVMCPUVectorTransferLoweringPassOptions transferLoweringOptions{};
if (!options.enableArmSME) {
// The ArmSME dialect has its own (more specific) lowerings for scalable
// vectors that occur later in the pipeline, so only enable the general
// lowerings if SME is not available.
transferLoweringOptions.enableScalableLowerings = true;
}
funcPassManager.addPass(
createLLVMCPUVectorTransferLoweringPass(transferLoweringOptions));
funcPassManager.addPass(createLLVMCPUVectorTransposeLoweringPass(
LLVMCPUVectorTransposeLoweringPassOptions{
options.lowerVectorTransposeToAVX2}));
Expand Down Expand Up @@ -354,6 +362,7 @@ void addCPUBufferOpsTileAndVectorizePipeline(
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -396,7 +405,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createLLVMCPUPeelPass());
}

if (pipelineOpt.enableAArch64SSVE) {
if (pipelineOpt.enableAArch64SME) {
funcPassManager.addPass(createLLVMCPU2DScalableTo1DScalablePass());
}

Expand Down Expand Up @@ -432,6 +441,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -494,6 +504,7 @@ void addConvTileAndDecomposeExpertPassPipeline(
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "shuffle";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -542,6 +553,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}

Expand Down Expand Up @@ -583,6 +595,7 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -623,6 +636,7 @@ void addCPULinalgExtTileAndVectorizePipeline(
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct LLVMCPUVectorLoweringPassOptions {
std::string splitVectorTransfersTo = "";
bool lowerVectorTransposeToAVX2 = false;
bool enableArmI8mm = false;
bool enableArmSME = false;
};

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
Expand Down Expand Up @@ -72,7 +73,7 @@ struct LLVMCPUPipelineOptions {
bool useConfiguredVectorSizes = true;
bool enablePeeling = false;
bool enableVectorMasking = false;
bool enableAArch64SSVE = false;
bool enableAArch64SME = false;
bool enableAArch64I8mm = false;
bool lowerToAVX2 = false;
};
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ def LLVMCPUVirtualVectorLoweringPass :
// Function-level pass that lowers the remaining vector.transfer_read /
// vector.transfer_write ops into simpler memory ops plus scf control flow
// (see the summary string below for the target op set).
def LLVMCPUVectorTransferLoweringPass :
InterfacePass<"iree-llvmcpu-vector-transfer-lowering", "mlir::FunctionOpInterface"> {
let summary = "Pass to lower transfer ops to simpler ops like `vector.load`, `vector.store`, `vector.broadcast`, and a set of scf ops.";
let options = [
// Opt-in lowerings for transfers of scalable (vector-length-agnostic)
// vectors. Kept off by default: pipelines targeting ArmSME use that
// dialect's own, more specific lowerings later in the pipeline instead.
Option<"enableScalableLowerings", "enable-scalable-lowerings", "bool",
/*default=*/"false",
"Enables scalable vector specific transfer lowerings">,
];
}

def LLVMCPUVectorTransposeLoweringPass :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,19 @@ func.func @gather_strided_memref() {
// CHECK-LABEL: func.func @gather_strided_memref
// CHECK-NOT: memref.subview {{.*}} : memref<2592000xf32, strided<[3]>
// CHECK-NOT: vector.gather %subview[%c0] [%7], %cst_0, %cst : memref<2592000xf32, strided<[3]>

// -----

// Transposes a 4x[4] (fixed x scalable) vector and writes it out; exercises
// the scalable transfer_write lowering. The CHECK lines below verify that the
// vector.transpose is fully eliminated and the write becomes row-wise
// vector.store ops of fixed-length vector<4xf32>.
func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
%transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
return
}

/// Note: The lowering for this is implemented/tested upstream (this just checks
/// it is enabled in IREE).

// CHECK-LABEL: func.func @scalable_transpose_store
// CHECK-NOT: vector.transpose
// CHECK: vector.store {{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK-NOT: vector.transpose
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the second CHECK-NOT?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECK-NOT only checks between two matches. The first checks between func.func @scalable_transpose_store and vector.store, the second checks from vector.store to the end of the function (IIRC).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I understand what is happening. My point is that the check of vector.store already shows that the lowering happens. In this case, why do we need to check whether there is a vector.transpose following it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also want to check that the transpose (which is not directly supported) is eliminated.

Loading