@@ -96,7 +96,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
96
96
}
97
97
}
98
98
99
- if (isa<AtenProdOp>(op)) {
99
+ if (isa<AtenProdOp, AtenProdDimIntOp >(op)) {
100
100
if (isa<mlir::FloatType>(elementTy)) {
101
101
APFloat one (cast<mlir::FloatType>(elementTy).getFloatSemantics (), 1 );
102
102
auto constAttr = DenseElementsAttr::get (constType, one);
@@ -172,7 +172,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
172
172
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
173
173
result = rewriter.create <stablehlo::OrOp>(
174
174
op->getLoc (), blockArgumentTy, *firstArgument, *secondArgument);
175
- } else if (isa<AtenProdOp>(op)) {
175
+ } else if (isa<AtenProdOp, AtenProdDimIntOp >(op)) {
176
176
result = rewriter.create <stablehlo::MulOp>(
177
177
op->getLoc (), blockArgumentTy, *firstArgument, *secondArgument);
178
178
} else {
@@ -689,6 +689,69 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
689
689
}
690
690
} // namespace
691
691
692
+ // AtenProdDimIntOp
693
+ namespace {
694
+ template <>
695
+ LogicalResult ConvertAtenReductionOp<AtenProdDimIntOp>::matchAndRewrite(
696
+ AtenProdDimIntOp op, OpAdaptor adaptor,
697
+ ConversionPatternRewriter &rewriter) const {
698
+ Value input = adaptor.getSelf ();
699
+ auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
700
+ auto outTy =
701
+ dyn_cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
702
+ if (!inputTy) {
703
+ return rewriter.notifyMatchFailure (
704
+ op, " only Tensor types supported in StableHLO" );
705
+ }
706
+ if (inputTy.getElementType () != outTy.getElementType ()) {
707
+ // Use output element type as computation type.
708
+ auto dstElemTy = outTy.getElementType ();
709
+ input =
710
+ rewriter.create <stablehlo::ConvertOp>(op->getLoc (), input, dstElemTy);
711
+ inputTy = dyn_cast<RankedTensorType>(input.getType ());
712
+ }
713
+ auto inputElemTy = inputTy.getElementType ();
714
+ if (!inputElemTy.isIntOrFloat ()) {
715
+ return op.emitError (
716
+ " Only floating-point or integer datatype legalization supported" );
717
+ }
718
+
719
+ int64_t dim;
720
+ if (!matchPattern (op.getDim (), m_TorchConstantInt (&dim))) {
721
+ return rewriter.notifyMatchFailure (
722
+ op, " non-const integer `dim` is not supported" );
723
+ }
724
+ dim = toPositiveDim (dim, inputTy.getRank ());
725
+ SmallVector<int64_t > reduceResultShape =
726
+ getReduceOutputShape (inputTy.getShape (), {dim});
727
+
728
+ bool keepDim = false ;
729
+ if (!matchPattern (op.getKeepdim (), m_TorchConstantBool (&keepDim))) {
730
+ return rewriter.notifyMatchFailure (op, " non-bool keepdim unsupported" );
731
+ }
732
+
733
+ Value reduceResult = createReduceOpWithSingleRegionOp (
734
+ op, input,
735
+ RankedTensorType::get (reduceResultShape, outTy.getElementType ()), dim,
736
+ rewriter);
737
+ if (!reduceResult) {
738
+ return op->emitError (" createReduceOpWithSingleRegionOp return nullptr" );
739
+ }
740
+
741
+ if (keepDim) {
742
+ auto outShapeInfo = hlo::getDimIndexOfTensor (rewriter, op, input);
743
+ if (failed (outShapeInfo)) {
744
+ return rewriter.notifyMatchFailure (
745
+ op, " failed to get dimension sizes of the input" );
746
+ }
747
+ reduceResult = reshapeReduceResultWhenKeepDim (
748
+ rewriter, op->getLoc (), reduceResult, *outShapeInfo, outTy, dim);
749
+ }
750
+ rewriter.replaceOp (op, reduceResult);
751
+ return success ();
752
+ }
753
+ } // namespace
754
+
692
755
// AtenFrobeniusNormDimOp
693
756
// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
694
757
// dims) + stablehlo.sqrt
@@ -868,6 +931,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
868
931
INSERT_ATEN_REDUCTION_OP_PATTERN (AtenSumDimIntListOp);
869
932
INSERT_ATEN_REDUCTION_OP_PATTERN (AtenFrobeniusNormDimOp);
870
933
INSERT_ATEN_REDUCTION_OP_PATTERN (AtenLinalgVectorNormOp);
934
+ INSERT_ATEN_REDUCTION_OP_PATTERN (AtenProdDimIntOp);
871
935
#undef INSERT_ATEN_REDUCTION_OP_PATTERN
872
936
873
937
#define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN (AtenOp ) \
0 commit comments