Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand All @@ -43,6 +44,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

#define DEBUG_TYPE "iree-dispatch-creation-form-dispatch-regions"

Expand Down Expand Up @@ -145,6 +147,49 @@ class FusionGroup {
// Insert `op` into the fusion group.
void insert(Operation *op);

/// Returns true if `consumerOp` transitively depends on the fusion group
/// through some operation that is itself *outside* the group. Such a
/// dependency makes fusion illegal, because it must be possible to extract a
/// program slice that contains only the ops in the fusion group.
bool hasTransitiveDependencyOnFusionGroup(Operation *consumerOp) const {
  DominanceInfo dominanceInfo(consumerOp);
  BackwardSliceOptions sliceOptions;
  sliceOptions.inclusive = true;
  sliceOptions.omitUsesFromAbove = false;
  sliceOptions.omitBlockArguments = true;
  // Prune the backward traversal: an op that properly dominates every op in
  // the fusion group cannot itself depend on the group, so there is no need
  // to walk past it.
  sliceOptions.filter = [&](Operation *candidateOp) {
    return llvm::any_of(
        loopMaps.getArrayRef(),
        [&](const std::pair<Operation *, AffineMap> &entry) {
          return !dominanceInfo.properlyDominates(candidateOp, entry.first);
        });
  };

  llvm::SetVector<Operation *> backwardSlice;
  auto collectSliceOf = [&](OpOperand *operand) {
    // A direct use of an op in the fusion group is fine; only transitive
    // paths that go through ops outside the group are problematic.
    if (loopMaps.contains(operand->get().getDefiningOp())) {
      return;
    }
    LogicalResult sliceResult =
        getBackwardSlice(operand->get(), &backwardSlice, sliceOptions);
    assert(sliceResult.succeeded() && "expected a backward slice");
    (void)sliceResult;
  };

  // Seed the slice from every operand of `consumerOp` as well as from every
  // value its regions capture from above.
  for (OpOperand &operand : consumerOp->getOpOperands()) {
    collectSliceOf(&operand);
  }
  mlir::visitUsedValuesDefinedAbove(consumerOp->getRegions(), collectSliceOf);

  // The consumer transitively depends on the fusion group iff any op of the
  // group ended up in the collected slice.
  for (const auto &entry : loopMaps.getArrayRef()) {
    if (backwardSlice.contains(entry.first)) {
      return true;
    }
  }
  return false;
}

private:
Operation *rootOp;
// All operations to be fused with the root op. This does not include
Expand Down Expand Up @@ -667,6 +712,12 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
continue;
}

// Ensure that fusing the consumer would not cause use-def violations.
if (tracker.getFusionGroup(currRoot)
.hasTransitiveDependencyOnFusionGroup(fusableUse->getOwner())) {
continue;
}

if (isFusableWithConsumer(*fusableUse, tracker, options)) {
tracker.appendToFusionGroup(consumerOp, fusionGroup);
workList.push_back(consumerOp);
Expand Down Expand Up @@ -974,7 +1025,7 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
auto newRegionOp = IREE::Flow::moveFollowingOpIntoDispatchRegion(
rewriter, consumer, regionOp);
if (failed(newRegionOp)) {
continue;
return consumer->emitOpError("failed to move consumer into region");
}
regionOp = *newRegionOp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1391,15 +1391,16 @@ util.func public @avoid_illegal_consumer_fusion(%arg0: tensor<75600x5120xbf16>)
util.return %6 : tensor<75600x1x5120xbf16>
}
// CHECK-LABEL: @avoid_illegal_consumer_fusion(
// CHECK: %[[DISPATCH:.+]]:2 = flow.dispatch.region
// CHECK: %[[DISPATCH0:.+]]:2 = flow.dispatch.region
// CHECK: %[[GENERIC0:.+]] = linalg.generic
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GENERIC0]] :
// CHECK: flow.return %[[GENERIC1]], %[[GENERIC0]]
// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[DISPATCH]]#1
// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[DISPATCH0]]#1
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND_SHAPE]], %[[DISPATCH]]#0 :
// CHECK: util.return %[[GENERIC2]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviewers: this checks that hasTransitiveDependencyOnFusionGroup prevents the fusion since the old check will now error out.

// CHECK-SAME: ins(%[[EXPAND_SHAPE]], %[[DISPATCH0]]#0 :
// CHECK: util.return %[[DISPATCH1]]

// -----

Expand Down
Loading