@@ -862,7 +862,7 @@ namespace {
862862// used in the divisor of the average pooling operator.
863863template <int NumOfDims> class PoolSizeCalculator {
864864public:
865- PoolSizeCalculator (Value self, Value sumPool,
865+ PoolSizeCalculator (Value self, Value sumPool, bool countIncludePad,
866866 ConversionPatternRewriter &rewriter, Location loc);
867867
868868 // The algorithm for computing the divisor with
@@ -877,36 +877,37 @@ template <int NumOfDims> class PoolSizeCalculator {
877877 SmallVectorImpl<int64_t > &paddingInts);
878878
879879private:
880- int64_t DimSizeFromSumPoolType [NumOfDims];
881- Value InputSpatialDimValues [NumOfDims];
880+ int64_t SumPoolTypeDimIndex [NumOfDims];
881+ Value InputSpatialDimSizes [NumOfDims];
882882 Location location;
883+ bool isCountIncludePad;
883884};
884885
885886} // namespace
886887
887888template <int NumOfDims>
888889PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
889- Value self, Value sumPool, ConversionPatternRewriter &rewriter ,
890- Location loc)
891- : location(loc) {
890+ Value self, Value sumPool, bool countIncludePad ,
891+ ConversionPatternRewriter &rewriter, Location loc)
892+ : location(loc), isCountIncludePad(countIncludePad) {
892893 auto selfType = cast<RankedTensorType>(self.getType ());
893894 const int64_t selfRank = selfType.getRank ();
894895 RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType ());
895896 const int64_t rank = sumPoolType.getRank ();
896897
897898 // Store dimensions in this order:
898- // 0 => width , 1 => height, 2 => depth
899+ // 0 => depth, 1 => height, 2 => width (3D case; in general, the trailing NumOfDims dims in tensor order)
899900 for (int i = 0 ; i < NumOfDims; ++i) {
900- int64_t DimSizeFromSelfType = toPositiveDim (-(i + 1 ), selfRank);
901- InputSpatialDimValues[i ] =
902- getDimOp (rewriter, location, self, DimSizeFromSelfType );
903- DimSizeFromSumPoolType[i ] = toPositiveDim (-(i + 1 ), rank);
901+ int64_t inputSpatialDimIndex = toPositiveDim (-(i + 1 ), selfRank);
902+ InputSpatialDimSizes[NumOfDims - i - 1 ] =
903+ getDimOp (rewriter, location, self, inputSpatialDimIndex );
904+ SumPoolTypeDimIndex[NumOfDims - i - 1 ] = toPositiveDim (-(i + 1 ), rank);
904905 }
905906}
906907
907908template <int NumOfDims>
908909Value PoolSizeCalculator<NumOfDims>::getPoolSize(
909- OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues ,
910+ OpBuilder &b, SmallVectorImpl<Value> &kernelDimSizes ,
910911 SmallVectorImpl<int64_t > &strideInts,
911912 SmallVectorImpl<int64_t > &paddingInts) {
912913 Value poolSize;
@@ -921,19 +922,20 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
921922 // Dim below stands for spatial dimension. Prior to the February 2025
922923 // change, the names of these intermediate variables used "height" and
923924 // "width" (or "h" and "w") instead of "Dim".
925+
924926 Value IndexODim =
925927 b.create <linalg::IndexOp>(location,
926- /* value=*/ DimSizeFromSumPoolType [i]);
928+ /* value=*/ SumPoolTypeDimIndex [i]);
927929 Value ODim = castIndexToInt64 (b, location, IndexODim);
928930 Value DDim = b.createOrFold <arith::ConstantOp>(
929931 location, b.getI64IntegerAttr (strideInts[i]));
930932 Value PadDim = b.createOrFold <arith::ConstantOp>(
931933 location, b.getI64IntegerAttr (paddingInts[i]));
932934 Value ODimDDim = b.createOrFold <arith::MulIOp>(location, ODim, DDim);
933935 Value IDim0 = b.createOrFold <arith::SubIOp>(location, ODimDDim, PadDim);
934- Value IDim = castIndexToInt64 (b, location, InputSpatialDimValues [i]);
936+ Value IDim = castIndexToInt64 (b, location, InputSpatialDimSizes [i]);
935937 Value IDim0KDim =
936- b.createOrFold <arith::AddIOp>(location, IDim0, kernelSizeIntValues [i]);
938+ b.createOrFold <arith::AddIOp>(location, IDim0, kernelDimSizes [i]);
937939 Value IDimPadDim = b.createOrFold <arith::AddIOp>(location, IDim, PadDim);
938940 Value IDim1 =
939941 b.createOrFold <arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
@@ -943,11 +945,15 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
943945 Value IDim1Clamped = b.createOrFold <arith::MinSIOp>(location, IDim1, IDim);
944946 Value IDim1_IDim0_Clamped =
945947 b.createOrFold <arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
948+
949+ Value poolSizeDim =
950+ !isCountIncludePad
951+ ? IDim1_IDim0_Clamped
952+ : b.createOrFold <arith::SubIOp>(location, IDim1, IDim0);
946953 if (i == 0 ) {
947- poolSize = IDim1_IDim0_Clamped ;
954+ poolSize = poolSizeDim ;
948955 } else {
949- poolSize = b.createOrFold <arith::MulIOp>(location, poolSize,
950- IDim1_IDim0_Clamped);
956+ poolSize = b.createOrFold <arith::MulIOp>(location, poolSize, poolSizeDim);
951957 }
952958 }
953959 return poolSize;
@@ -963,26 +969,35 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
963969 matchAndRewrite (OpTy op, typename OpTy::Adaptor adaptor,
964970 ConversionPatternRewriter &rewriter) const override ;
965971
966- // Creates the average pooling operation value when the
967- // count_include_pad parameter is equal to false.
968- static std::optional<LogicalResult>
969- createAvgPoolValueCountIncludePadFalseCase (
970- bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
971- ConversionPatternRewriter &rewriter, Value self, Value sumPool,
972- Value outputTensor, Type resultType,
973- SmallVectorImpl<Value> &kernelSizeIntValues,
972+ // Returns true when the divisor must exclude the elements that are not
973+ // counted (clamped divisor). When it returns false, the divisor is
974+ // just the product of kernel dimensions.
975+ static bool
976+ doesAvgPoolDivisorNeedsClamping (bool ceilMode, bool countIncludePad,
977+ SmallVectorImpl<int64_t > &strideInts,
978+ SmallVectorImpl<int64_t > &paddingInts);
979+
980+ // Creates the average pooling operation value with a clamped
981+ // divisor. The clamped divisor is the product of kernel
982+ // dimensions minus the elements not counted; e.g., padding
983+ // and ceiling mode implicit padding.
984+ static LogicalResult createAveragePoolValueWithClampedDivisor (
985+ bool ceilMode, bool countIncludePad, OpTy op,
986+ typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
987+ Value self, Value sumPool, Value outputTensor, Type resultType,
988+ SmallVectorImpl<Value> &kernelDimSizes,
974989 SmallVectorImpl<int64_t > &strideInts,
975990 SmallVectorImpl<int64_t > &paddingInts,
976991 SmallVector<AffineMap> &indexingMapsAvg,
977992 SmallVector<utils::IteratorType> &iteratorTypesAvg);
978993
979- // Creates the average pooling operation value when the
980- // count_include_pad parameter is equal to true .
981- static LogicalResult createAvgPoolValueCountIncludePadTrueCase (
994+ // Creates the average pooling operation value with a
995+ // regular divisor; i.e., the product of kernel dimensions .
996+ static LogicalResult createAveragePoolValueWithRegularDivisor (
982997 OpTy op, typename OpTy::Adaptor &adaptor,
983998 ConversionPatternRewriter &rewriter, Value self, Value sumPool,
984999 Value outputTensor, Type resultType,
985- SmallVectorImpl<Value> &kernelSizeIntValues ,
1000+ SmallVectorImpl<Value> &kernelDimSizes ,
9861001 SmallVector<AffineMap> &indexingMapsAvg,
9871002 SmallVector<utils::IteratorType> &iteratorTypesAvg);
9881003};
@@ -1046,27 +1061,64 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::matchAndRewrite(
10461061 SmallVector<utils::IteratorType> iteratorTypesAvg (
10471062 Dim + 2 , utils::IteratorType::parallel);
10481063
1049- auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase (
1050- countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
1051- resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg,
1052- iteratorTypesAvg);
1053- if (divisorOpResult)
1054- return *divisorOpResult;
1064+ if (doesAvgPoolDivisorNeedsClamping (ceilMode, countIncludePad, strideInts,
1065+ paddingInts)) {
1066+ return createAveragePoolValueWithClampedDivisor (
1067+ ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool,
1068+ outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts,
1069+ indexingMapsAvg, iteratorTypesAvg);
1070+ }
10551071
1056- return createAvgPoolValueCountIncludePadTrueCase (
1072+ return createAveragePoolValueWithRegularDivisor (
10571073 op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
10581074 kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
1075+ }
10591076
1060- return success ();
1077+ template <typename OpTy, typename PoolingOpTy, int Dim>
1078+ bool ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1079+ doesAvgPoolDivisorNeedsClamping (bool ceilMode, bool countIncludePad,
1080+ SmallVectorImpl<int64_t > &strideInts,
1081+ SmallVectorImpl<int64_t > &paddingInts) {
1082+ // Returns true when the average pooling divisor needs to be clamped
1083+ // (i.e., adjusted to exclude padded or out-of-bounds elements).
1084+ //
1085+ // Clamping is requested in either of the following two cases:
1086+ // 1. Padding with count_include_pad == false:
1087+ // - If any padding entry is nonzero and count_include_pad is false,
1088+ // then padding elements are *excluded* from the divisor, effectively
1089+ // clamping the divisor to the number of valid input elements.
1090+ //
1091+ // 2. Ceil mode with non-unit stride:
1092+ // - When ceil_mode is enabled, output dimensions are rounded up,
1093+ // potentially creating pooling windows that extend beyond the
1094+ // input tensor bounds. PyTorch handles this by implicitly adding
1095+ // zero-padding outside the tensor, but these extra (implicit)
1096+ // padded elements are *not* included in the divisor. This is why
1097+ // the ceil_mode check below does not consult countIncludePad; it
1098+ // applies regardless of the count_include_pad flag.
1099+ // - If all strides are 1, the output-size division by the stride is exact,
1100+ // so ceil_mode adds no extra windows and no clamping is needed.
1101+ //
1102+ // Reference: PyTorch AvgPool2d documentation and formula for H_out/W_out:
1103+ // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
1104+ //
1105+ // See torch.nn.AvgPool2d E2E tests for comprehensive coverage.
1106+
1107+ bool hasPadding =
1108+ !llvm::all_of (paddingInts, [](int64_t p) { return p == 0 ; });
1109+ bool allStridesUnitary =
1110+ llvm::all_of (strideInts, [](int64_t s) { return s == 1 ; });
1111+
1112+ return (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary);
10611113}
10631115template <typename OpTy, typename PoolingOpTy, int Dim>
1064- std::optional< LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1065- createAvgPoolValueCountIncludePadFalseCase (
1066- bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor ,
1067- ConversionPatternRewriter &rewriter, Value self, Value sumPool ,
1068- Value outputTensor, Type resultType,
1069- SmallVectorImpl<Value> &kernelSizeIntValues ,
1116+ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1117+ createAveragePoolValueWithClampedDivisor (
1118+ bool ceilMode, bool countIncludePad, OpTy op ,
1119+ typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter ,
1120+ Value self, Value sumPool, Value outputTensor, Type resultType,
1121+ SmallVectorImpl<Value> &kernelDimSizes ,
10701122 SmallVectorImpl<int64_t > &strideInts,
10711123 SmallVectorImpl<int64_t > &paddingInts,
10721124 SmallVector<AffineMap> &indexingMapsAvg,
@@ -1075,11 +1127,6 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
10751127
10761128 constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();
10771129
1078- bool noPadding = llvm::all_of (paddingInts, [](int64_t p) { return p == 0 ; });
1079- if (countIncludePad || noPadding) {
1080- // These cases are not handled here.
1081- return std::nullopt ;
1082- }
10831130 if (avgPoolDims < 1 ) {
10841131 return rewriter.notifyMatchFailure (
10851132 op, " Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, "
@@ -1088,8 +1135,8 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
10881135
10891136 Type resultElementType = cast<RankedTensorType>(resultType).getElementType ();
10901137
1091- PoolSizeCalculator<avgPoolDims> poolSizeCalculator (self, sumPool, rewriter,
1092- loc);
1138+ PoolSizeCalculator<avgPoolDims> poolSizeCalculator (
1139+ self, sumPool, countIncludePad, rewriter, loc);
10931140
10941141 // AtenAvgPool2/3dOp has an optional divisor_override
10951142 // attribute while AtenAvgPool1dOp does not.
@@ -1110,7 +1157,7 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
11101157 [&](OpBuilder &b, Location loc, ValueRange args) {
11111158 if (!poolSize) {
11121159 poolSize = poolSizeCalculator.getPoolSize (
1113- b, kernelSizeIntValues , strideInts, paddingInts);
1160+ b, kernelDimSizes , strideInts, paddingInts);
11141161 }
11151162 Value divisor =
11161163 convertScalarToDtype (b, loc, poolSize, resultElementType);
@@ -1128,21 +1175,21 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
11281175
11291176template <typename OpTy, typename PoolingOpTy, int Dim>
11301177LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1131- createAvgPoolValueCountIncludePadTrueCase (
1178+ createAveragePoolValueWithRegularDivisor (
11321179 OpTy op, typename OpTy::Adaptor &adaptor,
11331180 ConversionPatternRewriter &rewriter, Value self, Value sumPool,
11341181 Value outputTensor, Type resultType,
1135- SmallVectorImpl<Value> &kernelSizeIntValues ,
1182+ SmallVectorImpl<Value> &kernelDimSizes ,
11361183 SmallVector<AffineMap> &indexingMapsAvg,
11371184 SmallVector<utils::IteratorType> &iteratorTypesAvg) {
11381185 Location loc = op->getLoc ();
11391186
11401187 Type resultElementType = cast<RankedTensorType>(resultType).getElementType ();
11411188
1142- Value divisor = kernelSizeIntValues [0 ];
1143- for (uint32_t i = 1 ; i < kernelSizeIntValues .size (); ++i) {
1144- divisor = rewriter. createOrFold <arith::MulIOp>(loc, divisor,
1145- kernelSizeIntValues [i]);
1189+ Value divisor = kernelDimSizes [0 ];
1190+ for (uint32_t i = 1 ; i < kernelDimSizes .size (); ++i) {
1191+ divisor =
1192+ rewriter. createOrFold <arith::MulIOp>(loc, divisor, kernelDimSizes [i]);
11461193 }
11471194 // Only average pooling 2D/3D have optional divisor override.
11481195 if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
0 commit comments