Skip to content

Commit

Permalink
[mlir][DispatchCreation] Avoid cloning for ops that are index return …
Browse files Browse the repository at this point in the history
…type dispatch operands. (#18377)

Such operations are shape computations that need to happen on host.

Fixes: #18229
  • Loading branch information
nirvedhmeshram authored Aug 28, 2024
1 parent 9f7b25e commit 6e3be28
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -762,12 +762,12 @@ static bool hasUnfusableUseInDispatch(Value v, Operation *dispatchOp) {
for (OpOperand &use : v.getUses()) {
Operation *user = use.getOwner();

// Do not fuse `index_cast` operations if it is already an operand of the
// owner.
if (auto indexCastOp = v.getDefiningOp<arith::IndexCastOp>()) {
if (user == dispatchOp) {
return true;
}
// Do not fuse operations if they are already an operand of the
// owner and have an index return type as that means it's a shape
// computation that needs to happen on the host.
if (user == dispatchOp && v.getType().isIndex() &&
isa<IREE::Flow::DispatchRegionOp>(dispatchOp)) {
return true;
}

Operation *ownerWorkgroupsOp =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,15 +330,34 @@ util.func public @clone_broadcast_dequant_op(
// -----

// Do not clone index cast operations when they are operands to the dispatch
util.func public @dont_clone_index_cast(%arg0 : i64) {
util.func public @dont_clone_index_type_op(%arg0 : i64) {
%0 = arith.index_cast %arg0 : i64 to index
%1 = flow.dispatch.region[] -> (tensor<?xf32>{%0}) {
%2 = tensor.empty(%0) : tensor<?xf32>
flow.return %2 : tensor<?xf32>
}
util.return
}
// CHECK-LABEL: func public @dont_clone_index_cast
// CHECK-LABEL: func public @dont_clone_index_type_op
// CHECK: arith.index_cast
// CHECK: flow.dispatch.region
// CHECK-NOT: arith.index_cast

// -----
// Do not clone index cast operations when they are indirect operands to the dispatch
#map = affine_map<()[s0] -> (s0 * 12)>
// Test input: an index value is produced by a two-op chain (`arith.index_cast`
// followed by `affine.apply`) and only the final result is used by the
// dispatch region.
util.func public @dont_clone_index_type_op_2(%arg0: i64) {
// Convert the i64 argument to `index`.
%0 = arith.index_cast %arg0 : i64 to index
// Apply #map (s0 * 12) to %0; %1 is the value the dispatch region consumes.
%1 = affine.apply #map()[%0]
// %1 is used both as the dynamic dimension of the region's result type and
// inside the region body as the size passed to `tensor.empty`.
%2 = flow.dispatch.region -> (tensor<?xf32>{%1}) {
%3 = tensor.empty(%1) : tensor<?xf32>
flow.return %3 : tensor<?xf32>
}
util.return
}
// CHECK-LABEL: func public @dont_clone_index_type_op_2
// CHECK: arith.index_cast
// CHECK: affine.apply
// CHECK: flow.dispatch.region
// CHECK-NOT: arith.index_cast
// CHECK-NOT: affine.apply

0 comments on commit 6e3be28

Please sign in to comment.