Skip to content

Commit

Permalink
[GlobalOpt] Improve unary elementwise propagation to consider broadca…
Browse files Browse the repository at this point in the history
…sted operands (#17903)

For binary (or more operands) elementwise operations, if one of the
operands is broadcasted or otherwise unaffected by a transposition, then
it can effectively be treated like a unary elementwise operation for the
purpose of propagation because propagating the transpose would introduce
only one additional transpose on the input operand. This improves the
unary elementwise propagation patterns to handle such cases.
  • Loading branch information
qedawkins authored Aug 12, 2024
1 parent 8dc6820 commit 6ac6be6
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 39 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ jobs:
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 315.0 \
--goldendispatch-rocm-unet 1714 \
--goldendispatch-rocm-clip 1569 \
--goldendispatch-rocm-clip 1311 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
Expand All @@ -364,7 +364,7 @@ jobs:
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 74.0 \
--goldendispatch-rocm-unet 1714 \
--goldendispatch-rocm-clip 1569 \
--goldendispatch-rocm-clip 1311 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,23 +644,71 @@ class FuseTransposeWithLinalgOpConsumer
bool allowGeneralizing = false;
};

bool isUnaryElementwiseGeneric(linalg::GenericOp genericOp) {
if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInputs() != 1 ||
!linalg::isElementwise(genericOp)) {
return false;
static bool isIndexingMapAffectedByTransposeMap(
AffineMap indexingMap, ArrayRef<int64_t> iterationSpacePermutation) {
int64_t prevIdx = -1;
for (auto result : indexingMap.getResults()) {
int64_t idx =
iterationSpacePermutation[cast<AffineDimExpr>(result).getPosition()];
// Verify that the relative ordering of indices in the map remain the same.
// If not, then the transposition affects the access order for the given
// map (and associated operand).
if (idx <= prevIdx) {
return true;
}
prevIdx = idx;
}
return false;
}

// Skip transposes and broadcasts. Transposes make more sense to fuse
// rather than propagate through, and broadcasts are cheaper to transpose
// before broadcasting.
if (genericOp.getMatchingIndexingMap(genericOp.getDpsInputOperand(0)) !=
genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0))) {
return false;
// Finds a single DPS input operand of the given |genericOp| that is affected by
// the |iterationSpacePermutation|. In other words, the permutation changes the
// relative ordering of any of the dimensions of that input operand.
//
// For example, with permutation [1, 0, 2], affine map (d0, d1, d2) -> (d0, d1)
// is affected by the permutation because the first two dimensions are iterated
// in a different order while (d0, d1, d2) -> (d0, d2) is unaffected.
//
// If no such operand is found or there is more than one such operation, nullptr
// is returned.
static OpOperand *
getSingleTransposedInputOperand(linalg::GenericOp genericOp,
ArrayRef<int64_t> iterationSpacePermutation) {
OpOperand *operand = nullptr;
for (auto input : genericOp.getDpsInputOperands()) {
if (!isIndexingMapAffectedByTransposeMap(
genericOp.getMatchingIndexingMap(input),
iterationSpacePermutation)) {
continue;
}
if (operand) {
return nullptr;
}
operand = input;
}
return true;
return operand;
}

// Returns a new list of indexing maps that composes the iteration space
// permutation map |transposeMap| with all indexing maps of |genericOp| except
// for the |transposedInputIdx|'th operand. The unchanged operand is expected
// to have an explicit `linalg.transpose` op constructed for it so its map does
// not need to be updated.
static SmallVector<AffineMap>
getTransposedIndexingMaps(linalg::GenericOp genericOp,
int64_t transposedInputIdx, AffineMap transposeMap) {
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
for (unsigned i = 0, e = genericOp.getNumDpsInputs(); i < e; ++i) {
if (i == transposedInputIdx) {
continue;
}
indexingMaps[i] = indexingMaps[i].compose(transposeMap);
}
return indexingMaps;
}

// Sinks a transpose through the input of a unary elementwise operation.
// Sinks a transpose through the input of a elementwise operation where the
// transposition of the iteration space only affects a single input operand.
class SinkTransposeThroughUnaryElementwiseInput
: public OpRewritePattern<linalg::GenericOp> {
public:
Expand All @@ -669,22 +717,57 @@ class SinkTransposeThroughUnaryElementwiseInput
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!IREE::Flow::isNonNullAndOutsideDispatch(genericOp)) {
return failure();
return rewriter.notifyMatchFailure(genericOp, "pre-formed dispatch");
}

if (!isUnaryElementwiseGeneric(genericOp)) {
return rewriter.notifyMatchFailure(genericOp, "not unary elementwise");
if (!linalg::isElementwise(genericOp)) {
return rewriter.notifyMatchFailure(genericOp, "non-elementwise generic");
}

auto transposeOp =
genericOp.getDpsInputs()[0].getDefiningOp<linalg::TransposeOp>();
if (!transposeOp) {
return rewriter.notifyMatchFailure(genericOp, "no transpose operand");
if (genericOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(genericOp,
"unimplemented: multiple results");
}

if (!transposeOp->hasOneUse()) {
AffineMap resultMap =
genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
if (!resultMap.isIdentity()) {
return rewriter.notifyMatchFailure(
genericOp, "do not propagate multi-use transpose");
genericOp, "unimplemented: non-identity result map");
}

linalg::TransposeOp transposeOp;
OpOperand *inputOperand;
for (auto input : genericOp.getDpsInputOperands()) {
// Skip broadcasted operands and transposed operands. If the input is
// broadcasted then we would not want to propagate because that would
// do the transpose on larger data, and if transposed we would rather
// simply compose the transposes (handled in a separate pattern).
if (genericOp.getMatchingIndexingMap(input) != resultMap) {
continue;
}

auto maybeTransposeOp = input->get().getDefiningOp<linalg::TransposeOp>();
// Skip multi-use transposes.
if (!maybeTransposeOp || !maybeTransposeOp->hasOneUse()) {
continue;
}

auto transposableInputOperand = getSingleTransposedInputOperand(
genericOp, maybeTransposeOp.getPermutation());
// Skip if more than one operand is affected by the transpose.
if (transposableInputOperand != input) {
continue;
}

transposeOp = maybeTransposeOp;
inputOperand = transposableInputOperand;
break;
}

if (!transposeOp) {
return rewriter.notifyMatchFailure(genericOp,
"no single use transpose operand");
}

ArrayRef<int64_t> perm = transposeOp.getPermutation();
Expand All @@ -694,18 +777,30 @@ class SinkTransposeThroughUnaryElementwiseInput
Value newInit =
createTransposeInit(rewriter, genericOp.getDpsInits()[0], invPerm);

// We do not need to update indexing maps because this is a unary
// elementwise op where the input and output maps are the same. Just
// replace the operands with transposed variants.
auto newGenericOp = mlir::clone(rewriter, genericOp, newInit.getType(),
{transposeOp.getInput(), newInit});
// We do not need to update iterator types because this is an elementwise
// op. We just need to update the indexing maps of all other input operands
// by composing the transpose map.
AffineMap transposeMap =
AffineMap::getPermutationMap(perm, rewriter.getContext());
SmallVector<AffineMap> indexingMaps = getTransposedIndexingMaps(
genericOp, inputOperand->getOperandNumber(), transposeMap);

SmallVector<Value> newOperands = genericOp->getOperands();
newOperands[inputOperand->getOperandNumber()] = transposeOp.getInput();
newOperands[genericOp.getDpsInitOperand(0)->getOperandNumber()] = newInit;

auto newGenericOp =
mlir::clone(rewriter, genericOp, newInit.getType(), newOperands);
newGenericOp.setIndexingMapsAttr(
rewriter.getAffineMapArrayAttr(indexingMaps));
rewriter.replaceOp(
genericOp, createTranspose(rewriter, newGenericOp->getResult(0), perm));
return success();
}
};

// Bubbles a transpose through the init of a unary elementwise operation.
// Bubbles a transpose through the init of a elementwise operation where the
// transposition of the iteration space only affects a single input operand.
class BubbleTransposeThroughUnaryElementwiseDpsInit
: public OpRewritePattern<linalg::TransposeOp> {
public:
Expand All @@ -715,33 +810,64 @@ class BubbleTransposeThroughUnaryElementwiseDpsInit
PatternRewriter &rewriter) const override {
auto genericOp = transposeOp.getInput().getDefiningOp<linalg::GenericOp>();
if (!genericOp) {
return failure();
return rewriter.notifyMatchFailure(transposeOp, "non-generic producer");
}

if (genericOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(transposeOp,
"unimplemented: multiple results");
}

if (!IREE::Flow::isNonNullAndOutsideDispatch({genericOp, transposeOp})) {
return failure();
}

if (!isUnaryElementwiseGeneric(genericOp)) {
return rewriter.notifyMatchFailure(genericOp, "not unary elementwise");
if (!linalg::isElementwise(genericOp) ||
!genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0))
.isIdentity()) {
return rewriter.notifyMatchFailure(transposeOp, "not elementwise");
}

if (!genericOp->hasOneUse()) {
return rewriter.notifyMatchFailure(genericOp, "not single user");
return rewriter.notifyMatchFailure(transposeOp, "not single user");
}

ArrayRef<int64_t> perm = transposeOp.getPermutation();
Value newTranspose =
createTranspose(rewriter, genericOp.getOperand(0), perm);
auto invPerm = invertPermutationVector(perm);

auto inputOperand = getSingleTransposedInputOperand(genericOp, invPerm);
if (!inputOperand ||
!genericOp.getMatchingIndexingMap(inputOperand).isIdentity()) {
return rewriter.notifyMatchFailure(
genericOp, "no single transposable input operand");
}

Value newTranspose = createTranspose(rewriter, inputOperand->get(), perm);

// Create a new empty init for the transposed generic.
Value newInit =
createTransposeInit(rewriter, genericOp.getDpsInits()[0], perm);

SmallVector<Value> newOperands = genericOp->getOperands();
newOperands[inputOperand->getOperandNumber()] = newTranspose;
newOperands[genericOp.getDpsInitOperand(0)->getOperandNumber()] = newInit;

AffineMap transposeMap =
AffineMap::getPermutationMap(invPerm, rewriter.getContext());

// We do not need to update iterator types because this is an elementwise
// op. We just need to update the indexing maps of all other input operands
// by composing the transpose map.
SmallVector<AffineMap> indexingMaps = getTransposedIndexingMaps(
genericOp, inputOperand->getOperandNumber(), transposeMap);

// We do not need to update indexing maps because this is a unary
// elementwise op where the input and output maps are the same. Just
// replace the operands with transposed variants.
auto newGenericOp = mlir::clone(rewriter, genericOp, newInit.getType(),
{newTranspose, newInit});
auto newGenericOp =
mlir::clone(rewriter, genericOp, newInit.getType(), newOperands);
newGenericOp.setIndexingMapsAttr(
rewriter.getAffineMapArrayAttr(indexingMaps));
rewriter.replaceOp(transposeOp, newGenericOp);
return success();
}
Expand Down Expand Up @@ -912,6 +1038,7 @@ void PropagateLinalgTransposePass::runOnOperation() {
context, /*benefit=*/2);
if (failed(
applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns)))) {
funcOp.emitError("Transpose initial sinking patterns failed");
return signalPassFailure();
}
}
Expand Down Expand Up @@ -968,6 +1095,7 @@ void PropagateLinalgTransposePass::runOnOperation() {
populateCommonCanonicalizationPatterns(context, bubblingPatterns);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(bubblingPatterns)))) {
funcOp.emitError("Transpose bubbling patterns failed");
return signalPassFailure();
}
}
Expand Down Expand Up @@ -1020,8 +1148,13 @@ void PropagateLinalgTransposePass::runOnOperation() {
populateCommonCanonicalizationPatterns(context, sinkingPatterns);
sinkingPatterns.add<SinkTransposeThroughUnaryElementwiseInput>(
context, /*benefit=*/2);
if (failed(
applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns)))) {
GreedyRewriteConfig config;
// TODO: This is inefficient. Consider rewriting this pass to use a
// worklist of just the transpose operations.
config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(sinkingPatterns),
config))) {
funcOp.emitError("Transpose sinking patterns failed");
return signalPassFailure();
}
}
Expand Down
Loading

0 comments on commit 6ac6be6

Please sign in to comment.