Skip to content

Commit

Permalink
updated pixel_mask_extractor and image_mask_extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
pauladkisson committed Oct 4, 2024
1 parent 8d0b51e commit e6d2bca
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ def get_roi_pixel_masks(self, roi_ids=None) -> np.array:
-------
pixel_masks: list
List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3).
Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of
Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of
the pixel.
"""
return _pixel_mask_extractor(image_masks=self.get_roi_image_masks(roi_ids=roi_ids))
return convert_image_masks_to_pixel_masks(image_masks=self.get_roi_image_masks(roi_ids=roi_ids))

@abstractmethod
def get_roi_response_traces(
Expand Down Expand Up @@ -225,7 +225,7 @@ def get_background_pixel_masks(self, background_ids=None) -> np.array:
Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of
the pixel.
"""
return _pixel_mask_extractor(self.get_background_image_masks(background_ids=background_ids))
return convert_image_masks_to_pixel_masks(self.get_background_image_masks(background_ids=background_ids))

@abstractmethod
def get_background_response_traces(
Expand Down Expand Up @@ -411,7 +411,7 @@ def get_summary_images(self, names: Optional[list[str]] = None) -> dict:
return self._parent_segmentation.get_summary_images(names=names)


def _pixel_mask_extractor(image_masks: np.ndarray) -> list:
def convert_image_masks_to_pixel_masks(image_masks: np.ndarray) -> list:
"""Convert image masks to pixel masks.
Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images.
Expand All @@ -426,7 +426,7 @@ def _pixel_mask_extractor(image_masks: np.ndarray) -> list:
-------
pixel_masks: list
List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3).
Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of
Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of
the pixel.
"""
pixel_mask_list = []
Expand All @@ -438,28 +438,27 @@ def _pixel_mask_extractor(image_masks: np.ndarray) -> list:
return pixel_mask_list


def _image_mask_extractor(pixel_mask, _roi_ids, image_shape) -> np.ndarray:
"""Convert a pixel mask to image mask.
def convert_pixel_masks_to_image_masks(pixel_masks: list[np.ndarray], image_shape: tuple) -> np.ndarray:
"""Convert pixel masks to image masks.
Parameters
----------
pixel_mask: list
list of pixel masks (no pixels X 3)
_roi_ids: list
list of roi ids with length number_of_rois
image_shape: array_like
shape of the image (number_of_rows, number_of_columns)
pixel_masks: list[np.ndarray]
List of pixel mask arrays (number_of_non_zero_pixels X 3) for each ROI.
image_shape: tuple
Shape of the image (number_of_rows, number_of_columns).
Returns
-------
image_mask: np.ndarray
image_masks: np.ndarray
Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois).
"""
image_mask = np.zeros(list(image_shape) + [len(_roi_ids)])
for no, rois in enumerate(_roi_ids):
for y, x, wt in pixel_mask[rois]:
image_mask[int(y), int(x), no] = wt
return image_mask
shape = (*image_shape, len(pixel_masks))
image_masks = np.zeros(shape=shape)
for i, pixel_mask in enumerate(pixel_masks):
for row, column, wt in pixel_mask:
image_masks[row, column, i] = wt
return image_masks


def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.ndarray:
Expand Down

0 comments on commit e6d2bca

Please sign in to comment.