Commit dcdb77e

[TORCH] Add aten.logaddexp & aten.logaddexp2 ops support (#4201)
This PR takes care of #4194: e2e support for the **aten.logaddexp** and **aten.logaddexp2** ops by decomposing them into primitive ops that already have support in torch-mlir.

- Added the relevant expected fails in `xfail_sets.py`

---------

Signed-off-by: Zahid Wakeel <[email protected]>
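For reference, the decompositions rely on the standard identities below (stated here for clarity; they are not spelled out in the patch itself):

```latex
\operatorname{logaddexp}(x, y)  = \log\!\left(e^{x} + e^{y}\right),
\qquad
\operatorname{logaddexp2}(x, y) = \log_{2}\!\left(2^{x} + 2^{y}\right)
```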
1 parent dcf9bbf commit dcdb77e

File tree

8 files changed: +220 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 48 additions & 0 deletions

@@ -8813,6 +8813,54 @@ def Torch_AtenLogsumexpOp : Torch_Op<"aten.logsumexp", [
   }];
 }
 
+def Torch_AtenLogaddexpOp : Torch_Op<"aten.logaddexp", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logaddexp : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogaddexpOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogaddexpOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogaddexp2Op : Torch_Op<"aten.logaddexp2", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logaddexp2 : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogaddexp2Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogaddexp2Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 49 additions & 0 deletions

@@ -9286,6 +9286,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logaddexp\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logaddexp2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.masked_fill.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12796,6 +12804,47 @@
 "  func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
 "    return %arg3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logaddexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
+"      %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %4 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logaddexp2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int10 = torch.constant.int 10\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int8 = torch.constant.int 8\n"
+"    %int11 = torch.constant.int 11\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %int11, %int8, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 48 additions & 0 deletions

@@ -2976,6 +2976,52 @@ class DecomposeAtenLogSigmoidOp : public OpRewritePattern<AtenLogSigmoidOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenLogAddExpOp : public OpRewritePattern<AtenLogaddexpOp> {
+public:
+  using OpRewritePattern<AtenLogaddexpOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLogaddexpOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto outTy = op.getType();
+
+    Value constantOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value expSelf = rewriter.create<AtenExpOp>(loc, outTy, self);
+    Value expOther = rewriter.create<AtenExpOp>(loc, outTy, other);
+    Value addValue = rewriter.create<AtenAddTensorOp>(loc, outTy, expSelf,
+                                                      expOther, constantOne);
+    rewriter.replaceOpWithNewOp<AtenLogOp>(op, outTy, addValue);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenLogAddExp2Op : public OpRewritePattern<AtenLogaddexp2Op> {
+public:
+  using OpRewritePattern<AtenLogaddexp2Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLogaddexp2Op op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto outTy = op.getType();
+
+    Value constantOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value expSelf = rewriter.create<AtenExp2Op>(loc, outTy, self);
+    Value expOther = rewriter.create<AtenExp2Op>(loc, outTy, other);
+    Value addValue = rewriter.create<AtenAddTensorOp>(loc, outTy, expSelf,
+                                                      expOther, constantOne);
+    rewriter.replaceOpWithNewOp<AtenLog2Op>(op, outTy, addValue);
+    return success();
+  }
+};
+} // namespace
+
 // SoftShrink(x, lambda) function:
 // Applies a shrinkage function where:
 //  - If x > lambda, returns x - lambda
@@ -12068,6 +12114,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExpOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenLogAddExp2Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
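A rough eager-PyTorch analogue of what the two patterns emit, assuming floating-point inputs; the `constantOne` value corresponds to the `alpha` operand of `aten.add.Tensor` (a sketch, not the generated IR):

```python
import torch

def decomposed_logaddexp(self, other):
    # Mirrors DecomposeAtenLogAddExpOp: exp -> add.Tensor (alpha = 1) -> log.
    return torch.log(torch.add(torch.exp(self), torch.exp(other), alpha=1))

def decomposed_logaddexp2(self, other):
    # Mirrors DecomposeAtenLogAddExp2Op: exp2 -> add.Tensor (alpha = 1) -> log2.
    return torch.log2(torch.add(torch.exp2(self), torch.exp2(other), alpha=1))

# The decomposed forms agree with the reference ops for well-scaled inputs.
x, y = torch.rand(3, 2, 4), torch.rand(3, 2, 4)
assert torch.allclose(decomposed_logaddexp(x, y), torch.logaddexp(x, y))
assert torch.allclose(decomposed_logaddexp2(x, y), torch.logaddexp2(x, y))
```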

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 2 additions & 0 deletions

@@ -582,6 +582,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenSpecialExpm1Op>();
   target.addIllegalOp<AtenFliplrOp>();
   target.addIllegalOp<AtenFlipudOp>();
+  target.addIllegalOp<AtenLogaddexpOp>();
+  target.addIllegalOp<AtenLogaddexp2Op>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions

@@ -2923,6 +2923,8 @@
     "ElementwiseEluNonDefaultModule_basic",
     "ElementwiseExpm1IntModule_basic",
     "ElementwiseExpm1Module_basic",
+    "ElementwiseLogAddExpModule_basic",
+    "ElementwiseLogAddExp2Module_basic",
     "ElementwiseSpecialExpm1IntModule_basic",
     "ElementwiseSpecialExpm1Module_basic",
     "ElementwiseFmodTensor_Int_basic",
@@ -3931,6 +3933,8 @@
     "L1LossMeanReductionModule_basic",
     "L1LossNoReductionModule_basic",
     "L1LossSumReductionModule_basic",
+    "ElementwiseLogAddExpModule_basic",
+    "ElementwiseLogAddExp2Module_basic",
     "FloatPowerTensorTensorStaticModule_basic",
     "IsInfiniteModule_basic",
     "ElementwiseCopysignModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 19 additions & 0 deletions

@@ -1482,6 +1482,12 @@ def aten〇count_nonzero〇dim_IntList〡shape(self: List[int], dim: List[int])
 def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇logaddexp〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
+def aten〇logaddexp2〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇masked_fill〇Scalar〡shape(self: List[int], mask: List[int], value: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -3475,6 +3481,19 @@ def aten〇kthvalue〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int =
 def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int:
     return input_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool}))
+def aten〇logaddexp〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    other_rank, other_dtype = other_rank_dtype
+    assert self_dtype != torch.bool and other_dtype != torch.bool
+    return self_dtype
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, torch.complex32, torch.complex64, torch.complex128}))
+def aten〇logaddexp2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype not in [torch.bool, torch.complex32, torch.complex64, torch.complex128]
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0))
 def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
     self_rank, self_dtype = self_rank_dtype
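A quick eager-mode illustration of the dtype rule encoded above, for the simple case where both operands already share a dtype (shapes chosen arbitrarily):

```python
import torch

a = torch.rand(2, 3, dtype=torch.float64)
b = torch.rand(2, 3, dtype=torch.float64)

# The dtype functions above return self's dtype; eager PyTorch agrees when
# both inputs share a dtype.
assert torch.logaddexp(a, b).dtype == torch.float64
assert torch.logaddexp2(a, b).dtype == torch.float64
```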

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions

@@ -725,6 +725,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
     emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
+    emit("aten::logaddexp : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::logaddexp2 : (Tensor, Tensor) -> (Tensor)")
     emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
     emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 48 additions & 0 deletions

@@ -2580,6 +2580,54 @@ def ElementwiseLogitModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseLogAddExpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.logaddexp(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLogAddExpModule())
+def ElementwiseLogAddExpModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))
+
+
+# ==============================================================================
+
+
+class ElementwiseLogAddExp2Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.logaddexp2(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLogAddExp2Module())
+def ElementwiseLogAddExp2Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))
+
+
+# ==============================================================================
+
+
 class ElementwiseLogSigmoidModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
