
Commit 7073781
[TORCH] Add support for fix Op (#4195)
- Added support for the fix op, which is an alias for torch.trunc
- Added the e2e tests

This implementation addresses and closes #4193.

Signed-off-by: sharavana20 <[email protected]>
1 parent 997251d commit 7073781
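
For context on the semantics being added: PyTorch documents torch.fix as an alias for torch.trunc, which rounds each element toward zero. A minimal sketch of the expected behavior in plain PyTorch (an illustration, not part of this commit):

import torch

x = torch.tensor([-1.7, -0.5, 0.5, 2.9])

# torch.fix is an alias for torch.trunc: both round toward zero, elementwise.
assert torch.equal(torch.fix(x), torch.trunc(x))
print(torch.fix(x))  # tensor([-1., -0.,  0.,  2.])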

8 files changed, 128 insertions(+), 0 deletions(-)

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 45 additions & 0 deletions
@@ -4608,6 +4608,51 @@ def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
   }];
 }
 
+def Torch_AtenFixOp : Torch_Op<"aten.fix", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::fix : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFixOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenFixOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenFix_Op : Torch_Op<"aten.fix_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::fix_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFix_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenFix_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 8 additions & 0 deletions
@@ -6710,6 +6710,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.fix\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
@@ -12357,6 +12361,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.fix\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 " %int4 = torch.constant.int 4\n"
 " %int11 = torch.constant.int 11\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 14 additions & 0 deletions
@@ -8780,6 +8780,19 @@ class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
 };
 } // namespace
 
+namespace {
+// fix Op is an alias for trunc Op
+class DecomposeAtenFixOp : public OpRewritePattern<AtenFixOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFixOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.getSelf();
+    rewriter.replaceOpWithNewOp<AtenTruncOp>(op, op.getType(), self);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // decompose `signbit(x)` to `view.dtype(x, si32/si64) < 0 `
 class DecomposeAtenSignbitOp : public OpRewritePattern<AtenSignbitOp> {
@@ -12085,6 +12098,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRad2degOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenFixOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSignbitOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenFracOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCopysignTensorOp>(patterns);
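
The pattern above is a pure aliasing rewrite: every aten.fix is replaced by aten.trunc on the same operand, with the result type unchanged. A minimal Python sketch of the equivalence the rewrite relies on (decompose_fix is a hypothetical name, for illustration only):

import torch

# Stands in for what DecomposeAtenFixOp does at the IR level:
# replace fix with trunc on the same operand.
def decompose_fix(self: torch.Tensor) -> torch.Tensor:
    return torch.trunc(self)

x = torch.tensor([-2.5, 1.9, -0.3])
assert torch.equal(decompose_fix(x), torch.fix(x))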

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -550,6 +550,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenRad2degOp>();
   target.addIllegalOp<AtenCosineSimilarityOp>();
   target.addIllegalOp<AtenTruncOp>();
+  target.addIllegalOp<AtenFixOp>();
   target.addIllegalOp<AtenSignbitOp>();
   target.addIllegalOp<AtenFracOp>();
   target.addIllegalOp<AtenCopysignTensorOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 8 additions & 0 deletions
@@ -1694,6 +1694,8 @@
     "ElementwiseSinhModule_basic",
     "ElementwiseTruncIntModule_basic",
     "ElementwiseTruncModule_basic",
+    "ElementwiseFixModule_basic",
+    "ElementwiseFixIntModule_basic",
     "ElementwiseLogSigmoidModule_basic",
     "ElementwiseHardshrinkStaticModule_basic",
     "ElementwiseSoftshrinkStaticModule_basic",
@@ -2049,6 +2051,8 @@
     "ElementwiseTernaryStaticShapeModule_basic",
     "ElementwiseTruncModule_basic",
     "ElementwiseTruncIntModule_basic",
+    "ElementwiseFixModule_basic",
+    "ElementwiseFixIntModule_basic",
     "ElementwiseSgnModule_basic",
     "ElementwiseSignIntModule_basic",
     "AddCDivModule_basic",
@@ -2900,6 +2904,8 @@
     "ElementwiseCoshModule_basic",
     "ElementwiseTruncIntModule_basic",
     "ElementwiseTruncModule_basic",
+    "ElementwiseFixModule_basic",
+    "ElementwiseFixIntModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
@@ -4356,6 +4362,8 @@
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "ElementwiseTruncIntModule_basic",
     "ElementwiseTruncModule_basic",
+    "ElementwiseFixModule_basic",
+    "ElementwiseFixIntModule_basic",
     "ElementwiseUnaryIntModule_basic",
     "ElementwiseUnsqueezeNegDimsModule_basic",
     "ElementwiseWhereScalarOtherModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
@@ -319,6 +319,9 @@ def aten〇ceil〡shape(self: List[int]) -> List[int]:
 def aten〇trunc〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇fix〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇log〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -3108,6 +3111,11 @@ def aten〇trunc〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇fix〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0))
 def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int:
     self_rank, self_dtype = self_rank_dtype
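
Both additions record that fix is shape- and dtype-preserving, matching the unary shape helper and the passthrough dtype function above. A quick sanity check in plain PyTorch (an illustration, not part of the commit):

import torch

x = torch.randn(2, 3, dtype=torch.float32)
y = torch.fix(x)

# fix is a unary elementwise op: the output keeps the input's shape and dtype.
assert y.shape == x.shape and y.dtype == x.dtype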

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -452,6 +452,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
+    emit_with_mutating_variants("aten::fix : (Tensor) -> (Tensor)")
     emit("aten::special_expm1 : (Tensor) -> (Tensor)")
     emit_with_mutating_variants(
         "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 43 additions & 0 deletions
@@ -2900,6 +2900,49 @@ def ElementwiseTruncIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseFixModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.fix(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseFixModule())
+def ElementwiseFixModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([torch.nan, torch.inf, -torch.inf, 2, 0.5]))
+
+
+class ElementwiseFixIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3], torch.int64, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.fix(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseFixIntModule())
+def ElementwiseFixIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3, low=0, high=500))
+
+
+# ==============================================================================
+
+
 class ElementwiseSignbitModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
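
Note that ElementwiseFixModule deliberately feeds NaN and infinities: trunc (and therefore fix) passes non-finite values through unchanged, so the test covers more than plain rounding. For reference, in plain PyTorch:

import torch

# NaN and +/-inf are preserved; finite values round toward zero.
print(torch.fix(torch.tensor([torch.nan, torch.inf, -torch.inf, 0.5])))
# tensor([nan, inf, -inf, 0.])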
