
Conversation

kakumarabhishek

Added a loss function for binary image segmentation based on the Matthews Correlation Coefficient (MCC), following the reference implementation and the paper.
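For reference, the loss is one minus the soft MCC computed from soft confusion-matrix counts. A minimal sketch of the idea (function name and eps value are illustrative):

import torch

def soft_mcc_loss(y_pred: torch.Tensor, y_true: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Sketch only: y_pred are probabilities in (0, 1), y_true is a binary mask
    tp = torch.sum(y_pred * y_true) + eps              # soft true positives
    tn = torch.sum((1 - y_pred) * (1 - y_true)) + eps  # soft true negatives
    fp = torch.sum(y_pred * (1 - y_true)) + eps        # soft false positives
    fn = torch.sum((1 - y_pred) * y_true) + eps        # soft false negatives
    mcc = (tp * tn - fp * fn) / torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return 1.0 - mcc  # MCC lies in [-1, 1], so the loss lies in [0, 2]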

@BloodAxe
Owner

@kakumarabhishek Hi, and thanks for your PR!
I'm not familiar with the MCC loss, but I went quickly through the attached paper (kudos for that!)
and I have a few questions. Could you go through my comments to make sure we don't miss anything? TIA!

batch_size = y_true.shape[0]

# Flatten each sample's spatial dimensions: resulting shape is (B, 1, N)
y_true = y_true.view(batch_size, 1, -1)
y_pred = y_pred.view(batch_size, 1, -1)
@BloodAxe
Owner

Here y_pred are logits in (-inf, +inf) according to the docstring, but the underlying loss treats them as values bounded in (0, 1). E.g., is a .sigmoid() missing?

@kakumarabhishek
Author

Thanks for this comment. You are right: in my usage, I always applied the loss after a sigmoid on the network output, but this should be an explicit choice. I have updated the function to accommodate it.

# Soft confusion-matrix counts accumulated over the whole batch;
# eps guards against zero denominators
tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps
tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps
fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps
fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps
@BloodAxe
Owner

Here tp, tn, fp, fn are computed over the whole batch. Shouldn't we compute them per-sample? Given the shapes of y_true/y_pred, it should be something like tp = torch.sum(..., dim=(1, 2)) + eps.
This way tp, tn, fp, fn will have shape [B], and at the very end you compute the per-sample mean via loss.mean(), as in the sketch below.
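Something like this (a sketch, assuming y_true and y_pred have shape [B, 1, N] after the .view above):

# Reduce over all dims except batch, so tp, tn, fp, fn each have shape [B]
tp = torch.sum(torch.mul(y_pred, y_true), dim=(1, 2)) + self.eps
tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true)), dim=(1, 2)) + self.eps
fp = torch.sum(torch.mul(y_pred, (1 - y_true)), dim=(1, 2)) + self.eps
fn = torch.sum(torch.mul((1 - y_pred), y_true), dim=(1, 2)) + self.eps

# ...compute the per-sample MCC from these, then average over the batch:
loss = loss.mean()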

@kakumarabhishek
Author

Thank you. I updated the implementation to perform either a sample-level or batch-level reduction. This should address your comment.

Add option to calculate loss from either logits or predictions. Add option to perform sample-wise or batch-wise reduction. Add tests for `MCCLoss`.
@kakumarabhishek
Author

Thanks @BloodAxe for reviewing the code. Both your comments are valid and I appreciate your inputs. Here are the changes:

losses/mcc.py

  • Added a from_logits parameter (bool) to specify whether the loss is computed on logits or on final predictions. If the former, I apply a sigmoid (logsigmoid + exp for numerical stability); see the sketch below this list.
  • Added a reduction parameter ([sample, batch]; default: batch) to compute the MCC at the sample level or the batch level.
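A sketch of the from_logits branch (illustrative; the parameter name and the logsigmoid + exp trick are as described above):

import torch
import torch.nn.functional as F

if self.from_logits:
    # exp(logsigmoid(x)) equals sigmoid(x), but logsigmoid avoids
    # overflow/underflow for large-magnitude logits
    y_pred = torch.exp(F.logsigmoid(y_pred))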

tests/test_losses.py

  • Added tests for the MCCLoss (a usage sketch follows below).
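For illustration, exercising the new options might look like this (hypothetical import path and tensor shapes; not the actual test code):

import torch
from losses.mcc import MCCLoss  # assumed import path, per the file layout above

criterion = MCCLoss(from_logits=True, reduction="sample")
y_pred = torch.randn(4, 1, 64, 64)                 # raw logits
y_true = (torch.rand(4, 1, 64, 64) > 0.5).float()  # binary ground-truth mask
loss = criterion(y_pred, y_true)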

Please let me know if you have any other concerns.

@kakumarabhishek
Author

@BloodAxe, can you please take a look at this? Thank you.

@kakumarabhishek
Author

@BloodAxe, any updates?
