
Commit 1ad9702

[Torch] Add support for aten.any.dims (#4200)
* Added lowering to Linalg-on-Tensors
* Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
1 parent 0f95851 commit 1ad9702

8 files changed: +100 −2 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
@@ -10940,6 +10940,32 @@ def Torch_AtenAnyDimOp : Torch_Op<"aten.any.dim", [
   }];
 }
 
+def Torch_AtenAnyDimsOp : Torch_Op<"aten.any.dims", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::any.dims : (Tensor, int[]?, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAnyDimsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenAnyDimsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenArangeOp : Torch_Op<"aten.arange", [
     AllowsTypeRefinement,
     HasValueSemantics,
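The generated op mirrors PyTorch's `aten::any.dims` overload, which reduces over a list of dimensions at once instead of the single `int` dim taken by `aten::any.dim`. As a quick orientation, a minimal sketch of the eager semantics, assuming a PyTorch build recent enough to expose this overload:

```python
import torch

# aten::any.dims : (Tensor self, int[]? dim, bool keepdim) -> (Tensor)
x = torch.tensor([[False, True], [False, False]])

print(torch.ops.aten.any.dims(x, dim=[0, 1], keepdim=True))   # tensor([[True]])
print(torch.ops.aten.any.dims(x, dim=[0, 1], keepdim=False))  # tensor(True)
```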

lib/Conversion/TorchToLinalg/Reduction.cpp

Lines changed: 6 additions & 2 deletions
@@ -337,7 +337,7 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
     return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
   }
 
-  if (isa<AtenAnyOp>(op)) {
+  if (isa<AtenAnyOp, AtenAnyDimsOp>(op)) {
     return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
   }
 
@@ -434,7 +434,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
     Value result = payloadArgs[1];
     Value self = convertScalarToDtype(b, loc, elem, resultElementType);
     return b.create<arith::AndIOp>(loc, self, result);
-  } else if (isa<AtenAnyOp>(op)) {
+  } else if (isa<AtenAnyOp, AtenAnyDimsOp>(op)) {
     Value elem = payloadArgs[0];
     Value result = payloadArgs[1];
     Value self = convertScalarToDtype(b, loc, elem, resultElementType);
@@ -532,6 +532,9 @@ class ConvertReductionOp : public ConversionPattern {
     if (auto allOp = dyn_cast<AtenAllDimOp>(op))
      return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter);
 
+    if (auto anyOp = dyn_cast<AtenAnyDimsOp>(op))
+      return computeReductionOpInfoForDimVariantOp(anyOp, operands, rewriter);
+
     return rewriter.notifyMatchFailure(op, "not a supported reduce op");
   }
 
@@ -709,6 +712,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
   target.addIllegalOp<AtenSumOp>();
   target.addIllegalOp<AtenAnyOp>();
+  target.addIllegalOp<AtenAnyDimsOp>();
   target.addIllegalOp<AtenAllOp>();
   target.addIllegalOp<AtenSumDimIntListOp>();
   target.addIllegalOp<AtenProdOp>();
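The lowering reuses the existing dim-variant reduction machinery: the init element for `any` is the boolean constant `false`, and each payload step ORs the (bool-converted) element into the accumulator, the counterpart of the `and` payload used for `aten.all`. A reference model of that recipe in plain PyTorch, as a hedged sketch of the semantics rather than the actual generated IR:

```python
import torch

def any_dims_reference(x: torch.Tensor, dims, keepdim: bool = False):
    """Semantics sketch for the linalg lowering of aten.any.dims:
    init element `false`, combine with logical OR; non-bool inputs
    conceptually count as true when nonzero (convertScalarToDtype)."""
    out = x != 0  # stand-in for the conversion to the bool result type
    # Reduce the highest dims first so the remaining dim indices stay
    # valid when keepdim is false.
    for d in sorted(dims, reverse=True):
        out = out.any(dim=d, keepdim=keepdim)
    return out

x = torch.rand(3, 4, 5)
assert torch.equal(any_dims_reference(x, [0, 1]),
                   torch.ops.aten.any.dims(x, dim=[0, 1]))
```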

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 16 additions & 0 deletions
@@ -1120,6 +1120,22 @@ OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenAnyDimsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAnyDimsOp::fold(FoldAdaptor adaptor) {
+  auto resultType = dyn_cast<ValueTensorType>(getResult().getType());
+  auto resultShape = resultType.toBuiltinTensor().getShape();
+  auto inputType = dyn_cast<ValueTensorType>(getOperand(0).getType());
+  auto inputShape = inputType.toBuiltinTensor().getShape();
+  if ((inputType.getDtype() == resultType.getDtype()) &&
+      (inputShape == resultShape)) {
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenLenTOp
 //===----------------------------------------------------------------------===//
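The folder handles the degenerate case where the reduction changes neither shape nor dtype, so the op is a no-op and folds to its input. One concrete shape/dtype configuration where that condition holds, as a hypothetical illustration (the assertion only checks types, not that the folder actually fired in a given pipeline):

```python
import torch

# If every reduced dim has size 1, keepdim=True, and the input is already
# bool, the result has the input's shape and dtype -- exactly the condition
# under which the new folder returns getSelf().
x = torch.zeros(1, 1, 4, dtype=torch.bool)
y = torch.ops.aten.any.dims(x, dim=[0, 1], keepdim=True)
assert y.shape == x.shape and y.dtype == x.dtype
```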

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 18 additions & 0 deletions
@@ -7504,6 +7504,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.any.dims\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.derefine %none : !torch.none to !torch.any\n"
+"    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
@@ -15420,6 +15426,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.any.dims\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %int11 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    %int0 = torch.constant.int 0\n"

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -824,6 +824,7 @@
     "RandnLikeDtypeModule_basic",
     "RandnLikeModule_basic",
     "RandnModule_basic",
+    "ReduceAnyDimsFloatModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
     "ReflectionPad1dModule3dInput_Left",
@@ -2691,6 +2692,7 @@
     "PermuteNegativeIndexModule_basic",
     # Failure - incorrect numerics
     "ReduceAnyDimFloatModule_basic",
+    "ReduceAnyDimsFloatModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "BroadcastDynamicDimModule_basic",
     "ElementwiseAtan2TensorIntModule_basic",
@@ -3793,6 +3795,7 @@
     "RandnLikeModule_basic",
     "RandnModule_basic",
     "ReduceAllDimEmpty_basic",
+    "ReduceAnyDimsFloatModule_basic",
     "ReduceFrobeniusNormComplexModule_basic",
     "ReduceL1NormComplexModule_basic",
     "ReduceL2NormComplexModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 10 additions & 0 deletions
@@ -778,6 +778,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
 def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
+def aten〇any〇dims〡shape(self: List[int], dim: Optional[List[int]] = None, keepdim: bool = False) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
+
 def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.argmax(self, dim, keepdim)
 
@@ -5249,6 +5252,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim
         return self_dtype
     return torch.bool
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇any〇dims〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, keepdim: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if self_dtype == torch.uint8:
+        return self_dtype
+    return torch.bool
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
 def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int:
     self_rank, self_dtype = self_rank_dtype
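Because the shape function simply delegates to the upstream `sum_mean_dim` helper, `aten.any.dims` follows the familiar reduce-over-dim-list shape rule. A small sanity sketch of that rule, assuming the overload exists in the installed PyTorch:

```python
import torch

x = torch.rand(3, 4, 5)
# keepdim=False drops the reduced dims; keepdim=True keeps them as size 1.
assert torch.ops.aten.any.dims(x, dim=[0, 1], keepdim=False).shape == (5,)
assert torch.ops.aten.any.dims(x, dim=[0, 1], keepdim=True).shape == (1, 1, 5)
```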

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -842,6 +842,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)")
     emit("aten::any : (Tensor) -> (Tensor)")
     emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
+    emit("aten::any.dims : (Tensor, int[]?, bool) -> (Tensor)", has_folder=True)
     emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit(
         "aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)"
    )

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 20 additions & 0 deletions
@@ -302,6 +302,26 @@ def ReduceAnyDimFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
 
 
+class ReduceAnyDimsFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.any(a, dim=[0, 1])
+
+
+@register_test_case(module_factory=lambda: ReduceAnyDimsFloatModule())
+def ReduceAnyDimsFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
 # ==============================================================================
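Note that the test calls the overload packet `torch.ops.aten.any` with a `dim` list; overload resolution selects `aten::any.dims` because `aten::any.dim` only accepts a single `int`. A hedged equivalence check of that dispatch:

```python
import torch

x = torch.rand(3, 4, 5)
# The packet call with a list and the explicit .dims overload should agree.
assert torch.equal(torch.ops.aten.any(x, dim=[0, 1]),
                   torch.ops.aten.any.dims(x, dim=[0, 1]))
```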

307327
