|  | 
| 6 | 6 | 
 | 
| 7 | 7 | #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" | 
| 8 | 8 | #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" | 
|  | 9 | +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" | 
| 9 | 10 | #include "iree/compiler/DispatchCreation/Passes.h" | 
| 10 | 11 | #include "llvm/Support/DebugLog.h" | 
| 11 | 12 | #include "mlir/Dialect/Linalg/IR/Linalg.h" | 
| @@ -110,6 +111,12 @@ struct SetSplitReductionSizesPass final | 
| 110 | 111 |         IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes); | 
| 111 | 112 |         return; | 
| 112 | 113 |       } | 
|  | 114 | + | 
|  | 115 | +      // --- Case 4: arg_compare operations --- | 
|  | 116 | +      if (auto tileSizes = getArgCompareReductionSizes(tilingOp)) { | 
|  | 117 | +        IREE::LinalgExt::setSplitReductionAttribute(tilingOp, *tileSizes); | 
|  | 118 | +        return; | 
|  | 119 | +      } | 
| 113 | 120 |     }); | 
| 114 | 121 |   } | 
| 115 | 122 | 
 | 
| @@ -389,6 +396,39 @@ struct SetSplitReductionSizesPass final | 
| 389 | 396 | 
 | 
| 390 | 397 |     return tileSizes; | 
| 391 | 398 |   } | 
|  | 399 | + | 
|  | 400 | +  /// Determine split reduction sizes specifically for arg_compare operations. | 
|  | 401 | +  std::optional<SmallVector<int64_t>> | 
|  | 402 | +  getArgCompareReductionSizes(PartialReductionOpInterface op) const { | 
|  | 403 | +    auto argCompareOp = | 
|  | 404 | +        dyn_cast<IREE::LinalgExt::ArgCompareOp>(op.getOperation()); | 
|  | 405 | +    if (!argCompareOp) { | 
|  | 406 | +      return std::nullopt; | 
|  | 407 | +    } | 
|  | 408 | + | 
|  | 409 | +    ShapedType inputType = argCompareOp.getInputType(); | 
|  | 410 | +    ArrayRef<int64_t> inputShape = inputType.getShape(); | 
|  | 411 | +    int64_t reductionDim = argCompareOp.getDimension(); | 
|  | 412 | +    int64_t reductionSize = inputShape[reductionDim]; | 
|  | 413 | + | 
|  | 414 | +    if (reductionSize == ShapedType::kDynamic) { | 
|  | 415 | +      return std::nullopt; | 
|  | 416 | +    } | 
|  | 417 | + | 
|  | 418 | +    const int64_t minSizeToSplit = 1024; | 
|  | 419 | +    if (reductionSize < minSizeToSplit) { | 
|  | 420 | +      return std::nullopt; | 
|  | 421 | +    } | 
|  | 422 | + | 
|  | 423 | +    int64_t tileSize = findSmallestFactorWithLowerBound( | 
|  | 424 | +                           reductionSize, splitReductionTargetSize) | 
|  | 425 | +                           .value_or(reductionSize); | 
|  | 426 | + | 
|  | 427 | +    LDBG() << "arg_compare split: dim=" << reductionDim | 
|  | 428 | +           << " size=" << reductionSize << " tile=" << tileSize; | 
|  | 429 | + | 
|  | 430 | +    return SmallVector<int64_t>{tileSize}; | 
|  | 431 | +  } | 
| 392 | 432 | }; | 
| 393 | 433 | } // namespace | 
| 394 | 434 | } // namespace mlir::iree_compiler::DispatchCreation | 
0 commit comments