[DispatchCreation] Collapse iree_linalg_ext.attention (#19012)
This change adds support for attention in `CollapseDimensionsPass` so
that the attention op will be collapsed as much as possible. This is
motivated by reducing the number of attention variants that the SDXL
attention spec has to handle.


Changes to LinalgExt/Transforms/ReshapeFusion.cpp are mostly taken
directly from
https://github.com/llvm/llvm-project/blob/002a0a27bc4702d6f34434c1838cb1698a0b0098/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
(attributed at the top of the file). I attempted to modify the original
logic as little as possible, to keep it general in case it needs to be
reused for other `LinalgExt` ops.
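
As a rough sketch of how the new entry point can be driven (the `attentionOp`,
`rewriter`, and `{0, 1}` grouping below are illustrative placeholders, not
taken from the pass):

```cpp
// Sketch only: fold the first two iteration dims of an attention op.
// `attentionOp` and `rewriter` are assumed to be provided by the caller.
SmallVector<ReassociationIndices> foldedDims = {{0, 1}};
auto collapsed =
    IREE::LinalgExt::collapseOpIterationDims(attentionOp, foldedDims, rewriter);
if (succeeded(collapsed)) {
  // The results returned in `CollapseResult` are already expanded back to the
  // original shapes, so they can replace the original op's results directly.
  rewriter.replaceOp(attentionOp, collapsed->results);
}
```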

---------

Signed-off-by: Ian Wood <[email protected]>
IanWood1 authored Nov 12, 2024
1 parent e8f755d commit 2bfc639
Showing 5 changed files with 512 additions and 86 deletions.
@@ -9,6 +9,7 @@
// modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`.

#include <optional>
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -56,6 +57,56 @@ class ExpansionInfo {
SmallVector<int64_t> originalLoopExtent;
unsigned expandedOpNumDims;
};

class CollapsingInfo {
public:
LogicalResult initialize(unsigned origNumLoops,
ArrayRef<ReassociationIndices> foldedIterationDims);

/// Return mapping from collapsed loop domain to original loop domain.
ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
return collapsedOpToOrigOpIterationDim;
}

/// Return mapping from original loop domain to collapsed loop domain. The
/// mapping is a pair. First value is the dimension in the collapsed loop that
/// the original loop is mapped to. Second is the relative position in folded
/// list of this domain. For example, if the original loop domain is 5D and
/// is folded into two collapsed dimensions, i.e.
///
/// ```
/// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
/// ```
///
/// then
///
/// ```
/// origOpToCollapsedOpMapping[0] = {0, 0};
/// origOpToCollapsedOpMapping[1] = {0, 1};
/// origOpToCollapsedOpMapping[2] = {0, 2};
/// origOpToCollapsedOpMapping[3] = {1, 0};
/// origOpToCollapsedOpMapping[4] = {1, 1};
/// ```
///
ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
return origOpToCollapsedOpIterationDim;
}

/// Return the collapsed op iteration domain rank.
unsigned getCollapsedOpIterationRank() const {
return collapsedOpToOrigOpIterationDim.size();
}

private:
/// Map from the iteration domain index in collapsed op to the iteration
/// domain indices in the original op.
SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;

/// Map from iteration domain index in the original op to the iteration domain
/// index in the collapsed op.
SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};

} // namespace

template <typename OpTy>
@@ -105,6 +156,51 @@ LogicalResult ExpansionInfo::compute(OpTy op, OpOperand *fusableOpOperand,
return success();
}

LogicalResult
CollapsingInfo::initialize(unsigned origNumLoops,
ArrayRef<ReassociationIndices> foldedIterationDims) {
llvm::SmallDenseSet<int64_t, 4> processedDims;
// Find all the dims that are folded.
for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
if (foldedIterationDim.empty())
continue;
// If the folded dims contain dims already folded, that's an illegal
// specification. Repetition within a list is also illegal.
for (auto dim : foldedIterationDim) {
if (dim >= origNumLoops)
return failure();
if (processedDims.count(dim))
return failure();
processedDims.insert(dim);
}
collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
foldedIterationDim.end());
}
if (processedDims.size() > origNumLoops)
return failure();

// Add all the preserved dims of the original op as single
// elements to `collapsedOpToOrigOpIterationDim`.
for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
if (processedDims.count(dim))
continue;
collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
}

llvm::sort(collapsedOpToOrigOpIterationDim,
[&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
return lhs[0] < rhs[0];
});
origOpToCollapsedOpIterationDim.resize(origNumLoops);
for (const auto &foldedDims :
llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
for (const auto &dim : enumerate(foldedDims.value()))
origOpToCollapsedOpIterationDim[dim.value()] =
std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
}
return success();
}

static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
@@ -392,6 +488,217 @@ struct FoldAttentionWithProducerReshapeByExpansion final

} // namespace

/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
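/// For example (illustrative, not tied to a specific op), given
///   indexingMap = (d0, d1, d2, d3) -> (d0, d1, d3)
/// and folded iteration dims [[0, 1], [2], [3]], dims `d0` and `d1` are folded
/// together and both appear in the map, so the operand reassociation is
/// [[0, 1], [2]].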
static SmallVector<ReassociationIndices>
getOperandReassociation(AffineMap indexingMap,
const CollapsingInfo &collapsingInfo) {
unsigned counter = 0;
SmallVector<ReassociationIndices> operandReassociation;
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
auto collapsedOpToOrigOpMapping =
collapsingInfo.getCollapsedOpToOrigOpMapping();
while (counter < indexingMap.getNumResults()) {
unsigned dim =
cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
// This is the start of a group of collapsed dimensions of the iteration
// space that is guaranteed to be preserved in the indexing map. The number
// of folded dims is obtained from the collapsed op to original op mapping.
unsigned numFoldedDims =
collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
.size();
if (origOpToCollapsedOpMapping[dim].second == 0) {
auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
operandReassociation.emplace_back(range.begin(), range.end());
}
counter += numFoldedDims;
}
return operandReassociation;
}

/// Get the new value to use for a given `OpOperand` in the collapsed operation.
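/// For example (illustrative), collapsing a `tensor<4x8x16xf32>` operand with
/// reassociation [[0, 1], [2]] inserts a `tensor.collapse_shape` that yields a
/// `tensor<32x16xf32>`.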
static Value getCollapsedOpOperand(Location loc, AttentionOp op,
OpOperand *opOperand,
const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
getOperandReassociation(indexingMap, collapsingInfo);

// If the number of entries in the reassociation for the operand is the same
// as the number of results of the indexing map, then there is nothing to do
// for this operand.
Value operand = opOperand->get();
if (operandReassociation.size() == indexingMap.getNumResults())
return operand;

// Insert a reshape to collapse the dimensions.
if (isa<MemRefType>(operand.getType())) {
return builder
.create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
.getResult();
}
return builder
.create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
.getResult();
}

static void collapseOperandsAndResults(AttentionOp op,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter,
SmallVectorImpl<Value> &inputOperands,
SmallVectorImpl<Value> &outputOperands,
SmallVectorImpl<Type> &resultTypes) {
Location loc = op->getLoc();
inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
rewriter);
});

// Get the output operands and result types.
resultTypes.reserve(op.getNumDpsInits());
outputOperands.reserve(op.getNumDpsInits());
for (OpOperand &output : op.getDpsInitsMutable()) {
Value newOutput =
getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
// If the op has "buffer semantics", then the init operands are ranked
// memrefs and the op has no results.
if (!op.hasPureBufferSemantics())
resultTypes.push_back(newOutput.getType());
}
}

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
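/// For example (illustrative), with
///   indexingMap = (d0, d1, d2, d3) -> (d0, d1, d3)
/// and folded iteration dims [[0, 1], [2], [3]], the collapsed map is
///   (d0, d1, d2) -> (d0, d2)
/// because `d0`/`d1` fold into collapsed dim 0 and `d3` becomes collapsed
/// dim 2.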
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
const CollapsingInfo &collapsingInfo) {
MLIRContext *context = indexingMap.getContext();
assert(indexingMap.isProjectedPermutation() &&
"expected indexing map to be projected permutation");
SmallVector<AffineExpr> resultExprs;
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
for (auto expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
// If the dim is not the first dim in its collapsed group, do nothing.
if (origOpToCollapsedOpMapping[dim].second != 0)
continue;
// The next n-dims are guaranteed to be collapsed. So just use the
// iteration dimension of the collapsed op.
resultExprs.push_back(
getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
}
return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
resultExprs, context);
}

/// Get the iterator types for the collapsed operation given the original
/// iterator types and collapsed dimensions.
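/// For example (illustrative), iterator types
///   [parallel, parallel, reduction, parallel]
/// collapsed with [[0, 1], [2], [3]] become [parallel, reduction, parallel];
/// the first iterator type of each folded group is representative since all
/// dims in a group must share the same iterator type.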
static SmallVector<utils::IteratorType>
getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
const CollapsingInfo &collapsingInfo) {
SmallVector<utils::IteratorType> collapsedIteratorTypes;
for (ReassociationIndicesRef foldedIterDims :
collapsingInfo.getCollapsedOpToOrigOpMapping()) {
assert(!foldedIterDims.empty() &&
"reassociation indices expected to have non-empty sets");
// Just pick the iterator type of the first folded dim. Pre-condition checks
// are expected to have verified that the iterator types of all folded
// dimensions are the same.
collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
}
return collapsedIteratorTypes;
}

/// Returns a copy of `attentionOp` with collapsed iteration dimensions.
static Operation *createCollapsedOp(AttentionOp origOp,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter) {
SmallVector<Value> inputOperands, outputOperands;
SmallVector<Type> resultTypes;
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
outputOperands, resultTypes);
SmallVector<AffineMap> indexingMaps(
llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));

SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
origOp.getLoopIteratorTypes(), collapsingInfo));

Value maskOperand;
if (inputOperands.size() > 4) {
maskOperand = inputOperands[4];
}

auto collapsedOp = rewriter.create<AttentionOp>(
origOp.getLoc(), resultTypes, inputOperands[0], inputOperands[1],
inputOperands[2], inputOperands[3], outputOperands[0],
rewriter.getAffineMapArrayAttr(indexingMaps), maskOperand);
rewriter.inlineRegionBefore(origOp.getRegion(), collapsedOp.getRegion(),
collapsedOp.getRegion().begin());
return collapsedOp;
}

FailureOr<CollapseResult>
collapseOpIterationDims(AttentionOp op,
ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
return foldedDims.size() <= 1;
}))
return failure();

FailureOr<SmallVector<int64_t>> staticLoops = op.getStaticLoopRanges();
if (failed(staticLoops) ||
llvm::any_of(staticLoops.value(), ShapedType::isDynamic)) {
return failure();
}

CollapsingInfo collapsingInfo;
if (failed(
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
return rewriter.notifyMatchFailure(
op, "illegal to collapse specified dimensions");
}

Operation *collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

auto loc = op.getLoc();
SmallVector<Value> results;
for (const auto &originalResult : llvm::enumerate(op->getResults())) {
Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
auto originalResultType =
cast<ShapedType>(originalResult.value().getType());
auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
MemRefType expandShapeResultType = MemRefType::get(
originalResultType.getShape(), originalResultType.getElementType());
result = rewriter.create<memref::ExpandShapeOp>(
loc, expandShapeResultType, collapsedOpResult, reassociation);
} else {
result = rewriter.create<tensor::ExpandShapeOp>(
loc, originalResultType, collapsedOpResult, reassociation);
}
results.push_back(result);
} else {
results.push_back(collapsedOpResult);
}
}
return CollapseResult{results, collapsedOp};
}

void populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFoldingReshapes) {
@@ -4,17 +4,33 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

namespace mlir::iree_compiler::IREE::LinalgExt {

// Fold expand_shape ops with their producers (only `AttentionOp` supported)
/// Fold expand_shape ops with their producers (only `AttentionOp` supported)
void populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFoldingReshapes);

/// Fuse transpose-like ops into LinalgExt ops (only `AttentionOp` supported).
void populateFuseLinalgExtOpsWithTransposes(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn);

/// Helper struct to hold the results of collapsing an operation.
struct CollapseResult {
SmallVector<Value> results;
Operation *collapsedOp;
};

/// Collapse the iteration dimension of `op` as described by
/// `foldedIterationDims`. Returns failure when the op cannot be collapsed or it
/// is a no-op.
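/// For example (illustrative), folding iteration dims {0, 1} of an attention
/// op produces a lower-rank attention op whose operands are collapsed via
/// `collapse_shape` ops and whose results are expanded back to the original
/// shapes.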
FailureOr<CollapseResult>
collapseOpIterationDims(AttentionOp op,
ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);

}; // namespace mlir::iree_compiler::IREE::LinalgExt