Skip to content

Commit

Permalink
[Codegen] Add canonicalization pass to track lowering configs (#19138)
Browse files Browse the repository at this point in the history
This allows us to retain lowering configs (or other discardable
attributes we need) through canonicalization patterns. This patch only
replaces canonicalizer uses before bufferization/vectorization as
currently those are the only places where we rely on lowering configs.
  • Loading branch information
qedawkins authored Nov 13, 2024
1 parent 1c43bcd commit 11fe5cd
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 35 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ iree_compiler_cc_library(
"BufferizeCopyOnlyDispatchesPass.cpp",
"CleanupBufferAllocViewPass.cpp",
"ConcretizePadResultShape.cpp",
"ConfigTrackingCanonicalizer.cpp",
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ iree_cc_library(
"BufferizeCopyOnlyDispatchesPass.cpp"
"CleanupBufferAllocViewPass.cpp"
"ConcretizePadResultShape.cpp"
"ConfigTrackingCanonicalizer.cpp"
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down Expand Up @@ -138,7 +139,11 @@ class ConcretizePadResultShapePass final
{
RewritePatternSet patterns(context);
populateConcretizePadResultShapePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
GreedyRewriteConfig config;
auto listener = ConfigTrackingListener();
config.listener = &listener;
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns),
config))) {
return signalPassFailure();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-config-tracking-canonicalizer"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONFIGTRACKINGCANONICALIZERPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

static Operation *skipCastsDefiningOp(Value v) {
auto producer = v.getDefiningOp();
while (auto castProducer = dyn_cast<tensor::CastOp>(producer)) {
producer = castProducer.getSource().getDefiningOp();
}
return producer;
}

void ConfigTrackingListener::notifyOperationReplaced(Operation *op,
ValueRange replacement) {
// We have no way to track replacements without a producer.
if (replacement.empty()) {
return;
}

IREE::Codegen::LoweringConfigAttrInterface loweringConfig =
getLoweringConfig(op);
if (!loweringConfig) {
return;
}

// Must have a producer of the same type to track the lowering config.
auto producer = skipCastsDefiningOp(replacement.front());
if (!producer || producer->getName() != op->getName()) {
return;
}

for (auto v : replacement.drop_front()) {
// Conservatively require that all replacements are produced by the same
// operation.
if (skipCastsDefiningOp(v) != producer) {
return;
}
}

// No need to add the lowering config if it's already present.
if (getLoweringConfig(producer)) {
return;
}

setLoweringConfig(producer, loweringConfig);
}

namespace {

/// Add the corresponding fast-math flags to operations given a floating-point
/// optimization mode.
// TODO: For now we only allow default flags, such as arithmetic reassociation.
struct ConfigTrackingCanonicalizerPass final
: impl::ConfigTrackingCanonicalizerPassBase<
ConfigTrackingCanonicalizerPass> {
public:
using impl::ConfigTrackingCanonicalizerPassBase<
ConfigTrackingCanonicalizerPass>::ConfigTrackingCanonicalizerPassBase;
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
// Inherit the same config defaults from the upstream canonicalizer pass.
config.useTopDownTraversal = true;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Normal;

RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);

patterns =
std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();
}

void runOnOperation() override {
// Canonicalization is best-effort. Non-convergence is not a pass failure.
auto listener = ConfigTrackingListener();
config.listener = &listener;
LogicalResult didConverge =
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
if (this->testConvergence && failed(didConverge)) {
getOperation()->emitError("Canonicalizer failed to converge");
return signalPassFailure();
}
}
GreedyRewriteConfig config;
std::shared_ptr<const FrozenRewritePatternSet> patterns;
};

} // namespace
} // namespace mlir::iree_compiler
9 changes: 9 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def BufferizeCopyOnlyDispatchesPass :
}];
}

def ConfigTrackingCanonicalizerPass :
Pass<"iree-codegen-config-tracking-canonicalize", ""> {
let summary = "Codegen specific canonicalization pass that tracks lowering configs";
let options = [
Option<"testConvergence", "test-convergence", "bool",
/*default=*/"false", "Fails if the patterns fail to converge">
];
}

def CleanupBufferAllocViewPass :
InterfacePass<"iree-codegen-cleanup-buffer-alloc-view", "mlir::FunctionOpInterface"> {
let summary =
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ struct OneShotBufferizationOptions;

namespace mlir::iree_compiler {

/// Common helper class for tracking lowering configs through pattern
/// applications.
class ConfigTrackingListener : public RewriterBase::Listener {
public:
ConfigTrackingListener() = default;
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
};

using IGEMMConfigFn =
std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
using IGEMMControlFn = std::function<bool(Operation *)>;
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createConvertToDestinationPassingStylePass());
funcPassManager.addPass(createFoldAffineMinInDistributedLoopsPass());
}
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
funcPassManager.addPass(createConcretizePadResultShapePass());
Expand Down Expand Up @@ -425,7 +425,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createTensorToVectorVectorizePadPass());
if (pipelineOpt.decomposePackUnPackOps) {
funcPassManager.addPass(createDecomposePackUnPackOpsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}

Expand Down
48 changes: 24 additions & 24 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ static void tileAndDistributeToWorkgroup(
// TODO(#16421): Disable decomposition due to failure in bufferization.
// funcPassManager.addPass(
// IREE::LinalgExt::createTileAndDecomposeAttentionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}

Expand Down Expand Up @@ -238,13 +238,13 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager,
void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Distribute linalg onto threads within the workgroup.
funcPassManager.addPass(createGPUTensorTilePass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Linalg -> vector
Expand Down Expand Up @@ -365,7 +365,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
GPUApplyTilingLevelPassOptions options;
options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}

Expand All @@ -384,15 +384,15 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
}

funcPassManager.addPass(createPropagateReshapesByExpansionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Step 4. Tile and fuse tileable ops to subgroups/threads.
{
GPUApplyTilingLevelPassOptions options;
options.tilingLevel = IREE::GPU::TilingLevel::Thread;
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}
{
Expand All @@ -406,7 +406,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass(
NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false,
/*normalizeForall=*/true}));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// TODO: This LICM instance is load bearing due to brittleness of the
Expand Down Expand Up @@ -489,13 +489,13 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Distribute linalg onto threads within the workgroup.
funcPassManager.addPass(createGPUTilePass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(
IREE::LinalgExt::createDecomposeWinogradTransformPass());
Expand All @@ -512,7 +512,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
// Post bufferization optimizations.
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createOptimizeVectorTransferPass());
funcPassManager.addPass(createOptimizeTensorInsertExtractSlicesPass());
Expand All @@ -526,8 +526,8 @@ void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(createGPUTensorTileToSerialLoopsPass());
Expand Down Expand Up @@ -727,8 +727,8 @@ void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(
Expand Down Expand Up @@ -844,7 +844,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(
IREE::LinalgExt::createConvertAttentionToOnlineAttentionPass());

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());

Expand All @@ -855,12 +855,12 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
options.allowZeroSlices = true;
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
funcPassManager.addPass(affine::createLoopCoalescingPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}

funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Set anchors at tensor level for vector distribution later and hoist out
Expand Down Expand Up @@ -927,9 +927,9 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
funcPassManager.addPass(createRematerializeParallelOpsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createGPUTileReductionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

// Linalg -> vector
Expand Down Expand Up @@ -970,11 +970,11 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {

void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(createGPUTensorTilePass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(createDecomposePackUnPackOpsPass(
Expand Down Expand Up @@ -1165,7 +1165,7 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
addCommonTargetExecutablePreprocessingPasses(funcPassManager);
addEncodingToNopPasses(funcPassManager);
funcPassManager.addPass(createBlockDynamicDimensionsPass);
funcPassManager.addPass(createCanonicalizerPass);
funcPassManager.addPass(createConfigTrackingCanonicalizerPass);
funcPassManager.addPass(createCSEPass);
}
modulePassManager.addPass(createMaterializeUserConfigsPass());
Expand Down
Loading

0 comments on commit 11fe5cd

Please sign in to comment.