Skip to content

Commit

Permalink
Hoist collapse shape out of scf.forall when possible and expand its d…
Browse files Browse the repository at this point in the history
…estination (#19044)

This pattern allows to hoist out collapse shape producer of
tensor.parallel_insert_slice and expand the enclosing scf.forall
destination. This is only safe because we are doing this on workgroup
mapped sc.forall's and hence the slices are disjoint. The reason to have
this pattern is that it allows us to eliminate empty tensors during
bufferization which would otherwise be blocked due to the collapse.

---------

Signed-off-by: Nirvedh <[email protected]>
  • Loading branch information
nirvedhmeshram authored Nov 13, 2024
1 parent 11f0099 commit f3c1467
Show file tree
Hide file tree
Showing 5 changed files with 562 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand All @@ -17,6 +20,282 @@ namespace mlir::iree_compiler {

namespace {

/// Calculate the expanded shape of `dest` if it can be expanded with the inner
/// expanded sizes of `sliceStaticSizes`. Returns failure if such expansion is
/// not possible.
static LogicalResult
getExpandedShape(SmallVector<ReassociationIndices> reIndices,
ArrayRef<int64_t> sliceStaticSizes, Value dest,
SmallVectorImpl<int64_t> &expandedShape,
SmallVectorImpl<int64_t> &totalInnerSizes) {
auto destType = dyn_cast<ShapedType>(dest.getType());
if (!destType)
return failure();
// TODO (nirvedhmeshram): Support rank reducing parallel_insert_slice.
if (reIndices.size() != destType.getShape().size())
return failure();
// Iterator to insert outer sizes.
auto outerShapeIter = expandedShape.begin();
for (auto [reassociations, destSize] :
llvm::zip_equal(reIndices, destType.getShape())) {
// Dynamic destination dims that are not getting expanded are allowed.
if (ShapedType::isDynamic(destSize) && reassociations.size() == 1) {
expandedShape.insert(outerShapeIter++, destSize);
totalInnerSizes.push_back(1);
continue;
}
// Dynamic destination dims that are expanded are currently unsupported but
// this support can be added if needed.
if (ShapedType::isDynamic(destSize)) {
return failure();
}
int64_t totalInnerSize = 1;
for (int64_t reasociation : llvm::drop_begin(reassociations)) {
int64_t expandedInnerSize = sliceStaticSizes[reasociation];
// It is not safe to do this pattern if inner dimensions are dynamic.
if (ShapedType::isDynamic(expandedInnerSize))
return failure();
expandedShape.push_back(expandedInnerSize);
totalInnerSize *= expandedInnerSize;
}
if (destSize % totalInnerSize != 0)
return failure();
totalInnerSizes.push_back(totalInnerSize);
// insert the outer size in front of any inner sizes.
expandedShape.insert(outerShapeIter, destSize / totalInnerSize);
// set up the iterator for the next uncollapsed dimension.
outerShapeIter = expandedShape.end();
}
return success();
}

/// Check if the users of the expanded scf.forall destination can be updated to
/// account for the expand. If not we bail out. There are two supported users
/// which are extract_slice -> expand_shape with the same exact reassociation
/// map as the collapse op to be hoisted out or the root parallel_insert_slice.
static LogicalResult
verifyAndCollectExpandableUsers(Value insertDest,
SmallVector<ReassociationIndices> reIndices,
tensor::ParallelInsertSliceOp parallelInsertOp,
SmallVector<Operation *> &expandableUsers) {
for (Operation *user : insertDest.getUsers()) {
if (user == parallelInsertOp) {
expandableUsers.push_back(user);
continue;
}
auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
if (!extractSliceOp)
return failure();
if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes())
return failure();
if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets())
return failure();
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
if (!expandShapeOp)
return failure();
SmallVector<ReassociationIndices> expandReIndices =
expandShapeOp.getReassociationIndices();
if (reIndices != expandReIndices)
return failure();
expandableUsers.push_back(user);
}
return success();
}

/// Utility to expand the pre-verified expandable users of the scf.forall
/// output.
static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
MLIRContext *ctx,
SmallVector<Operation *> expandableUsers,
SmallVector<int64_t> totalInnerSizes,
SmallVector<ReassociationIndices> reIndices,
scf::ForallOp forallOp) {
// compute the offsets,sizes,strides in the expanded dimensions.
auto computeExpandedAccess = [&](ArrayRef<OpFoldResult> mixedOffsets,
ShapedType resultType)
-> std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
SmallVector<OpFoldResult>> {
SmallVector<OpFoldResult> expandedOffsets;
auto expandedOffsetsIter = expandedOffsets.begin();

for (auto [index, offset] : llvm::enumerate(mixedOffsets)) {
// Add zero offsets for the extra dimensions from reIndices.
for (size_t i = 1, e = reIndices[index].size(); i < e; ++i) {
expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0));
}
// Compute the outer dimension expression.
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
AffineExpr outerDimExpr = (s0).floorDiv(s1);
// Insert computed offset using affine expression.
expandedOffsets.insert(
expandedOffsetsIter,
affine::makeComposedFoldedAffineApply(
rewriter, loc, outerDimExpr,
{getValueOrCreateConstantIndexOp(rewriter, loc, offset),
rewriter.getIndexAttr(totalInnerSizes[index])}));

expandedOffsetsIter = expandedOffsets.end();
}
SmallVector<OpFoldResult> expandedSizes =
getAsIndexOpFoldResult(ctx, resultType.getShape());
SmallVector<OpFoldResult> expandedStrides(resultType.getRank(),
rewriter.getIndexAttr(1));
return {expandedOffsets, expandedSizes, expandedStrides};
};
for (Operation *user : expandableUsers) {
rewriter.setInsertionPointToStart(forallOp.getBody());
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
RankedTensorType resultType = expandShapeOp.getResultType();
auto [expandedOffsets, expandedSizes, expandedStrides] =
computeExpandedAccess(extractSliceOp.getMixedOffsets(), resultType);
rewriter.setInsertionPoint(extractSliceOp);
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
extractSliceOp, resultType, extractSliceOp.getSource(),
expandedOffsets, expandedSizes, expandedStrides);
} else if (auto parallelInsertOp =
dyn_cast<tensor::ParallelInsertSliceOp>(user)) {
auto collapseShapeOp =
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
RankedTensorType resultType = collapseShapeOp.getSrcType();
auto [expandedOffsets, expandedSizes, expandedStrides] =
computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);
rewriter.setInsertionPoint(parallelInsertOp);
rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
parallelInsertOp, collapseShapeOp.getSrc(),
parallelInsertOp.getDest(), expandedOffsets, expandedSizes,
expandedStrides);
}
}
return;
}

/// This pattern expands destination of workgroup mapped scf.foralls by
/// hoisting out collapse_shape op consumed by its parallel.insert_slice op.
struct ExpandDestinationForallOp final
: OpRewritePattern<tensor::ParallelInsertSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ParallelInsertSliceOp parallelInsertOp,
PatternRewriter &rewriter) const override {
Location loc = parallelInsertOp.getLoc();
MLIRContext *ctx = getContext();
auto collapseOp =
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
// No collapse op to hoist out.
if (!collapseOp)
return failure();

// Ignore trivially foldable collapse ops.
if (collapseOp.getSrcType().getRank() ==
collapseOp.getResultType().getRank()) {
return failure();
}

// Get the destination to expand.
Value insertDest = parallelInsertOp.getDest();

// Get the enclosing scf.forall op.
OpResult tiedResult = parallelInsertOp.getTiedOpResult();
int64_t tiedResultIdx = tiedResult.getResultNumber();

auto forallOp = dyn_cast<scf::ForallOp>(tiedResult.getOwner());
if (!forallOp)
return failure();

// We only want this pattern if the forall op result is being written to a
// full slice. Otherwise the hoisted collapse op is not foldable.
for (Operation *foralluser : tiedResult.getUsers()) {
auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(foralluser);
if (!storeOp)
return failure();
if (!isFullSlice(storeOp, storeOp.getTargetType(),
storeOp.getTargetDims())) {
return failure();
}
}

// This allows us to assume that the extract/inserts in the loop are
// disjoint and makes the application of this pattern safe.
if (!forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
forallOp)) {
return failure();
}
// This pattern only supports forall ops with single
// output.
SmallVector<Value> forallOutputs(forallOp.getOutputs());

SmallVector<ReassociationIndices> reIndices =
collapseOp.getReassociationIndices();
SmallVector<int64_t> expandedDestShape;
SmallVector<int64_t> totalInnerSizes;
// Get the shape of the outer expand which will be the new destination
// of the scf.forall and the total size of inner dimensions per uncollapsed
// dimension.
if (failed(getExpandedShape(reIndices, collapseOp.getSrcType().getShape(),
insertDest, expandedDestShape,
totalInnerSizes))) {
return failure();
}

// Verify that the users of destination are valid to expand and collect all
// such users.
SmallVector<Operation *> expandableUsers;
if (failed(verifyAndCollectExpandableUsers(
insertDest, collapseOp.getReassociationIndices(), parallelInsertOp,
expandableUsers))) {
return failure();
}

// Expand the users of the destination.
rewriter.setInsertionPointToStart(forallOp.getBody());
expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes,
reIndices, forallOp);
rewriter.setInsertionPoint(forallOp);

// Create the expand -> new scf.forall -> collapse chain.
auto expandedDestType =
cast<RankedTensorType>(forallOutputs[tiedResultIdx].getType())
.clone(expandedDestShape);
auto expandedDest = rewriter.create<tensor::ExpandShapeOp>(
loc, expandedDestType, forallOutputs[tiedResultIdx], reIndices);

forallOutputs[tiedResultIdx] = expandedDest;

scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep(), forallOutputs, forallOp.getMappingAttr());

auto collapsedResultOp = rewriter.create<tensor::CollapseShapeOp>(
loc, cast<ShapedType>(forallOp->getResult(tiedResultIdx).getType()),
newForallOp->getResult(tiedResultIdx), reIndices);

// Merge the old scf.forall block which has the expanded users into the new
// scf.forall which has the expanded destination.
SmallVector<Value> argReplacements(newForallOp.getInductionVars());
argReplacements.append(newForallOp.getRegionIterArgs().begin(),
newForallOp.getRegionIterArgs().end());
scf::InParallelOp parallelTerminator = newForallOp.getTerminator();
parallelTerminator->erase();
rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
argReplacements);

// Replaces the uses of the old scf.forall with the new scf.forall
for (int idx = 0; idx < forallOp->getNumResults(); ++idx) {
if (idx == tiedResultIdx) {
forallOp->getResult(idx).replaceAllUsesWith(
collapsedResultOp->getResult(0));
} else {
forallOp->getResult(idx).replaceAllUsesWith(
newForallOp->getResult(idx));
}
}
return success();
}
};

struct PropagateReshapesByExpansionPass final
: impl::PropagateReshapesByExpansionPassBase<
PropagateReshapesByExpansionPass> {
Expand Down Expand Up @@ -65,6 +344,7 @@ void PropagateReshapesByExpansionPass::runOnOperation() {
tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
context);
populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
bubbleExpandShapePatterns.add<ExpandDestinationForallOp>(context);

if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(bubbleExpandShapePatterns)))) {
Expand Down
Loading

0 comments on commit f3c1467

Please sign in to comment.