-
Notifications
You must be signed in to change notification settings - Fork 126
Add MCCLoss implementation for binary image segmentation #108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ | |
from .wing_loss import * | ||
from .logcosh import * | ||
from .quality_focal_loss import * | ||
from .mcc import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch.nn.modules.loss import _Loss | ||
|
||
__all__ = ["MCCLoss"] | ||
|
||
|
||
class MCCLoss(_Loss): | ||
""" | ||
Implementation of Matthews Correlation Coefficient (MCC) loss for image segmentation task. | ||
It supports binary cases. | ||
Reference: https://github.com/kakumarabhishek/MCC-Loss | ||
Paper: https://doi.org/10.1109/ISBI48211.2021.9433782 | ||
""" | ||
|
||
def __init__(self, eps: Optional[float] = 1e-7): | ||
""" | ||
Initializes the MCCLoss class. | ||
|
||
:param eps: Small epsilon for numerical stability | ||
""" | ||
super().__init__() | ||
self.eps = eps | ||
|
||
def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: | ||
""" | ||
Computes the Matthews Correlation Coefficient (MCC) loss. | ||
MCC = (TP.TN - FP.FN) / sqrt((TP+FP) . (TP+FN) . (TN+FP) . (TN+FN)) | ||
where TP, TN, FP, and FN are elements in the confusion matrix. | ||
|
||
:param y_pred: Predicted probabilities (logits) of shape (N, 1, H, W) | ||
:param y_true: Ground truth labels of shape (N, 1, H, W) | ||
:return: Computed MCC loss | ||
""" | ||
|
||
batch_size = y_true.shape[0] | ||
|
||
y_true = y_true.view(batch_size, 1, -1) | ||
y_pred = y_pred.view(batch_size, 1, -1) | ||
|
||
tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps | ||
tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps | ||
fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps | ||
fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps | ||
|
||
|
||
numerator = torch.mul(tp, tn) - torch.mul(fp, fn) | ||
denominator = torch.sqrt( | ||
torch.add(tp, fp) | ||
* torch.add(tp, fn) | ||
* torch.add(tn, fp) | ||
* torch.add(tn, fn) | ||
) | ||
|
||
mcc = torch.div(numerator.sum(), denominator.sum()) | ||
loss = 1 - mcc | ||
|
||
return loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here
y_pred
are logits (-inf,+inf) accroding to docstring, but underlying loss treat them as bounded (0,1) value (E.g.sigmoid()
missing?)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this comment. You are right - in my usage, I always applied the loss after a sigmoid operation to the network output, but you are right, this should be an explicit choice. Updated the function to accommodate this choice.