
Commit 4ba2e34

[LinalgExt] Fix reshape fusion crash (#21472)
The upstream reshape builder requires that there be no change in static/dynamic dims. This causes `getOrCreateExpanded` to fail, but at that point the pattern has already modified the IR, so it is too late for the pattern to return `failure()` (doing so would cause infinite looping). Instead, exit early when computing expansion info for ops with mixed static and dynamic dims.

Related issue for propagating static shape info: #21471

Closes #21439

Signed-off-by: Ian Wood <[email protected]>
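For context on why returning `failure()` late is fatal here: a rewrite driver that sees the IR change on every attempt keeps retrying the same pattern, so a pattern must run all of its legality checks before the first mutation. Below is a minimal standalone sketch of that failure mode; the toy `IR` struct, `buggyPattern`/`fixedPattern` names, and capped retry loop are invented for illustration and are not the actual MLIR driver or the IREE code.

```cpp
#include <cstdio>

// Toy stand-in for illustration only; the real MLIR greedy driver and
// pattern classes are far richer than this.
struct IR {
  int version = 0; // bumped whenever a "pattern" mutates the IR
};

// Buggy shape: mutate first, then discover the rewrite can't finish.
bool buggyPattern(IR &ir) {
  ir.version++;  // IR already modified...
  return false;  // ...so this failure leaves changed IR behind
}

// Fixed shape (what this commit does): run every legality check before
// touching the IR, so failure is only reported while the IR is intact.
bool fixedPattern(IR &ir, bool shapesCompatible) {
  if (!shapesCompatible)
    return false; // early exit, IR untouched
  ir.version++;   // mutate only once success is guaranteed
  return true;
}

int main() {
  IR ir;
  // A driver that retries as long as the IR changed would spin forever on
  // buggyPattern; cap the attempts so this sketch terminates.
  for (int attempt = 0; attempt < 3; ++attempt)
    (void)buggyPattern(ir); // fails every time, yet mutates every time
  std::printf("buggy: IR mutated %d times despite always failing\n",
              ir.version);

  IR ir2;
  (void)fixedPattern(ir2, /*shapesCompatible=*/false);
  std::printf("fixed: IR mutated %d times after a clean failure\n",
              ir2.version); // 0: the failure left the IR untouched
}
```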

2 files changed: +42 -2 lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

Lines changed: 16 additions & 2 deletions
```diff
@@ -200,6 +200,18 @@ LogicalResult ExpansionInfo::compute(
   if (operandReassoc.empty())
     return failure();
 
+  // Check that the operand dim size matches the iteration space dim size. This
+  // can fail when one is static and the other is dynamic.
+  for (const ReshapeOperandInfo &info : infos) {
+    for (auto [operandDim, iterDim] :
+         llvm::enumerate(info.operandToIterationSpace)) {
+      if (iterDim != ReshapeOperandInfo::kNoMapping &&
+          loopRanges[iterDim] != info.originalShape[operandDim]) {
+        return failure();
+      }
+    }
+  }
+
   int64_t operandNum = fusableOpOperand->getOperandNumber();
   ReshapeOperandInfo &fusionOperandInfo = infos[operandNum];
   this->loopShapeMap.clear();
```
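The new check can use plain `int64_t` inequality because MLIR encodes a dynamic extent as a sentinel value (`mlir::ShapedType::kDynamic`, `std::numeric_limits<int64_t>::min()` in recent MLIR), so a static extent never compares equal to a dynamic one. A standalone sketch of that comparison, with plain vectors standing in for the real `ExpansionInfo` bookkeeping:

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Sentinel for a dynamic dim, mirroring mlir::ShapedType::kDynamic
// (std::numeric_limits<int64_t>::min() in recent MLIR).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Plain int64_t equality: a static extent never equals the dynamic
// sentinel, which is exactly the mixed static/dynamic mismatch the
// commit now rejects before any IR has been modified.
bool dimsMatch(const std::vector<int64_t> &operandShape,
               const std::vector<int64_t> &loopRanges) {
  for (size_t i = 0; i < operandShape.size(); ++i)
    if (operandShape[i] != loopRanges[i])
      return false;
  return true;
}

int main() {
  // Mirrors the regression test below: a tensor<?x4096x16xf16> operand
  // against a fully static iteration space.
  std::printf("%s\n", dimsMatch({kDynamic, 4096, 16}, {20, 4096, 16})
                          ? "fuse"
                          : "bail out"); // prints "bail out"
}
```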
```diff
@@ -413,8 +425,10 @@ fuseWithReshapeByExpansion(OpTy op, Operation *reshapeOp,
 
   IRMapping mapping;
   for (OpOperand &operand : op->getOpOperands()) {
-    mapping.map(operand.get(),
-                info.getOrCreateExpanded(loc, &operand, rewriter).value());
+    std::optional<Value> maybeNewOperand =
+        info.getOrCreateExpanded(loc, &operand, rewriter);
+    assert(maybeNewOperand.has_value());
+    mapping.map(operand.get(), maybeNewOperand.value());
   }
 
   assert(op.getNumDpsInits() == 1);
```
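With the mixed static/dynamic case rejected up front in `ExpansionInfo::compute`, `getOrCreateExpanded` can no longer come back empty at this point, so the second hunk replaces the old unchecked `.value()` with an explicit `assert` on the invariant. A small sketch of the difference, using a hypothetical `expand` helper rather than the IREE API:

```cpp
#include <cassert>
#include <optional>

// Hypothetical helper standing in for getOrCreateExpanded: empty on failure.
std::optional<int> expand(bool ok) {
  return ok ? std::optional<int>(42) : std::nullopt;
}

int use() {
  std::optional<int> v = expand(true);
  // Asserting first documents the invariant and fails loudly in debug
  // builds; a bare .value() on an empty optional would instead throw
  // std::bad_optional_access (or terminate in no-exceptions builds).
  assert(v.has_value());
  return v.value();
}

int main() { return use() == 42 ? 0 : 1; }
```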

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/reshape_fusion.mlir

Lines changed: 26 additions & 0 deletions
```diff
@@ -482,6 +482,32 @@ util.func public @sink_through_k2(%0 : tensor<128x16x2x2x128xf16>, %1 : tensor<1
 
 // -----
 
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
+
+util.func public @no_fuse_attention_mixed_static_dynamic(%arg0: tensor<?x4096x16xf16>, %arg1: tensor<20x1024x16xf16>, %arg2: tensor<20x1024x64xf16>, %arg3: f16) -> tensor<2x10x4096x64xf16> {
+  %0 = tensor.empty() : tensor<20x4096x64xf16>
+  %1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4]} ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16) outs(%0 : tensor<20x4096x64xf16>) {
+  ^bb0(%score: f16):
+    iree_linalg_ext.yield %score : f16
+  } -> tensor<20x4096x64xf16>
+  %expanded = tensor.expand_shape %1 [[0, 1], [2], [3]] output_shape [2, 10, 4096, 64] : tensor<20x4096x64xf16> into tensor<2x10x4096x64xf16>
+  util.return %expanded : tensor<2x10x4096x64xf16>
+}
+
+//CHECK-LABEL: func public @no_fuse_attention_mixed_static_dynamic(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x4096x16xf16>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<20x1024x16xf16>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<20x1024x64xf16>
+// CHECK-SAME:   %[[ARG3:.+]]: f16)
+// CHECK:        %[[ATTENTION:.+]] = iree_linalg_ext.attention
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]] :
+// CHECK:        tensor.expand_shape %[[ATTENTION]]
+
+
 util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
   %collapsed = tensor.collapse_shape %arg0[[0, 1], [2], [3], [4], [5]] : tensor<4x?x2x16x4x128xf16> into tensor<?x2x16x4x128xf16>
   %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
```
