Skip to content

Commit

Permalink
added tests for segmentation_extractor_functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Oct 7, 2024
1 parent e6d2bca commit 5c71966
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ...tools.typing import PathType
from ...multisegmentationextractor import MultiSegmentationExtractor
from ...segmentationextractor import SegmentationExtractor, _image_mask_extractor
from ...segmentationextractor import SegmentationExtractor, convert_pixel_masks_to_image_masks


class Suite2pSegmentationExtractor(SegmentationExtractor):
Expand Down Expand Up @@ -180,7 +180,7 @@ def __init__(
image_mean_name = "meanImg" if channel_name == "chan1" else f"meanImg_chan2"
self._image_mean = self.options[image_mean_name] if image_mean_name in self.options else None
roi_indices = list(range(self.get_num_rois()))
self._image_masks = _image_mask_extractor(
self._image_masks = convert_pixel_masks_to_image_masks(
self.get_roi_pixel_masks(),
roi_indices,
self.get_image_size(),
Expand Down
18 changes: 9 additions & 9 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,13 @@ def convert_image_masks_to_pixel_masks(image_masks: np.ndarray) -> list:
Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of
the pixel.
"""
pixel_mask_list = []
pixel_masks = []
for i in range(image_masks.shape[2]):
image_mask = image_masks[:, :, i]
locs = np.where(image_mask > 0)
pix_values = image_mask[image_mask > 0]
pixel_mask_list.append(np.vstack((locs[0], locs[1], pix_values)).T)
return pixel_mask_list
pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T)
return pixel_masks


def convert_pixel_masks_to_image_masks(pixel_masks: list[np.ndarray], image_shape: tuple) -> np.ndarray:
Expand All @@ -457,16 +457,16 @@ def convert_pixel_masks_to_image_masks(pixel_masks: list[np.ndarray], image_shap
image_masks = np.zeros(shape=shape)
for i, pixel_mask in enumerate(pixel_masks):
for row, column, wt in pixel_mask:
image_masks[row, column, i] = wt
image_masks[int(row), int(column), i] = wt
return image_masks


def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.ndarray:
"""Calculate the default ROI locations from given image masks.
This function takes a 3D numpy array of image masks and computes the median
coordinates of the maximum values in each 2D mask. The result is a 2D numpy
array where each column represents the (x, y) coordinates of the ROI for
This function takes a 3D numpy array of image masks and computes the coordinates (row, column)
of the maximum values in each 2D mask. In the case of a tie, the integer median of the coordinates is used.
The result is a 2D numpy array where each column represents the (row, column) coordinates of the ROI for
each mask.
Parameters
Expand All @@ -478,12 +478,12 @@ def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.nd
-------
np.ndarray
A 2D numpy array of shape (2, num_rois) where each column contains the
(x, y) coordinates of the ROI for each mask.
(row, column) coordinates of the ROI for each mask.
"""
num_rois = image_masks.shape[2]
roi_locations = np.zeros([2, num_rois], dtype="int")
for i in range(num_rois):
image_mask = image_masks[:, :, i]
max_value_indices = np.where(image_mask == np.amax(image_mask))
roi_locations[:, i] = np.array([np.median(max_value_indices[0]), np.median(max_value_indices[1])]).T
roi_locations[:, i] = np.array([int(np.median(max_value_indices[0])), int(np.median(max_value_indices[1]))]).T
return roi_locations
132 changes: 132 additions & 0 deletions tests/test_minimal/test_segmentation_extractor_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytest
import numpy as np

from roiextractors.segmentationextractor import (
convert_image_masks_to_pixel_masks,
convert_pixel_masks_to_image_masks,
get_default_roi_locations_from_image_masks,
)


@pytest.fixture(scope="module")
def rng():
seed = 1728084845 # int(datetime.datetime.now().timestamp()) at the time of writing
return np.random.default_rng(seed=seed)


@pytest.fixture(scope="function")
def image_masks(rng):
return rng.random((3, 3, 3))


def test_convert_image_masks_to_pixel_masks(image_masks):
pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks)
for i, pixel_mask in enumerate(pixel_masks):
assert pixel_mask.shape == (image_masks.shape[0] * image_masks.shape[1], 3)
for row, column, wt in pixel_mask:
assert row == int(row)
assert column == int(column)
assert image_masks[int(row), int(column), i] == wt


def test_convert_image_masks_to_pixel_masks_with_zeros(image_masks):
image_masks[0, 0, 0] = 0
pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks)
assert pixel_masks[0].shape == (image_masks.shape[0] * image_masks.shape[1] - 1, 3)
for i, pixel_mask in enumerate(pixel_masks):
for row, column, wt in pixel_mask:
assert row == int(row)
assert column == int(column)
assert image_masks[int(row), int(column), i] == wt


def test_convert_image_masks_to_pixel_masks_all_zeros(image_masks):
image_masks = np.zeros(image_masks.shape)
pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks)
for pixel_mask in pixel_masks:
assert pixel_mask.shape == (0, 3)


def test_convert_pixel_masks_to_image_masks(image_masks):
pixel_masks = []
for i in range(image_masks.shape[2]):
image_mask = image_masks[:, :, i]
locs = np.where(image_mask > 0)
pix_values = image_mask[image_mask > 0]
pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T)

image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2])
for i in range(image_masks.shape[2]):
image_mask = image_masks[:, :, i]
indices = np.ndindex(image_mask.shape)
for row, column in indices:
pixel_mask_mask = np.logical_and(pixel_masks[i][:, 0] == row, pixel_masks[i][:, 1] == column)
assert image_mask[row, column] == pixel_masks[i][pixel_mask_mask, 2]


def test_convert_pixel_masks_to_image_masks_with_zeros(image_masks):
pixel_masks = []
for i in range(image_masks.shape[2]):
image_mask = image_masks[:, :, i]
locs = np.where(image_mask > 0)
pix_values = image_mask[image_mask > 0]
pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T)

pixel_masks[0] = pixel_masks[0][1:]
image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2])
for i in range(image_masks.shape[2]):
image_mask = image_masks[:, :, i]
indices = np.ndindex(image_mask.shape)
for row, column in indices:
pixel_mask_mask = np.logical_and(pixel_masks[i][:, 0] == row, pixel_masks[i][:, 1] == column)
if i == 0 and row == 0 and column == 0:
assert np.all(np.logical_not(pixel_mask_mask))
else:
assert image_mask[row, column] == pixel_masks[i][pixel_mask_mask, 2]


def test_convert_pixel_masks_to_image_masks_all_zeros(image_masks):
pixel_masks = [np.zeros((0, 0)) for _ in range(image_masks.shape[2])]
output_image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2])
assert output_image_masks.shape == image_masks.shape
for image_mask in output_image_masks:
assert np.all(image_mask == 0)


def test_convert_masks_roundtrip(image_masks):
pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks)
output_image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2])
np.testing.assert_array_equal(image_masks, output_image_masks)


def test_get_default_roi_locations_from_image_masks():
image_masks = np.zeros((3, 3, 3))
image_masks[0, 0, 0] = 1
image_masks[1, 1, 1] = 1
image_masks[2, 2, 2] = 1
roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks)
expected_roi_locations = np.array([[0, 0], [1, 1], [2, 2]]).T
np.testing.assert_array_equal(roi_locations, expected_roi_locations)


def test_get_default_roi_locations_from_image_masks_tie1():
image_masks = np.zeros((3, 3, 3))
image_masks[0, 0, 0] = 1
image_masks[0, 1, 0] = 1
image_masks[1, 1, 1] = 1
image_masks[2, 2, 2] = 1
roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks)
expected_roi_locations = np.array([[0, 0], [1, 1], [2, 2]]).T
np.testing.assert_array_equal(roi_locations, expected_roi_locations)


def test_get_default_roi_locations_from_image_masks_tie2():
image_masks = np.zeros((3, 3, 3))
image_masks[0, 0, 0] = 1
image_masks[0, 1, 0] = 1
image_masks[1, 1, 0] = 1
image_masks[1, 1, 1] = 1
image_masks[2, 2, 2] = 1
roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks)
expected_roi_locations = np.array([[0, 1], [1, 1], [2, 2]]).T
np.testing.assert_array_equal(roi_locations, expected_roi_locations)

0 comments on commit 5c71966

Please sign in to comment.