Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/Support/DebugLog.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand Down Expand Up @@ -110,6 +111,12 @@ struct SetSplitReductionSizesPass final
IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes);
return;
}

// --- Case 4: arg_compare operations ---
if (auto tileSizes = getArgCompareReductionSizes(tilingOp)) {
IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes);
return;
}
});
}

Expand Down Expand Up @@ -389,6 +396,39 @@ struct SetSplitReductionSizesPass final

return tileSizes;
}

/// Computes the split-reduction tile size for an `arg_compare` operation.
///
/// Returns `std::nullopt` when `op` is not an `ArgCompareOp`, when the extent
/// of the reduced dimension is dynamic, or when that extent is below 1024.
/// Otherwise returns a one-element vector holding the value produced by
/// `findSmallestFactorWithLowerBound(extent, splitReductionTargetSize)`, or the
/// full extent when that helper yields no value.
std::optional<SmallVector<int64_t>>
getArgCompareReductionSizes(PartialReductionOpInterface op) const {
  auto argCompare = dyn_cast<IREE::LinalgExt::ArgCompareOp>(op.getOperation());
  if (!argCompare) {
    return std::nullopt;
  }

  // Extent of the input along the dimension being reduced.
  int64_t dim = argCompare.getDimension();
  int64_t extent = argCompare.getInputType().getShape()[dim];

  // Only statically sized reductions of at least `minSizeToSplit` are split.
  const int64_t minSizeToSplit = 1024;
  if (extent == ShapedType::kDynamic || extent < minSizeToSplit) {
    return std::nullopt;
  }

  // Fall back to the whole extent when no factor is reported.
  auto factor =
      findSmallestFactorWithLowerBound(extent, splitReductionTargetSize);
  int64_t tile = factor.value_or(extent);

  LDBG() << "arg_compare split: dim=" << dim << " size=" << extent
         << " tile=" << tile;

  return SmallVector<int64_t>{tile};
}
};
} // namespace
} // namespace mlir::iree_compiler::DispatchCreation
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,24 @@ util.func public @arg_compare_negative_outer_dynamic_reduction(

util.return %res_val, %res_idx : tensor<64xf32>, tensor<64xindex>
}

// -----

// Exercises an arg_compare that reduces a large, statically sized innermost
// dimension (dimension 2, extent 128256) and expects the op to be annotated
// with a split-reduction tile size of 1336. Note 128256 = 96 * 1336, so the
// expected tile divides the reduction extent evenly.
// CHECK-LABEL: @arg_compare_large_inner_reduction
util.func public @arg_compare_large_inner_reduction(%arg0: tensor<4x1x128256xf16>)
-> tensor<4x1xi32> {
// CHECK: iree_linalg_ext.split_reduction = [1336 : index]
// Destination tensors for the selected values and their i32 indices.
%init_val = tensor.empty() : tensor<4x1xf16>
%init_idx = tensor.empty() : tensor<4x1xi32>

%res:2 = iree_linalg_ext.arg_compare
dimension(2)
ins(%arg0 : tensor<4x1x128256xf16>)
outs(%init_val, %init_idx : tensor<4x1xf16>, tensor<4x1xi32>) {
^bb0(%arg1: f16, %arg2: f16):
// Comparator region: yields true when %arg1 is ordered-greater-than %arg2.
%cmp = arith.cmpf ogt, %arg1, %arg2 : f16
iree_linalg_ext.yield %cmp : i1
} -> tensor<4x1xf16>, tensor<4x1xi32>

// Only the index result is returned; the value result is unused here.
util.return %res#1 : tensor<4x1xi32>
}
Loading