diff --git a/configs/vision/radiology/online/segmentation/btcv.yaml b/configs/vision/radiology/online/segmentation/btcv.yaml
index d8e83ce21..5afdd2164 100644
--- a/configs/vision/radiology/online/segmentation/btcv.yaml
+++ b/configs/vision/radiology/online/segmentation/btcv.yaml
@@ -76,21 +76,19 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: torchmetrics.segmentation.DiceScore
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
           init_args:
             num_classes: *NUM_CLASSES
-            include_background: false
-            average: macro
             input_format: one-hot
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.vision.metrics.MonaiDiceScore
+              class_path: torchmetrics.segmentation.DiceScore
               init_args:
                 include_background: true
                 num_classes: *NUM_CLASSES
+                average: none
                 input_format: one-hot
-                reduction: none
             prefix: DiceScore_
             labels:
               - "0_background"
diff --git a/configs/vision/radiology/online/segmentation/lits17.yaml b/configs/vision/radiology/online/segmentation/lits17.yaml
index 2978a9595..b48a29a12 100644
--- a/configs/vision/radiology/online/segmentation/lits17.yaml
+++ b/configs/vision/radiology/online/segmentation/lits17.yaml
@@ -76,21 +76,19 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: torchmetrics.segmentation.DiceScore
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
           init_args:
             num_classes: *NUM_CLASSES
-            include_background: false
-            average: macro
             input_format: one-hot
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.vision.metrics.MonaiDiceScore
+              class_path: torchmetrics.segmentation.DiceScore
               init_args:
                 include_background: true
                 num_classes: *NUM_CLASSES
+                average: none
                 input_format: one-hot
-                reduction: none
             prefix: DiceScore_
             labels:
               - "0_background"
diff --git a/configs/vision/radiology/online/segmentation/lits17_2d.yaml b/configs/vision/radiology/online/segmentation/lits17_2d.yaml
index 5532fb672..9b175fcc6 100644
--- a/configs/vision/radiology/online/segmentation/lits17_2d.yaml
+++ b/configs/vision/radiology/online/segmentation/lits17_2d.yaml
@@ -81,21 +81,19 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: torchmetrics.segmentation.DiceScore
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
           init_args:
             num_classes: *NUM_CLASSES
-            include_background: false
-            average: macro
             input_format: one-hot
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.vision.metrics.MonaiDiceScore
+              class_path: torchmetrics.segmentation.DiceScore
               init_args:
                 include_background: true
                 num_classes: *NUM_CLASSES
+                average: none
                 input_format: one-hot
-                reduction: none
             prefix: DiceScore_
             labels:
               - "0_background"
diff --git a/configs/vision/radiology/online/segmentation/msd_task7_pancreas.yaml b/configs/vision/radiology/online/segmentation/msd_task7_pancreas.yaml
index 5ff36ae6e..be2a9b1bb 100644
--- a/configs/vision/radiology/online/segmentation/msd_task7_pancreas.yaml
+++ b/configs/vision/radiology/online/segmentation/msd_task7_pancreas.yaml
@@ -76,21 +76,19 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: torchmetrics.segmentation.DiceScore
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetricsV2
           init_args:
             num_classes: *NUM_CLASSES
-            include_background: false
-            average: macro
             input_format: one-hot
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.vision.metrics.MonaiDiceScore
+              class_path: torchmetrics.segmentation.DiceScore
               init_args:
                 include_background: true
                 num_classes: *NUM_CLASSES
+                average: none
                 input_format: one-hot
-                reduction: none
             prefix: DiceScore_
             labels:
               - "0_background"
diff --git a/src/eva/vision/metrics/defaults/__init__.py b/src/eva/vision/metrics/defaults/__init__.py
index 14bcecda0..f8196a75f 100644
--- a/src/eva/vision/metrics/defaults/__init__.py
+++ b/src/eva/vision/metrics/defaults/__init__.py
@@ -1,7 +1,11 @@
 """Default metric collections API."""
 
-from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+from eva.vision.metrics.defaults.segmentation import (
+    MulticlassSegmentationMetrics,
+    MulticlassSegmentationMetricsV2,
+)
 
 __all__ = [
     "MulticlassSegmentationMetrics",
+    "MulticlassSegmentationMetricsV2",
 ]
diff --git a/src/eva/vision/metrics/defaults/segmentation/__init__.py b/src/eva/vision/metrics/defaults/segmentation/__init__.py
index 34d11a381..f38b4e6c5 100644
--- a/src/eva/vision/metrics/defaults/segmentation/__init__.py
+++ b/src/eva/vision/metrics/defaults/segmentation/__init__.py
@@ -1,5 +1,8 @@
 """Default segmentation metric collections API."""
 
-from eva.vision.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
+from eva.vision.metrics.defaults.segmentation.multiclass import (
+    MulticlassSegmentationMetrics,
+    MulticlassSegmentationMetricsV2,
+)
 
-__all__ = ["MulticlassSegmentationMetrics"]
+__all__ = ["MulticlassSegmentationMetrics", "MulticlassSegmentationMetricsV2"]
diff --git a/src/eva/vision/metrics/defaults/segmentation/multiclass.py b/src/eva/vision/metrics/defaults/segmentation/multiclass.py
index 26b2b0f53..92c54a7b5 100644
--- a/src/eva/vision/metrics/defaults/segmentation/multiclass.py
+++ b/src/eva/vision/metrics/defaults/segmentation/multiclass.py
@@ -1,11 +1,14 @@
 """Default metric collection for multiclass semantic segmentation tasks."""
 
+from typing import Literal
+
 from eva.core.metrics import structs
+from eva.core.utils import requirements
 from eva.vision.metrics import segmentation
 
 
 class MulticlassSegmentationMetrics(structs.MetricCollection):
-    """Default metrics for multi-class semantic segmentation tasks."""
+    """Metrics for multi-class semantic segmentation tasks."""
 
     def __init__(
         self,
@@ -66,3 +69,54 @@ def __init__(
             prefix=prefix,
             postfix=postfix,
         )
+
+
+class MulticlassSegmentationMetricsV2(structs.MetricCollection):
+    """Metrics for multi-class semantic segmentation tasks.
+
+    In torchmetrics 1.8.0, the DiceScore implementation has been
+    improved and should now provide enough signal on its own, so the
+    MONAI implementation and IoU are dropped for simplicity and
+    computational efficiency.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = False,
+        prefix: str | None = None,
+        postfix: str | None = None,
+        input_format: Literal["one-hot", "index"] = "one-hot",
+    ) -> None:
+        """Initializes the multi-class semantic segmentation metrics.
+
+        Args:
+            num_classes: Integer specifying the number of classes.
+            include_background: Whether to include the background class in the metrics computation.
+            prefix: A string to add before the keys in the output dictionary.
+            postfix: A string to add after the keys in the output dictionary.
+            input_format: Input tensor format. Options are `"one-hot"` for one-hot encoded tensors,
+                `"index"` for index tensors.
+        """
+        requirements.check_dependencies(requirements={"torchmetrics": "1.8.0"})
+        super().__init__(
+            metrics={
+                "DiceScore (macro)": segmentation.DiceScore(
+                    num_classes=num_classes,
+                    include_background=include_background,
+                    average="macro",
+                    aggregation_level="samplewise",
+                    input_format=input_format,
+                ),
+                "DiceScore (macro/global)": segmentation.DiceScore(
+                    num_classes=num_classes,
+                    include_background=include_background,
+                    average="macro",
+                    aggregation_level="global",
+                    input_format=input_format,
+                ),
+            },
+            prefix=prefix,
+            postfix=postfix,
+        )
+        self.num_classes = num_classes
diff --git a/src/eva/vision/metrics/segmentation/dice.py b/src/eva/vision/metrics/segmentation/dice.py
index 485524af6..1853553b6 100644
--- a/src/eva/vision/metrics/segmentation/dice.py
+++ b/src/eva/vision/metrics/segmentation/dice.py
@@ -1,20 +1,20 @@
-"""Generalized Dice Score metric for semantic segmentation."""
+"""Dice Score metric for semantic segmentation."""
 
 from typing import Any, Literal
 
 import torch
 from torchmetrics import segmentation
-from torchmetrics.functional.segmentation.dice import _dice_score_update
 from typing_extensions import override
 
 from eva.vision.metrics.segmentation import _utils
 
 
 class DiceScore(segmentation.DiceScore):
-    """Defines the Generalized Dice Score.
+    """Dice Score metric for semantic segmentation tasks.
 
-    It expands the `torchmetrics` class by including an `ignore_index`
-    functionality and converting tensors to one-hot format.
+    This implementation expands the `torchmetrics` class by including
+    `ignore_index` functionality and converting tensors to one-hot
+    format on the fly.
     """
 
     def __init__(
@@ -22,6 +22,7 @@
         num_classes: int,
         include_background: bool = True,
         average: Literal["micro", "macro", "weighted", "none"] | None = "micro",
+        input_format: Literal["one-hot", "index", "auto"] = "auto",
         ignore_index: int | None = None,
         **kwargs: Any,
     ) -> None:
@@ -32,6 +33,9 @@
             include_background: Whether to include the background class in the computation
             average: The method to average the dice score accross the different classes. Options
                 are `"micro"`, `"macro"`, `"weighted"`, `"none"` or `None`.
+            input_format: Input tensor format. Options are `"one-hot"` for one-hot encoded tensors,
+                `"index"` for index tensors, or `"auto"` to automatically convert the format
+                to one-hot.
             ignore_index: Integer specifying a target class to ignore. If given, this class
                 index does not contribute to the returned score, regardless of reduction method.
             kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -42,28 +46,23 @@ def __init__(
             + (ignore_index == 0 and not include_background),
             include_background=include_background,
             average=average,
-            input_format="one-hot",
+            input_format=input_format if input_format == "index" else "one-hot",
             **kwargs,
         )
         self.orig_num_classes = num_classes
         self.ignore_index = ignore_index
+        self.input_format = input_format
 
     @override
     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
-        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
-        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+        if self.input_format == "auto":
+            preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+            target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
         if self.ignore_index is not None:
+            if self.input_format == "index":
+                raise ValueError(
+                    "When `ignore_index` is set, `input_format` must be 'one-hot' or 'auto'."
+                )
             preds, target = _utils.apply_ignore_index(preds, target, self.ignore_index)
-        # TODO: Replace _update by super.update() once the following issue is fixed:
-        # https://github.com/Lightning-AI/torchmetrics/issues/2847
-        self._update(preds.long(), target.long())
-        # super().update(preds=preds.long(), target=target.long())
-
-    def _update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
-        numerator, denominator, support = _dice_score_update(
-            preds, target, self.num_classes, self.include_background, self.input_format  # type: ignore
-        )
-        self.numerator.append(numerator)
-        self.denominator.append(denominator)
-        self.support.append(support)
+        super().update(preds.long(), target.long())
diff --git a/tests/eva/core/metrics/core/test_metric_module.py b/tests/eva/core/metrics/core/test_metric_module.py
index faccf2853..59bbe5912 100644
--- a/tests/eva/core/metrics/core/test_metric_module.py
+++ b/tests/eva/core/metrics/core/test_metric_module.py
@@ -3,19 +3,39 @@
 from typing import List
 
 import pytest
-import torchmetrics
+import torchmetrics.segmentation
 
 from eva.core.metrics import structs
 
+NUM_CLASSES = 3
+
 
 @pytest.mark.parametrize(
     "schema, expected",
     [
-        (structs.MetricsSchema(train=torchmetrics.Dice()), [1, 0, 0]),
-        (structs.MetricsSchema(evaluation=torchmetrics.Dice()), [0, 1, 1]),
-        (structs.MetricsSchema(common=torchmetrics.Dice()), [1, 1, 1]),
         (
-            structs.MetricsSchema(train=torchmetrics.Dice(), evaluation=torchmetrics.Dice()),
+            structs.MetricsSchema(
+                train=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES)
+            ),
+            [1, 0, 0],
+        ),
+        (
+            structs.MetricsSchema(
+                evaluation=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES)
+            ),
+            [0, 1, 1],
+        ),
+        (
+            structs.MetricsSchema(
+                common=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES)
+            ),
+            [1, 1, 1],
+        ),
+        (
+            structs.MetricsSchema(
+                train=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES),
+                evaluation=torchmetrics.segmentation.DiceScore(num_classes=NUM_CLASSES),
+            ),
             [1, 1, 1],
         ),
     ],
diff --git a/tests/eva/core/metrics/core/test_schemas.py b/tests/eva/core/metrics/core/test_schemas.py
index 29d0c2a84..3f7125ffc 100644
--- a/tests/eva/core/metrics/core/test_schemas.py
+++ b/tests/eva/core/metrics/core/test_schemas.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torchmetrics
+import torchmetrics.segmentation
 
 from eva.core.metrics import structs
 from eva.core.metrics.structs.typings import MetricModuleType
@@ -33,23 +34,23 @@
         ),
         (
             torchmetrics.Accuracy("binary"),
-            torchmetrics.Dice(),
+            torchmetrics.segmentation.DiceScore(num_classes=2),
             None,
-            "[BinaryAccuracy(), Dice()]",
+            "[BinaryAccuracy(), DiceScore()]",
             "BinaryAccuracy()",
         ),
         (
             torchmetrics.Accuracy("binary"),
             None,
-            torchmetrics.Dice(),
+            torchmetrics.segmentation.DiceScore(num_classes=2),
             "BinaryAccuracy()",
-            "[BinaryAccuracy(), Dice()]",
+            "[BinaryAccuracy(), DiceScore()]",
         ),
         (
             torchmetrics.Accuracy("binary"),
-            torchmetrics.Dice(),
+            torchmetrics.segmentation.DiceScore(num_classes=2),
             torchmetrics.AUROC("binary"),
-            "[BinaryAccuracy(), Dice()]",
+            "[BinaryAccuracy(), DiceScore()]",
             "[BinaryAccuracy(), BinaryAUROC()]",
         ),
     ],
diff --git a/tests/eva/vision/metrics/defaults/segmentation/test_multiclass.py b/tests/eva/vision/metrics/defaults/segmentation/test_multiclass.py
index 9ab370382..e487044fb 100644
--- a/tests/eva/vision/metrics/defaults/segmentation/test_multiclass.py
+++ b/tests/eva/vision/metrics/defaults/segmentation/test_multiclass.py
@@ -3,15 +3,14 @@
 import pytest
 import torch
 
+from eva.core.metrics import structs
 from eva.vision.metrics import defaults
 
 NUM_BATCHES = 2
 BATCH_SIZE = 4
-"""Test parameters."""
-
-NUM_CLASSES_ONE = 3
-PREDS_ONE = torch.randint(0, NUM_CLASSES_ONE, (NUM_BATCHES, BATCH_SIZE, 32, 32))
-TARGET_ONE = torch.randint(0, NUM_CLASSES_ONE, (NUM_BATCHES, BATCH_SIZE, 32, 32))
+NUM_CLASSES = 3
+PREDS = torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32))
+TARGET = torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32))
 EXPECTED_ONE = {
     "MonaiDiceScore": torch.tensor(0.34805023670196533),
     "MonaiDiceScore (ignore_empty=False)": torch.tensor(0.34805023670196533),
@@ -20,18 +19,25 @@
     "DiceScore (macro)": torch.tensor(0.34805023670196533),
     "DiceScore (weighted)": torch.tensor(0.3484232723712921),
     "MeanIoU": torch.tensor(0.2109210342168808),
 }
-"""Test features."""
 
 assert EXPECTED_ONE["MonaiDiceScore (ignore_empty=False)"] == EXPECTED_ONE["DiceScore (macro)"]
 
+PREDS_ONE_HOT = torch.nn.functional.one_hot(PREDS, num_classes=NUM_CLASSES).movedim(-1, -3)
+TARGET_ONE_HOT = torch.nn.functional.one_hot(TARGET, num_classes=NUM_CLASSES).movedim(-1, -3)
+EXPECTED_TWO = {
+    "DiceScore (macro)": torch.tensor(0.34805023670196533),
+    "DiceScore (macro/global)": torch.tensor(0.3483232259750366),
+}
+
 
 @pytest.mark.parametrize(
-    "num_classes, preds, target, expected",
+    "metrics_collection, preds, target, expected",
     [
-        (NUM_CLASSES_ONE, PREDS_ONE, TARGET_ONE, EXPECTED_ONE),
+        ("multiclass_segmentation_metrics", PREDS, TARGET, EXPECTED_ONE),
+        ("multiclass_segmentation_metrics_v2", PREDS_ONE_HOT, TARGET_ONE_HOT, EXPECTED_TWO),
     ],
+    indirect=["metrics_collection"],
 )
 def test_multiclass_segmentation_metrics(
-    multiclass_segmentation_metrics: defaults.MulticlassSegmentationMetrics,
+    metrics_collection: structs.MetricCollection,
     preds: torch.Tensor,
     target: torch.Tensor,
     expected: torch.Tensor,
@@ -40,16 +46,21 @@
 
     def _calculate_metric() -> None:
         for batch_preds, batch_target in zip(preds, target, strict=False):
-            multiclass_segmentation_metrics.update(preds=batch_preds, target=batch_target)  # type: ignore
-        actual = multiclass_segmentation_metrics.compute()
+            metrics_collection.update(preds=batch_preds, target=batch_target)  # type: ignore
+        actual = metrics_collection.compute()
         torch.testing.assert_close(actual, expected, rtol=1e-04, atol=1e-04)
 
     _calculate_metric()
-    multiclass_segmentation_metrics.reset()
+    metrics_collection.reset()
     _calculate_metric()
 
 
 @pytest.fixture(scope="function")
-def multiclass_segmentation_metrics(num_classes: int) -> defaults.MulticlassSegmentationMetrics:
-    """MulticlassSegmentationMetrics fixture."""
-    return defaults.MulticlassSegmentationMetrics(num_classes=num_classes)
+def metrics_collection(request) -> structs.MetricCollection:
+    """Indirect fixture that returns the appropriate metrics class."""
+    if request.param == "multiclass_segmentation_metrics":
+        return defaults.MulticlassSegmentationMetrics(num_classes=NUM_CLASSES)
+    elif request.param == "multiclass_segmentation_metrics_v2":
+        return defaults.MulticlassSegmentationMetricsV2(num_classes=NUM_CLASSES)
+    else:
+        raise ValueError(f"Unknown metrics fixture: {request.param}")
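
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new MulticlassSegmentationMetricsV2 collection could be exercised, assuming an environment with eva and torchmetrics >= 1.8.0 installed. All class and argument names come from the diff above; the tensor shapes mirror test_multiclass.py.

    # Hedged example: build the V2 collection and feed it one-hot tensors of
    # shape (batch, num_classes, height, width), its default input format.
    import torch

    from eva.vision.metrics import defaults

    NUM_CLASSES = 3
    metrics = defaults.MulticlassSegmentationMetricsV2(num_classes=NUM_CLASSES)

    preds = torch.nn.functional.one_hot(
        torch.randint(0, NUM_CLASSES, (4, 32, 32)), num_classes=NUM_CLASSES
    ).movedim(-1, -3)
    target = torch.nn.functional.one_hot(
        torch.randint(0, NUM_CLASSES, (4, 32, 32)), num_classes=NUM_CLASSES
    ).movedim(-1, -3)

    metrics.update(preds=preds, target=target)
    # Should yield the two entries registered by the collection, e.g.
    # {"DiceScore (macro)": ..., "DiceScore (macro/global)": ...}
    print(metrics.compute())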
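
A second sketch, again an assumption-laden example rather than anything from the patch, shows the patched DiceScore wrapper on its own: with the new input_format="auto" default, plain index masks are converted to one-hot on the fly (per the update() logic above), while ignore_index continues to exclude a class; the choice of ignore_index=0 here is arbitrary.

    # Hedged example of the wrapped DiceScore from
    # src/eva/vision/metrics/segmentation/dice.py.
    import torch

    from eva.vision.metrics import segmentation

    NUM_CLASSES = 3
    dice = segmentation.DiceScore(num_classes=NUM_CLASSES, average="macro", ignore_index=0)

    # Index-format tensors of shape (batch, height, width); "auto" converts
    # them to one-hot internally before the ignore_index mask is applied.
    preds = torch.randint(0, NUM_CLASSES, (4, 32, 32))
    target = torch.randint(0, NUM_CLASSES, (4, 32, 32))

    dice.update(preds, target)
    print(dice.compute())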