diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp index 0675fa43cd2d..7ca65ab30e64 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp @@ -188,17 +188,16 @@ struct ConvertTensorConcatPattern : public OpRewritePattern { return rewriter.notifyMatchFailure( concatOp, "only outer-dim concat lowering supported"); } - if (cast(concatOp.getInputs().front().getType()) - .getRank() == 0) { - // This should be handled here, but not sure what concat operation does - // when inptus are of rank 0. - return rewriter.notifyMatchFailure( - concatOp, "unhandled concat of zero-rank tensors"); - } + assert(cast(concatOp.getInputs().front().getType()) + .getRank() != 0 && + "concat cannot be of zero-rank tensors"); Location loc = concatOp.getLoc(); SmallVector> inputShapes; inputShapes.reserve(concatOp.getInputs().size()); + // Note the output shape is computed directly without using + // `reifyResultShapes` since we need the `inputShapes` anyway and using the + // method would create duplicate `tensor.dim` operations. SmallVector outputShape; AffineExpr addExpr = rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1); diff --git a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/concat.mlir b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/concat.mlir index 9494acccadbf..9e11ca2bc2fd 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/concat.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/concat.mlir @@ -24,3 +24,13 @@ func.func @mixed_concat(%arg0: tensor<2x?xf32>, %arg1 : tensor, %arg2 : // CHECK-SAME: : tensor{%[[ARG1_D0]], %[[ARG1_D1]]} -> %[[UPDATE0]] as tensor{%[[RESULT_D0]], %[[ARG0_D1]]} // CHECK: %[[UPDATE2:.+]] = flow.tensor.update %[[ARG2]], %[[UPDATE1]][%[[OFFSET0]], %[[C0]]] // CHECK-SAME: : tensor<4x?xf32>{%[[ARG2_D1]]} -> %[[UPDATE1]] as tensor{%[[RESULT_D0]], %[[ARG0_D1]]} + +// ----- + +func.func @dont_lower_non_outer_dim_concat(%arg0: tensor<4x?xf32>, %arg1 : tensor, %arg2 : tensor<4x?xf32>) -> tensor { + %0 = tensor.concat dim(1) %arg0, %arg1, %arg2 : (tensor<4x?xf32>, tensor, tensor<4x?xf32>) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @dont_lower_non_outer_dim_concat +// CHECK: %[[CONCAT:.+]] = tensor.concat +// CHECK: return %[[CONCAT]]