From f3c1467833242e7f32d183f13c47335520fc8b25 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Tue, 12 Nov 2024 20:58:02 -0600 Subject: [PATCH] Hoist collapse shape out of scf.forall when possible and expand its destination (#19044) This pattern allows to hoist out collapse shape producer of tensor.parallel_insert_slice and expand the enclosing scf.forall destination. This is only safe because we are doing this on workgroup mapped sc.forall's and hence the slices are disjoint. The reason to have this pattern is that it allows us to eliminate empty tensors during bufferization which would otherwise be blocked due to the collapse. --------- Signed-off-by: Nirvedh --- .../Common/PropagateReshapesByExpansion.cpp | 280 ++++++++++++++++++ .../test/propagate_reshapes_by_expansion.mlir | 253 +++++++++++++++- .../Codegen/Transforms/Transforms.cpp | 26 +- .../src/iree/compiler/Codegen/Utils/Utils.cpp | 22 ++ .../src/iree/compiler/Codegen/Utils/Utils.h | 7 + 5 files changed, 562 insertions(+), 26 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp index 67d2358d7f89..0f16f67130be 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp @@ -7,6 +7,9 @@ #include "iree/compiler/Codegen/Common/Passes.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" #include "iree/compiler/Codegen/Transforms/Transforms.h" +#include "iree/compiler/Codegen/Utils/GPUUtils.h" +#include "iree/compiler/Codegen/Utils/Utils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -17,6 +20,282 @@ namespace mlir::iree_compiler { namespace { +/// Calculate the expanded shape of `dest` if it can be expanded with the inner +/// expanded sizes of `sliceStaticSizes`. Returns failure if such expansion is +/// not possible. +static LogicalResult +getExpandedShape(SmallVector reIndices, + ArrayRef sliceStaticSizes, Value dest, + SmallVectorImpl &expandedShape, + SmallVectorImpl &totalInnerSizes) { + auto destType = dyn_cast(dest.getType()); + if (!destType) + return failure(); + // TODO (nirvedhmeshram): Support rank reducing parallel_insert_slice. + if (reIndices.size() != destType.getShape().size()) + return failure(); + // Iterator to insert outer sizes. + auto outerShapeIter = expandedShape.begin(); + for (auto [reassociations, destSize] : + llvm::zip_equal(reIndices, destType.getShape())) { + // Dynamic destination dims that are not getting expanded are allowed. + if (ShapedType::isDynamic(destSize) && reassociations.size() == 1) { + expandedShape.insert(outerShapeIter++, destSize); + totalInnerSizes.push_back(1); + continue; + } + // Dynamic destination dims that are expanded are currently unsupported but + // this support can be added if needed. + if (ShapedType::isDynamic(destSize)) { + return failure(); + } + int64_t totalInnerSize = 1; + for (int64_t reasociation : llvm::drop_begin(reassociations)) { + int64_t expandedInnerSize = sliceStaticSizes[reasociation]; + // It is not safe to do this pattern if inner dimensions are dynamic. + if (ShapedType::isDynamic(expandedInnerSize)) + return failure(); + expandedShape.push_back(expandedInnerSize); + totalInnerSize *= expandedInnerSize; + } + if (destSize % totalInnerSize != 0) + return failure(); + totalInnerSizes.push_back(totalInnerSize); + // insert the outer size in front of any inner sizes. + expandedShape.insert(outerShapeIter, destSize / totalInnerSize); + // set up the iterator for the next uncollapsed dimension. + outerShapeIter = expandedShape.end(); + } + return success(); +} + +/// Check if the users of the expanded scf.forall destination can be updated to +/// account for the expand. If not we bail out. There are two supported users +/// which are extract_slice -> expand_shape with the same exact reassociation +/// map as the collapse op to be hoisted out or the root parallel_insert_slice. +static LogicalResult +verifyAndCollectExpandableUsers(Value insertDest, + SmallVector reIndices, + tensor::ParallelInsertSliceOp parallelInsertOp, + SmallVector &expandableUsers) { + for (Operation *user : insertDest.getUsers()) { + if (user == parallelInsertOp) { + expandableUsers.push_back(user); + continue; + } + auto extractSliceOp = dyn_cast(user); + if (!extractSliceOp) + return failure(); + if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes()) + return failure(); + if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets()) + return failure(); + auto expandShapeOp = + dyn_cast(*extractSliceOp->getUsers().begin()); + if (!expandShapeOp) + return failure(); + SmallVector expandReIndices = + expandShapeOp.getReassociationIndices(); + if (reIndices != expandReIndices) + return failure(); + expandableUsers.push_back(user); + } + return success(); +} + +/// Utility to expand the pre-verified expandable users of the scf.forall +/// output. +static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc, + MLIRContext *ctx, + SmallVector expandableUsers, + SmallVector totalInnerSizes, + SmallVector reIndices, + scf::ForallOp forallOp) { + // compute the offsets,sizes,strides in the expanded dimensions. + auto computeExpandedAccess = [&](ArrayRef mixedOffsets, + ShapedType resultType) + -> std::tuple, SmallVector, + SmallVector> { + SmallVector expandedOffsets; + auto expandedOffsetsIter = expandedOffsets.begin(); + + for (auto [index, offset] : llvm::enumerate(mixedOffsets)) { + // Add zero offsets for the extra dimensions from reIndices. + for (size_t i = 1, e = reIndices[index].size(); i < e; ++i) { + expandedOffsets.push_back(getAsIndexOpFoldResult(ctx, 0)); + } + // Compute the outer dimension expression. + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + AffineExpr outerDimExpr = (s0).floorDiv(s1); + // Insert computed offset using affine expression. + expandedOffsets.insert( + expandedOffsetsIter, + affine::makeComposedFoldedAffineApply( + rewriter, loc, outerDimExpr, + {getValueOrCreateConstantIndexOp(rewriter, loc, offset), + rewriter.getIndexAttr(totalInnerSizes[index])})); + + expandedOffsetsIter = expandedOffsets.end(); + } + SmallVector expandedSizes = + getAsIndexOpFoldResult(ctx, resultType.getShape()); + SmallVector expandedStrides(resultType.getRank(), + rewriter.getIndexAttr(1)); + return {expandedOffsets, expandedSizes, expandedStrides}; + }; + for (Operation *user : expandableUsers) { + rewriter.setInsertionPointToStart(forallOp.getBody()); + if (auto extractSliceOp = dyn_cast(user)) { + auto expandShapeOp = + dyn_cast(*extractSliceOp->getUsers().begin()); + RankedTensorType resultType = expandShapeOp.getResultType(); + auto [expandedOffsets, expandedSizes, expandedStrides] = + computeExpandedAccess(extractSliceOp.getMixedOffsets(), resultType); + rewriter.setInsertionPoint(extractSliceOp); + rewriter.replaceOpWithNewOp( + extractSliceOp, resultType, extractSliceOp.getSource(), + expandedOffsets, expandedSizes, expandedStrides); + } else if (auto parallelInsertOp = + dyn_cast(user)) { + auto collapseShapeOp = + parallelInsertOp.getSource().getDefiningOp(); + RankedTensorType resultType = collapseShapeOp.getSrcType(); + auto [expandedOffsets, expandedSizes, expandedStrides] = + computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType); + rewriter.setInsertionPoint(parallelInsertOp); + rewriter.replaceOpWithNewOp( + parallelInsertOp, collapseShapeOp.getSrc(), + parallelInsertOp.getDest(), expandedOffsets, expandedSizes, + expandedStrides); + } + } + return; +} + +/// This pattern expands destination of workgroup mapped scf.foralls by +/// hoisting out collapse_shape op consumed by its parallel.insert_slice op. +struct ExpandDestinationForallOp final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(tensor::ParallelInsertSliceOp parallelInsertOp, + PatternRewriter &rewriter) const override { + Location loc = parallelInsertOp.getLoc(); + MLIRContext *ctx = getContext(); + auto collapseOp = + parallelInsertOp.getSource().getDefiningOp(); + // No collapse op to hoist out. + if (!collapseOp) + return failure(); + + // Ignore trivially foldable collapse ops. + if (collapseOp.getSrcType().getRank() == + collapseOp.getResultType().getRank()) { + return failure(); + } + + // Get the destination to expand. + Value insertDest = parallelInsertOp.getDest(); + + // Get the enclosing scf.forall op. + OpResult tiedResult = parallelInsertOp.getTiedOpResult(); + int64_t tiedResultIdx = tiedResult.getResultNumber(); + + auto forallOp = dyn_cast(tiedResult.getOwner()); + if (!forallOp) + return failure(); + + // We only want this pattern if the forall op result is being written to a + // full slice. Otherwise the hoisted collapse op is not foldable. + for (Operation *foralluser : tiedResult.getUsers()) { + auto storeOp = dyn_cast(foralluser); + if (!storeOp) + return failure(); + if (!isFullSlice(storeOp, storeOp.getTargetType(), + storeOp.getTargetDims())) { + return failure(); + } + } + + // This allows us to assume that the extract/inserts in the loop are + // disjoint and makes the application of this pattern safe. + if (!forallOpHasMappingType( + forallOp)) { + return failure(); + } + // This pattern only supports forall ops with single + // output. + SmallVector forallOutputs(forallOp.getOutputs()); + + SmallVector reIndices = + collapseOp.getReassociationIndices(); + SmallVector expandedDestShape; + SmallVector totalInnerSizes; + // Get the shape of the outer expand which will be the new destination + // of the scf.forall and the total size of inner dimensions per uncollapsed + // dimension. + if (failed(getExpandedShape(reIndices, collapseOp.getSrcType().getShape(), + insertDest, expandedDestShape, + totalInnerSizes))) { + return failure(); + } + + // Verify that the users of destination are valid to expand and collect all + // such users. + SmallVector expandableUsers; + if (failed(verifyAndCollectExpandableUsers( + insertDest, collapseOp.getReassociationIndices(), parallelInsertOp, + expandableUsers))) { + return failure(); + } + + // Expand the users of the destination. + rewriter.setInsertionPointToStart(forallOp.getBody()); + expandVerifiedUsers(rewriter, loc, ctx, expandableUsers, totalInnerSizes, + reIndices, forallOp); + rewriter.setInsertionPoint(forallOp); + + // Create the expand -> new scf.forall -> collapse chain. + auto expandedDestType = + cast(forallOutputs[tiedResultIdx].getType()) + .clone(expandedDestShape); + auto expandedDest = rewriter.create( + loc, expandedDestType, forallOutputs[tiedResultIdx], reIndices); + + forallOutputs[tiedResultIdx] = expandedDest; + + scf::ForallOp newForallOp = rewriter.create( + loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), + forallOp.getMixedStep(), forallOutputs, forallOp.getMappingAttr()); + + auto collapsedResultOp = rewriter.create( + loc, cast(forallOp->getResult(tiedResultIdx).getType()), + newForallOp->getResult(tiedResultIdx), reIndices); + + // Merge the old scf.forall block which has the expanded users into the new + // scf.forall which has the expanded destination. + SmallVector argReplacements(newForallOp.getInductionVars()); + argReplacements.append(newForallOp.getRegionIterArgs().begin(), + newForallOp.getRegionIterArgs().end()); + scf::InParallelOp parallelTerminator = newForallOp.getTerminator(); + parallelTerminator->erase(); + rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(), + argReplacements); + + // Replaces the uses of the old scf.forall with the new scf.forall + for (int idx = 0; idx < forallOp->getNumResults(); ++idx) { + if (idx == tiedResultIdx) { + forallOp->getResult(idx).replaceAllUsesWith( + collapsedResultOp->getResult(0)); + } else { + forallOp->getResult(idx).replaceAllUsesWith( + newForallOp->getResult(idx)); + } + } + return success(); + } +}; + struct PropagateReshapesByExpansionPass final : impl::PropagateReshapesByExpansionPassBase< PropagateReshapesByExpansionPass> { @@ -65,6 +344,7 @@ void PropagateReshapesByExpansionPass::runOnOperation() { tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns, context); populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns); + bubbleExpandShapePatterns.add(context); if (failed(applyPatternsAndFoldGreedily( getOperation(), std::move(bubbleExpandShapePatterns)))) { diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir index fc9e85e3a764..faeb828097b4 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir @@ -1,4 +1,5 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s --mlir-print-local-scope | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion), cse)" \ +// RUN: --split-input-file %s --mlir-print-local-scope | FileCheck %s func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> { %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16> @@ -86,3 +87,253 @@ func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) { // CHECK: flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]] // CHECK-SAME: offsets = [0, 0, 0], sizes = [2, %[[SHAPE]], 32], strides = [1, 1, 1] // CHECK-SAME: !flow.dispatch.tensor>{%[[SHAPE]]} + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @expand_dest_forall() { + %cst = arith.constant 0.000000e+00 : f16 + %c0 = arith.constant 0 : index + %index = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags(Indirect) : !flow.dispatch.tensor>{%index} + %1 = tensor.empty(%index) : tensor + %extra = tensor.empty() : tensor<32x32xf32> + %2 = scf.forall (%arg0, %arg1) = (0, 0) to (64, 32) step (16, 16) + shared_outs(%arg2 = %1) -> (tensor) { + %extracted_slice = tensor.extract_slice %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1] + : tensor to tensor<1x16x16xf32> + %expanded = tensor.expand_shape %extracted_slice [[0], [1], [2, 3, 4]] + output_shape [1, 16, 2, 4, 2] : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32> + %expanded_barrier = util.optimization_barrier %expanded : tensor<1x16x2x4x2xf32> + %collapsed = tensor.collapse_shape %expanded_barrier [[0], [1], [2, 3, 4]] : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %collapsed into %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1] + : tensor<1x16x16xf32> into tensor + } + } {mapping = [#iree_codegen.workgroup_mapping, #iree_codegen.workgroup_mapping]} + flow.dispatch.tensor.store %2, %0, offsets = [0, 0, 0], sizes = [%index, 64, 32], strides = [1, 1, 1] + : tensor -> !flow.dispatch.tensor>{%index} + return +} + +// CHECK-LABEL: func @expand_dest_forall( +// CHECK: %[[LOAD_CONST:.+]] = hal.interface.constant.load +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[LOAD_CONST]]) : tensor +// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]], %[[ARG1:.+]]) = (0, 0) +// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor) { +// CHECK-DAG: %[[OFFSET:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%[[ARG1]]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]] +// CHECK-SAME: [0, %[[ARG0]], %[[OFFSET]], 0, 0] [1, 16, 2, 4, 2] [1, 1, 1, 1, 1] +// CHECK-SAME: tensor to tensor<1x16x2x4x2xf32> +// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[EXTRACT]] : tensor<1x16x2x4x2xf32> +// CHECK: tensor.parallel_insert_slice %[[BARRIER]] into %[[ARG2]] +// CHECK-SAME: [0, %[[ARG0]], %[[OFFSET]], 0, 0] [1, 16, 2, 4, 2] [1, 1, 1, 1, 1] +// CHECK-SAME: tensor<1x16x2x4x2xf32> into tensor +// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]] +// CHECK-SAME: offsets = [0, 0, 0, 0, 0], sizes = [%[[LOAD_CONST]], 64, 4, 4, 2], strides = [1, 1, 1, 1, 1] +// CHECK-SAME: !flow.dispatch.tensor>{%[[LOAD_CONST]]} + +// ----- +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @expand_dest_forall_multiresult() { + %cst = arith.constant 0.000000e+00 : f16 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags(Indirect) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) + offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %2 = tensor.empty() : tensor<32xf32> + %3 = tensor.empty() : tensor<32x32xf32> + %4:2 = scf.forall (%arg0) = (0) to (32) step (16) + shared_outs(%arg1 = %3, %arg2 = %2) -> (tensor<32x32xf32>, tensor<32xf32>) { + %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32> + %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [2, 8] + : tensor<16xf32> into tensor<2x8xf32> + %5 = util.optimization_barrier %expanded : tensor<2x8xf32> + %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %arg1 into %arg1[%c0, %c0] [32, 32] [1, 1] + : tensor<32x32xf32> into tensor<32x32xf32> + tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1] + : tensor<16xf32> into tensor<32xf32> + } + } {mapping = [#iree_codegen.workgroup_mapping]} + flow.dispatch.tensor.store %4#1, %0, offsets = [0], sizes = [32], strides = [1] + : tensor<32xf32> -> !flow.dispatch.tensor> + flow.dispatch.tensor.store %4#0, %1, offsets = [0, 0], sizes = [32, 32], strides = [1, 1] + : tensor<32x32xf32> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: func @expand_dest_forall_multiresult( +// CHECK: %[[SUBSPAN0:.+]] = hal.interface.binding.subspan +// CHECK: %[[SUBSPAN1:.+]] = hal.interface.binding.subspan +// CHECK: %[[EMPTY0:.+]] = tensor.empty() : tensor<32x32xf32> +// CHECK: %[[EMPTY1:.+]] = tensor.empty() : tensor<4x8xf32> +// CHECK: %[[SCFFORALL:.+]]:2 = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[ARG1:.+]] = %[[EMPTY0]], %[[ARG2:.+]] = %[[EMPTY1]]) +// CHECK-SAME: -> (tensor<32x32xf32>, tensor<4x8xf32>) { +// CHECK-DAG: %[[OFFSET:.+]] = affine.apply affine_map<()[s0] -> (s0 floordiv 8)>()[%[[ARG0]]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG2]] +// CHECK-SAME: [%[[OFFSET]], 0] [2, 8] [1, 1] +// CHECK-SAME: tensor<4x8xf32> to tensor<2x8xf32> +// CHECK: %[[BARRIER:.+]] = util.optimization_barrier %[[EXTRACT]] : tensor<2x8xf32> +// CHECK: tensor.parallel_insert_slice %[[ARG1]] into %[[ARG1]] +// CHECK-SAME: tensor<32x32xf32> into tensor<32x32xf32> +// CHECK: tensor.parallel_insert_slice %[[BARRIER]] into %[[ARG2]] +// CHECK-SAME: [%[[OFFSET]], 0] [2, 8] [1, 1] +// CHECK-SAME: tensor<2x8xf32> into tensor<4x8xf32> +// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]]#1, %[[SUBSPAN0]] +// CHECK-SAME: offsets = [0, 0], sizes = [4, 8], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor> +// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]]#0, %[[SUBSPAN1]] +// CHECK-SAME: offsets = [0, 0], sizes = [32, 32], strides = [1, 1] +// CHECK-SAME: !flow.dispatch.tensor> + + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @noexpand_dest_forall_dynamicpacked() { + %index1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %index2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %index3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %cst = arith.constant 0.000000e+00 : f16 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags(Indirect) : !flow.dispatch.tensor> + %2 = tensor.empty() : tensor<32xf32> + %4 = scf.forall (%arg0) = (0) to (32) step (16) + shared_outs(%arg2 = %2) -> (tensor<32xf32>) { + %extracted_slice = tensor.extract_slice %arg2[%arg0] [%index1] [1] : tensor<32xf32> to tensor + %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [%index2, %index3] + : tensor into tensor + %5 = util.optimization_barrier %expanded : tensor + %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor into tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [%index1] [1] + : tensor into tensor<32xf32> + } + } {mapping = [#iree_codegen.workgroup_mapping]} + flow.dispatch.tensor.store %4, %0, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> + -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: func @noexpand_dest_forall_dynamicpacked( +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) { +// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]] +// CHECK-SAME: offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> +// CHECK-SAME: !flow.dispatch.tensor> + +// ----- +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @expand_dest_forall_unsupporteduse() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags(Indirect) : !flow.dispatch.tensor> + %2 = tensor.empty() : tensor<32xf32> + %4 = scf.forall (%arg0) = (0) to (32) step (16) + shared_outs(%arg2 = %2) -> (tensor<32xf32>) { + %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32> + %arith_op = arith.negf %extracted_slice : tensor<16xf32> + %expanded = tensor.expand_shape %arith_op [[0, 1]] output_shape [2, 8] + : tensor<16xf32> into tensor<2x8xf32> + %5 = util.optimization_barrier %expanded : tensor<2x8xf32> + %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1] + : tensor<16xf32> into tensor<32xf32> + } + } {mapping = [#iree_codegen.workgroup_mapping]} + flow.dispatch.tensor.store %4, %0, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: func @expand_dest_forall_unsupporteduse( +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) { +// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]] +// CHECK-SAME: offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> +// CHECK-SAME: !flow.dispatch.tensor> + + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @noexpand_dest_forall_nomapping() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags(Indirect) : !flow.dispatch.tensor> + %2 = tensor.empty() : tensor<32xf32> + %4 = scf.forall (%arg0) = (0) to (32) step (16) + shared_outs(%arg2 = %2) -> (tensor<32xf32>) { + %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32> + %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [2, 8] + : tensor<16xf32> into tensor<2x8xf32> + %5 = util.optimization_barrier %expanded : tensor<2x8xf32> + %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1] + : tensor<16xf32> into tensor<32xf32> + } + } + flow.dispatch.tensor.store %4, %0, offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: func @noexpand_dest_forall_nomapping( +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) { +// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]] +// CHECK-SAME: offsets = [0], sizes = [32], strides = [1] : tensor<32xf32> +// CHECK-SAME: !flow.dispatch.tensor> + + +// ----- + +#pipeline_layout = #hal.pipeline.layout], flags = Indirect> +func.func @noexpand_dest_forall_notfullslicestore() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) + flags(Indirect) : !flow.dispatch.tensor> + %2 = tensor.empty() : tensor<32xf32> + %4 = scf.forall (%arg0) = (0) to (32) step (16) + shared_outs(%arg2 = %2) -> (tensor<32xf32>) { + %extracted_slice = tensor.extract_slice %arg2[%arg0] [16] [1] : tensor<32xf32> to tensor<16xf32> + %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [2, 8] + : tensor<16xf32> into tensor<2x8xf32> + %5 = util.optimization_barrier %expanded : tensor<2x8xf32> + %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<2x8xf32> into tensor<16xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %collapsed into %arg2[%arg0] [16] [1] + : tensor<16xf32> into tensor<32xf32> + } + } {mapping = [#iree_codegen.workgroup_mapping]} + flow.dispatch.tensor.store %4, %0, offsets = [1], sizes = [32], strides = [1] : tensor<32xf32> -> !flow.dispatch.tensor> + return +} + +// CHECK-LABEL: func @noexpand_dest_forall_notfullslicestore( +// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32xf32> +// CHECK: %[[SCFFORALL:.+]] = scf.forall (%[[ARG0:.+]]) = (0) to (32) step (16) +// CHECK-SAME: shared_outs(%[[ARG2:.+]] = %[[EMPTY]]) -> (tensor<32xf32>) { +// CHECK: flow.dispatch.tensor.store %[[SCFFORALL]], %[[SUBSPAN]] +// CHECK-SAME: offsets = [1], sizes = [32], strides = [1] : tensor<32xf32> +// CHECK-SAME: !flow.dispatch.tensor> diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp index 5fe45e9cd862..980925ccb11a 100644 --- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Codegen/Transforms/Transforms.h" #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" +#include "iree/compiler/Codegen/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/Liveness.h" @@ -36,31 +37,6 @@ namespace mlir::iree_compiler { -static bool isAllConstantValue(ArrayRef ofrs, int64_t v) { - return llvm::all_of( - ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, v); }); -} - -static bool isFullSlice(ArrayRef mixedOffsets, - ArrayRef mixedSizes, - ArrayRef mixedStrides, - IREE::Flow::DispatchTensorType tensorType, - ValueRange dynamicDims) { - OpBuilder builder(tensorType.getContext()); - SmallVector tensorShape = llvm::to_vector(tensorType.getShape()); - SmallVector mixedTensorShape = - mlir::getMixedValues(tensorShape, dynamicDims, builder); - return isAllConstantValue(mixedOffsets, 0) && - isAllConstantValue(mixedStrides, 1) && mixedTensorShape == mixedSizes; -} -static bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp, - IREE::Flow::DispatchTensorType tensorType, - ValueRange dynamicDims) { - return isFullSlice( - sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(), - sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims); -} - static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands, Operation *baseOp) { for (auto val : nonIndexComputationOperands) { diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp index 544e0558ead2..f86f447c49dc 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.cpp @@ -1511,4 +1511,26 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum, return roundedDimBound.getSize(); } +static bool isFullSlice(ArrayRef mixedOffsets, + ArrayRef mixedSizes, + ArrayRef mixedStrides, + IREE::Flow::DispatchTensorType tensorType, + ValueRange dynamicDims) { + OpBuilder builder(tensorType.getContext()); + SmallVector tensorShape = llvm::to_vector(tensorType.getShape()); + SmallVector mixedTensorShape = + mlir::getMixedValues(tensorShape, dynamicDims, builder); + return areAllConstantIntValue(mixedOffsets, 0) && + areAllConstantIntValue(mixedStrides, 1) && + mixedTensorShape == mixedSizes; +} + +bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp, + IREE::Flow::DispatchTensorType tensorType, + ValueRange dynamicDims) { + return isFullSlice( + sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(), + sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims); +} + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Utils/Utils.h b/compiler/src/iree/compiler/Codegen/Utils/Utils.h index 7337549d5ec2..4603429afd96 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/Utils.h +++ b/compiler/src/iree/compiler/Codegen/Utils/Utils.h @@ -7,6 +7,7 @@ #ifndef IREE_COMPILER_CODEGEN_UTILS_UTILS_H_ #define IREE_COMPILER_CODEGEN_UTILS_UTILS_H_ +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" #include "llvm/TargetParser/Triple.h" @@ -251,6 +252,12 @@ computeDimUpperBound(Value shapedValue, unsigned dimNum, std::optional vscaleRange, RoundUpVscaleMultiple = RoundUpVscaleMultiple::No); +// Utility to make sure we are storing the full incoming subspan. Otherwise we +// cannot simply adjust the subspan's resultant type later. +bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp, + IREE::Flow::DispatchTensorType tensorType, + ValueRange dynamicDims); + } // namespace mlir::iree_compiler #endif // IREE_COMPILER_CODEGEN_UTILS_UTILS_H_