🚀 Feature
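When computing IoU with MeanIoU, targets that contain an "ignore" label (here 255) cannot be handled:

```python
import torch
from torchmetrics.segmentation import MeanIoU

_ = torch.manual_seed(0)
miou = MeanIoU(num_classes=3, input_format="index")  # inputs are class indices, not one-hot
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # label 255 marks a pixel to be ignored
miou(preds, target)  # raises an error
```

This raises an error, since 255 is not a valid class index when num_classes=3.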
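Assuming the index-input path shown under Alternatives (where targets are one-hot encoded with torch.nn.functional.one_hot), the failure can be reproduced with PyTorch alone: one-hot encoding rejects any label that is not smaller than num_classes. This is a minimal sketch of that step, not of MeanIoU itself.

```python
import torch
import torch.nn.functional as F

target = torch.as_tensor((0, 1, 2, 0, 255))  # 255 is the "ignore" label

try:
    F.one_hot(target, num_classes=3)
    failed = False
except RuntimeError as err:
    # the label 255 cannot be encoded with only 3 classes
    print(f"one_hot failed: {err}")
    failed = True
```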
Motivation
When I generate sample pairs (an image and its corresponding mask, assuming 3 classes), not every pixel in the mask belongs to one of the classes, so I set those pixels to 255. These pixels are then ignored in the loss by using torch.nn.CrossEntropyLoss(ignore_index=255). MeanIoU, however, has no equivalent option, which leads to errors in the IoU calculation. Could MeanIoU support an ignore_index parameter as well, so that such pixels are excluded from the metric?
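For reference, this is the loss-side behavior the request asks MeanIoU to mirror. The sketch below (random logits, a 5-pixel target) checks that CrossEntropyLoss(ignore_index=255) gives the same value as computing the loss on the non-ignored pixels only.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(5, 3)                # 5 pixels, 3 classes
target = torch.tensor([0, 1, 2, 0, 255])  # 255 marks an unlabeled pixel

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255)
loss = loss_fn(logits, target)

# With the default reduction="mean", the ignored pixel is excluded from both
# the sum and the denominator, so this equals the loss over the first 4 pixels.
reference = torch.nn.functional.cross_entropy(logits[:4], target[:4])
```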
Pitch
```python
import torch
from torchmetrics.segmentation import MeanIoU

_ = torch.manual_seed(0)
# proposed: an ignore_index parameter that excludes pixels labeled 255
miou = MeanIoU(num_classes=3, input_format="index", ignore_index=255)
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # label 255 marks a pixel to be ignored
miou(preds, target)
```
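To pin down the intended semantics, here is a plain-PyTorch reference sketch of mean IoU with ignored pixels. The function name mean_iou_with_ignore is hypothetical, and leaving classes with an empty union out of the mean is an assumption, not torchmetrics' documented behavior.

```python
import torch
from torch import Tensor


def mean_iou_with_ignore(preds: Tensor, target: Tensor, num_classes: int,
                         ignore_index: int = 255) -> Tensor:
    """Hypothetical reference: mean IoU over pixels whose target != ignore_index.

    Classes absent from both preds and target (union == 0) are skipped
    (an assumption made for this sketch).
    """
    valid = target != ignore_index
    preds, target = preds[valid], target[valid]  # drop ignored pixels entirely
    ious = []
    for c in range(num_classes):
        pred_c, target_c = preds == c, target == c
        union = (pred_c | target_c).sum()
        if union > 0:
            intersection = (pred_c & target_c).sum()
            ious.append(intersection.float() / union.float())
    if not ious:
        return torch.tensor(float("nan"))
    return torch.stack(ious).mean()


preds = torch.tensor([0, 1, 1, 0, 2])
target = torch.tensor([0, 1, 2, 0, 255])
score = mean_iou_with_ignore(preds, target, num_classes=3)  # (1.0 + 0.5 + 0.0) / 3 = 0.5
```

The prediction at the ignored position (class 2) does not enter the union for class 2, which is exactly what ignore_index should guarantee.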
Alternatives
A possible change to _mean_iou_update in torchmetrics/src/torchmetrics/functional/segmentation/mean_iou.py
(Lines 52 to 55 in 62d9d32):
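```python
def _mean_iou_update(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index"] = "one-hot",
    ignore_index: Optional[int] = None,  # proposed new parameter (needs `Optional` from typing)
) -> Tuple[Tensor, Tensor]:
    ...
    if input_format == "index":
        if ignore_index is not None:
            ignore = target == ignore_index         # boolean mask of ignored pixels
            target = target.masked_fill(ignore, 0)  # temporary valid label; avoids mutating the caller's tensor
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
        if ignore_index is not None:
            # set ignored pixels to "zero-hot" in BOTH tensors, so they count
            # toward neither the intersection nor the union
            preds[ignore] = 0
            target[ignore] = 0
        preds = preds.movedim(-1, 1)
        target = target.movedim(-1, 1)
    ...
```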
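The "zero-hot" idea can be checked in isolation. This sketch uses a hypothetical helper, masked_one_hot, that one-hot encodes labels channels-first and zeroes the class vector wherever a pixel is ignored; it then computes per-class intersection and union counts for one 5-pixel sample. It illustrates the masking step only and is not torchmetrics code.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def masked_one_hot(labels: Tensor, num_classes: int, ignore: Tensor) -> Tensor:
    """Hypothetical helper: one-hot encode `labels` to shape (N, C, ...),
    with all-zero class vectors where `ignore` is True."""
    labels = labels.masked_fill(ignore, 0)                 # any valid label; zeroed below
    encoded = F.one_hot(labels, num_classes=num_classes)   # shape (N, ..., C)
    encoded[ignore] = 0                                    # "zero-hot" for ignored pixels
    return encoded.movedim(-1, 1)                          # channels-first: (N, C, ...)


target = torch.tensor([[0, 1, 2, 0, 255]])  # batch of 1, 5 pixels
preds = torch.tensor([[0, 1, 1, 0, 2]])
ignore = target == 255

t = masked_one_hot(target, 3, ignore)
p = masked_one_hot(preds, 3, ignore)
intersection = (p & t).sum(dim=-1)  # per (batch, class): [[2, 1, 0]]
union = (p | t).sum(dim=-1)         # per (batch, class): [[2, 2, 1]]
```

Because preds is masked too, the class-2 prediction at the ignored pixel no longer inflates the class-2 union, and masked_fill leaves the caller's target tensor unchanged.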