Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Nov 14, 2024
1 parent 4e326b5 commit 3b79bc7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,16 @@ struct ConvertTensorConcatPattern : public OpRewritePattern<tensor::ConcatOp> {
return rewriter.notifyMatchFailure(
concatOp, "only outer-dim concat lowering supported");
}
if (cast<RankedTensorType>(concatOp.getInputs().front().getType())
.getRank() == 0) {
// This should be handled here, but not sure what concat operation does
// when inptus are of rank 0.
return rewriter.notifyMatchFailure(
concatOp, "unhandled concat of zero-rank tensors");
}
assert(cast<RankedTensorType>(concatOp.getInputs().front().getType())
.getRank() != 0 &&
"concat cannot be of zero-rank tensors");

Location loc = concatOp.getLoc();
SmallVector<SmallVector<OpFoldResult>> inputShapes;
inputShapes.reserve(concatOp.getInputs().size());
// Note the output shape is computed directly without using
// `reifyResultShapes` since we need the `inputShapes` anyway and using the
// method would create duplicate `tensor.dim` operations.
SmallVector<OpFoldResult> outputShape;
AffineExpr addExpr =
rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,13 @@ func.func @mixed_concat(%arg0: tensor<2x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 :
// CHECK-SAME: : tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]} -> %[[UPDATE0]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
// CHECK: %[[UPDATE2:.+]] = flow.tensor.update %[[ARG2]], %[[UPDATE1]][%[[OFFSET0]], %[[C0]]]
// CHECK-SAME: : tensor<4x?xf32>{%[[ARG2_D1]]} -> %[[UPDATE1]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}

// -----

func.func @dont_lower_non_outer_dim_concat(%arg0: tensor<4x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> {
%0 = tensor.concat dim(1) %arg0, %arg1, %arg2 : (tensor<4x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @dont_lower_non_outer_dim_concat
// CHECK: %[[CONCAT:.+]] = tensor.concat
// CHECK: return %[[CONCAT]]

0 comments on commit 3b79bc7

Please sign in to comment.