Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ struct FuseEncodingOpsIntoDispatchRegionsPass final
for (IREE::Encoding::SetEncodingOp encodingOp : encodingOps) {
OpOperand &operand = encodingOp.getSourceMutable();
std::optional<std::pair<OpResult, SmallVector<Operation *>>>
producerChain = getProducerDispatchValueAndOpChain(operand.get());
producerChain = getProducerDispatchValueAndOpChain(
operand.get(), enableAggressiveFusion);
// Nothing to fuse with, so wrap the `encodingOp` in its own dispatch.
if (!producerChain) {
continue;
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
}

std::optional<std::pair<OpResult, SmallVector<Operation *>>>
getProducerDispatchValueAndOpChain(Value operand) {
getProducerDispatchValueAndOpChain(Value operand, bool enableAggressiveFusion) {
auto operandType = dyn_cast<RankedTensorType>(operand.getType());
if (!operandType || operandType.getRank() == 0) {
return std::nullopt;
Expand Down Expand Up @@ -163,7 +163,8 @@ getProducerDispatchValueAndOpChain(Value operand) {
!llvm::hasSingleElement(producerDispatch.getBody())) {
return std::nullopt;
}
if (!llvm::hasSingleElement(producerValue.getUses())) {
if (!enableAggressiveFusion &&
!llvm::hasSingleElement(producerValue.getUses())) {
return std::nullopt;
}
return std::make_pair(producerValue, opChain);
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
/// dispatch. Returns std::nullopt if the dispatch can not be found in the
/// chain or any op in the chain is not a reshape-like op.
std::optional<std::pair<OpResult, SmallVector<Operation *>>>
getProducerDispatchValueAndOpChain(Value operand);
getProducerDispatchValueAndOpChain(Value operand,
bool enableAggressiveFusion = false);

} // namespace mlir::iree_compiler::DispatchCreation
8 changes: 7 additions & 1 deletion compiler/src/iree/compiler/DispatchCreation/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,13 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager,
passManager.addPass(DispatchCreation::createHoistEncodingOpsPass());
}
FunctionLikeNest(passManager)
.addPass(DispatchCreation::createFuseEncodingOpsIntoDispatchRegionsPass)
.addPass([&]() {
FuseEncodingOpsIntoDispatchRegionsPassOptions passOptions;
passOptions.enableAggressiveFusion =
options.enableMultiUseEncodingFusion;
return DispatchCreation::createFuseEncodingOpsIntoDispatchRegionsPass(
passOptions);
})
.addPass(DispatchCreation::createConvertEncodingToFlowPass);
// Hoist encoding operations into initializers when possible.
if (options.constExprHoisting) {
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ struct TransformOptions : public PassPipelineOptions<TransformOptions> {
llvm::cl::desc("Enable aggressive fusion for dispatch creation pipeline"),
llvm::cl::init(false),
};
Option<bool> enableMultiUseEncodingFusion{
*this,
"multi-use-encoding-fusion",
llvm::cl::desc(
"Enable encoding ops' fusion if the producer has more than one uses"),
llvm::cl::init(false),
};
Option<bool> enableFuseMultiUse{
*this,
"fuse-multi-use",
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ def FuseEncodingOpsIntoDispatchRegionsPass :
"IREE::Flow::FlowDialect",
"IREE::Encoding::IREEEncodingDialect",
];
let options = [
Option<"enableAggressiveFusion", "enable-aggressive-fusion", "bool",
/*default=*/"false",
"Enable encoding ops' fusion if the producer has more than one uses">,
];
}

def HoistEncodingOpsPass : Pass<"iree-dispatch-creation-hoist-encoding-ops", "mlir::ModuleOp"> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-fuse-encoding-ops-into-dispatch-regions-pass))" --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-fuse-encoding-ops-into-dispatch-regions-pass{enable-aggressive-fusion}))" --split-input-file %s | FileCheck %s --check-prefix=MULTI-USE

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#encoding = #iree_encoding.testing<>
Expand Down Expand Up @@ -347,3 +348,45 @@ util.func public @multi_result_fusion(%arg0: tensor<123x456xf32>) -> (tensor<123
// CHECK: %[[SET_ENCODING:.+]] = iree_encoding.set_encoding %[[ELEM]]#1
// CHECK: flow.return %[[ELEM]]#0, %[[SET_ENCODING]]
// CHECK: util.return %[[DISPATCH0]]#0, %[[DISPATCH0]]#1

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#encoding = #iree_encoding.testing<>
util.func public @multi_use_producer_fusion(%arg0: tensor<2x11008x128xf32>) -> (tensor<2x11008x128xf32>, tensor<2x11008x128xf32, #encoding>) {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<2x11008x128xf32>
%1 = flow.dispatch.region -> (tensor<2x11008x128xf32>) {
%3 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0, %arg0 : tensor<2x11008x128xf32>, tensor<2x11008x128xf32>)
outs(%0 : tensor<2x11008x128xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%4 = arith.addf %in, %in_0 : f32
linalg.yield %4 : f32
} -> tensor<2x11008x128xf32>
flow.return %3 : tensor<2x11008x128xf32>
}
%2 = iree_encoding.set_encoding %1 : tensor<2x11008x128xf32> -> tensor<2x11008x128xf32, #encoding>
util.return %1, %2 : tensor<2x11008x128xf32>, tensor<2x11008x128xf32, #encoding>
}
// CHECK-DAG: #[[$ENCODING:.+]] = #iree_encoding.testing<>
// CHECK-LABEL: @multi_use_producer_fusion
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region -> (tensor<2x11008x128xf32>)
// CHECK: %[[ADD:.+]] = linalg.generic
// CHECK: flow.return %[[ADD]] :
// CHECK: }
// CHECK: %[[ENCODING:.+]] = iree_encoding.set_encoding %[[DISPATCH0]]
// CHECK: util.return %[[DISPATCH0]], %[[ENCODING:.+]] : tensor<2x11008x128xf32>, tensor<2x11008x128xf32, #[[$ENCODING]]>


// MULTI-USE-DAG: #[[$ENCODING:.+]] = #iree_encoding.testing<>
// MULTI-USE-LABEL: @multi_use_producer_fusion
// MULTI-USE: %[[DISPATCH:.+]]:2 = flow.dispatch.region -> (tensor<2x11008x128xf32>, tensor<2x11008x128xf32, #[[$ENCODING]]>)
// MULTI-USE: %[[ADD:.+]] = linalg.generic
// MULTI-USE: %[[ENCODING:.+]] = iree_encoding.set_encoding %[[ADD]]
// MULTI-USE: flow.return %[[ADD]], %[[ENCODING]] :
// MULTI-USE: }
// MULTI-USE-NOT: iree_encoding.set_encoding
// MULTI-USE: util.return %[[DISPATCH]]#0, %[[DISPATCH]]#1 : tensor<2x11008x128xf32>, tensor<2x11008x128xf32, #[[$ENCODING]]>
9 changes: 9 additions & 0 deletions compiler/src/iree/compiler/Pipelines/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ void DispatchCreationOptions::bindOptions(OptionsBinder &binder) {
llvm::cl::desc("Aggressive fusion opportunities that are behind a flag "
"since all backends dont support it yet"),
llvm::cl::cat(category));
binder.opt<bool>(
"iree-dispatch-creation-enable-multi-use-encoding-fusion",
enableMultiUseEncodingFusion,
{init_at_opt(llvm::OptimizationLevel::O0, false),
init_at_opt(llvm::OptimizationLevel::O2, false),
init_at_opt(llvm::OptimizationLevel::O3, true)},
llvm::cl::desc(
"Enable encoding ops' fusion if the producer has more than one uses"),
llvm::cl::cat(category));
binder.opt<bool>("iree-dispatch-creation-fuse-multi-use", enableFuseMultiUse,
llvm::cl::desc("Fuse operations with multiple uses."),
llvm::cl::cat(category));
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Pipelines/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ struct DispatchCreationOptions {
llvm::OptimizationLevel optLevel;

bool enableAggressiveFusion = false;
bool enableMultiUseEncodingFusion = false;
bool enableFuseMultiUse = true;
bool enableSplitReduction = false;

Expand Down
Loading