diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 61229ba0..998d7cde 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -139,10 +139,10 @@ def get_roi_pixel_masks(self, roi_ids=None) -> np.array: ------- pixel_masks: list List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3). - Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of + Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of the pixel. """ - return _pixel_mask_extractor(image_masks=self.get_roi_image_masks(roi_ids=roi_ids)) + return convert_image_masks_to_pixel_masks(image_masks=self.get_roi_image_masks(roi_ids=roi_ids)) @abstractmethod def get_roi_response_traces( @@ -225,7 +225,7 @@ def get_background_pixel_masks(self, background_ids=None) -> np.array: Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of the pixel. """ - return _pixel_mask_extractor(self.get_background_image_masks(background_ids=background_ids)) + return convert_image_masks_to_pixel_masks(self.get_background_image_masks(background_ids=background_ids)) @abstractmethod def get_background_response_traces( @@ -411,7 +411,7 @@ def get_summary_images(self, names: Optional[list[str]] = None) -> dict: return self._parent_segmentation.get_summary_images(names=names) -def _pixel_mask_extractor(image_masks: np.ndarray) -> list: +def convert_image_masks_to_pixel_masks(image_masks: np.ndarray) -> list: """Convert image masks to pixel masks. Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images. @@ -426,7 +426,7 @@ def _pixel_mask_extractor(image_masks: np.ndarray) -> list: ------- pixel_masks: list List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3). - Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of + Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of the pixel. """ pixel_mask_list = [] @@ -438,28 +438,27 @@ def _pixel_mask_extractor(image_masks: np.ndarray) -> list: return pixel_mask_list -def _image_mask_extractor(pixel_mask, _roi_ids, image_shape) -> np.ndarray: - """Convert a pixel mask to image mask. +def convert_pixel_masks_to_image_masks(pixel_masks: list[np.ndarray], image_shape: tuple) -> np.ndarray: + """Convert pixel masks to image masks. Parameters ---------- - pixel_mask: list - list of pixel masks (no pixels X 3) - _roi_ids: list - list of roi ids with length number_of_rois - image_shape: array_like - shape of the image (number_of_rows, number_of_columns) + pixel_masks: list[np.ndarray] + List of pixel mask arrays (number_of_non_zero_pixels X 3) for each ROI. + image_shape: tuple + Shape of the image (number_of_rows, number_of_columns). Returns ------- - image_mask: np.ndarray + image_masks: np.ndarray Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois). """ - image_mask = np.zeros(list(image_shape) + [len(_roi_ids)]) - for no, rois in enumerate(_roi_ids): - for y, x, wt in pixel_mask[rois]: - image_mask[int(y), int(x), no] = wt - return image_mask + shape = (*image_shape, len(pixel_masks)) + image_masks = np.zeros(shape=shape) + for i, pixel_mask in enumerate(pixel_masks): + for row, column, wt in pixel_mask: + image_masks[row, column, i] = wt + return image_masks def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.ndarray: