1 change: 1 addition & 0 deletions pytorch_toolbelt/losses/__init__.py
@@ -14,3 +14,4 @@
from .wing_loss import *
from .logcosh import *
from .quality_focal_loss import *
from .mcc import *
59 changes: 59 additions & 0 deletions pytorch_toolbelt/losses/mcc.py
@@ -0,0 +1,59 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn.modules.loss import _Loss

__all__ = ["MCCLoss"]


class MCCLoss(_Loss):
    """
    Implementation of the Matthews Correlation Coefficient (MCC) loss for image segmentation tasks.
    It supports binary cases.
    Reference: https://github.com/kakumarabhishek/MCC-Loss
    Paper: https://doi.org/10.1109/ISBI48211.2021.9433782
    """

    def __init__(self, eps: Optional[float] = 1e-7):
        """
        Initializes the MCCLoss class.

        :param eps: Small epsilon for numerical stability
        """
        super().__init__()
        self.eps = eps

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """
        Computes the Matthews Correlation Coefficient (MCC) loss.
        MCC = (TP·TN - FP·FN) / sqrt((TP+FP)·(TP+FN)·(TN+FP)·(TN+FN))
        where TP, TN, FP, and FN are elements in the confusion matrix.

        :param y_pred: Predicted probabilities (logits) of shape (N, 1, H, W)
        :param y_true: Ground truth labels of shape (N, 1, H, W)
        :return: Computed MCC loss
        """

        batch_size = y_true.shape[0]

        y_true = y_true.view(batch_size, 1, -1)
        y_pred = y_pred.view(batch_size, 1, -1)
Owner: Here y_pred are logits (-inf, +inf) according to the docstring, but the underlying loss treats them as bounded (0, 1) values (e.g. is a .sigmoid() missing?).

Author: Thanks for this comment. You are right: in my usage I always applied the loss after a sigmoid operation on the network output, but this should be an explicit choice. Updated the function to accommodate this choice.
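A minimal sketch of what such an explicit choice could look like: a flag that controls whether the loss applies the sigmoid itself. The class name and the apply_sigmoid parameter here are illustrative assumptions, not necessarily the API the updated PR ended up with.

from torch import Tensor
from pytorch_toolbelt.losses import MCCLoss

class MCCLossWithActivation(MCCLoss):
    # Hypothetical wrapper: the `apply_sigmoid` flag is assumed, not taken from the PR.
    def __init__(self, eps: float = 1e-7, apply_sigmoid: bool = False):
        super().__init__(eps)
        self.apply_sigmoid = apply_sigmoid

    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        if self.apply_sigmoid:
            # Map raw logits in (-inf, +inf) to probabilities in (0, 1).
            y_pred = y_pred.sigmoid()
        return super().forward(y_pred, y_true)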


        tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps
        tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps
        fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps
        fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps
Owner: Here tp, tn, fp, fn are computed over the whole batch. Shouldn't we compute them per sample? Given the shapes of y_true/y_pred it should be something like tp = torch.sum(..., dim=(1, 2)) + eps. This way tp, tn, fp, fn will have shape [B], and at the very end you compute the per-sample mean via loss.mean().

Author: Thank you. I updated the implementation to perform either a sample-level or batch-level reduction. This should address your comment.
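For illustration, a self-contained sketch of the reviewer's per-sample suggestion (the function name and the final batch mean are illustrative, not necessarily how the updated PR structures the reduction):

import torch
from torch import Tensor

def mcc_loss_per_sample(y_pred: Tensor, y_true: Tensor, eps: float = 1e-7) -> Tensor:
    """Per-sample MCC loss averaged over the batch (illustrative sketch)."""
    bs = y_true.size(0)
    y_true = y_true.view(bs, 1, -1)
    y_pred = y_pred.view(bs, 1, -1)
    # Reduce over channel and spatial dims only, keeping one value per sample.
    tp = torch.sum(y_pred * y_true, dim=(1, 2)) + eps      # shape [bs]
    tn = torch.sum((1 - y_pred) * (1 - y_true), dim=(1, 2)) + eps
    fp = torch.sum(y_pred * (1 - y_true), dim=(1, 2)) + eps
    fn = torch.sum((1 - y_pred) * y_true, dim=(1, 2)) + eps
    numerator = tp * tn - fp * fn
    denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # 1 - per-sample MCC, then mean over the batch.
    return (1.0 - numerator / denominator).mean()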


        numerator = torch.mul(tp, tn) - torch.mul(fp, fn)
        denominator = torch.sqrt(
            torch.add(tp, fp)
            * torch.add(tp, fn)
            * torch.add(tn, fp)
            * torch.add(tn, fn)
        )

        mcc = torch.div(numerator.sum(), denominator.sum())
        loss = 1 - mcc

        return loss
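For context, a small usage sketch of the loss as added in this diff, applying the sigmoid to the raw logits at the call site as discussed in the review (tensor shapes are illustrative):

import torch
from pytorch_toolbelt.losses import MCCLoss

criterion = MCCLoss()
logits = torch.randn(4, 1, 64, 64, requires_grad=True)  # raw network output in (-inf, +inf)
target = torch.randint(0, 2, (4, 1, 64, 64)).float()    # binary ground-truth mask
loss = criterion(logits.sigmoid(), target)              # pass probabilities in (0, 1)
loss.backward()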