
Commit 1cb25e9

[TORCH] Add support for aten.fliplr & aten.flipud ops via torch decomposition (#4197)
This PR addresses issue #4192.

**Feature**: Adds e2e support for two related ops, `aten.fliplr` and `aten.flipud`, by decomposing them to the `aten.flip` op, which is already supported in torch-mlir.

---------

Signed-off-by: Zahid Wakeel <[email protected]>
1 parent 7073781 commit 1cb25e9
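At the PyTorch level, the decomposition relies on the standard identity that `fliplr` reverses dimension 1 (rank >= 2 inputs) and `flipud` reverses dimension 0 (rank >= 1 inputs). A quick eager-mode sanity check of that identity in plain PyTorch (shapes chosen arbitrarily, not part of this commit):

```python
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# fliplr(x) is equivalent to flip(x, dims=[1]) for inputs of rank >= 2.
assert torch.equal(torch.fliplr(x), torch.flip(x, dims=[1]))

# flipud(x) is equivalent to flip(x, dims=[0]) for inputs of rank >= 1.
assert torch.equal(torch.flipud(x), torch.flip(x, dims=[0]))
```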

File tree

8 files changed: +289 -0 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 46 additions & 0 deletions
@@ -7226,6 +7226,52 @@ def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
   }];
 }
 
+def Torch_AtenFliplrOp : Torch_Op<"aten.fliplr", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::fliplr : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFliplrOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenFliplrOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenFlipudOp : Torch_Op<"aten.flipud", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::flipud : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFlipudOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenFlipudOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenNativeBatchNormOp : Torch_Op<"aten.native_batch_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 36 additions & 0 deletions
@@ -10281,6 +10281,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.fliplr\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.flipud\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.convolution_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.bool, %arg8: !torch.list<int>, %arg9: !torch.int, %arg10: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
@@ -12590,6 +12618,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.fliplr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.flipud\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.sign\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 80 additions & 0 deletions
@@ -1372,6 +1372,84 @@ class DecomposeAtenDeg2radOp : public OpRewritePattern<AtenDeg2radOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenFliplrOp : public OpRewritePattern<AtenFliplrOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFliplrOp op,
+                                PatternRewriter &rewriter) const override {
+    auto inputTy = dyn_cast<ValueTensorType>(op.getSelf().getType());
+    auto maybeSizes = inputTy.getOptionalSizes();
+    if (!maybeSizes) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have known rank.");
+    }
+    auto inShape = maybeSizes.value();
+    auto inRank = inShape.size();
+
+    if (inRank < 2) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Fliplr expects input rank >= 2D.");
+    }
+
+    Location loc = op.getLoc();
+    Value constI = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+
+    SmallVector<Value> dims;
+    dims.push_back(constI);
+
+    Value flipDimList = rewriter.create<Torch::PrimListConstructOp>(
+        loc,
+        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+        dims);
+    Value flip = rewriter.create<AtenFlipOp>(loc, op.getType(), op.getSelf(),
+                                             flipDimList);
+    rewriter.replaceOp(op, flip);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenFlipudOp : public OpRewritePattern<AtenFlipudOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFlipudOp op,
+                                PatternRewriter &rewriter) const override {
+    auto inputTy = dyn_cast<ValueTensorType>(op.getSelf().getType());
+    auto maybeSizes = inputTy.getOptionalSizes();
+    if (!maybeSizes) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have known rank.");
+    }
+    auto inShape = maybeSizes.value();
+    auto inRank = inShape.size();
+
+    if (inRank < 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Flipud expects input rank >= 1D.");
+    }
+
+    Location loc = op.getLoc();
+    Value constI = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+
+    SmallVector<Value> dims;
+    dims.push_back(constI);
+
+    Value flipDimList = rewriter.create<Torch::PrimListConstructOp>(
+        loc,
+        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+        dims);
+    Value flip = rewriter.create<AtenFlipOp>(loc, op.getType(), op.getSelf(),
+                                             flipDimList);
+    rewriter.replaceOp(op, flip);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -12048,6 +12126,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenFliplrOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenFlipudOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
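The two rewrite patterns above differ only in the hard-coded flip dimension and the minimum rank they accept. An eager-mode Python sketch of the same rewrite logic (illustrative only, not the C++ pattern itself): apply the rank guard, then fall back to `aten.flip` on the fixed dimension.

```python
import torch

def decompose_fliplr(x: torch.Tensor) -> torch.Tensor:
    # Mirrors DecomposeAtenFliplrOp: require rank >= 2, then flip dim 1.
    if x.dim() < 2:
        raise ValueError("Fliplr expects input rank >= 2D.")
    return torch.flip(x, dims=[1])

def decompose_flipud(x: torch.Tensor) -> torch.Tensor:
    # Mirrors DecomposeAtenFlipudOp: require rank >= 1, then flip dim 0.
    if x.dim() < 1:
        raise ValueError("Flipud expects input rank >= 1D.")
    return torch.flip(x, dims=[0])

x = torch.rand(3, 5, 2)
assert torch.equal(decompose_fliplr(x), torch.fliplr(x))
assert torch.equal(decompose_flipud(x), torch.flipud(x))
```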

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 2 additions & 0 deletions
@@ -579,6 +579,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenFminOp>();
   target.addIllegalOp<AtenFmaxOp>();
   target.addIllegalOp<AtenSpecialExpm1Op>();
+  target.addIllegalOp<AtenFliplrOp>();
+  target.addIllegalOp<AtenFlipudOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 12 additions & 0 deletions
@@ -1910,6 +1910,10 @@
     "FlipModuleStaticShape_basic",
     "FlipModule_basic",
     "FlipNegativeIndexModule_basic",
+    "FliplrOddRankModule_basic",
+    "FliplrEvenRankModule_basic",
+    "FlipudOddRankModule_basic",
+    "FlipudEvenRankModule_basic",
     "Rot90BasicModule_basic",
     "Rot90DynamicDimsModule_basic",
     "Rot90MultipleRotationsModule_basic",
@@ -2697,6 +2701,10 @@
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
     "Exp2StaticModule_basic",
+    "FliplrOddRankModule_basic",
+    "FliplrEvenRankModule_basic",
+    "FlipudOddRankModule_basic",
+    "FlipudEvenRankModule_basic",
     "FloatPowerTensorTensorStaticModule_basic",
     "MultinomialModule2D_basic",
     "MultinomialModule2D_F32",
@@ -4400,6 +4408,10 @@
     "FlipModuleStaticShape_basic",
     "FlipModule_basic",
     "FlipNegativeIndexModule_basic",
+    "FliplrOddRankModule_basic",
+    "FliplrEvenRankModule_basic",
+    "FlipudOddRankModule_basic",
+    "FlipudEvenRankModule_basic",
     "FloatImplicitModule_basic",
     "FullLikeModuleDefaultDtype_basic",
     "FullLikeModuleFalsePinMemory_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 18 additions & 0 deletions
@@ -1975,6 +1975,14 @@ def aten〇_convolution〇deprecated〡shape(input: List[int], weight: List[int]
 def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]:
     return self
 
+def aten〇fliplr〡shape(self: List[int]) -> List[int]:
+    assert len(self) >= 2
+    return self
+
+def aten〇flipud〡shape(self: List[int]) -> List[int]:
+    assert len(self) >= 1
+    return self
+
 def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int], weight: List[int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
     return upstream_shape_functions.conv_backwards(grad_output, input, weight, bias_sizes)
 
@@ -3293,6 +3301,16 @@ def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> in
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function([Invocation(TensorOfShape(2, 3, dtype=torch.float32, device="cpu"))])
+def aten〇fliplr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
+@check_dtype_function([Invocation(TensorOfShape(2, 3, dtype=torch.float32, device="cpu"))])
+def aten〇flipud〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇sign〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
@@ -622,6 +622,8 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)"
     )
     emit("aten::flip : (Tensor, int[]) -> (Tensor)")
+    emit("aten::fliplr : (Tensor) -> (Tensor)")
+    emit("aten::flipud : (Tensor) -> (Tensor)")
     emit(
         "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
     )

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 93 additions & 0 deletions
@@ -10,6 +10,7 @@
 from torch_mlir_e2e_test.registry import register_test_case
 from torch_mlir_e2e_test.annotations import annotate_args, export
 
+
 # ==============================================================================
 
 
@@ -4267,6 +4268,98 @@ def FlipNegativeIndexModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class FliplrOddRankModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.fliplr(a)
+
+
+@register_test_case(module_factory=lambda: FliplrOddRankModule())
+def FliplrOddRankModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class FliplrEvenRankModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.fliplr(a)
+
+
+@register_test_case(module_factory=lambda: FliplrEvenRankModule())
+def FliplrEvenRankModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2, 4))
+
+
+# ==============================================================================
+
+
+class FlipudOddRankModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.flipud(a)
+
+
+@register_test_case(module_factory=lambda: FlipudOddRankModule())
+def FlipudOddRankModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class FlipudEvenRankModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.flipud(a)
+
+
+@register_test_case(module_factory=lambda: FlipudEvenRankModule())
+def FlipudEvenRankModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2, 4))
+
+
+# ==============================================================================
+
+
 class DetachModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
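For intuition about what these e2e tests exercise numerically on rank-3 and rank-4 inputs, here is a small eager-mode check (independent of torch-mlir, tensor shape chosen to match the odd-rank tests): `fliplr` reverses dim 1 and `flipud` reverses dim 0, leaving all other dims untouched.

```python
import torch

a = torch.rand(3, 5, 2)

# fliplr reverses dim 1; build the expected result by index reversal.
expected_lr = a[:, torch.arange(a.size(1) - 1, -1, -1), :]
assert torch.equal(torch.ops.aten.fliplr(a), expected_lr)

# flipud reverses dim 0.
expected_ud = a[torch.arange(a.size(0) - 1, -1, -1)]
assert torch.equal(torch.ops.aten.flipud(a), expected_ud)
```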
