Skip to content

Commit

Permalink
[LinalgExt] Switch to new pass generation tablegen definitions. (#18216)
Browse files Browse the repository at this point in the history
The revision applies few cleanups:

- Remove outdated pass constructors in DecomposeAttention pass. The
`tileOnly` option is not used at all.
- Add dummy summery to ConvertAttentionToOnlineAttentionPass
- Switch namespaces to the single-line syntax for few passes.

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Aug 14, 2024
1 parent fe638b0 commit 7d60397
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ iree_compiler_cc_library(
"DecomposeIm2col.cpp",
"DecomposeWinogradPass.cpp",
"PadContractionToBlockSize.cpp",
"PassDetail.h",
"Passes.cpp",
"SplitReduction.cpp",
"TileAttention.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ iree_cc_library(
"DecomposeIm2col.cpp"
"DecomposeWinogradPass.cpp"
"PadContractionToBlockSize.cpp"
"PassDetail.h"
"Passes.cpp"
"SplitReduction.cpp"
"TileAttention.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -16,6 +15,9 @@

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_CONVERTCONV2DTOIM2COLOPPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

static bool hasAllOneValues(DenseIntElementsAttr attr) {
return llvm::all_of(
attr, [](APInt element) { return element.getSExtValue() == 1; });
Expand Down Expand Up @@ -322,8 +324,8 @@ class ConvertConv2DNchwFchw final
ControlFnTy controlFn;
};

struct ConvertConv2DToIm2ColOpPass
: ConvertConv2DToIm2ColOpBase<ConvertConv2DToIm2ColOpPass> {
struct ConvertConv2DToIm2ColOpPass final
: impl::ConvertConv2DToIm2ColOpPassBase<ConvertConv2DToIm2ColOpPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
Expand All @@ -345,9 +347,4 @@ void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
patterns.getContext(), std::move(controlFn));
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass() {
return std::make_unique<ConvertConv2DToIm2ColOpPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/WinogradConstants.h"
Expand All @@ -27,6 +26,9 @@

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_CONVERTCONV2DTOWINOGRADPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

static const char kWinogradAttr[] = "__winograd_conv";

static bool hasAllOneValues(DenseIntElementsAttr attr) {
Expand Down Expand Up @@ -403,8 +405,11 @@ class ConvertConvToWinograd final : public OpRewritePattern<ConvOp> {
/// }
/// }
/// ```
struct ConvertConv2DToWinogradPass
: ConvertConv2DToWinogradBase<ConvertConv2DToWinogradPass> {
struct ConvertConv2DToWinogradPass final
: impl::ConvertConv2DToWinogradPassBase<ConvertConv2DToWinogradPass> {
using impl::ConvertConv2DToWinogradPassBase<
ConvertConv2DToWinogradPass>::ConvertConv2DToWinogradPassBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<linalg::LinalgDialect, IREE::LinalgExt::IREELinalgExtDialect>();
Expand All @@ -423,10 +428,4 @@ struct ConvertConv2DToWinogradPass
};

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToWinogradPass() {
return std::make_unique<ConvertConv2DToWinogradPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
Expand All @@ -23,9 +22,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
using namespace IREE::LinalgExt;
namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_LINALGEXTTOLOOPSPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

/// Recursive method that lowers one dimension of the `TiledOpInterface` to
/// scalar loops at a time.
Expand Down Expand Up @@ -100,8 +100,8 @@ struct TilingInterfaceLowerToLoopsPattern : public RewritePattern {
//===----------------------------------------------------------------------===//

namespace {
struct LinalgExtToLoopsPass
: public LinalgExtToLoopsBase<LinalgExtToLoopsPass> {
struct LinalgExtToLoopsPass final
: impl::LinalgExtToLoopsPassBase<LinalgExtToLoopsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<linalg::LinalgDialect, mlir::arith::ArithDialect,
Expand All @@ -120,8 +120,4 @@ struct LinalgExtToLoopsPass
}
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
IREE::LinalgExt::createLinalgExtToLoopsPass() {
return std::make_unique<LinalgExtToLoopsPass>();
}
} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -16,6 +15,9 @@

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_DECOMPOSEATTENTIONPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

namespace {

// Computes a reduction along the rows of a 2d tensor of shape MxN
Expand Down Expand Up @@ -337,20 +339,16 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
}

namespace {
struct DecomposeAttentionPass
: public DecomposeAttentionBase<DecomposeAttentionPass> {
struct DecomposeAttentionPass final
: impl::DecomposeAttentionPassBase<DecomposeAttentionPass> {
using impl::DecomposeAttentionPassBase<
DecomposeAttentionPass>::DecomposeAttentionPassBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
}
DecomposeAttentionPass() = default;
DecomposeAttentionPass(bool onlyTile, uint64_t tileSize) {
this->tileSize = tileSize;
}
DecomposeAttentionPass(const DecomposeAttentionPass &pass) {
tileSize = pass.tileSize;
}
void runOnOperation() override;
};
} // namespace
Expand All @@ -377,9 +375,4 @@ void DecomposeAttentionPass::runOnOperation() {
rewriter.replaceOp(onlineAtt, results.value());
});
}

std::unique_ptr<Pass> createDecomposeAttentionPass() {
return std::make_unique<DecomposeAttentionPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -16,6 +15,10 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_DECOMPOSEIM2COLPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

namespace {

/// Pattern to decompose the tiled im2col op.
Expand All @@ -37,7 +40,8 @@ struct DecomposeIm2col : public OpRewritePattern<Im2colOp> {
} // namespace

namespace {
struct DecomposeIm2colPass : public DecomposeIm2colBase<DecomposeIm2colPass> {
struct DecomposeIm2colPass final
: impl::DecomposeIm2colPassBase<DecomposeIm2colPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
Expand All @@ -58,10 +62,4 @@ void DecomposeIm2colPass::runOnOperation() {
return signalPassFailure();
}
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposeIm2colPass() {
return std::make_unique<DecomposeIm2colPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/WinogradConstants.h"
Expand All @@ -24,6 +23,10 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_DECOMPOSEWINOGRADTRANSFORMPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

namespace {

/// Pattern to remove unit dims from winograd ops after tililng. Tiling is
Expand Down Expand Up @@ -333,8 +336,8 @@ struct DecomposeWinogradOutputTransform
} // namespace

namespace {
struct DecomposeWinogradTransformPass
: public DecomposeWinogradTransformBase<DecomposeWinogradTransformPass> {
struct DecomposeWinogradTransformPass final
: impl::DecomposeWinogradTransformPassBase<DecomposeWinogradTransformPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<
affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
Expand Down Expand Up @@ -363,9 +366,4 @@ void DecomposeWinogradTransformPass::runOnOperation() {
}
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposeWinogradTransformPass() {
return std::make_unique<DecomposeWinogradTransformPass>();
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/Input/InputOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -16,9 +15,10 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
namespace IREE = mlir::iree_compiler::IREE;
using namespace IREE::LinalgExt;
namespace mlir::iree_compiler::IREE::LinalgExt {

#define GEN_PASS_DEF_PADCONTRACTIONTOBLOCKSIZEPASS
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h.inc"

static Operation *sliceTensor(Location loc, Value expanded, Value original,
OpBuilder &builder) {
Expand Down Expand Up @@ -87,8 +87,11 @@ static bool padTensor(Location loc, OpOperand *operand,

namespace {

struct PadContractionToBlockSizePass
: public PadContractionToBlockSizeBase<PadContractionToBlockSizePass> {
struct PadContractionToBlockSizePass final
: impl::PadContractionToBlockSizePassBase<PadContractionToBlockSizePass> {
using impl::PadContractionToBlockSizePassBase<
PadContractionToBlockSizePass>::PadContractionToBlockSizePassBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Input::IREEInputDialect>();
}
Expand Down Expand Up @@ -126,8 +129,4 @@ struct PadContractionToBlockSizePass
}
};
} // namespace

std::unique_ptr<OperationPass<>>
IREE::LinalgExt::createPadContractionToBlockSizePass() {
return std::make_unique<PadContractionToBlockSizePass>();
}
} // namespace mlir::iree_compiler::IREE::LinalgExt

This file was deleted.

Loading

0 comments on commit 7d60397

Please sign in to comment.