
[TORCH] Add support for PoissonNLLLoss Op #4232

Merged (7 commits, Jun 19, 2025). Changes shown from 3 commits.
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9452,6 +9452,34 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
}];
}

def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$target,
Torch_BoolType:$log_input,
Torch_BoolType:$full,
Torch_FloatType:$eps,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPoissonNllLossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenPoissonNllLossOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
AllowsTypeRefinement,
HasValueSemantics,
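For reference (not part of the diff), the element-wise loss computed by aten::poisson_nll_loss, as documented for torch.nn.PoissonNLLLoss, is:

\ell_i = \exp(x_i) - t_i \, x_i                  (log_input = true)
\ell_i = x_i - t_i \log(x_i + \varepsilon)       (log_input = false)

With full = true, PyTorch additionally adds the Stirling approximation term t_i \log t_i - t_i + \tfrac{1}{2} \log(2 \pi t_i) for targets greater than one; the decomposition added in this patch rejects that case.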
15 changes: 15 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10736,6 +10736,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.poisson_nll_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.float, %arg5: !torch.int) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.eq.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
" torch.prim.If.yield %arg0 : !torch.list<int>\n"
" } else {\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.If.yield %2 : !torch.list<int>\n"
" }\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
@@ -15220,6 +15231,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.poisson_nll_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.float, %arg5: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float) -> !torch.tuple<int, int, int> {\n"
" %int7 = torch.constant.int 7\n"
" %int10 = torch.constant.int 10\n"
83 changes: 83 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -10471,6 +10471,88 @@ class DecomposeAtenNllLossForwardOp
};
} // namespace

namespace {
class DecomposeAtenPoissonNllLossOp
: public OpRewritePattern<AtenPoissonNllLossOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenPoissonNllLossOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getInput();
Value target = op.getTarget();
Value logInValV = op.getLogInput();
Value fullValV = op.getFull();
Value redValV = op.getReduction();
Value epsA = op.getEps();

bool logInVal, fullVal;
if (!matchPattern(logInValV, m_TorchConstantBool(&logInVal)))
return rewriter.notifyMatchFailure(op, "expected constant log_input");
    if (!matchPattern(fullValV, m_TorchConstantBool(&fullVal)))
      return rewriter.notifyMatchFailure(op, "expected constant full");
    // Bail out before creating any ops: the Stirling-approximation term used
    // when full == true is not implemented by this decomposition.
    if (fullVal)
      return rewriter.notifyMatchFailure(op, "full==true not supported yet");

    int64_t r;
    if (!matchPattern(redValV, m_TorchConstantInt(&r)))
      return rewriter.notifyMatchFailure(op, "expected constant reduction");

double epsValue;
if (!matchPattern(epsA, m_TorchConstantFloat(&epsValue))) {
return rewriter.notifyMatchFailure(op, "expected constant eps");
}

Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
    Value lossElem;
    if (logInVal) {
      // log_input == true: loss = exp(input) - target * input.
      Value expIn = rewriter.create<AtenExpOp>(loc, input.getType(), input);
      Value tX =
          rewriter.create<AtenMulTensorOp>(loc, input.getType(), target, input);
      lossElem = rewriter.create<AtenSubTensorOp>(loc, input.getType(), expIn,
                                                  tX, one);
    } else {
      // log_input == false: loss = input - target * log(input + eps).
      Value epsConst = rewriter.create<ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(epsValue));
      Value safeIn = rewriter.create<AtenAddScalarOp>(loc, input.getType(),
                                                      input, epsConst, one);
      Value logOp = rewriter.create<AtenLogOp>(loc, input.getType(), safeIn);
      Value tL =
          rewriter.create<AtenMulTensorOp>(loc, input.getType(), target, logOp);
      lossElem = rewriter.create<AtenSubTensorOp>(loc, input.getType(), input,
                                                  tL, one);
    }

Value result;
if (r == 0) { // None reduction
result = lossElem;
} else {
// Create sum of all elements
Value sum = rewriter.create<AtenSumOp>(
loc, op.getType(), lossElem, rewriter.create<ConstantNoneOp>(loc));

if (r == 2) { // Sum reduction
result = sum;
} else { // Mean reduction (r == 1)
Value nElem = rewriter.create<AtenNumelOp>(loc, lossElem);

Value nElemFloat = rewriter.create<AtenFloatScalarOp>(loc, nElem);

result = rewriter.create<AtenDivScalarOp>(loc, sum.getType(), sum,
nElemFloat);
}
}

rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12384,6 +12466,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPoissonNllLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
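As an aid to reading the pattern above, here is a minimal eager-mode sketch of what the decomposition computes; the function name poisson_nll_loss_reference is illustrative and not part of this patch, and full=True is omitted because the pattern rejects it:

import torch

def poisson_nll_loss_reference(input, target, log_input=True, full=False,
                               eps=1e-8, reduction=1):
    # Mirrors DecomposeAtenPoissonNllLossOp: reduction 0 = none, 1 = mean, 2 = sum.
    assert not full, "full=True (Stirling term) is not handled"
    if log_input:
        loss = torch.exp(input) - target * input
    else:
        loss = input - target * torch.log(input + eps)
    if reduction == 0:
        return loss
    total = loss.sum()
    if reduction == 2:
        return total
    return total / loss.numel()  # mean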
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -537,6 +537,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLerpTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenL1LossOp>();
target.addIllegalOp<AtenPoissonNllLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenRandintOp>();
target.addIllegalOp<AtenVarMeanCorrectionOp>();
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3068,6 +3068,7 @@
"NllLossStaticModule_mean_basic",
"NllLossModule_sum_basic",
"NllLossStaticModule_sum_basic",
"PoissonNLLLoss_basic",
"NormScalarComplexModule_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",
@@ -2205,6 +2205,16 @@ def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: Lis
result_shape = scalar_shape
return result_shape

@check_shape_function([
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 0), # No reduction
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 1), # Mean reduction
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3), True, False, 1e-8, 2), # Sum reduction
])
def aten〇poisson_nll_loss〡shape(input: List[int], target: List[int], log_input: bool, full: bool, eps: float, reduction: int) -> List[int]:
if reduction == 0:
return input
return []

@check_shape_function([
Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
])
@@ -5094,6 +5104,18 @@ def aten〇nll_loss_forward〡dtype(self_rank_dtype: Tuple[int, int], target_ran
assert target_dtype == torch.int64 or target_dtype == torch.int32
return self_dtype, self_dtype

@check_dtype_function([
Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, 3, dtype=torch.float32), # No reduction
True, False, 1e-8, 0),
Invocation(TensorOfShape(4, 5, dtype=torch.float32), TensorOfShape(4, 5, dtype=torch.float32), # Mean reduction
True, False, 1e-8, 1),
Invocation(TensorOfShape(3, 3, dtype=torch.float64), TensorOfShape(3, 3, dtype=torch.float64), # Sum reduction
True, False, 1e-8, 2),
])
def aten〇poisson_nll_loss〡dtype(input_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], log_input: bool, full: bool, eps: float, reduction: int) -> int:
_, in_dtype = input_rank_dtype
return in_dtype

@check_dtype_function(
[Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32),
TensorOfShape(3, dtype=torch.float32), eps=0.0),
@@ -760,6 +760,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
)
emit(
"aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
)
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
28 changes: 28 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py
@@ -670,3 +670,31 @@ def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils):
module.forward(
tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), tu.rand(3), torch.tensor(3.0)
)


class PoissonNLLLossModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
]
)
def forward(self, input, target):
return torch.ops.aten.poisson_nll_loss(
input=input,
target=target,
log_input=True,
full=False,
eps=1e-8,
reduction=1,
)


@register_test_case(module_factory=lambda: PoissonNLLLossModule())
def PoissonNLLLoss_basic(module: PoissonNLLLossModule, tu: TestUtils):
module.forward(tu.rand(5, 7).abs(), torch.poisson(tu.rand(5, 7).abs()))
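
Outside the e2e harness, the same configuration can be sanity-checked against eager PyTorch; this snippet is illustrative and not part of the test suite:

import torch

inp = torch.rand(5, 7).abs()
tgt = torch.poisson(torch.rand(5, 7).abs())

# Eager reference from torch.nn.functional (reduction=1 corresponds to "mean").
expected = torch.nn.functional.poisson_nll_loss(
    inp, tgt, log_input=True, full=False, eps=1e-8, reduction="mean")
# Direct call to the aten op exercised by PoissonNLLLossModule above.
actual = torch.ops.aten.poisson_nll_loss(inp, tgt, True, False, 1e-8, 1)

assert torch.allclose(expected, actual)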