
Segmentation IoU: ignore tagged label values that should not be counted (such as 255) #2747

Open
@woldier


🚀 Feature

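When computing IoU with `MeanIoU`, target pixels tagged with a value outside the class range (such as 255) cannot be skipped, and the computation fails: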

```python
import torch
from torchmetrics.segmentation import MeanIoU

_ = torch.manual_seed(0)

miou = MeanIoU(num_classes=3, input_format="index")
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # A label of 255 marks a pixel to be ignored.
miou(preds, target)  # -> this results in an error
```
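The error most likely comes from the one-hot encoding of index inputs: `_mean_iou_update` calls `torch.nn.functional.one_hot`, which rejects any class value greater than or equal to `num_classes`. A minimal check in plain PyTorch, assuming a recent PyTorch version:

```python
import torch
import torch.nn.functional as F

target = torch.as_tensor((0, 1, 2, 0, 255))  # 255 is meant as an "ignore" tag

try:
    F.one_hot(target, num_classes=3)
    raised = False
except RuntimeError as err:  # out-of-range class values are rejected
    raised = True
    print(f"one_hot failed: {err}")
```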

Motivation

When I generate training pairs (an image and its corresponding mask, assuming 3 classes), not every pixel in the mask belongs to one of the classes, so I set those pixels to 255. These pixels are then excluded from the loss by using `torch.nn.CrossEntropyLoss(ignore_index=255)`. However, the IoU computation has no equivalent option, so the 255 labels break the IoU calculation. Could `MeanIoU` support an `ignore_index` parameter as well, so that such pixels are skipped?
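For reference, this is how `ignore_index` behaves in `torch.nn.CrossEntropyLoss`: with the default `reduction="mean"`, ignored targets are dropped from both the sum and the normalizer, so the result equals the loss over the remaining pixels only. A small check with random logits (the tensors are illustrative, not taken from the original training setup):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 3
logits = torch.randn(5, num_classes)  # 5 pixels, raw class scores
target = torch.as_tensor((0, 1, 2, 0, 255))  # 255 marks a pixel to ignore

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255)
loss_with_ignore = loss_fn(logits, target)

# Same loss computed on the non-ignored pixels only
valid = target != 255
loss_valid_only = F.cross_entropy(logits[valid], target[valid])

print(loss_with_ignore.item(), loss_valid_only.item())
```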

Pitch

```python
import torch
from torchmetrics.segmentation import MeanIoU

_ = torch.manual_seed(0)

miou = MeanIoU(num_classes=3, input_format="index", ignore_index=255)  # proposed: ignore_index skips pixels labeled 255
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # A label of 255 marks a pixel to be ignored.
miou(preds, target)
```
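To make the requested semantics concrete, here is a plain-PyTorch sketch of mean IoU that skips every pixel whose target equals `ignore_index`. `mean_iou_with_ignore` is a hypothetical helper, not torchmetrics code, and it averages per-class IoU over the whole input rather than reproducing `MeanIoU`'s exact reduction:

```python
import torch
from torch import Tensor


def mean_iou_with_ignore(preds: Tensor, target: Tensor, num_classes: int, ignore_index: int = 255) -> Tensor:
    """Hypothetical helper: mean IoU over classes, skipping pixels where target == ignore_index."""
    valid = target != ignore_index
    # Drop ignored pixels from BOTH tensors, so predictions there are not counted as false positives.
    preds, target = preds[valid], target[valid]
    ious = []
    for c in range(num_classes):
        pred_c, target_c = preds == c, target == c
        union = (pred_c | target_c).sum()
        if union > 0:  # skip classes absent from both preds and target
            intersection = (pred_c & target_c).sum()
            ious.append(intersection.float() / union.float())
    if not ious:
        return torch.tensor(float("nan"))
    return torch.stack(ious).mean()


preds = torch.as_tensor((0, 1, 2, 1, 2))
target = torch.as_tensor((0, 1, 2, 0, 255))
score = mean_iou_with_ignore(preds, target, num_classes=3)
print(score)  # class IoUs 0.5, 0.5, 1.0 -> mean 2/3
```

In this example the prediction at the ignored pixel is class 2; if that pixel were counted, it would add a false positive to class 2 and lower its IoU from 1.0 to 0.5.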

Alternatives

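Currently, `_mean_iou_update` one-hot encodes index inputs directly, so any label greater than or equal to `num_classes` makes it fail:

```python
if input_format == "index":
    preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
    target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
```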

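One possible change is to mask the ignored pixels before one-hot encoding, in both `target` and `preds`:

```python
from typing import Literal, Optional, Tuple

import torch
from torch import Tensor


def _mean_iou_update(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index"] = "one-hot",
    ignore_index: Optional[int] = None,  # proposed: label value to exclude, e.g. 255
) -> Tuple[Tensor, Tensor]:
    ...

    if input_format == "index":
        valid = None
        if ignore_index is not None:
            valid = target != ignore_index  # True for pixels that should be counted
            # Map ignored labels to a placeholder class so one_hot does not fail;
            # masked_fill returns a copy, so the caller's target is not modified.
            target = target.masked_fill(~valid, 0)
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
        if valid is not None:
            # Zero the one-hot vectors at ignored pixels in BOTH tensors, so they add
            # nothing to the intersection or the union (including as false positives).
            preds = preds * valid.unsqueeze(-1)
            target = target * valid.unsqueeze(-1)
        preds = preds.movedim(-1, 1)
        target = target.movedim(-1, 1)
    ...
```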
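The masking step of the proposed change can be exercised on its own. A minimal sketch, where `one_hot_with_ignore` is a hypothetical helper (not part of torchmetrics) and the inputs are assumed to be index tensors of shape `(N, H, W)`:

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def one_hot_with_ignore(labels: Tensor, num_classes: int, ignore_index: int = 255) -> Tensor:
    """Hypothetical helper: channels-first one-hot where ignored pixels become all-zero vectors."""
    valid = labels != ignore_index
    # Placeholder class 0 keeps F.one_hot from raising; masked_fill returns a new tensor.
    safe = labels.masked_fill(~valid, 0)
    encoded = F.one_hot(safe, num_classes=num_classes) * valid.unsqueeze(-1)
    return encoded.movedim(-1, 1)  # (N, H, W, C) -> (N, C, H, W)


target = torch.as_tensor([[[0, 1, 255], [2, 0, 1]]])  # (N=1, H=2, W=3); 255 = ignore
preds = torch.as_tensor([[[0, 1, 2], [2, 1, 1]]])

t = one_hot_with_ignore(target, num_classes=3)
# Mask preds with the *target's* ignore mask, broadcast over the class dimension.
valid = (target != 255).unsqueeze(1)  # (N, 1, H, W)
p = F.one_hot(preds, num_classes=3).movedim(-1, 1) * valid

intersection = (p & t).sum(dim=(2, 3))  # per sample, per class
union = (p | t).sum(dim=(2, 3))
print(intersection.tolist(), union.tolist())  # [[1, 2, 1]] [[2, 3, 1]]
```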

Additional context


Labels: enhancement (New feature or request)
