diff --git a/src/segmantic/detect/__init__.py b/src/segmantic/detect/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/segmantic/seg/boundary_loss.py b/src/segmantic/seg/boundary_loss.py
new file mode 100644
index 0000000..e33855f
--- /dev/null
+++ b/src/segmantic/seg/boundary_loss.py
@@ -0,0 +1,119 @@
+import warnings
+from typing import List, Optional, Union
+
+import torch
+from monai.networks import one_hot
+from monai.utils import LossReduction
+from torch.nn.modules.loss import _Loss
+
+
+class BoundaryLoss(_Loss):
+    """
+    Integrates the ground-truth boundary distance over the voxels where the
+    (discretized) prediction disagrees with the ground truth segmentation.
+    """
+
+    def __init__(
+        self,
+        include_background: bool = True,
+        argmax: bool = True,
+        threshold: Optional[float] = None,
+        to_onehot_y: bool = False,
+        reduction: Union[LossReduction, str] = LossReduction.MEAN,
+    ) -> None:
+        """
+        Args:
+            include_background: if False, channel index 0 (the background category) is excluded from the
+                calculation. If the non-background segmentations are small compared to the total image
+                size, they can get overwhelmed by the signal from the background, so excluding it in
+                such cases helps convergence.
+            argmax: whether to discretize `pred` by a channel-wise argmax followed by one-hot encoding.
+            threshold: if not None, binarize `pred` with ``pred >= threshold`` instead of argmax.
+            to_onehot_y: whether to convert `y` into the one-hot format.
+            reduction: {``"none"``, ``"mean"``, ``"sum"``}
+                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
+
+                - ``"none"``: no reduction will be applied.
+                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
+                - ``"sum"``: the output will be summed.
+
+        Raises:
+            ValueError: When more than 1 of [``argmax=True``, ``threshold is not None``].
+                Incompatible values.
+
+        """
+        super().__init__(reduction=LossReduction(reduction).value)
+        if argmax and (threshold is not None):
+            raise ValueError(
+                "Incompatible values: more than 1 of [argmax=True, threshold is not None]."
+            )
+        self.include_background = include_background
+        self.argmax = argmax
+        self.threshold = threshold
+        self.to_onehot_y = to_onehot_y
+
+    def forward(
+        self, pred: torch.Tensor, seg_gt: torch.Tensor, dist_gt: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+            pred: the shape should be BNH[WD], where N is the number of classes.
+            seg_gt: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+            dist_gt: the shape should be BNH[WD], where N is the number of classes.
+
+        Raises:
+            AssertionError: When pred and seg_gt (after one hot transform if set)
+                have different shapes.
+            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
+        """
+        num_classes = dist_gt.shape[1]
+
+        if self.argmax:
+            if num_classes == 1:
+                warnings.warn("single channel prediction, `argmax=True` ignored.")
+            else:
+                # discretize the prediction: argmax over channels, then one-hot
+                pred = torch.argmax(pred, dim=1, keepdim=True)
+                pred = one_hot(pred, num_classes=num_classes, dim=1)
+        if self.threshold is not None:
+            pred = pred >= self.threshold
+
+        if self.to_onehot_y:
+            if num_classes == 1:
+                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
+            else:
+                seg_gt = one_hot(seg_gt, num_classes=num_classes, dim=1)
+
+        if dist_gt.shape != pred.shape:
+            raise AssertionError(
+                f"boundary distance has different shape ({dist_gt.shape}) from input ({pred.shape})"
+            )
+        if seg_gt.shape != pred.shape:
+            raise AssertionError(
+                f"ground truth has different shape ({seg_gt.shape}) from input ({pred.shape})"
+            )
+
+        if not self.include_background:
+            if num_classes == 1:
+                warnings.warn(
+                    "single channel prediction, `include_background=False` ignored."
+                )
+            else:
+                # if skipping the background, remove the first channel
+                dist_gt = dist_gt[:, 1:]
+                pred = pred[:, 1:]
+                seg_gt = seg_gt[:, 1:]
+
+        # reduce only spatial dimensions (not batch nor channels)
+        reduce_axis: List[int] = torch.arange(2, len(pred.shape)).tolist()
+
+        # mismatch mask: voxels where prediction and ground truth disagree
+        pred_gt_xor = torch.logical_xor(seg_gt, pred)
+
+        f = torch.sum(dist_gt * pred_gt_xor, dim=reduce_axis)
+
+        if self.reduction == LossReduction.MEAN.value:
+            f = torch.mean(f)  # the batch and channel average
+        elif self.reduction == LossReduction.SUM.value:
+            f = torch.sum(f)  # sum over the batch and channel dims
+        elif self.reduction == LossReduction.NONE.value:
+            # If we are not computing voxelwise loss components at least
+            # make sure a none reduction maintains a broadcastable shape
+            broadcast_shape = list(f.shape[0:2]) + [1] * (len(pred.shape) - 2)
+            f = f.view(broadcast_shape)
+        else:
+            raise ValueError(
+                f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
+            )
+
+        return f
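
A minimal usage sketch of the new loss (editor's illustration, not part of this changeset; the toy shapes, the random logits and the Dice combination are assumptions):

    import torch

    from segmantic.seg.boundary_loss import BoundaryLoss
    from segmantic.transforms.distance import DistanceTransform

    # toy ground truth: one 8x9 image with a rectangular foreground
    label = torch.zeros(8, 9, dtype=torch.bool)
    label[2:5, 3:6] = True

    seg_gt = label.unsqueeze(0).unsqueeze(0)           # B1HW labels
    dist_gt = DistanceTransform()(label).unsqueeze(0)  # BNHW boundary distances (N=2)
    logits = torch.randn(1, 2, 8, 9)                   # BNHW network output

    loss_fn = BoundaryLoss(to_onehot_y=True)  # argmax=True by default
    loss = loss_fn(pred=logits, seg_gt=seg_gt, dist_gt=dist_gt)

    # in training, this is typically added to an overlap term,
    # e.g. total = dice + alpha * boundary (the weighting is a choice, not fixed here)
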
diff --git a/src/segmantic/seg/monai_unet.py b/src/segmantic/seg/monai_unet.py
index c5e4ab1..0243dc0 100644
--- a/src/segmantic/seg/monai_unet.py
+++ b/src/segmantic/seg/monai_unet.py
@@ -1,4 +1,3 @@
-import json
 import os
 import subprocess as sp
 import sys
@@ -65,7 +64,7 @@
 from ..image.labels import load_tissue_list
 from ..seg.enum import EnsembleCombination
-from ..seg.transforms import SelectBestEnsembled
+from ..transforms.ensemble import SelectBestEnsembled
 from ..utils import config
 from .dataset import PairedDataSet
 from .evaluation import confusion_matrix
@@ -540,16 +539,9 @@ def predict(
     gpu_ids: List[int] = [],
 ) -> None:
     # load trained model
-    model_settings_json = model_file.with_suffix(".json")
-    if model_settings_json.exists():
-        print(f"WARNING: Loading legacy model settings from {model_settings_json}")
-        with model_settings_json.open() as json_file:
-            settings = json.load(json_file)
-        net = Net.load_from_checkpoint(f"{model_file}", **settings)
-    else:
-        net = Net.load_from_checkpoint(
-            f"{model_file}", channels=channels, strides=strides, dropout=dropout
-        )
+    net = Net.load_from_checkpoint(
+        f"{model_file}", channels=channels, strides=strides, dropout=dropout
+    )
 
     num_classes = net.num_classes
     net.freeze()
diff --git a/src/segmantic/data/__init__.py b/src/segmantic/transforms/__init__.py
similarity index 100%
rename from src/segmantic/data/__init__.py
rename to src/segmantic/transforms/__init__.py
diff --git a/src/segmantic/detect/transforms.py b/src/segmantic/transforms/detect.py
similarity index 100%
rename from src/segmantic/detect/transforms.py
rename to src/segmantic/transforms/detect.py
diff --git a/src/segmantic/transforms/distance.py b/src/segmantic/transforms/distance.py
new file mode 100644
index 0000000..dbcd146
--- /dev/null
+++ b/src/segmantic/transforms/distance.py
@@ -0,0 +1,127 @@
+from typing import Dict, Hashable, Optional, Sequence, Union
+
+import numpy as np
+import torch
+from monai.config import DtypeLike, KeysCollection
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.transforms.transform import MapTransform, Transform
+from monai.utils.enums import TransformBackends
+from monai.utils.type_conversion import convert_to_dst_type, convert_to_numpy
+from scipy.ndimage import binary_erosion, distance_transform_cdt, distance_transform_edt
+
+
+def get_mask_edges(seg_gt: NdarrayOrTensor, label_idx: int = 1) -> np.ndarray:
+    """Get the edges of a mask region.
+
+    Adapted from monai.metrics.utils.get_mask_edges.
+    """
+    if isinstance(seg_gt, torch.Tensor):
+        seg_gt = seg_gt.detach().cpu().numpy()
+
+    # if not a binary image, convert it
+    if seg_gt.dtype != bool:
+        seg_gt = seg_gt == label_idx
+
+    # do binary erosion and use XOR to get the edges
+    edges_gt: np.ndarray = binary_erosion(seg_gt) ^ seg_gt
+
+    return edges_gt
+
+
+def get_boundary_distance(
+    labels: np.ndarray,
+    distance_metric: str = "euclidean",
+    num_classes: int = 2,
+    spacing: Optional[Union[float, Sequence[float]]] = None,
+) -> np.ndarray:
+    """
+    Compute, for each class in `labels`, the distance of every voxel to that
+    class's boundary (mask edge).
+
+    Args:
+        labels: ground truth label image, either a boolean mask or an integer
+            label map with values in [0, num_classes - 1].
+        distance_metric: [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]
+            the metric used to compute the surface distance. Defaults to ``"euclidean"``.
+
+            - ``"euclidean"``, uses the exact Euclidean distance transform.
+            - ``"chessboard"``, uses the `chessboard` metric in a chamfer type of transform.
+            - ``"taxicab"``, uses the `taxicab` metric in a chamfer type of transform.
+        num_classes: number of classes including background (num_classes == max(labels) + 1)
+        spacing: the image spacing
+
+    Note:
+        If a class is absent from `labels`, its distances are set to `inf`.
+        (TODO: `inf` in boundary loss?)
+    """
+
+    masks: np.ndarray
+    if labels.dtype == bool:
+        masks = np.empty((2,) + labels.shape, dtype=bool)
+        masks[0, ...] = ~labels
+        masks[1, ...] = labels
+    else:
+        masks = np.empty((num_classes,) + labels.shape, dtype=bool)
+        for label_idx in range(num_classes):
+            masks[label_idx, ...] = labels == label_idx
+
+    result = np.empty_like(masks, dtype=float)
+    for i, binary_mask in enumerate(masks):
+        if not np.any(binary_mask):
+            # class not present: there is no boundary, so the distance is infinite
+            result[i, ...] = np.full(binary_mask.shape, np.inf, dtype=float)
+        else:
+            edges = get_mask_edges(binary_mask)
+            if distance_metric == "euclidean":
+                result[i, ...] = distance_transform_edt(~edges, sampling=spacing)
+            elif distance_metric in {"chessboard", "taxicab"}:
+                result[i, ...] = distance_transform_cdt(~edges, metric=distance_metric)
+            else:
+                raise ValueError(
+                    f"distance_metric {distance_metric} is not implemented."
+                )
+    return result
+
+
+class DistanceTransform(Transform):
+    """Compute per-class boundary distance maps of a label image.
+
+    The output adds a leading channel dimension of size `num_classes`
+    (or 2 for a boolean mask), one distance map per class.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(
+        self,
+        dtype: DtypeLike = np.float32,
+        num_classes: int = 2,
+        spacing: Optional[Union[float, Sequence[float]]] = None,
+    ) -> None:
+        super().__init__()
+        self.dtype = dtype
+        self.num_classes = num_classes
+        self.spacing = spacing
+
+    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
+        # compute in numpy, then convert back to the input type
+        img_ = convert_to_numpy(img)
+        dist = get_boundary_distance(
+            img_, num_classes=self.num_classes, spacing=self.spacing
+        )
+        ret = convert_to_dst_type(dist, dst=img, dtype=self.dtype or img.dtype)[0]
+        return ret
+
+
+class DistanceTransformd(MapTransform):
+    """Dictionary-based wrapper of :py:class:`DistanceTransform`.
+
+    The distance map of each key in `keys` is written to the matching
+    key in `output_keys`.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        output_keys: Union[str, Sequence[str]] = "dist",
+        dtype: DtypeLike = np.float32,
+        num_classes: int = 2,
+        spacing: Optional[Union[float, Sequence[float]]] = None,
+    ):
+        super().__init__(keys)
+
+        self.output_keys = (
+            (output_keys,) if isinstance(output_keys, str) else tuple(output_keys)
+        )
+        if len(self.keys) != len(self.output_keys):
+            raise RuntimeError("Length of `output_keys` must match `keys`")
+
+        self.dt = DistanceTransform(dtype, num_classes, spacing)
+
+    def __call__(
+        self, data: Dict[Hashable, NdarrayOrTensor]
+    ) -> Dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key, output_key in zip(self.keys, self.output_keys):
+            d[output_key] = self.dt(d[key])
+        return d
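
To make the output convention of get_boundary_distance (and hence DistanceTransform) concrete, a small worked example (editor's sketch; the 1x7 mask is made up):

    import numpy as np

    from segmantic.transforms.distance import get_boundary_distance

    mask = np.zeros((1, 7), dtype=bool)
    mask[0, 3] = True  # a single foreground pixel

    dist = get_boundary_distance(mask)  # shape (2, 1, 7): background, foreground
    # channel 1 holds each voxel's Euclidean distance to the foreground edge;
    # the edge voxel itself is at distance 0
    print(dist[1])  # [[3. 2. 1. 0. 1. 2. 3.]]
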
diff --git a/src/segmantic/seg/transforms.py b/src/segmantic/transforms/ensemble.py
similarity index 100%
rename from src/segmantic/seg/transforms.py
rename to src/segmantic/transforms/ensemble.py
diff --git a/src/segmantic/data/transforms.py b/src/segmantic/transforms/io.py
similarity index 100%
rename from src/segmantic/data/transforms.py
rename to src/segmantic/transforms/io.py
diff --git a/tests/seg/test_boundary_loss.py b/tests/seg/test_boundary_loss.py
new file mode 100644
index 0000000..7c36522
--- /dev/null
+++ b/tests/seg/test_boundary_loss.py
@@ -0,0 +1,37 @@
+import torch
+from monai.networks import one_hot
+
+from segmantic.seg.boundary_loss import BoundaryLoss
+from segmantic.transforms.distance import DistanceTransform
+
+
+def test_BoundaryLoss_exact_match():
+    mask = torch.zeros(8, 9, dtype=torch.bool)
+    mask[2:5, 3:6] = 1
+
+    distance_transform = DistanceTransform()
+    df = distance_transform(mask)
+
+    # add batch (and, for the mask, channel) dimensions
+    mask = mask.unsqueeze(0).unsqueeze(0)
+    df = df.unsqueeze(0)
+
+    # pseudo prediction: the one-hot encoded ground truth
+    pred = one_hot(mask, num_classes=2, dim=1)
+
+    # the loss for a perfect prediction should be zero
+    loss_function = BoundaryLoss(to_onehot_y=True, argmax=False)
+    loss = loss_function(pred=pred, seg_gt=mask, dist_gt=df)
+    assert loss == 0.0
+
+    loss_function = BoundaryLoss(to_onehot_y=True, argmax=False, reduction="sum")
+    loss = loss_function(pred=pred, seg_gt=mask, dist_gt=df)
+    assert loss == 0.0
+
+
+if __name__ == "__main__":
+    test_BoundaryLoss_exact_match()
diff --git a/tests/transforms/test_distance.py b/tests/transforms/test_distance.py
new file mode 100644
index 0000000..fd1524c
--- /dev/null
+++ b/tests/transforms/test_distance.py
@@ -0,0 +1,65 @@
+import numpy as np
+import torch
+from scipy.ndimage import distance_transform_edt
+
+from segmantic.transforms.distance import DistanceTransform, DistanceTransformd
+
+
+def test_DistanceTransform_2d():
+    mask = torch.zeros(8, 9, dtype=torch.bool)
+    mask[2:5, 3:6] = 1
+
+    distance_transform = DistanceTransform()
+    df = distance_transform(mask)
+    assert isinstance(df, torch.Tensor)
+    assert df.shape == (2, 8, 9)
+
+
+def test_DistanceTransform_3d():
+    mask = torch.zeros(8, 9, 7, dtype=torch.bool)
+    mask[2:5, 3:6, 3:5] = 1
+
+    distance_transform = DistanceTransform()
+    df = distance_transform(mask)
+    assert isinstance(df, torch.Tensor)
+    assert df.shape == (2, 8, 9, 7)
+
+
+def test_DistanceTransformd():
+    mask = torch.zeros(8, 9, dtype=torch.bool)
+    mask[2:5, 3:6] = 1
+
+    distance_transform = DistanceTransformd(keys="label", output_keys="dist")
+    df = distance_transform({"label": mask})
+    assert isinstance(df, dict)
+    assert "dist" in df
+    assert isinstance(df["dist"], torch.Tensor)
+
+
+def test_DistanceTransform_MultiClass():
+    mask = np.zeros((6, 7), dtype=int)
+    mask[2, 3] = 1
+    mask[4, 5] = 2
+
+    spacing = (1.2, 1.7)
+
+    distance_transform = DistanceTransform(num_classes=3, spacing=spacing)
+    df = distance_transform(mask)
+    assert df.shape == (3, 6, 7)
+
+    # compare against scipy computed on each class mask directly
+    ref1 = distance_transform_edt(~(mask == 1), sampling=spacing)
+    ref2 = distance_transform_edt(~(mask == 2), sampling=spacing)
+    np.testing.assert_almost_equal(df[1, ...], ref1, decimal=6)
+    np.testing.assert_almost_equal(df[2, ...], ref2, decimal=6)
diff --git a/tests/seg/test_transforms.py b/tests/transforms/test_ensemble.py
similarity index 95%
rename from tests/seg/test_transforms.py
rename to tests/transforms/test_ensemble.py
index 8c0b881..e98e5c5 100644
--- a/tests/seg/test_transforms.py
+++ b/tests/transforms/test_ensemble.py
@@ -2,7 +2,7 @@
 from monai.networks import one_hot
 from torch.testing import assert_close
 
-from segmantic.seg.transforms import SelectBestEnsembled
+from segmantic.transforms.ensemble import SelectBestEnsembled
 
 
 def test_SelectBestEnsembled():
diff --git a/tests/data/test_iseg_saver.py b/tests/transforms/test_iseg_saver.py
similarity index 96%
rename from tests/data/test_iseg_saver.py
rename to tests/transforms/test_iseg_saver.py
index 6a96edd..9ea3791 100644
--- a/tests/data/test_iseg_saver.py
+++ b/tests/transforms/test_iseg_saver.py
@@ -5,7 +5,7 @@
 import SimpleITK as sitk
 from monai.transforms import LoadImaged
 
-from segmantic.data.transforms import LabelInfo, iSegSaver
+from segmantic.transforms.io import LabelInfo, iSegSaver
 
 
 @pytest.fixture
diff --git a/tests/detect/test_vert_transforms.py b/tests/transforms/test_vert_transforms.py
similarity index 99%
rename from tests/detect/test_vert_transforms.py
rename to tests/transforms/test_vert_transforms.py
index 3ff8d28..2005467 100644
--- a/tests/detect/test_vert_transforms.py
+++ b/tests/transforms/test_vert_transforms.py
@@ -8,14 +8,14 @@
 from monai.transforms import AsDiscreted, Compose, EnsureChannelFirstd, LoadImaged
 from numpy.testing import assert_almost_equal
 
-from segmantic.detect.transforms import (
+from segmantic.image.processing import make_image
+from segmantic.transforms.detect import (
     BoundingBoxd,
     EmbedVert,
     ExtractVertPosition,
     LoadVert,
     SaveVert,
 )
-from segmantic.image.processing import make_image
 
 KEY_0 = "point_0"
 KEY_1 = "point_1"
diff --git a/tests/utils/test_cli.py b/tests/utils/test_cli.py
index 7e04659..e171fb4 100644
--- a/tests/utils/test_cli.py
+++ b/tests/utils/test_cli.py
@@ -1,6 +1,7 @@
 import json
 from inspect import signature
 from pathlib import Path
+from typing import Optional
 
 import pytest
 import yaml
@@ -12,7 +13,7 @@ def function1(path: Path, arg_int: int, arg_float: float = -1.5):
     pass
 
 
-def function2(arg_int: int, path: Path = None):
+def function2(arg_int: int, path: Optional[Path] = None):
     pass
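
Putting the pieces together: DistanceTransformd can precompute, per sample, the distance maps that BoundaryLoss consumes as `dist_gt`. A minimal sketch (the key names and the surrounding transforms are assumptions, not taken from this changeset):

    from monai.transforms import Compose, LoadImaged

    from segmantic.transforms.distance import DistanceTransformd

    # each sample dict gains a "dist" entry derived from its "label" entry
    preprocess = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            DistanceTransformd(keys="label", output_keys="dist", num_classes=2),
        ]
    )

    # in the training step, the batch then provides all three loss inputs:
    #   loss = boundary_loss(pred=logits, seg_gt=batch["label"], dist_gt=batch["dist"])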