Skip to content

Commit d81d7fa

Browse files
committed
[DispatchCreation] Set split reduction size for ArgCompare
Signed-off-by: Bangtian Liu <[email protected]>
1 parent a775bc9 commit d81d7fa

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

compiler/src/iree/compiler/DispatchCreation/SetSplitReductionSizes.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
88
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
9+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
910
#include "iree/compiler/DispatchCreation/Passes.h"
1011
#include "llvm/Support/DebugLog.h"
1112
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -110,6 +111,12 @@ struct SetSplitReductionSizesPass final
110111
IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes);
111112
return;
112113
}
114+
115+
// --- Case 4: arg_compare operations ---
116+
if (auto tileSizes = getArgCompareReductionSizes(tilingOp)) {
117+
IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes);
118+
return;
119+
}
113120
});
114121
}
115122

@@ -389,6 +396,40 @@ struct SetSplitReductionSizesPass final
389396

390397
return tileSizes;
391398
}
399+
400+
/// Determine split reduction sizes specifically for arg_compare operations.
///
/// Returns std::nullopt (leaving the op unsplit) when `op` is not an
/// IREE::LinalgExt::ArgCompareOp, when the extent of the reduced input
/// dimension is dynamic, or when that extent is below 1024. Otherwise returns
/// a one-element vector holding the tile size chosen for the reduction
/// dimension.
401+
std::optional<SmallVector<int64_t>>
402+
getArgCompareReductionSizes(PartialReductionOpInterface op) const {
403+
auto argCompareOp =
404+
dyn_cast<IREE::LinalgExt::ArgCompareOp>(op.getOperation());
405+
if (!argCompareOp) {
406+
return std::nullopt;
407+
}
408+
409+
// Look up the static extent of the input along the op's reduction dimension.
ShapedType inputType = argCompareOp.getInputType();
410+
ArrayRef<int64_t> inputShape = inputType.getShape();
411+
int64_t reductionDim = argCompareOp.getDimension();
412+
int64_t reductionSize = inputShape[reductionDim];
413+
414+
// A dynamic extent has no static size to derive a tile from, so bail out.
if (reductionSize == ShapedType::kDynamic) {
415+
return std::nullopt;
416+
}
417+
418+
// Only split if reduction is large enough (empirical threshold)
419+
const int64_t minSizeToSplit = 1024;
420+
if (reductionSize < minSizeToSplit) {
421+
return std::nullopt;
422+
}
423+
424+
// NOTE(review): presumably the helper returns the smallest divisor of
// reductionSize that is >= splitReductionTargetSize -- confirm against its
// definition. When it yields no value, the tile falls back to the whole
// reduction extent.
int64_t tileSize = findSmallestFactorWithLowerBound(
425+
reductionSize, splitReductionTargetSize)
426+
.value_or(reductionSize);
427+
428+
LDBG() << "arg_compare split: dim=" << reductionDim
429+
<< " size=" << reductionSize << " tile=" << tileSize;
430+
431+
// Single entry: the tile size for the one reduction dimension read above.
return SmallVector<int64_t>{tileSize};
432+
}
392433
};
393434
} // namespace
394435
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/test/set_split_reduction_sizes_outer_reduction.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,24 @@ util.func public @arg_compare_negative_outer_dynamic_reduction(
179179

180180
util.return %res_val, %res_idx : tensor<64xf32>, tensor<64xindex>
181181
}
182+
183+
// -----
184+
185+
// Positive case for arg_compare: the reduction runs over dimension 2 (static
// extent 128256) of a 4x1x128256 f16 input, and the pass is expected to attach
// a split_reduction attribute to the op. The expected tile size 1336 divides
// that extent evenly (128256 = 1336 * 96); presumably it is the smallest such
// factor at or above the pass's target size -- TODO confirm against
// splitReductionTargetSize.
// CHECK-LABEL: @arg_compare_large_inner_reduction
186+
util.func public @arg_compare_large_inner_reduction(%arg0: tensor<4x1x128256xf16>)
187+
-> tensor<4x1xi32> {
188+
// CHECK: iree_linalg_ext.split_reduction = [1336 : index]
189+
%init_val = tensor.empty() : tensor<4x1xf16>
190+
%init_idx = tensor.empty() : tensor<4x1xi32>
191+
192+
%res:2 = iree_linalg_ext.arg_compare
193+
dimension(2)
194+
ins(%arg0 : tensor<4x1x128256xf16>)
195+
outs(%init_val, %init_idx : tensor<4x1xf16>, tensor<4x1xi32>) {
196+
// Comparator region: yields true when the first value is greater (ordered).
^bb0(%arg1: f16, %arg2: f16):
197+
%cmp = arith.cmpf ogt, %arg1, %arg2 : f16
198+
iree_linalg_ext.yield %cmp : i1
199+
} -> tensor<4x1xf16>, tensor<4x1xi32>
200+
201+
// Only the index result is returned from the function.
util.return %res#1 : tensor<4x1xi32>
202+
}

0 commit comments

Comments
 (0)