@@ -117,7 +117,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
117
117
constAttr);
118
118
}
119
119
120
- if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
120
+ if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp >(op)) {
121
121
auto constAttr =
122
122
DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 0 )});
123
123
return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
@@ -169,7 +169,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
169
169
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
170
170
result = rewriter.create <stablehlo::AndOp>(
171
171
op->getLoc (), blockArgumentTy, *firstArgument, *secondArgument);
172
- } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
172
+ } else if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp >(op)) {
173
173
result = rewriter.create <stablehlo::OrOp>(
174
174
op->getLoc (), blockArgumentTy, *firstArgument, *secondArgument);
175
175
} else if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
@@ -610,6 +610,82 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
610
610
};
611
611
} // namespace
612
612
613
+ // AtenAnyDimsOp
614
+ namespace {
615
+ template <>
616
+ LogicalResult ConvertAtenReductionOp<AtenAnyDimsOp>::matchAndRewrite(
617
+ AtenAnyDimsOp op, OpAdaptor adaptor,
618
+ ConversionPatternRewriter &rewriter) const {
619
+ Value input = adaptor.getSelf ();
620
+ auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
621
+ auto outTy =
622
+ dyn_cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
623
+ if (!inputTy) {
624
+ return rewriter.notifyMatchFailure (
625
+ op, " only Tensor types supported in StableHLO" );
626
+ }
627
+ if (inputTy.getElementType () != outTy.getElementType ()) {
628
+ // Use output element type as computation type.
629
+ auto dstElemTy = outTy.getElementType ();
630
+ input =
631
+ rewriter.create <stablehlo::ConvertOp>(op->getLoc (), input, dstElemTy);
632
+ inputTy = dyn_cast<RankedTensorType>(input.getType ());
633
+ }
634
+ auto inputElemTy = inputTy.getElementType ();
635
+ if (!inputElemTy.isIntOrFloat ()) {
636
+ return op.emitError (
637
+ " Only floating-point or integer datatype legalization supported" );
638
+ }
639
+
640
+ SmallVector<int64_t > inputDims;
641
+ SmallVector<int64_t > dims;
642
+ if (!matchPattern (op.getDim (), m_TorchListOfConstantInts (inputDims))) {
643
+ return rewriter.notifyMatchFailure (
644
+ op, " non-const integer `dim` is not supported" );
645
+ }
646
+ if (inputDims.size () == 0 ) {
647
+ rewriter.replaceOp (op, input);
648
+ return success ();
649
+ }
650
+ for (auto d : inputDims) {
651
+ d = toPositiveDim (d, inputTy.getRank ());
652
+ // Drop invalid dims
653
+ if (isValidDim (d, inputTy.getRank ())) {
654
+ dims.push_back (d);
655
+ }
656
+ }
657
+ llvm::sort (dims.begin (), dims.end ());
658
+
659
+ SmallVector<int64_t > reduceResultShape =
660
+ getReduceOutputShape (inputTy.getShape (), dims);
661
+
662
+ bool keepDim = false ;
663
+ if (!matchPattern (op.getKeepdim (), m_TorchConstantBool (&keepDim))) {
664
+ return rewriter.notifyMatchFailure (op, " non-bool keepdim unsupported" );
665
+ }
666
+
667
+ Value reduceResult = createReduceOpWithSingleRegionOp (
668
+ op, input,
669
+ RankedTensorType::get (reduceResultShape, outTy.getElementType ()), dims,
670
+ rewriter);
671
+ if (!reduceResult) {
672
+ return op->emitError (" createReduceOpWithSingleRegionOp return nullptr" );
673
+ }
674
+
675
+ if (keepDim) {
676
+ auto outShapeInfo = hlo::getDimIndexOfTensor (rewriter, op, input);
677
+ if (failed (outShapeInfo)) {
678
+ return rewriter.notifyMatchFailure (
679
+ op, " failed to get dimension sizes of the input" );
680
+ }
681
+ reduceResult = reshapeReduceResultWhenKeepDim (
682
+ rewriter, op->getLoc (), reduceResult, *outShapeInfo, outTy, dims);
683
+ }
684
+ rewriter.replaceOp (op, reduceResult);
685
+ return success ();
686
+ }
687
+ } // namespace
688
+
613
689
// AtenSumDimIntListOp
614
690
namespace {
615
691
template <>
@@ -928,6 +1004,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
928
1004
#define INSERT_ATEN_REDUCTION_OP_PATTERN (AtenOp ) \
929
1005
target.addIllegalOp <AtenOp>(); \
930
1006
patterns.add <ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
1007
+ INSERT_ATEN_REDUCTION_OP_PATTERN (AtenAnyDimsOp);
931
1008
INSERT_ATEN_REDUCTION_OP_PATTERN (AtenSumDimIntListOp);
932
1009
INSERT_ATEN_REDUCTION_OP_PATTERN (AtenFrobeniusNormDimOp);
933
1010
INSERT_ATEN_REDUCTION_OP_PATTERN (AtenLinalgVectorNormOp);
0 commit comments