-
Notifications
You must be signed in to change notification settings - Fork 612
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pass to bubble-up extract_slice operations. (#18332)
This adds a pass that replaces a `tensor.extract_slice` operation with a slice of its producer. In general there may be more opportunities to apply this pass more aggressively (for example, when an operation has a single use that is a slice), but for now it is applied only to bit-extend operations. Co-authored-by: Ian Wood <[email protected]>
- Loading branch information
1 parent
6e3be28
commit d6762d4
Showing
8 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
134 changes: 134 additions & 0 deletions
134
compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Linalg/Utils/Utils.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler::DispatchCreation { | ||
|
||
#define GEN_PASS_DEF_BUBBLEUPEXTRACTSLICESPASS | ||
#include "iree/compiler/DispatchCreation/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
// Convert extract_slice(dequant) to dequant(extract_slice)
//
// Because `extract_slice` ops and dequantize-like ops get cloned into regions
// later, it's okay to bubble up through multi-use dequant ops.
//
// The pattern only applies when all of the following hold:
//   - the slice source is produced by a single-result `linalg.generic`,
//   - `IREE::LinalgExt::isBitExtendOp` accepts that generic op,
//   - the slice has unit strides,
//   - every indexing map of the generic op is a projected permutation,
//   - the generic op has no index semantics.
struct BubbleUpExtract : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Value source = sliceOp.getSource();
    auto genericOp = source.getDefiningOp<linalg::GenericOp>();
    if (!genericOp || genericOp->getNumResults() != 1) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected source to be a `linalg::GenericOp` with a "
                   "single result");
    }

    if (!IREE::LinalgExt::isBitExtendOp(genericOp)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected source to be dequantize-like");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
    }

    if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
          return map.isProjectedPermutation();
        })) {
      return rewriter.notifyMatchFailure(
          genericOp,
          "expected generic op to have all projected permutation maps");
    }

    if (genericOp.hasIndexSemantics()) {
      return rewriter.notifyMatchFailure(
          genericOp, "pattern doesn't support index semantics");
    }

    // Swap the slice with its producer. The outcome is checked explicitly
    // rather than with `assert`, since asserts are compiled out under NDEBUG
    // and a failed `FailureOr` must not be dereferenced.
    FailureOr<TilingResult> tilingResult =
        tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
                                                     genericOp->getResult(0));
    if (failed(tilingResult)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "failed to swap extract_slice with op");
    }
    if (tilingResult->tiledOps.size() != 1 ||
        tilingResult->tiledValues.empty()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected swapping to produce exactly one tiled op");
    }
    auto swappedOp = dyn_cast<linalg::GenericOp>(tilingResult->tiledOps[0]);
    if (!swappedOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected tiled producer to be a `linalg::GenericOp`");
    }
    Value replacement = tilingResult->tiledValues[0];

    // Check if this is a rank-reducing slice, if so we need to fold the unit
    // dimensions of the op.
    // This is necessary because `replaceExtractSliceWithTiledProducer` does not
    // take into account the `extract_slice`'s implicit rank reduction. The
    // operations generated by that function will have any unit dims that were
    // removed by the original `extract_slice`. Folding them away ensures that
    // the types match.
    if (sliceOp.getSourceType().getRank() !=
        sliceOp.getResultType().getRank()) {
      llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
      // Get the indexing map for the result.
      AffineMap resultMap =
          swappedOp.getIndexingMapMatchingResult(swappedOp->getResult(0));
      linalg::ControlDropUnitDims options;
      options.rankReductionStrategy = linalg::ControlDropUnitDims::
          RankReductionStrategy::ExtractInsertSlice;
      // Translate the result dimensions dropped by the slice into the
      // iteration-space dimensions that index them.
      options.controlFn = [&](Operation *op) -> SmallVector<unsigned> {
        SmallVector<unsigned> droppedDimsVec;
        for (auto [index, expr] : llvm::enumerate(resultMap.getResults())) {
          if (!droppedDims.test(index)) {
            continue;
          }
          // NOTE(review): relies on the tiled op keeping the projected
          // permutation maps verified on `genericOp` above, so each result
          // expression is expected to be a plain dim expression.
          auto dimExpr = cast<AffineDimExpr>(expr);
          droppedDimsVec.push_back(dimExpr.getPosition());
        }
        return droppedDimsVec;
      };
      FailureOr<linalg::DropUnitDimsResult> dropUnitDims =
          linalg::dropUnitDims(rewriter, swappedOp, options);
      if (failed(dropUnitDims)) {
        // The ops created above are left unused; they are expected to be
        // cleaned up as dead code.
        return rewriter.notifyMatchFailure(
            sliceOp, "failed to drop unit dims of produced operation");
      }
      swappedOp = dropUnitDims->resultOp;
      replacement = swappedOp->getResult(0);
    }
    rewriter.replaceOp(sliceOp, replacement);
    return success();
  }
};
|
||
struct BubbleUpExtractSlicesPass | ||
: impl::BubbleUpExtractSlicesPassBase<BubbleUpExtractSlicesPass> { | ||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
{ | ||
RewritePatternSet patterns(context); | ||
patterns.insert<BubbleUpExtract>(context); | ||
if (failed(applyPatternsAndFoldGreedily(getOperation(), | ||
std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
} | ||
}; | ||
} // namespace | ||
|
||
} // namespace mlir::iree_compiler::DispatchCreation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
// RUN: iree-opt --split-input-file --iree-dispatch-creation-bubble-up-extract-slices --iree-flow-canonicalize %s | FileCheck %s | ||
|
||
// Test input: a rank-reducing slice of a bit-extending generic op.
// The generic op extends i8 to f32 (extsi, sitofp, then a multiply by a
// constant). The slice takes offset 1 / size 1 of the innermost dimension and
// drops it, so a 4-D producer result feeds a 3-D slice result.
util.func public @bubble_up_extract_rank_reduce(%arg0 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7xf32>{ | ||
%0 = tensor.empty() : tensor<1024x7x7x2xf32> | ||
%cst = arith.constant 5.000000e-01 : f32 | ||
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { | ||
^bb0(%in: i8, %out: f32): | ||
%4 = arith.extsi %in : i8 to i32 | ||
%5 = arith.sitofp %4 : i32 to f32 | ||
%6 = arith.mulf %5, %cst : f32 | ||
linalg.yield %6 : f32 | ||
} -> tensor<1024x7x7x2xf32> | ||
|
||
%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7xf32> | ||
util.return %extracted_slice : tensor<1024x7x7xf32> | ||
} | ||
|
||
// CHECK-LABEL: @bubble_up_extract_rank_reduce | ||
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice | ||
// CHECK: %[[GENERIC:.+]] = linalg.generic | ||
// CHECK: util.return %[[GENERIC]] | ||
|
||
// ----- | ||
|
||
// Test input: a rank-preserving slice of a bit-extending generic op.
// Same producer as the rank-reducing case, but the slice result keeps the
// unit innermost dimension (tensor<1024x7x7x1xf32>).
util.func public @bubble_up_extract(%arg0 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x1xf32>{ | ||
%0 = tensor.empty() : tensor<1024x7x7x2xf32> | ||
%cst = arith.constant 5.000000e-01 : f32 | ||
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { | ||
^bb0(%in: i8, %out: f32): | ||
%4 = arith.extsi %in : i8 to i32 | ||
%5 = arith.sitofp %4 : i32 to f32 | ||
%6 = arith.mulf %5, %cst : f32 | ||
linalg.yield %6 : f32 | ||
} -> tensor<1024x7x7x2xf32> | ||
|
||
%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32> | ||
util.return %extracted_slice : tensor<1024x7x7x1xf32> | ||
} | ||
|
||
// CHECK-LABEL: @bubble_up_extract | ||
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice | ||
// CHECK: %[[GENERIC:.+]] = linalg.generic | ||
// CHECK: util.return %[[GENERIC]] | ||
|
||
// ----- | ||
|
||
// Test input: the sliced generic op has two i8 inputs (%arg0 and %arg1).
// Only %in is read in the body; %in_0 is an unused block argument. The CHECK
// lines that follow expect an extract_slice on each input feeding the generic.
util.func public @bubble_up_extract_multi_input(%arg0 : tensor<1024x7x7x2xi8>, %arg1 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x1xf32>{ | ||
%0 = tensor.empty() : tensor<1024x7x7x2xf32> | ||
%cst = arith.constant 5.000000e-01 : f32 | ||
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1024x7x7x2xi8>, tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { | ||
^bb0(%in: i8, %in_0 : i8, %out: f32): | ||
%4 = arith.extsi %in : i8 to i32 | ||
%5 = arith.sitofp %4 : i32 to f32 | ||
%6 = arith.mulf %5, %cst : f32 | ||
linalg.yield %6 : f32 | ||
} -> tensor<1024x7x7x2xf32> | ||
|
||
%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32> | ||
util.return %extracted_slice : tensor<1024x7x7x1xf32> | ||
} | ||
|
||
// CHECK-LABEL: @bubble_up_extract_multi_input | ||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] | ||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] | ||
// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract_slice %[[ARG0]] | ||
// CHECK-DAG: %[[EXTRACT1:.+]] = tensor.extract_slice %[[ARG1]] | ||
// CHECK: %[[GENERIC:.+]] = linalg.generic | ||
// CHECK-SAME: ins(%[[EXTRACT0]], %[[EXTRACT1]] : tensor<1024x7x7x1xi8>, tensor<1024x7x7x1xi8>) | ||
// CHECK: util.return %[[GENERIC]] | ||
|
||
// ----- | ||
|
||
// Test input: the generic op result %1 is used both by a rank-reducing
// extract_slice and directly by util.return, i.e. the producer has a use
// besides the slice. The CHECK lines that follow expect the full-size generic
// to remain and a second generic to operate on a slice of %arg0.
util.func public @bubble_up_extract_with_use(%arg0 : tensor<1024x7x7x2xi8>) -> (tensor<1024x7x7xf32>, tensor<1024x7x7x2xf32>) { | ||
%0 = tensor.empty() : tensor<1024x7x7x2xf32> | ||
%cst = arith.constant 5.000000e-01 : f32 | ||
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { | ||
^bb0(%in: i8, %out: f32): | ||
%4 = arith.extsi %in : i8 to i32 | ||
%5 = arith.sitofp %4 : i32 to f32 | ||
%6 = arith.mulf %5, %cst : f32 | ||
linalg.yield %6 : f32 | ||
} -> tensor<1024x7x7x2xf32> | ||
|
||
%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7xf32> | ||
util.return %extracted_slice, %1 : tensor<1024x7x7xf32>, tensor<1024x7x7x2xf32> | ||
} | ||
|
||
// CHECK-LABEL: @bubble_up_extract_with_use | ||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] | ||
// CHECK-DAG: %[[GENERIC0:.+]] = linalg.generic | ||
// CHECK-SAME: ins(%[[ARG0]] : tensor<1024x7x7x2xi8>) | ||
// | ||
// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract_slice %[[ARG0]] | ||
// CHECK-DAG: %[[GENERIC1:.+]] = linalg.generic | ||
// CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>) | ||
// CHECK: util.return %[[GENERIC1]], %[[GENERIC0]] |