[TORCH] Add support for PoissonNLLLoss Op #4232

Merged 7 commits on Jun 19, 2025
Changes from 6 commits
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9452,6 +9452,34 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
}];
}

def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$target,
Torch_BoolType:$log_input,
Torch_BoolType:$full,
Torch_FloatType:$eps,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPoissonNllLossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenPoissonNllLossOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
AllowsTypeRefinement,
HasValueSemantics,
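
For context, here is a minimal eager-mode sketch of the schema this op definition mirrors, `aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)`. The tensor names and shapes are illustrative, not taken from this PR:

import torch

inp = torch.rand(2, 3)    # input
tgt = torch.poisson(inp)  # target drawn from a Poisson distribution
out = torch.ops.aten.poisson_nll_loss(
    inp, tgt,
    log_input=True,  # Torch_BoolType:$log_input
    full=False,      # Torch_BoolType:$full
    eps=1e-8,        # Torch_FloatType:$eps
    reduction=1,     # Torch_IntType:$reduction (0 = none, 1 = mean, 2 = sum)
)
print(out.shape)  # torch.Size([]) -- mean reduction yields a scalar
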
25 changes: 25 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10736,6 +10736,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.poisson_nll_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.float, %arg5: !torch.int) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.eq.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
" torch.prim.If.yield %arg0 : !torch.list<int>\n"
" } else {\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.If.yield %2 : !torch.list<int>\n"
" }\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
@@ -15220,6 +15231,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.poisson_nll_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.float, %arg5: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int15 = torch.constant.int 15\n"
" %int5 = torch.constant.int 5\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float) -> !torch.tuple<int, int, int> {\n"
" %int7 = torch.constant.int 7\n"
" %int10 = torch.constant.int 10\n"
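
The integer constants in the dtype function above are PyTorch ScalarType codes: 5 is float16, 15 is bfloat16, and 6 is float32. A sketch of the same promotion rule in plain Python (my restatement, not code from this PR):

HALF, FLOAT32, BFLOAT16 = 5, 6, 15  # ScalarType codes used in the IR above

def poisson_nll_loss_result_dtype(input_dtype: int) -> int:
    # float16/bfloat16 inputs promote to float32; other dtypes pass through.
    if input_dtype in (HALF, BFLOAT16):
        return FLOAT32
    return input_dtype
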
79 changes: 79 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -10471,6 +10471,84 @@ class DecomposeAtenNllLossForwardOp
};
} // namespace

namespace {
class DecomposeAtenPoissonNllLossOp
: public OpRewritePattern<AtenPoissonNllLossOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenPoissonNllLossOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getInput();
Value target = op.getTarget();
Value logInput = op.getLogInput();
Value full = op.getFull();
Value reduction = op.getReduction();
Value eps = op.getEps();

bool logInVal, fullVal;
if (!matchPattern(logInput, m_TorchConstantBool(&logInVal)))
return rewriter.notifyMatchFailure(
op, "expected logInput argument to be constant bool");
if (!matchPattern(full, m_TorchConstantBool(&fullVal)))
return rewriter.notifyMatchFailure(
op, "expected full argument to be constant bool");

int64_t reductionInt;
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt)))
return rewriter.notifyMatchFailure(op, "expected constant reduction");

double epsFloat;
if (!matchPattern(eps, m_TorchConstantFloat(&epsFloat))) {
return rewriter.notifyMatchFailure(op, "expected constant eps");
}
// TODO: add support for full=true (Stirling approximation)
if (fullVal)
return rewriter.notifyMatchFailure(
op, "Unimplemented: full loss computation is not supported");

Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value epsConst = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(epsFloat));

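    // Note: `eps` is only consumed on the log_input=false path below, where it
    // keeps log(input + eps) finite when the input contains zeros.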
Value safeInput = rewriter.create<AtenAddScalarOp>(loc, input.getType(),
input, epsConst, one);

Value loss;
if (logInVal) {
Value expIn = rewriter.create<AtenExpOp>(loc, input.getType(), input);
Value targetMulInput =
rewriter.create<AtenMulTensorOp>(loc, input.getType(), target, input);
loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), expIn,
targetMulInput, one);
} else {
Value logSafeInput =
rewriter.create<AtenLogOp>(loc, input.getType(), safeInput);
Value targetMulLog = rewriter.create<AtenMulTensorOp>(
loc, input.getType(), target, logSafeInput);
loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), input,
targetMulLog, one);
}

Value result;
if (reductionInt == 0) {
result = loss;
} else if (reductionInt == 1) {
      // reduction == 1: mean reduction
result = rewriter.create<AtenMeanOp>(
loc, op.getType(), loss, rewriter.create<ConstantNoneOp>(loc));
} else {
      // reduction == 2: sum reduction
result = rewriter.create<AtenSumOp>(loc, op.getType(), loss,
rewriter.create<ConstantNoneOp>(loc));
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
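
For readers following the pattern above, this is an eager-mode restatement of what the decomposition computes, under the pattern's own constraint that full=False (a sketch, not code from this PR):

import torch

def poisson_nll_loss_decomposed(input, target, log_input, eps, reduction):
    if log_input:
        loss = torch.exp(input) - target * input        # exp(input) - target*input
    else:
        loss = input - target * torch.log(input + eps)  # eps guards log(0)
    if reduction == 0:   # none
        return loss
    if reduction == 1:   # mean
        return loss.mean()
    return loss.sum()    # sum (reduction == 2)

This matches torch.nn.functional.poisson_nll_loss with full=False; note that eager PyTorch likewise applies eps only on the log_input=False path.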

namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12384,6 +12462,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPoissonNllLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -537,6 +537,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLerpTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenL1LossOp>();
target.addIllegalOp<AtenPoissonNllLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenRandintOp>();
target.addIllegalOp<AtenVarMeanCorrectionOp>();
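
Marking AtenPoissonNllLossOp illegal here is what forces the decomposition above to run: any instance that survives to the backend-contract check, including the unimplemented full=True form, fails lowering instead of leaking through to backends.
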
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3068,6 +3068,10 @@
"NllLossStaticModule_mean_basic",
"NllLossModule_sum_basic",
"NllLossStaticModule_sum_basic",
"PoissonNLLLossNoReductionModule_basic",
"PoissonNLLLossMeanReductionModule_basic",
"PoissonNLLLossSumReductionModule_basic",
"PoissonNLLLossNonDefaultEpsModule_basic",
"NormScalarComplexModule_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2205,6 +2205,16 @@ def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: Lis
result_shape = scalar_shape
return result_shape

@check_shape_function([
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 0), # No reduction
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 1), # Mean reduction
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 2), # Sum reduction
])
def aten〇poisson_nll_loss〡shape(input: List[int], target: List[int], log_input: bool, full: bool, eps: float, reduction: int) -> List[int]:
if reduction == 0:
return input
return []

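A quick eager-mode check of the shape rule above (illustrative shapes, not from this PR):

import torch

inp = torch.rand(2, 3)
tgt = torch.poisson(inp)
for reduction in (0, 1, 2):
    out = torch.ops.aten.poisson_nll_loss(inp, tgt, True, False, 1e-8, reduction)
    print(reduction, list(out.shape))  # 0 -> [2, 3]; 1 -> []; 2 -> []
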
@check_shape_function([
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
])
@@ -5094,6 +5104,20 @@ def aten〇nll_loss_forward〡dtype(self_rank_dtype: Tuple[int, int], target_ran
assert target_dtype == torch.int64 or target_dtype == torch.int32
return self_dtype, self_dtype

@check_dtype_function([
Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, 3, dtype=torch.float32), # No reduction
True, False, 1e-8, 0),
Invocation(TensorOfShape(4, 5, dtype=torch.float32), TensorOfShape(4, 5, dtype=torch.float32), # Mean reduction
True, False, 1e-8, 1),
Invocation(TensorOfShape(3, 3, dtype=torch.float64), TensorOfShape(3, 3, dtype=torch.float64), # Sum reduction
True, False, 1e-8, 2),
])
def aten〇poisson_nll_loss〡dtype(input_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], log_input: bool, full: bool, eps: float, reduction: int) -> int:
_, input_dtype = input_rank_dtype
if input_dtype in (torch.float16, torch.bfloat16):
return torch.float32
return input_dtype

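Note that the invocations above only exercise the float32 and float64 paths; the float16/bfloat16-to-float32 promotion branch in the function body is not covered by a check_dtype_function invocation here.
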
@check_dtype_function(
[Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32),
TensorOfShape(3, dtype=torch.float32), eps=0.0),
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -760,6 +760,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
)
emit(
"aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
)
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
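
As a practical note, this emit line and the Python shape/dtype functions above are the sources for the generated hunks earlier in the diff: at the time of writing, torch-mlir regenerates GeneratedTorchOps.td via build_tools/update_torch_ods.sh and the abstract interp library via build_tools/update_abstract_interp_lib.sh, which is presumably how those files were updated in this PR.
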
116 changes: 116 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py
@@ -670,3 +670,119 @@ def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
module.forward(
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
)


class PoissonNLLLossNoReductionModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
)
def forward(self, input, target):
return torch.ops.aten.poisson_nll_loss(
input=input,
target=target,
log_input=False,
full=False,
eps=1e-8,
reduction=0,
)


@register_test_case(module_factory=lambda: PoissonNLLLossNoReductionModule())
def PoissonNLLLossNoReductionModule_basic(
module: PoissonNLLLossNoReductionModule, tu: TestUtils
):
input = tu.rand(4, 3).abs()
target = torch.poisson(input)
module.forward(input, target)


class PoissonNLLLossMeanReductionModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
]
)
def forward(self, input, target):
return torch.ops.aten.poisson_nll_loss(
input=input,
target=target,
log_input=True,
full=False,
eps=1e-8,
reduction=1,
)


@register_test_case(module_factory=lambda: PoissonNLLLossMeanReductionModule())
def PoissonNLLLossMeanReductionModule_basic(
module: PoissonNLLLossMeanReductionModule, tu: TestUtils
):
input = tu.rand(5, 7).abs()
target = torch.poisson(input)
module.forward(input, target)


class PoissonNLLLossSumReductionModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
)
def forward(self, input, target):
return torch.ops.aten.poisson_nll_loss(
input=input,
target=target,
log_input=True,
full=False,
eps=1e-8,
reduction=2,
)


@register_test_case(module_factory=lambda: PoissonNLLLossSumReductionModule())
def PoissonNLLLossSumReductionModule_basic(
module: PoissonNLLLossSumReductionModule, tu: TestUtils
):
input = tu.rand(3, 3)
target = torch.poisson(input.abs())
module.forward(input, target)


class PoissonNLLLossNonDefaultEpsModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
)
def forward(self, input, target):
return torch.ops.aten.poisson_nll_loss(
input=input,
target=target,
log_input=False,
full=False,
eps=0.5,
reduction=1,
)


@register_test_case(module_factory=lambda: PoissonNLLLossNonDefaultEpsModule())
def PoissonNLLLossNonDefaultEpsModule_basic(
module: PoissonNLLLossNonDefaultEpsModule, tu: TestUtils
):
input = tu.rand(5, 4)
target = torch.poisson(input.abs())
module.forward(input, target)
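
As an eager-mode sanity check of what these test modules exercise, the following sketch compares the aten op against torch.nn.functional.poisson_nll_loss, which dispatches to it (shapes and values are illustrative):

import torch
import torch.nn.functional as F

inp = torch.rand(4, 3).abs()
tgt = torch.poisson(inp)
# String reductions map to the integer codes used above: none=0, mean=1, sum=2.
expected = F.poisson_nll_loss(inp, tgt, log_input=False, full=False,
                              eps=1e-8, reduction="none")
actual = torch.ops.aten.poisson_nll_loss(inp, tgt, log_input=False,
                                         full=False, eps=1e-8, reduction=0)
assert torch.allclose(actual, expected)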