Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨🔨 add distance transform #57

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed src/segmantic/detect/__init__.py
Empty file.
22 changes: 7 additions & 15 deletions src/segmantic/seg/monai_unet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import subprocess as sp
import sys
Expand Down Expand Up @@ -65,7 +64,7 @@

from ..image.labels import load_tissue_list
from ..seg.enum import EnsembleCombination
from ..seg.transforms import SelectBestEnsembled
from ..transforms.ensemble import SelectBestEnsembled
from ..utils import config
from .dataset import PairedDataSet
from .evaluation import confusion_matrix
Expand Down Expand Up @@ -538,19 +537,12 @@ def predict(
gpu_ids: List[int] = [],
) -> None:
# load trained model
model_settings_json = model_file.with_suffix(".json")
if model_settings_json.exists():
print(f"WARNING: Loading legacy model settings from {model_settings_json}")
with model_settings_json.open() as json_file:
settings = json.load(json_file)
net = cast(Net, Net.load_from_checkpoint(f"{model_file}", **settings))
else:
net = cast(
Net,
Net.load_from_checkpoint(
f"{model_file}", channels=channels, strides=strides, dropout=dropout
),
)
net = cast(
Net,
Net.load_from_checkpoint(
f"{model_file}", channels=channels, strides=strides, dropout=dropout
),
)
num_classes = net.num_classes

net.freeze()
Expand Down
File renamed without changes.
File renamed without changes.
125 changes: 125 additions & 0 deletions src/segmantic/transforms/distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Dict, Hashable, Optional, Sequence, Union

import numpy as np
import torch
from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.transforms.transform import MapTransform, Transform
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_to_dst_type, convert_to_numpy
from scipy.ndimage import binary_erosion, distance_transform_cdt, distance_transform_edt


def get_mask_edges(seg_gt, label_idx: int = 1) -> np.ndarray:
"""Get edges of mask region
copied from monai.metrics.utils.get_mask_edges
"""

if isinstance(seg_gt, torch.Tensor):
seg_gt = seg_gt.detach().cpu().numpy()

# If not binary images, convert them
if seg_gt.dtype != bool:
seg_gt = seg_gt == label_idx

# Do binary erosion and use XOR to get edges
edges_gt: np.ndarray = binary_erosion(seg_gt) ^ seg_gt

return edges_gt


def get_boundary_distance(
labels: np.ndarray,
distance_metric: str = "euclidean",
num_classes: int = 2,
spacing: Optional[Union[float, Sequence[float]]] = None,
) -> np.ndarray:
"""
This function is used to compute the surface distances to `labels`.
Args:
labels: the edge of the ground truth.
distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]
the metric used to compute surface distance. Defaults to ``"euclidean"``.
- ``"euclidean"``, uses Exact Euclidean distance transform.
- ``"chessboard"``, uses `chessboard` metric in chamfer type of transform.
- ``"taxicab"``, uses `taxicab` metric in chamfer type of transform.
num_classes: number of classes including background (num_classes == max(labels) + 1)
spacing: the image spacing
Note:
If labels is all 0, may result in inf distance. (TODO: `inf` in boundary loss?)
"""

masks: np.ndarray
if labels.dtype in (bool,):
masks = np.expand_dims(labels, axis=0)
else:
masks = np.empty((num_classes - 1,) + labels.shape)
for label_idx in range(1, num_classes):
masks[label_idx - 1, ...] = labels == label_idx

result = np.empty_like(masks, dtype=float)
for i, binary_mask in enumerate(masks):
if not np.any(binary_mask):
result[i, ...] = np.inf * np.ones(binary_mask.shape, dtype=float)
else:
edges = get_mask_edges(binary_mask)
if distance_metric == "euclidean":
result[i, ...] = distance_transform_edt(~edges, sampling=spacing)
elif distance_metric in {"chessboard", "taxicab"}:
result[i, ...] = distance_transform_cdt(~edges, metric=distance_metric)
else:
raise ValueError(
f"distance_metric {distance_metric} is not implemented."
)
return result


class DistanceTransform(Transform):
def __init__(
self,
dtype: DtypeLike = np.float32,
num_classes: int = 2,
spacing: Optional[Union[float, Sequence[float]]] = None,
) -> None:
super().__init__()
self.dtype = dtype
self.num_classes = num_classes
self.spacing = spacing

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
img_ = convert_to_numpy(img)
dist = get_boundary_distance(
img_, num_classes=self.num_classes, spacing=self.spacing
)
ret = convert_to_dst_type(dist, dst=img, dtype=self.dtype or img.dtype)[0]
return ret


class DistanceTransformd(MapTransform):
def __init__(
self,
keys: KeysCollection,
output_keys: Union[str, Sequence[str]] = "dist",
dtype: DtypeLike = np.float32,
num_classes: int = 2,
spacing: Optional[Union[float, Sequence[float]]] = None,
):
super().__init__(keys)

self.output_keys = (
(output_keys,) if isinstance(output_keys, str) else tuple(output_keys)
)
if len(self.keys) != len(self.output_keys):
raise RuntimeError("Length of `output_keys` must match `keys`")

self.dt = DistanceTransform(dtype, num_classes, spacing)

def __call__(
self, data: Dict[Hashable, NdarrayOrTensor]
) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, output_key in zip(self.keys, self.output_keys):
d[output_key] = self.dt(d[key])
return d
File renamed without changes.
File renamed without changes.
69 changes: 69 additions & 0 deletions tests/transforms/test_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy as np
import torch
from scipy.ndimage import distance_transform_edt

from segmantic.transforms.distance import DistanceTransform, DistanceTransformd


def test_DistanceTransform_2d():
mask = torch.zeros(8, 9, dtype=torch.bool)
mask[2:5, 3:6] = 1

distance_transform = DistanceTransform()
df = distance_transform(mask)
assert isinstance(df, torch.Tensor)
assert df.shape == (
1,
8,
9,
)


def test_DistanceTransform_3d():
mask = torch.zeros(8, 9, 7, dtype=torch.bool)
mask[2:5, 3:6, 3:5] = 1

distance_transform = DistanceTransform()
df = distance_transform(mask)
assert isinstance(df, torch.Tensor)
assert df.shape == (
1,
8,
9,
7,
)


def test_DistanceTransformd():
mask = torch.zeros(8, 9, dtype=torch.bool)
mask[2:5, 3:6] = 1

distance_transform = DistanceTransformd(keys="label", output_keys="dist")
df = distance_transform({"label": mask})
assert isinstance(df, dict)
assert "dist" in df
assert isinstance(df["dist"], torch.Tensor)


def test_DistanceTransform_MultiClass():
mask = np.zeros((6, 7), dtype=int)
mask[2, 3] = 1
mask[4, 5] = 2

spacing = (1.2, 1.7)

distance_transform = DistanceTransform(num_classes=3, spacing=spacing)
df = distance_transform(mask)
assert df.shape == (
2,
6,
7,
)
ref1 = distance_transform_edt(~(mask == 1), sampling=spacing)
ref2 = distance_transform_edt(~(mask == 2), sampling=spacing)
np.testing.assert_almost_equal(df[0, ...], ref1, decimal=6)
np.testing.assert_almost_equal(df[1, ...], ref2, decimal=6)


if __name__ == "__main__":
test_DistanceTransform_MultiClass()
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from monai.networks import one_hot
from torch.testing import assert_close

from segmantic.seg.transforms import SelectBestEnsembled
from segmantic.transforms.ensemble import SelectBestEnsembled


def test_SelectBestEnsembled():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import SimpleITK as sitk
from monai.transforms import LoadImaged

from segmantic.data.transforms import LabelInfo, iSegSaver
from segmantic.transforms.io import LabelInfo, iSegSaver


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from monai.transforms import AsDiscreted, Compose, EnsureChannelFirstd, LoadImaged
from numpy.testing import assert_almost_equal

from segmantic.detect.transforms import (
from segmantic.image.processing import make_image
from segmantic.transforms.detect import (
BoundingBoxd,
EmbedVert,
ExtractVertPosition,
LoadVert,
SaveVert,
)
from segmantic.image.processing import make_image

KEY_0 = "point_0"
KEY_1 = "point_1"
Expand Down
3 changes: 2 additions & 1 deletion tests/utils/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from inspect import signature
from pathlib import Path
from typing import Optional

import pytest
import yaml
Expand All @@ -12,7 +13,7 @@ def function1(path: Path, arg_int: int, arg_float: float = -1.5):
pass


def function2(arg_int: int, path: Path = None):
def function2(arg_int: int, path: Optional[Path] = None):
pass


Expand Down