Commit 3818827
Reapply "[Flow] Convert from tensor.cast to flow.tensor.reshape early …(
Browse files Browse the repository at this point in the history
#18256)" (#18331)

This reverts commit 8da4564.
nirvedhmeshram committed Aug 26, 2024
1 parent 7a7bfe1 · commit 3818827
Showing 5 changed files with 70 additions and 39 deletions.
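In effect, the reapplied pattern rewrites a shape-only tensor.cast into a flow.tensor.reshape during Flow canonicalization, so dispatch-region formation sees the shape change as a Flow op. A minimal sketch of the rewrite, mirroring the test updates below; the values %t, %m, %c0, %c4 and the surrounding IR are illustrative, not taken from the commit:

// Before: the cast only refines the second dynamic dimension to a static 4.
%cast = tensor.cast %t : tensor<?x?xf32> to tensor<?x4xf32>

// After: the equivalent flow.tensor.reshape; the braces list the runtime
// extent of each dynamic dimension of the source and result types.
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%m = tensor.dim %t, %c0 : tensor<?x?xf32>
%reshape = flow.tensor.reshape %t : tensor<?x?xf32>{%m, %c4} -> tensor<?x4xf32>{%m}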
File 1 of 5: iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp
@@ -326,4 +326,9 @@ void populateTensorToFlowConversionPatterns(MLIRContext *context,
               ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
 }
 
+void populateTensorDialectCastOpPattern(MLIRContext *context,
+                                        RewritePatternSet &patterns) {
+  patterns.insert<ConvertTensorCastPattern>(context);
+}
+
 } // namespace mlir::iree_compiler::IREE::Flow
File 2 of 5: iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h
@@ -16,6 +16,10 @@ namespace mlir::iree_compiler::IREE::Flow {
 void populateTensorToFlowConversionPatterns(MLIRContext *context,
                                             RewritePatternSet &patterns);
 
+// Add pattern to convert tensor.cast -> flow.tensor.reshape.
+void populateTensorDialectCastOpPattern(MLIRContext *context,
+                                        RewritePatternSet &patterns);
+
 } // namespace mlir::iree_compiler::IREE::Flow
 
 #endif // IREE_COMPILER_DIALECT_FLOW_CONVERSION_TENSORTOFLOW_PATTERNS_H_
File 3 of 5 (Flow CanonicalizerPass)
@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -103,6 +104,11 @@ struct CanonicalizerPass
       CanonicalizerPass>::CanonicalizerPassBase;
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Flow::FlowDialect>();
+  }
+
   LogicalResult initialize(MLIRContext *context) override {
     // Inherit the same config defaults from the upstream canonicalizer pass.
     config.useTopDownTraversal = true;
@@ -117,6 +123,7 @@ struct CanonicalizerPass
     // Pull in some borderline/downstream canonicalizations for the Flow
     // compilation phase.
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(owningPatterns);
+    IREE::Flow::populateTensorDialectCastOpPattern(context, owningPatterns);
     owningPatterns.add<FoldConsecutiveConstantPadding>(context);
 
     patterns =
File 4 of 5 (Flow canonicalization test)
@@ -82,3 +82,27 @@ util.func public @dont_merge_constant_padding_different_vals(
 // CHECK-LABEL: util.func public @dont_merge_constant_padding_different_vals
 // CHECK: tensor.pad
 // CHECK: tensor.pad
+
+// -----
+
+util.func public @tensor_cast_to_reshape(%reshape_17 : tensor<?x?x?x?xf32>, %65 : tensor<?x12x?x64xf32>, %0 : index, %1 : index) -> tensor<?x?x?x?xf32> {
+  %cast = tensor.cast %reshape_17 : tensor<?x?x?x?xf32> to tensor<?x?x12x64xf32>
+  %66 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                                         affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+                        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+    ins(%cast : tensor<?x?x12x64xf32>) outs(%65 : tensor<?x12x?x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<?x12x?x64xf32>
+  %cast_18 = tensor.cast %66 : tensor<?x12x?x64xf32> to tensor<?x?x?x?xf32>
+  util.return %cast_18 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: util.func public @tensor_cast_to_reshape
+// CHECK: flow.tensor.reshape
+// CHECK-SAME: tensor<?x?x?x?xf32>
+// CHECK-SAME: -> tensor<?x?x12x64xf32>
+// CHECK: linalg.generic
+// CHECK: flow.tensor.reshape
+// CHECK-SAME: tensor<?x12x?x64xf32>
+// CHECK-SAME: -> tensor<?x?x?x?xf32>
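The second cast in this test goes the other direction, relaxing the static 12 and 64 extents back to dynamic. Reconstructed from the CHECK lines above, the expected rewrite would look roughly like the following sketch; the constant and tensor.dim operands are assumptions for illustration:

%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c12 = arith.constant 12 : index
%c64 = arith.constant 64 : index
%d0 = tensor.dim %66, %c0 : tensor<?x12x?x64xf32>
%d2 = tensor.dim %66, %c2 : tensor<?x12x?x64xf32>
%cast_18 = flow.tensor.reshape %66 : tensor<?x12x?x64xf32>{%d0, %d2} -> tensor<?x?x?x?xf32>{%d0, %c12, %d2, %c64}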
File 5 of 5 (dispatch region formation test)
@@ -1,5 +1,4 @@
-// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, iree-dispatch-creation-clone-producers-into-dispatch-regions, iree-dispatch-creation-convert-dispatch-regions-to-workgroups, iree-dispatch-creation-convert-tensor-to-flow, canonicalize, iree-dispatch-creation-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
-
+// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-canonicalize,iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, iree-dispatch-creation-clone-producers-into-dispatch-regions, iree-dispatch-creation-convert-dispatch-regions-to-workgroups, iree-dispatch-creation-convert-tensor-to-flow, canonicalize, iree-dispatch-creation-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
 util.func public @tile_matmul_alone(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -232,17 +231,17 @@ util.func public @always_fuse_cast
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4x?xf32>
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
 // CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[RESHAPE:.*]] = flow.tensor.reshape %[[ARG0]]
+// CHECK-SAME: tensor<?x?xf32>{%[[M]], %[[C4]]} -> tensor<?x4xf32>{%[[M]]}
 // CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[M]], %[[K]], %[[N1]]]
-// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[M]], %[[K]], %[[N1]])
-// CHECK: tensor.cast
+// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[N1]], %[[M]]]
+// CHECK-SAME: (%[[RESHAPE]], %[[ARG1]], %[[N1]], %[[M]])
 // CHECK: flow.return
 // CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[M]], %[[K]], %[[N2]]]
-// CHECK-SAME: (%[[ARG0]], %[[ARG2]], %[[M]], %[[K]], %[[N2]])
-// CHECK: tensor.cast
+// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[N2]], %[[M]]]
+// CHECK-SAME: (%[[RESHAPE]], %[[ARG2]], %[[N2]], %[[M]])
 // CHECK: flow.return
 // CHECK: util.return %[[RESULT1]], %[[RESULT2]]
 
@@ -513,26 +512,21 @@ util.func public @inline_dag_1(
 // CHECK-NOT: linalg.
 // CHECK-NOT: tensor.extract_slice
 // CHECK: flow.dispatch.workgroups
-// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
+// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
 // CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG12:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
-// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]]
-// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]]
-// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]]
-// CHECK: %[[INIT:.+]] = tensor.empty
-// CHECK-DAG: %[[OP1:.+]] = tensor.cast %[[LEAF1]]
-// CHECK-DAG: %[[OP2:.+]] = tensor.cast %[[LEAF2]]
-// CHECK-DAG: %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
-// CHECK-DAG: %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
-// CHECK-DAG: %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
+// CHECK-DAG: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 0]
+// CHECK-DAG: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 10]
+// CHECK-DAG: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 20]
+// CHECK-DAG: %[[LEAF4:.+]] = flow.dispatch.tensor.load %[[ARG5]]
+// CHECK-DAG: %[[LEAF5:.+]] = flow.dispatch.tensor.load %[[ARG6]]
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty
 // CHECK: linalg.generic
-// CHECK-SAME: ins(%[[LEAF3]], %[[OP5]], %[[OP2]], %[[OP4]], %[[OP3]] :
+// CHECK-SAME: ins(%[[LEAF4]], %[[LEAF3]], %[[LEAF5]], %[[LEAF2]], %[[LEAF1]] :
 // CHECK-SAME: outs(%[[INIT]] :
 
 // -----
@@ -573,24 +567,21 @@ util.func public @inline_dag_2(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x?xf32>
 // CHECK: flow.dispatch.workgroups
-// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<?x?xf32>>
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
+// CHECK-NEXT: %[[ARG4:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<i32>>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<readonly:tensor<1x?xf32>>
 // CHECK-SAME: %[[ARG7:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG8:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG9:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
-// CHECK: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], {{.*}}
-// CHECK: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
-// CHECK: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
-// CHECK: %[[INIT:.+]] = tensor.empty
-// CHECK-DAG: %[[OP1:.+]] = tensor.cast %[[LEAF1]]
-// CHECK-DAG: %[[OP3:.+]] = tensor.extract_slice %[[OP1]][0, 0]
-// CHECK-DAG: %[[OP4:.+]] = tensor.extract_slice %[[OP1]][0, 10]
-// CHECK-DAG: %[[OP5:.+]] = tensor.extract_slice %[[OP1]][0, 20]
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:tensor<1x?xf32>>
+// CHECK-DAG: %[[LEAF1:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 0]
+// CHECK-DAG: %[[LEAF2:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 10]
+// CHECK-DAG: %[[LEAF3:.+]] = flow.dispatch.tensor.load %[[ARG4]], offsets = [0, 20]
+// CHECK-DAG: %[[LEAF4:.+]] = flow.dispatch.tensor.load %[[ARG5]], {{.*}}
+// CHECK-DAG: %[[LEAF5:.+]] = flow.dispatch.tensor.load %[[ARG6]], {{.*}}
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty
 // CHECK: linalg.generic
-// CHECK-SAME: ins(%[[LEAF3]], %[[OP5]], %[[LEAF2]], %[[OP4]], %[[OP3]] :
+// CHECK-SAME: ins(%[[LEAF4]], %[[LEAF3]], %[[LEAF5]], %[[LEAF2]], %[[LEAF1]] :
 // CHECK-SAME: outs(%[[INIT]] :
 
 // -----