Commit c675b2f

[Stablehlo] Support aten.prod.dim_int (#4198)
Parent: 4b206d7

File tree

3 files changed (+99, -3 lines)

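For context, `aten.prod.dim_int` is the ATen op behind `torch.prod` with an explicit `dim` argument (plus `keepdim` and an optional `dtype`). A short eager-PyTorch illustration of the semantics this commit lowers to StableHLO:

```python
import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

# Product along dim 1; the multiplicative identity 1 serves as the
# reduction's init value, matching the stablehlo.constant dense<1.0>
# emitted by the new lowering.
print(torch.prod(x, dim=1))                # tensor([  6., 120.])
print(torch.prod(x, dim=1, keepdim=True))  # shape [2, 1] instead of [2]
print(torch.prod(x, dim=-1))               # negative dims count from the end
```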

lib/Conversion/TorchToStablehlo/Reduction.cpp

Lines changed: 66 additions & 2 deletions
```diff
@@ -96,7 +96,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenProdOp>(op)) {
+  if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       APFloat one(cast<mlir::FloatType>(elementTy).getFloatSemantics(), 1);
       auto constAttr = DenseElementsAttr::get(constType, one);
@@ -172,7 +172,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
   } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
     result = rewriter.create<stablehlo::OrOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-  } else if (isa<AtenProdOp>(op)) {
+  } else if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
     result = rewriter.create<stablehlo::MulOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
   } else {
@@ -689,6 +689,69 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
 }
 } // namespace
 
+// AtenProdDimIntOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenProdDimIntOp>::matchAndRewrite(
+    AtenProdDimIntOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  auto outTy =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  if (inputTy.getElementType() != outTy.getElementType()) {
+    // Use output element type as computation type.
+    auto dstElemTy = outTy.getElementType();
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = dyn_cast<RankedTensorType>(input.getType());
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "Only floating-point or integer datatype legalization supported");
+  }
+
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const integer `dim` is not supported");
+  }
+  dim = toPositiveDim(dim, inputTy.getRank());
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputTy.getShape(), {dim});
+
+  bool keepDim = false;
+  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+  }
+
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input,
+      RankedTensorType::get(reduceResultShape, outTy.getElementType()), dim,
+      rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
+  }
+
+  if (keepDim) {
+    auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
+    if (failed(outShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    reduceResult = reshapeReduceResultWhenKeepDim(
+        rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dim);
+  }
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
+
 // AtenFrobeniusNormDimOp
 // aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
 // dims) + stablehlo.sqrt
@@ -868,6 +931,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdDimIntOp);
 #undef INSERT_ATEN_REDUCTION_OP_PATTERN
 
 #define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp)                     \
```
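The new pattern mirrors the existing `AtenProdOp` handling: convert the input to the output element type if they differ, normalize a possibly negative `dim`, reduce with `stablehlo.multiply` from an init value of 1, and reshape afterwards when `keepdim` is set. A hedged Python sketch of that control flow (the function below is an illustrative stand-in for the C++ helpers `toPositiveDim`, `createReduceOpWithSingleRegionOp`, and `reshapeReduceResultWhenKeepDim`, not a literal translation):

```python
import numpy as np

def prod_dim_int(x: np.ndarray, dim: int, keepdim: bool) -> np.ndarray:
    # toPositiveDim: wrap a negative dim into [0, rank).
    if dim < 0:
        dim += x.ndim
    # stablehlo.reduce with multiply and init value 1 drops the reduced dim.
    out = np.multiply.reduce(x, axis=dim)
    # reshapeReduceResultWhenKeepDim: reinsert the reduced dim with size 1.
    if keepdim:
        out = np.expand_dims(out, axis=dim)
    return out

x = np.arange(1.0, 7.0).reshape(2, 3)
assert prod_dim_int(x, -1, False).shape == (2,)
assert prod_dim_int(x, -1, True).shape == (2, 1)
```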

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -823,7 +823,6 @@
     "RandnLikeDtypeModule_basic",
     "RandnLikeModule_basic",
     "RandnModule_basic",
-    "ReduceProdDimIntFloatModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
     "ReflectionPad1dModule3dInput_Left",
```
New test file

Lines changed: 33 additions & 0 deletions

```diff
@@ -0,0 +1,33 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+func.func @torch.aten.prod.intdim(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  // CHECK-LABEL: @torch.aten.prod.intdim(
+  // CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  // CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_2:.*]] = stablehlo.reduce(%[[VAL_0]] init: %[[VAL_1]]) applies stablehlo.multiply across dimensions = [1] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>
+  %0 = torch.aten.prod.dim_int %arg0, %int1, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
+  // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
+  // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?],f32>
+}
+
+// -----
+
+func.func @torch.aten.prod.intdim_negative_dim(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  // CHECK-LABEL: @torch.aten.prod.intdim_negative_dim(
+  // CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+  %int-1 = torch.constant.int -1
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  // CHECK: %[[VAL_1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[VAL_2:.*]] = stablehlo.reduce(%[[VAL_0]] init: %[[VAL_1]]) applies stablehlo.multiply across dimensions = [3] : (tensor<?x?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>
+  %0 = torch.aten.prod.dim_int %arg0, %int-1, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
+  // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
+  // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?],f32>
+}
```
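The second case checks that a negative `dim` is normalized before reduction: `dim = -1` on a rank-4 input lowers to `dimensions = [3]`. A quick eager-PyTorch sanity check of that expectation:

```python
import torch

x = torch.rand(2, 3, 4, 5)
# dim=-1 and dim=3 reduce the same axis of a rank-4 tensor, matching
# the "dimensions = [3]" in the CHECK line above.
assert torch.equal(torch.prod(x, dim=-1), torch.prod(x, dim=3))
print(torch.prod(x, dim=-1).shape)  # torch.Size([2, 3, 4])
```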
