diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp index 8ca036d33a90..f8aa0e8917b2 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp @@ -762,12 +762,12 @@ static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) { for (OpOperand &use : v.getUses()) { Operation *user = use.getOwner(); - // Do not fuse `index_cast` operations if it is already an operand of the - // owner. - if (auto indexCastOp = v.getDefiningOp()) { - if (user == dispatchOp) { - return true; - } + // Do not fuse operations if they are already an operand of the + // owner and have an index return type as that means its a shape + // computation that needs to happen on the host. + if (user == dispatchOp && v.getType().isIndex() && + isa(dispatchOp)) { + return true; } Operation *ownerWorkgroupsOp = diff --git a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir index 713ae74a2fb9..9b7a146ebfac 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/clone_producers_into_dispatch_regions.mlir @@ -330,7 +330,7 @@ util.func public @clone_broadcast_dequant_op( // ----- // Do no clone index cast operations when they are operands to the dispatch -util.func public @dont_clone_index_cast(%arg0 : i64) { +util.func public @dont_clone_index_type_op(%arg0 : i64) { %0 = arith.index_cast %arg0 : i64 to index %1 = flow.dispatch.region[] -> (tensor{%0}) { %2 = tensor.empty(%0) : tensor @@ -338,7 +338,26 @@ util.func public @dont_clone_index_cast(%arg0 : i64) { } util.return } -// CHECK-LABEL: func public @dont_clone_index_cast +// CHECK-LABEL: func public @dont_clone_index_type_op // CHECK: arith.index_cast // CHECK: flow.dispatch.region // CHECK-NOT: arith.index_cast + +// ----- +// Do no clone index cast operations when they are in-direct operands to the dispatch +#map = affine_map<()[s0] -> (s0 * 12)> +util.func public @dont_clone_index_type_op_2(%arg0: i64) { + %0 = arith.index_cast %arg0 : i64 to index + %1 = affine.apply #map()[%0] + %2 = flow.dispatch.region -> (tensor{%1}) { + %3 = tensor.empty(%1) : tensor + flow.return %3 : tensor + } + util.return +} +// CHECK-LABEL: func public @dont_clone_index_type_op_2 +// CHECK: arith.index_cast +// CHECK: affine.apply +// CHECK: flow.dispatch.region +// CHECK-NOT: arith.index_cast +// CHECK-NOT: affine.apply