Skip to content

Commit dcf9bbf

Browse files
authored
[Stablehlo] Add conversion for AtenAnyDimsOp (#4223)
1 parent f726c81 commit dcf9bbf

File tree

2 files changed

+79
-3
lines changed

2 files changed

+79
-3
lines changed

lib/Conversion/TorchToStablehlo/Reduction.cpp

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
117117
constAttr);
118118
}
119119

120-
if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
120+
if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp>(op)) {
121121
auto constAttr =
122122
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
123123
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -169,7 +169,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
169169
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
170170
result = rewriter.create<stablehlo::AndOp>(
171171
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
172-
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
172+
} else if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp>(op)) {
173173
result = rewriter.create<stablehlo::OrOp>(
174174
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
175175
} else if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
@@ -610,6 +610,82 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
610610
};
611611
} // namespace
612612

613+
// AtenAnyDimsOp
614+
namespace {
615+
template <>
616+
LogicalResult ConvertAtenReductionOp<AtenAnyDimsOp>::matchAndRewrite(
617+
AtenAnyDimsOp op, OpAdaptor adaptor,
618+
ConversionPatternRewriter &rewriter) const {
619+
Value input = adaptor.getSelf();
620+
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
621+
auto outTy =
622+
dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
623+
if (!inputTy) {
624+
return rewriter.notifyMatchFailure(
625+
op, "only Tensor types supported in StableHLO");
626+
}
627+
if (inputTy.getElementType() != outTy.getElementType()) {
628+
// Use output element type as computation type.
629+
auto dstElemTy = outTy.getElementType();
630+
input =
631+
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
632+
inputTy = dyn_cast<RankedTensorType>(input.getType());
633+
}
634+
auto inputElemTy = inputTy.getElementType();
635+
if (!inputElemTy.isIntOrFloat()) {
636+
return op.emitError(
637+
"Only floating-point or integer datatype legalization supported");
638+
}
639+
640+
SmallVector<int64_t> inputDims;
641+
SmallVector<int64_t> dims;
642+
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
643+
return rewriter.notifyMatchFailure(
644+
op, "non-const integer `dim` is not supported");
645+
}
646+
if (inputDims.size() == 0) {
647+
rewriter.replaceOp(op, input);
648+
return success();
649+
}
650+
for (auto d : inputDims) {
651+
d = toPositiveDim(d, inputTy.getRank());
652+
// Drop invalid dims
653+
if (isValidDim(d, inputTy.getRank())) {
654+
dims.push_back(d);
655+
}
656+
}
657+
llvm::sort(dims.begin(), dims.end());
658+
659+
SmallVector<int64_t> reduceResultShape =
660+
getReduceOutputShape(inputTy.getShape(), dims);
661+
662+
bool keepDim = false;
663+
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
664+
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
665+
}
666+
667+
Value reduceResult = createReduceOpWithSingleRegionOp(
668+
op, input,
669+
RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
670+
rewriter);
671+
if (!reduceResult) {
672+
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
673+
}
674+
675+
if (keepDim) {
676+
auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
677+
if (failed(outShapeInfo)) {
678+
return rewriter.notifyMatchFailure(
679+
op, "failed to get dimension sizes of the input");
680+
}
681+
reduceResult = reshapeReduceResultWhenKeepDim(
682+
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims);
683+
}
684+
rewriter.replaceOp(op, reduceResult);
685+
return success();
686+
}
687+
} // namespace
688+
613689
// AtenSumDimIntListOp
614690
namespace {
615691
template <>
@@ -928,6 +1004,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
9281004
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
9291005
target.addIllegalOp<AtenOp>(); \
9301006
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
1007+
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyDimsOp);
9311008
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
9321009
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
9331010
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,6 @@
824824
"RandnLikeDtypeModule_basic",
825825
"RandnLikeModule_basic",
826826
"RandnModule_basic",
827-
"ReduceAnyDimsFloatModule_basic",
828827
"ReflectionPad1dModule2dInput_Right",
829828
"ReflectionPad1dModule2dInput_basic",
830829
"ReflectionPad1dModule3dInput_Left",

0 commit comments

Comments
 (0)