diff --git a/requirements-minimal.txt b/requirements-minimal.txt index 6a8d39d7..c99106be 100644 --- a/requirements-minimal.txt +++ b/requirements-minimal.txt @@ -8,3 +8,4 @@ psutil>=5.8.0 PyYAML lxml packaging +matplotlib diff --git a/src/roiextractors/__init__.py b/src/roiextractors/__init__.py index 1f7ec5be..b0ab0f7d 100644 --- a/src/roiextractors/__init__.py +++ b/src/roiextractors/__init__.py @@ -4,8 +4,6 @@ __version__ = version("roiextractors") -from .example_datasets import toy_example -from .extraction_tools import show_video from .extractorlist import * from .imagingextractor import ImagingExtractor from .segmentationextractor import SegmentationExtractor diff --git a/src/roiextractors/baseextractor.py b/src/roiextractors/baseextractor.py index 6a66f70b..5c3cb0ac 100644 --- a/src/roiextractors/baseextractor.py +++ b/src/roiextractors/baseextractor.py @@ -2,7 +2,7 @@ from typing import Union, Tuple from copy import deepcopy import numpy as np -from .extraction_tools import ArrayType, FloatType +from .tools.typing import ArrayType, FloatType class BaseExtractor(ABC): diff --git a/src/roiextractors/example_datasets/__init__.py b/src/roiextractors/example_datasets/__init__.py deleted file mode 100644 index 4532ba9d..00000000 --- a/src/roiextractors/example_datasets/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Toy example ImagingExtractor and SegmentationExtractor for testing. - -Modules -------- -toy_example - Create a toy example of an ImagingExtractor and a SegmentationExtractor. - -Functions ---------- -toy_example - Create a toy example of an ImagingExtractor and a SegmentationExtractor. -""" - -from .toy_example import toy_example diff --git a/src/roiextractors/example_datasets/toy_example.py b/src/roiextractors/example_datasets/toy_example.py deleted file mode 100644 index dd9997d9..00000000 --- a/src/roiextractors/example_datasets/toy_example.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Toy example ImagingExtractor and SegmentationExtractor for testing. - -Functions ---------- -toy_example - Create a toy example of an ImagingExtractor and a SegmentationExtractor. -""" - -import numpy as np - -from ..extractors.numpyextractors import ( - NumpyImagingExtractor, - NumpySegmentationExtractor, -) - - -def _gaussian(x, mu, sigma): - """Compute classical gaussian with parameters x, mu, sigma.""" - return 1 / np.sqrt(2 * np.pi * sigma) * np.exp(-((x - mu) ** 2) / sigma) - - -def _generate_rois( - num_units=10, size_x=100, size_y=100, roi_size=4, min_dist=5, mode="uniform" -): # TODO: mode --> literal type - """Generate ROIs with given parameters. - - Parameters - ---------- - num_units: int - Number of ROIs - size_x: int - Size of x dimension (pixels) - size_y: int - Size of y dimension (pixels) - roi_size: int - Siz of ROI in x and y dimension (pixels) - min_dist: int - Minimum distance between ROI centers (pixels) - mode: str - 'uniform' or 'gaussian'. - If 'uniform', ROI values are uniform and equal to 1. 
- If 'gaussian', ROI values are gaussian modulated - - Returns - ------- - roi_pixels: list - List of pixel coordinates for each ROI - image: np.ndarray - Image with ROIs - means: list - List of mean coordinates for each ROI - """ - image = np.zeros((size_x, size_y)) - max_iter = 1000 - - count = 0 - it = 0 - means = [] - - while count < num_units: - mean_x = np.random.randint(0, size_x - 1) - mean_y = np.random.randint(0, size_y - 1) - - mean_ = np.array([mean_x, mean_y]) - - if len(means) == 0: - means.append(mean_) - count += 1 - else: - dists = np.array([np.linalg.norm(mean_ - m) for m in means]) - - if np.all(dists > min_dist): - means.append(mean_) - count += 1 - - it += 1 - - if it >= max_iter: - raise Exception("Could not fit ROIs given 'min_dist'") - - roi_pixels = [] - - for m, mean in enumerate(means): - # print(f"ROI {m + 1}/{num_units}") - pixels = [] - for i in np.arange(size_x): - for j in np.arange(size_y): - p = np.array([i, j]) - - if np.linalg.norm(p - mean) < roi_size: - pixels.append(p) - if mode == "uniform": - image[i, j] = 1 - elif mode == "gaussian": - image[i, j] = _gaussian(i, mean[0], roi_size) + _gaussian(j, mean[1], roi_size) - else: - raise Exception("'mode' can be 'uniform' or 'gaussian'") - roi_pixels.append(np.array(pixels)) - - return roi_pixels, image, means - - -def toy_example( - duration=10, - num_rois=10, - size_x=100, - size_y=100, - roi_size=4, - min_dist=5, - mode="uniform", - sampling_frequency=30.0, - decay_time=0.5, - noise_std=0.05, -): - """Create a toy example of an ImagingExtractor and a SegmentationExtractor. - - Parameters - ---------- - duration: float - Duration in s - num_rois: int - Number of ROIs - size_x: int - Size of x dimension (pixels) - size_y: int - Size of y dimension (pixels) - roi_size: int - Size of ROI in x and y dimension (pixels) - min_dist: int - Minimum distance between ROI centers (pixels) - mode: str - 'uniform' or 'gaussian'. - If 'uniform', ROI values are uniform and equal to 1. 
- If 'gaussian', ROI values are gaussian modulated - sampling_frequency: float - The sampling rate - decay_time: float - Decay time of fluorescence reponse - noise_std: float - Standard deviation of added gaussian noise - - Returns - ------- - imag: NumpyImagingExtractor - The output imaging extractor - seg: NumpySegmentationExtractor - The output segmentation extractor - """ - # generate ROIs - num_rois = int(num_rois) - roi_pixels, im, means = _generate_rois( - num_units=num_rois, - size_x=size_x, - size_y=size_y, - roi_size=roi_size, - min_dist=min_dist, - mode=mode, - ) - - from spikeinterface.core import generate_sorting - - sort = generate_sorting(durations=[duration], num_units=num_rois, sampling_frequency=sampling_frequency) - - # create decaying response - resp_samples = int(decay_time * sampling_frequency) - resp_tau = resp_samples / 5 - tresp = np.arange(resp_samples) - resp = np.exp(-tresp / resp_tau) - - num_frames = sampling_frequency * duration - - # convolve response with ROIs - raw = np.zeros(num_rois, num_frames) # TODO Change to new standard formating with time in first axis - deconvolved = np.zeros(num_rois, num_frames) # TODO Change to new standard formating with time in first axis - neuropil = noise_std * np.random.randn( - num_rois, num_frames - ) # TODO Change to new standard formating with time in first axis - frames = num_frames - for u_i, unit in range(num_rois): - unit = u_i + 1 # spikeextractor toy example has unit ids starting at 1 - for s in sort.get_unit_spike_train(unit): # TODO build a local function that generates frames with spikes - if s < num_frames: - if s + len(resp) < frames: - raw[u_i, s : s + len(resp)] += resp - else: - raw[u_i, s:] = resp[: frames - s] - deconvolved[u_i, s] = 1 - - # generate video - video = np.zeros((frames, size_x, size_y)) - for rp, t in zip(roi_pixels, raw): - for r in rp: - video[:, r[0], r[1]] += t * im[r[0], r[1]] - - # normalize video - video /= np.max(video) - - # add noise - video += noise_std * np.abs(np.random.randn(*video.shape)) - - # instantiate imaging and segmentation extractors - imag = NumpyImagingExtractor(timeseries=video, sampling_frequency=30) - - # create image masks - image_masks = np.zeros((size_x, size_y, num_rois)) - for rois_i, roi in enumerate(roi_pixels): - for r in roi: - image_masks[r[0], r[1], rois_i] += im[r[0], r[1]] - - seg = NumpySegmentationExtractor( - image_masks=image_masks, - raw=raw, - deconvolved=deconvolved, - neuropil=neuropil, - sampling_frequency=float(sampling_frequency), - ) - - return imag, seg diff --git a/src/roiextractors/extraction_tools.py b/src/roiextractors/extraction_tools.py deleted file mode 100644 index dfe2a50d..00000000 --- a/src/roiextractors/extraction_tools.py +++ /dev/null @@ -1,713 +0,0 @@ -"""Various tools for extraction of ROIs from imaging data. - -Classes -------- -VideoStructure - A data class for specifying the structure of a video. 
-""" - -import sys -import importlib.util -from functools import wraps -from pathlib import Path -from typing import Union, Tuple, Optional, Dict, List -from types import ModuleType -from dataclasses import dataclass -from platform import python_version - -import lazy_ops -import numpy as np -from numpy.typing import ArrayLike, DTypeLike -from tqdm import tqdm -from packaging import version - - -import h5py -import zarr - -ArrayType = ArrayLike -PathType = Union[str, Path] -NumpyArray = np.ndarray -DtypeType = DTypeLike -IntType = Union[int, np.integer] -FloatType = float -NoneType = type(None) - - -def raise_multi_channel_or_depth_not_implemented(extractor_name: str): - """Raise a NotImplementedError for an extractor that does not support multiple channels or depth (z-axis).""" - raise NotImplementedError( - f"The {extractor_name}Extractor does not currently support multiple color channels or 3-dimensional depth." - "If you with to request either of these features, please do so by raising an issue at " - "https://github.com/catalystneuro/roiextractors/issues" - ) - - -@dataclass -class VideoStructure: - """A data class for specifying the structure of a video. - - The role of the data class is to ensure consistency in naming and provide some initial - consistency checks to ensure the validity of the sturcture. - - Attributes - ---------- - num_rows : int - The number of rows of each frame as a matrix. - num_columns : int - The number of columns of each frame as a matrix. - num_channels : int - The number of channels (1 for grayscale, 3 for color). - rows_axis : int - The axis or dimension corresponding to the rows. - columns_axis : int - The axis or dimension corresponding to the columns. - channels_axis : int - The axis or dimension corresponding to the channels. - frame_axis : int - The axis or dimension corresponding to the frames in the video. 
- - As an example if you wanted to build the structure for a video with gray (n_channels=1) frames of 10 x 5 - where the video is to have the following shape (num_frames, num_rows, num_columns, num_channels) you - could define the class this way: - - >>> from roiextractors.extraction_tools import VideoStructure - >>> num_rows = 10 - >>> num_columns = 5 - >>> num_channels = 1 - >>> frame_axis = 0 - >>> rows_axis = 1 - >>> columns_axis = 2 - >>> channels_axis = 3 - >>> video_structure = VideoStructure( - num_rows=num_rows, - num_columns=num_columns, - num_channels=num_channels, - rows_axis=rows_axis, - columns_axis=columns_axis, - channels_axis=channels_axis, - frame_axis=frame_axis, - ) - """ - - num_rows: int - num_columns: int - num_channels: int - rows_axis: int - columns_axis: int - channels_axis: int - frame_axis: int - - def __post_init__(self) -> None: - """Validate the structure of the video and initialize the shape of the frame.""" - self._validate_video_structure() - self._initialize_frame_shape() - self.number_of_pixels_per_frame = np.prod(self.frame_shape) - - def _initialize_frame_shape(self) -> None: - """Initialize the shape of the frame.""" - self.frame_shape = [None, None, None, None] - self.frame_shape[self.rows_axis] = self.num_rows - self.frame_shape[self.columns_axis] = self.num_columns - self.frame_shape[self.channels_axis] = self.num_channels - self.frame_shape.pop(self.frame_axis) - self.frame_shape = tuple(self.frame_shape) - - def _validate_video_structure(self) -> None: - """Validate the structure of the video.""" - exception_message = ( - "Invalid structure: " - f"{self.__repr__()}, " - "each property axis should be unique value between 0 and 3 (inclusive)" - ) - - axis_values = {self.rows_axis, self.columns_axis, self.channels_axis, self.frame_axis} - axis_values_are_not_unique = len(axis_values) != 4 - if axis_values_are_not_unique: - raise ValueError(exception_message) - - values_out_of_range = any([axis < 0 or axis > 4 for axis in axis_values]) - if values_out_of_range: - raise ValueError(exception_message) - - def build_video_shape(self, n_frames: int) -> Tuple[int, int, int, int]: - """Build the shape of the video from class attributes. - - Parameters - ---------- - n_frames : int - The number of frames in the video. - - Returns - ------- - Tuple[int, int, int, int] - The shape of the video. - - Notes - ----- - The class attributes frame_axis, rows_axis, columns_axis and channels_axis are used to determine the order of the - dimensions in the returned tuple. - """ - video_shape = [None] * 4 - video_shape[self.frame_axis] = n_frames - video_shape[self.rows_axis] = self.num_rows - video_shape[self.columns_axis] = self.num_columns - video_shape[self.channels_axis] = self.num_channels - - return tuple(video_shape) - - def transform_video_to_canonical_form(self, video: np.ndarray) -> np.ndarray: - """Transform a video to the canonical internal format of roiextractors (num_frames, num_rows, num_columns, num_channels). - - Parameters - ---------- - video : numpy.ndarray - The video to be transformed - Returns - ------- - numpy.ndarray - The reshaped video - - Raises - ------ - KeyError - If the video is not in a format that can be transformed. 
- """ - canonical_shape = (self.frame_axis, self.rows_axis, self.columns_axis, self.channels_axis) - if isinstance(video, (h5py.Dataset, zarr.core.Array)): - re_mapped_video = lazy_ops.DatasetView(video).lazy_transpose(canonical_shape) - elif isinstance(video, np.ndarray): - re_mapped_video = video.transpose(canonical_shape) - else: - raise KeyError(f"Function not implemented for specific format {type(video)}") - - return re_mapped_video - - -def read_numpy_memmap_video( - file_path: PathType, video_structure: VideoStructure, dtype: DtypeType, offset: int = 0 -) -> np.array: - """Auxiliary function to read videos from binary files. - - Parameters - ---------- - file_path : PathType - the file_path where the data resides. - video_structure : VideoStructure - A VideoStructure instance describing the structure of the video to read. This includes parameters - such as the number of rows, columns and channels plus which axis (i.e. dimension) of the - image corresponds to each of them. - - As an example you create one of these structures in the following way: - - from roiextractors.extraction_tools import VideoStructure - - num_rows = 10 - num_columns = 5 - num_channels = 3 - frame_axis = 0 - rows_axis = 1 - columns_axis = 2 - channels_axis = 3 - - video_structure = VideoStructure( - num_rows=num_rows, - num_columns=num_columns, - num_channels=num_channels, - rows_axis=rows_axis, - columns_axis=columns_axis, - channels_axis=channels_axis, - frame_axis=frame_axis, - ) - - dtype : DtypeType - The type of the data to be loaded (int, float, etc.) - offset : int, optional - The offset in bytes. Usually corresponds to the number of bytes occupied by the header. 0 by default. - - Returns - ------- - video_memap: np.array - A numpy memmap pointing to the video. - """ - file_size_bytes = Path(file_path).stat().st_size - - pixels_per_frame = video_structure.number_of_pixels_per_frame - type_size = np.dtype(dtype).itemsize - frame_size_bytes = pixels_per_frame * type_size - - bytes_available = file_size_bytes - offset - number_of_frames = bytes_available // frame_size_bytes - - memmap_shape = video_structure.build_video_shape(n_frames=number_of_frames) - video_memap = np.memmap(file_path, offset=offset, dtype=dtype, mode="r", shape=memmap_shape) - - return video_memap - - -def _pixel_mask_extractor(image_masks: np.ndarray) -> list: - """Convert image masks to pixel masks. - - Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images. - The location and weight of each non-zero pixel is stored for each mask. - - Parameters - ---------- - image_masks: numpy.ndarray - Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois). - - Returns - ------- - pixel_masks: list - List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3). - Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of - the pixel. - """ - pixel_mask_list = [] - for i in range(image_masks.shape[2]): - image_mask = image_masks[:, :, i] - locs = np.where(image_mask > 0) - pix_values = image_mask[image_mask > 0] - pixel_mask_list.append(np.vstack((locs[0], locs[1], pix_values)).T) - return pixel_mask_list - - -def _image_mask_extractor(pixel_mask, _roi_ids, image_shape) -> np.ndarray: - """Convert a pixel mask to image mask. 
- - Parameters - ---------- - pixel_mask: list - list of pixel masks (no pixels X 3) - _roi_ids: list - list of roi ids with length number_of_rois - image_shape: array_like - shape of the image (number_of_rows, number_of_columns) - - Returns - ------- - image_mask: np.ndarray - Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois). - """ - image_mask = np.zeros(list(image_shape) + [len(_roi_ids)]) - for no, rois in enumerate(_roi_ids): - for y, x, wt in pixel_mask[rois]: - image_mask[int(y), int(x), no] = wt - return image_mask - - -def check_get_frames_args(func): - """Check the arguments of the get_frames function. - - This decorator allows the get_frames function to be queried with either - an integer, slice or an array and handles a common return. [I think that np.take can be used instead of this] - - Parameters - ---------- - func: function - The get_frames function. - - Returns - ------- - corrected_args: function - The get_frames function with corrected arguments. - - Raises - ------ - AssertionError - If 'frame_idxs' exceed the number of frames. - """ - - @wraps(func) - def corrected_args(imaging, frame_idxs, channel=0): - channel = int(channel) - if isinstance(frame_idxs, (int, np.integer)): - frame_idxs = [frame_idxs] - if not isinstance(frame_idxs, slice): - frame_idxs = np.array(frame_idxs) - assert np.all(frame_idxs < imaging.get_num_frames()), "'frame_idxs' exceed number of frames" - get_frames_correct_arg = func(imaging, frame_idxs, channel) - - if len(frame_idxs) == 1: - return get_frames_correct_arg[0] - else: - return get_frames_correct_arg - - return corrected_args - - -def _cast_start_end_frame(start_frame, end_frame): - """Cast start and end frame to int or None. - - Parameters - ---------- - start_frame: int, float, None - The start frame. - end_frame: int, float, None - The end frame. - - Returns - ------- - start_frame: int, None - The start frame. - end_frame: int, None - The end frame. - - Raises - ------ - ValueError - If start_frame is not an int, float or None. - ValueError - If end_frame is not an int, float or None. - """ - if isinstance(start_frame, float): - start_frame = int(start_frame) - elif isinstance(start_frame, (int, np.integer, type(None))): - start_frame = start_frame - else: - raise ValueError("start_frame must be an int, float (not infinity), or None") - if isinstance(end_frame, float) and np.isfinite(end_frame): - end_frame = int(end_frame) - elif isinstance(end_frame, (int, np.integer, type(None))): - end_frame = end_frame - # else end_frame is infinity (accepted for get_unit_spike_train) - if start_frame is not None: - start_frame = int(start_frame) - if end_frame is not None and np.isfinite(end_frame): - end_frame = int(end_frame) - return start_frame, end_frame - - -def check_get_videos_args(func): - """Check the arguments of the get_videos function. - - This decorator allows the get_videos function to be queried with either - an integer or slice and handles a common return. - - Parameters - ---------- - func: function - The get_videos function. - - Returns - ------- - corrected_args: function - The get_videos function with corrected arguments. - - Raises - ------ - AssertionError - If 'start_frame' exceeds the number of frames. - AssertionError - If 'end_frame' exceeds the number of frames. - AssertionError - If 'start_frame' is greater than 'end_frame'. 
- """ - - @wraps(func) - def corrected_args(imaging, start_frame=None, end_frame=None, channel=0): - if start_frame is not None: - if start_frame > imaging.get_num_frames(): - raise Exception(f"'start_frame' exceeds number of frames {imaging.get_num_frames()}!") - elif start_frame < 0: - start_frame = imaging.get_num_frames() + start_frame - else: - start_frame = 0 - if end_frame is not None: - if end_frame > imaging.get_num_frames(): - raise Exception(f"'end_frame' exceeds number of frames {imaging.get_num_frames()}!") - elif end_frame < 0: - end_frame = imaging.get_num_frames() + end_frame - else: - end_frame = imaging.get_num_frames() - assert end_frame - start_frame > 0, "'start_frame' must be less than 'end_frame'!" - - start_frame, end_frame = _cast_start_end_frame(start_frame, end_frame) - channel = int(channel) - get_videos_correct_arg = func(imaging, start_frame=start_frame, end_frame=end_frame, channel=channel) - - return get_videos_correct_arg - - return corrected_args - - -def write_to_h5_dataset_format( - imaging, - dataset_path, - save_path=None, - file_handle=None, - dtype=None, - chunk_size=None, - chunk_mb=1000, - verbose=False, -): - """Save the video of an imaging extractor in an h5 dataset. - - Parameters - ---------- - imaging: ImagingExtractor - The imaging extractor object to be saved in the .h5 file - dataset_path: str - Path to dataset in h5 file (e.g. '/dataset') - save_path: str - The path to the file. - file_handle: file handle - The file handle to dump data. This can be used to append data to an header. In case file_handle is given, - the file is NOT closed after writing the binary data. - dtype: dtype - Type of the saved data. Default float32. - chunk_size: None or int - Number of chunks to save the file in. This avoid to much memory consumption for big files. - If None and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) - chunk_mb: None or int - Chunk size in Mb (default 1000Mb) - verbose: bool - If True, output is verbose (when chunks are used) - - Returns - ------- - save_path: str - The path to the file. - - Raises - ------ - AssertionError - If neither 'save_path' nor 'file_handle' are given. - """ - assert save_path is not None or file_handle is not None, "Provide 'save_path' or 'file handle'" - - if save_path is not None: - save_path = Path(save_path) - if save_path.suffix == "": - # when suffix is already raw/bin/dat do not change it. - save_path = save_path.parent / (save_path.name + ".h5") - num_channels = imaging.get_num_channels() - num_frames = imaging.get_num_frames() - size_x, size_y = imaging.get_image_size() - - if file_handle is not None: - assert isinstance(file_handle, h5py.File) - else: - file_handle = h5py.File(save_path, "w") - if dtype is None: - dtype_file = imaging.get_dtype() - else: - dtype_file = dtype - dset = file_handle.create_dataset(dataset_path, shape=(num_channels, num_frames, size_x, size_y), dtype=dtype_file) - - # set chunk size - if chunk_size is not None: - chunk_size = int(chunk_size) - elif chunk_mb is not None: - n_bytes = np.dtype(imaging.get_dtype()).itemsize - max_size = int(chunk_mb * 1e6) # set Mb per chunk - chunk_size = max_size // (size_x * size_y * n_bytes) - # writ one channel at a time - for ch in range(num_channels): - if chunk_size is None: - video = imaging.get_video(channel=ch) - if dtype is not None: - video = video.astype(dtype_file) - dset[ch, ...] 
= np.squeeze(video) - else: - chunk_start = 0 - # chunk size is not None - n_chunk = num_frames // chunk_size - if num_frames % chunk_size > 0: - n_chunk += 1 - if verbose: - chunks = tqdm(range(n_chunk), ascii=True, desc="Writing to .h5 file") - else: - chunks = range(n_chunk) - for i in chunks: - video = imaging.get_video( - start_frame=i * chunk_size, - end_frame=min((i + 1) * chunk_size, num_frames), - channel=ch, - ) - chunk_frames = np.squeeze(video).shape[0] - if dtype is not None: - video = video.astype(dtype_file) - dset[ch, chunk_start : chunk_start + chunk_frames, ...] = np.squeeze(video) - chunk_start += chunk_frames - if save_path is not None: - file_handle.close() - return save_path - - -# TODO will be moved eventually, but for now it's very handy :) -def show_video(imaging, ax=None): - """Show video as animation. - - Parameters - ---------- - imaging: ImagingExtractor - The imaging extractor object to be saved in the .h5 file - ax: matplotlib axis - Axis to plot the video. If None, a new axis is created. - - Returns - ------- - anim: matplotlib.animation.FuncAnimation - Animation of the video. - """ - import matplotlib.pyplot as plt - import matplotlib.animation as animation - - def animate_func(i, imaging, im, ax): - ax.set_title(f"{i}") - im.set_array(imaging.get_frames(i)) - return [im] - - if ax is None: - fig = plt.figure(figsize=(5, 5)) - ax = fig.add_subplot(111) - im0 = imaging.get_frames(0) - im = ax.imshow(im0, interpolation="none", aspect="auto", vmin=0, vmax=1) - interval = 1 / imaging.get_sampling_frequency() * 1000 - anim = animation.FuncAnimation( - fig, - animate_func, - frames=imaging.get_num_frames(), - fargs=(imaging, im, ax), - interval=interval, - blit=False, - ) - return anim - - -def check_keys(dict_: dict) -> dict: - """Check keys of dictionary for mat-objects. - - Checks if entries in dictionary are mat-objects. If yes - todict is called to change them to nested dictionaries. - - Parameters - ---------- - dict_: dict - Dictionary to check. - - Returns - ------- - dict: dict - Dictionary with mat-objects converted to nested dictionaries. - - Raises - ------ - AssertionError - If scipy is not installed. - """ - from scipy.io.matlab.mio5_params import mat_struct - - for key in dict_: - if isinstance(dict_[key], mat_struct): - dict_[key] = todict(dict_[key]) - return dict_ - - -def todict(matobj): - """Recursively construct nested dictionaries from matobjects. - - Parameters - ---------- - matobj: mat_struct - Matlab object to convert to nested dictionary. - - Returns - ------- - dict: dict - Dictionary with mat-objects converted to nested dictionaries. - """ - from scipy.io.matlab.mio5_params import mat_struct - - dict_ = {} - from scipy.io.matlab.mio5_params import mat_struct - - for strg in matobj._fieldnames: - elem = matobj.__dict__[strg] - if isinstance(elem, mat_struct): - dict_[strg] = todict(elem) - else: - dict_[strg] = elem - return dict_ - - -def get_package( - package_name: str, - installation_instructions: Optional[str] = None, - excluded_platforms_and_python_versions: Optional[Dict[str, List[str]]] = None, -) -> ModuleType: - """Check if package is installed and return module if so. - - Otherwise, raise informative error describing how to perform the installation. - Inspired by https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported. - - Parameters - ---------- - package_name : str - Name of the package to be imported. 
- installation_instructions : str, optional - String describing the source, options, and alias of package name (if needed) for installation. - For example, - >>> installation_source = "conda install -c conda-forge my-package-name" - Defaults to f"pip install {package_name}". - excluded_platforms_and_python_versions : dict mapping string platform names to a list of string versions, optional - In case some combinations of platforms or Python versions are not allowed for the given package, specify - this dictionary to raise a more specific error to that issue. - For example, `excluded_platforms_and_python_versions = dict(darwin=["3.7"])` will raise an informative error - when running on MacOS with Python version 3.7. - Allows all platforms and Python versions used by default. - - Raises - ------ - ModuleNotFoundError - If the package is not installed. - """ - installation_instructions = installation_instructions or f"pip install {package_name}" - excluded_platforms_and_python_versions = excluded_platforms_and_python_versions or dict() - - if package_name in sys.modules: - return sys.modules[package_name] - - if importlib.util.find_spec(package_name) is not None: - return importlib.import_module(name=package_name) - - for excluded_version in excluded_platforms_and_python_versions.get(sys.platform, list()): - if version.parse(python_version()).minor == version.parse(excluded_version).minor: - raise ModuleNotFoundError( - f"\nThe package '{package_name}' is not available on the {sys.platform} platform for " - f"Python version {excluded_version}!" - ) - - raise ModuleNotFoundError( - f"\nThe required package'{package_name}' is not installed!\n" - f"To install this package, please run\n\n\t{installation_instructions}\n" - ) - - -def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.ndarray: - """Calculate the default ROI locations from given image masks. - - This function takes a 3D numpy array of image masks and computes the median - coordinates of the maximum values in each 2D mask. The result is a 2D numpy - array where each column represents the (x, y) coordinates of the ROI for - each mask. - - Parameters - ---------- - image_masks : np.ndarray - A 3D numpy array of shape (height, width, num_rois) containing the image masks. - - Returns - ------- - np.ndarray - A 2D numpy array of shape (2, num_rois) where each column contains the - (x, y) coordinates of the ROI for each mask. 
- """ - num_rois = image_masks.shape[2] - roi_locations = np.zeros([2, num_rois], dtype="int") - for i in range(num_rois): - image_mask = image_masks[:, :, i] - max_value_indices = np.where(image_mask == np.amax(image_mask)) - roi_locations[:, i] = np.array([np.median(max_value_indices[0]), np.median(max_value_indices[1])]).T - return roi_locations diff --git a/src/roiextractors/extractors/caiman/caimansegmentationextractor.py b/src/roiextractors/extractors/caiman/caimansegmentationextractor.py index 387eac26..fe06cf8c 100644 --- a/src/roiextractors/extractors/caiman/caimansegmentationextractor.py +++ b/src/roiextractors/extractors/caiman/caimansegmentationextractor.py @@ -13,7 +13,8 @@ from scipy.sparse import csc_matrix import numpy as np -from ...extraction_tools import PathType, get_package +from ...tools.typing import PathType +from ...tools.importing import get_package from ...multisegmentationextractor import MultiSegmentationExtractor from ...segmentationextractor import SegmentationExtractor diff --git a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py index 16e165c6..7c0a7389 100644 --- a/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py +++ b/src/roiextractors/extractors/hdf5imagingextractor/hdf5imagingextractor.py @@ -12,8 +12,7 @@ import numpy as np -from ...extraction_tools import PathType, FloatType, ArrayType -from ...extraction_tools import write_to_h5_dataset_format +from ...tools.typing import PathType, FloatType, ArrayType from ...imagingextractor import ImagingExtractor from lazy_ops import DatasetView @@ -187,3 +186,107 @@ def write_imaging( write_to_h5_dataset_format(imaging=imaging, dataset_path=mov_field, file_handle=f, **kwargs) dset = f[mov_field] dset.attrs["fr"] = imaging.get_sampling_frequency() + + +def write_to_h5_dataset_format( + imaging, + dataset_path, + save_path=None, + file_handle=None, + dtype=None, + chunk_size=None, + chunk_mb=1000, + verbose=False, +): + """Save the video of an imaging extractor in an h5 dataset. + + Parameters + ---------- + imaging: ImagingExtractor + The imaging extractor object to be saved in the .h5 file + dataset_path: str + Path to dataset in h5 file (e.g. '/dataset') + save_path: str + The path to the file. + file_handle: file handle + The file handle to dump data. This can be used to append data to an header. In case file_handle is given, + the file is NOT closed after writing the binary data. + dtype: dtype + Type of the saved data. Default float32. + chunk_size: None or int + Number of chunks to save the file in. This avoid to much memory consumption for big files. + If None and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) + chunk_mb: None or int + Chunk size in Mb (default 1000Mb) + verbose: bool + If True, output is verbose (when chunks are used) + + Returns + ------- + save_path: str + The path to the file. + + Raises + ------ + AssertionError + If neither 'save_path' nor 'file_handle' are given. + """ + assert save_path is not None or file_handle is not None, "Provide 'save_path' or 'file handle'" + + if save_path is not None: + save_path = Path(save_path) + if save_path.suffix == "": + # when suffix is already raw/bin/dat do not change it. 
+ save_path = save_path.parent / (save_path.name + ".h5") + num_channels = imaging.get_num_channels() + num_frames = imaging.get_num_frames() + size_x, size_y = imaging.get_image_size() + + if file_handle is not None: + assert isinstance(file_handle, h5py.File) + else: + file_handle = h5py.File(save_path, "w") + if dtype is None: + dtype_file = imaging.get_dtype() + else: + dtype_file = dtype + dset = file_handle.create_dataset(dataset_path, shape=(num_channels, num_frames, size_x, size_y), dtype=dtype_file) + + # set chunk size + if chunk_size is not None: + chunk_size = int(chunk_size) + elif chunk_mb is not None: + n_bytes = np.dtype(imaging.get_dtype()).itemsize + max_size = int(chunk_mb * 1e6) # set Mb per chunk + chunk_size = max_size // (size_x * size_y * n_bytes) + # writ one channel at a time + for ch in range(num_channels): + if chunk_size is None: + video = imaging.get_video(channel=ch) + if dtype is not None: + video = video.astype(dtype_file) + dset[ch, ...] = np.squeeze(video) + else: + chunk_start = 0 + # chunk size is not None + n_chunk = num_frames // chunk_size + if num_frames % chunk_size > 0: + n_chunk += 1 + if verbose: + chunks = tqdm(range(n_chunk), ascii=True, desc="Writing to .h5 file") + else: + chunks = range(n_chunk) + for i in chunks: + video = imaging.get_video( + start_frame=i * chunk_size, + end_frame=min((i + 1) * chunk_size, num_frames), + channel=ch, + ) + chunk_frames = np.squeeze(video).shape[0] + if dtype is not None: + video = video.astype(dtype_file) + dset[ch, chunk_start : chunk_start + chunk_frames, ...] = np.squeeze(video) + chunk_start += chunk_frames + if save_path is not None: + file_handle.close() + return save_path diff --git a/src/roiextractors/extractors/inscopixextractors/inscopiximagingextractor.py b/src/roiextractors/extractors/inscopixextractors/inscopiximagingextractor.py index ccd83189..410133ff 100644 --- a/src/roiextractors/extractors/inscopixextractors/inscopiximagingextractor.py +++ b/src/roiextractors/extractors/inscopixextractors/inscopiximagingextractor.py @@ -6,7 +6,7 @@ import numpy as np from ...imagingextractor import ImagingExtractor -from ...extraction_tools import PathType +from ...tools.typing import PathType class InscopixImagingExtractor(ImagingExtractor): diff --git a/src/roiextractors/extractors/memmapextractors/memmapextractors.py b/src/roiextractors/extractors/memmapextractors/memmapextractors.py index 88b7e52d..d9de3820 100644 --- a/src/roiextractors/extractors/memmapextractors/memmapextractors.py +++ b/src/roiextractors/extractors/memmapextractors/memmapextractors.py @@ -15,7 +15,7 @@ from ...imagingextractor import ImagingExtractor from typing import Tuple, Optional -from ...extraction_tools import PathType, DtypeType +from ...tools.typing import PathType, DtypeType class MemmapImagingExtractor(ImagingExtractor): diff --git a/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py b/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py index bd25b33a..d5af9512 100644 --- a/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py +++ b/src/roiextractors/extractors/memmapextractors/numpymemampextractor.py @@ -6,12 +6,214 @@ The class for reading optical imaging data stored in a binary format with numpy.memmap. 
""" -import os from pathlib import Path +from typing import Tuple +from dataclasses import dataclass +import lazy_ops +import numpy as np +import h5py +from ...tools.typing import PathType, DtypeType +from .memmapextractors import MemmapImagingExtractor -from roiextractors.extraction_tools import read_numpy_memmap_video, VideoStructure, DtypeType, PathType -from .memmapextractors import MemmapImagingExtractor +@dataclass +class VideoStructure: + """A data class for specifying the structure of a video. + + The role of the data class is to ensure consistency in naming and provide some initial + consistency checks to ensure the validity of the sturcture. + + Attributes + ---------- + num_rows : int + The number of rows of each frame as a matrix. + num_columns : int + The number of columns of each frame as a matrix. + num_channels : int + The number of channels (1 for grayscale, 3 for color). + rows_axis : int + The axis or dimension corresponding to the rows. + columns_axis : int + The axis or dimension corresponding to the columns. + channels_axis : int + The axis or dimension corresponding to the channels. + frame_axis : int + The axis or dimension corresponding to the frames in the video. + + As an example if you wanted to build the structure for a video with gray (n_channels=1) frames of 10 x 5 + where the video is to have the following shape (num_frames, num_rows, num_columns, num_channels) you + could define the class this way: + + >>> num_rows = 10 + >>> num_columns = 5 + >>> num_channels = 1 + >>> frame_axis = 0 + >>> rows_axis = 1 + >>> columns_axis = 2 + >>> channels_axis = 3 + >>> video_structure = VideoStructure( + num_rows=num_rows, + num_columns=num_columns, + num_channels=num_channels, + rows_axis=rows_axis, + columns_axis=columns_axis, + channels_axis=channels_axis, + frame_axis=frame_axis, + ) + """ + + num_rows: int + num_columns: int + num_channels: int + rows_axis: int + columns_axis: int + channels_axis: int + frame_axis: int + + def __post_init__(self) -> None: + """Validate the structure of the video and initialize the shape of the frame.""" + self._validate_video_structure() + self._initialize_frame_shape() + self.number_of_pixels_per_frame = np.prod(self.frame_shape) + + def _initialize_frame_shape(self) -> None: + """Initialize the shape of the frame.""" + self.frame_shape = [None, None, None, None] + self.frame_shape[self.rows_axis] = self.num_rows + self.frame_shape[self.columns_axis] = self.num_columns + self.frame_shape[self.channels_axis] = self.num_channels + self.frame_shape.pop(self.frame_axis) + self.frame_shape = tuple(self.frame_shape) + + def _validate_video_structure(self) -> None: + """Validate the structure of the video.""" + exception_message = ( + "Invalid structure: " + f"{self.__repr__()}, " + "each property axis should be unique value between 0 and 3 (inclusive)" + ) + + axis_values = {self.rows_axis, self.columns_axis, self.channels_axis, self.frame_axis} + axis_values_are_not_unique = len(axis_values) != 4 + if axis_values_are_not_unique: + raise ValueError(exception_message) + + values_out_of_range = any([axis < 0 or axis > 4 for axis in axis_values]) + if values_out_of_range: + raise ValueError(exception_message) + + def build_video_shape(self, n_frames: int) -> Tuple[int, int, int, int]: + """Build the shape of the video from class attributes. + + Parameters + ---------- + n_frames : int + The number of frames in the video. + + Returns + ------- + Tuple[int, int, int, int] + The shape of the video. 
+ + Notes + ----- + The class attributes frame_axis, rows_axis, columns_axis and channels_axis are used to determine the order of the + dimensions in the returned tuple. + """ + video_shape = [None] * 4 + video_shape[self.frame_axis] = n_frames + video_shape[self.rows_axis] = self.num_rows + video_shape[self.columns_axis] = self.num_columns + video_shape[self.channels_axis] = self.num_channels + + return tuple(video_shape) + + def transform_video_to_canonical_form(self, video: np.ndarray) -> np.ndarray: + """Transform a video to the canonical internal format of roiextractors (num_frames, num_rows, num_columns, num_channels). + + Parameters + ---------- + video : numpy.ndarray + The video to be transformed + Returns + ------- + numpy.ndarray + The reshaped video + + Raises + ------ + KeyError + If the video is not in a format that can be transformed. + """ + canonical_shape = (self.frame_axis, self.rows_axis, self.columns_axis, self.channels_axis) + if isinstance(video, (h5py.Dataset, zarr.core.Array)): + re_mapped_video = lazy_ops.DatasetView(video).lazy_transpose(canonical_shape) + elif isinstance(video, np.ndarray): + re_mapped_video = video.transpose(canonical_shape) + else: + raise KeyError(f"Function not implemented for specific format {type(video)}") + + return re_mapped_video + + +def read_numpy_memmap_video( + file_path: PathType, video_structure: VideoStructure, dtype: DtypeType, offset: int = 0 +) -> np.array: + """Auxiliary function to read videos from binary files. + + Parameters + ---------- + file_path : PathType + the file_path where the data resides. + video_structure : VideoStructure + A VideoStructure instance describing the structure of the video to read. This includes parameters + such as the number of rows, columns and channels plus which axis (i.e. dimension) of the + image corresponds to each of them. + + As an example you create one of these structures in the following way: + + + num_rows = 10 + num_columns = 5 + num_channels = 3 + frame_axis = 0 + rows_axis = 1 + columns_axis = 2 + channels_axis = 3 + + video_structure = VideoStructure( + num_rows=num_rows, + num_columns=num_columns, + num_channels=num_channels, + rows_axis=rows_axis, + columns_axis=columns_axis, + channels_axis=channels_axis, + frame_axis=frame_axis, + ) + + dtype : DtypeType + The type of the data to be loaded (int, float, etc.) + offset : int, optional + The offset in bytes. Usually corresponds to the number of bytes occupied by the header. 0 by default. + + Returns + ------- + video_memap: np.array + A numpy memmap pointing to the video. 
+ """ + file_size_bytes = Path(file_path).stat().st_size + + pixels_per_frame = video_structure.number_of_pixels_per_frame + type_size = np.dtype(dtype).itemsize + frame_size_bytes = pixels_per_frame * type_size + + bytes_available = file_size_bytes - offset + number_of_frames = bytes_available // frame_size_bytes + + memmap_shape = video_structure.build_video_shape(n_frames=number_of_frames) + video_memap = np.memmap(file_path, offset=offset, dtype=dtype, mode="r", shape=memmap_shape) + + return video_memap class NumpyMemmapImagingExtractor(MemmapImagingExtractor): @@ -40,7 +242,6 @@ def __init__( As an example you create one of these structures in the following way: - from roiextractors.extraction_tools import VideoStructure num_rows = 10 num_columns = 5 diff --git a/src/roiextractors/extractors/miniscopeimagingextractor/miniscopeimagingextractor.py b/src/roiextractors/extractors/miniscopeimagingextractor/miniscopeimagingextractor.py index e619f8b2..cb274fab 100644 --- a/src/roiextractors/extractors/miniscopeimagingextractor/miniscopeimagingextractor.py +++ b/src/roiextractors/extractors/miniscopeimagingextractor/miniscopeimagingextractor.py @@ -15,7 +15,8 @@ from ...imagingextractor import ImagingExtractor from ...multiimagingextractor import MultiImagingExtractor -from ...extraction_tools import PathType, DtypeType, get_package +from ...tools.typing import PathType, DtypeType +from ...tools.importing import get_package class MiniscopeImagingExtractor(MultiImagingExtractor): # TODO: rename to MiniscopeMultiImagingExtractor diff --git a/src/roiextractors/extractors/numpyextractors/numpyextractors.py b/src/roiextractors/extractors/numpyextractors/numpyextractors.py index 1d82ad39..16037f6f 100644 --- a/src/roiextractors/extractors/numpyextractors/numpyextractors.py +++ b/src/roiextractors/extractors/numpyextractors/numpyextractors.py @@ -13,16 +13,14 @@ import numpy as np -from ...extraction_tools import ( +from ...tools.typing import ( PathType, FloatType, ArrayType, IntType, - NoneType, - get_default_roi_locations_from_image_masks, ) from ...imagingextractor import ImagingExtractor -from ...segmentationextractor import SegmentationExtractor +from ...segmentationextractor import SegmentationExtractor, get_default_roi_locations_from_image_masks class NumpyImagingExtractor(ImagingExtractor): diff --git a/src/roiextractors/extractors/nwbextractors/nwbextractors.py b/src/roiextractors/extractors/nwbextractors/nwbextractors.py index 4f6ec2ca..7d0b3849 100644 --- a/src/roiextractors/extractors/nwbextractors/nwbextractors.py +++ b/src/roiextractors/extractors/nwbextractors/nwbextractors.py @@ -16,12 +16,11 @@ from pynwb import NWBHDF5IO from pynwb.ophys import TwoPhotonSeries, OnePhotonSeries -from ...extraction_tools import ( +from ...tools.typing import ( PathType, FloatType, IntType, ArrayType, - raise_multi_channel_or_depth_not_implemented, ) from ...imagingextractor import ImagingExtractor from ...segmentationextractor import SegmentationExtractor diff --git a/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py b/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py index 9b85dd37..eb999bca 100644 --- a/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py +++ b/src/roiextractors/extractors/sbximagingextractor/sbximagingextractor.py @@ -13,7 +13,7 @@ import numpy as np -from ...extraction_tools import PathType, ArrayType, raise_multi_channel_or_depth_not_implemented, check_keys +from ...tools.typing import PathType, ArrayType from 
...imagingextractor import ImagingExtractor import scipy.io as spio @@ -206,3 +206,59 @@ def write_imaging(imaging, save_path: PathType, overwrite: bool = False): This function is not implemented yet. """ raise NotImplementedError + + +# TODO: Check if we still need these functions +def check_keys(dict_: dict) -> dict: + """Check keys of dictionary for mat-objects. + + Checks if entries in dictionary are mat-objects. If yes + todict is called to change them to nested dictionaries. + + Parameters + ---------- + dict_: dict + Dictionary to check. + + Returns + ------- + dict: dict + Dictionary with mat-objects converted to nested dictionaries. + + Raises + ------ + AssertionError + If scipy is not installed. + """ + from scipy.io.matlab.mio5_params import mat_struct + + for key in dict_: + if isinstance(dict_[key], mat_struct): + dict_[key] = todict(dict_[key]) + return dict_ + + +def todict(matobj): + """Recursively construct nested dictionaries from matobjects. + + Parameters + ---------- + matobj: mat_struct + Matlab object to convert to nested dictionary. + + Returns + ------- + dict: dict + Dictionary with mat-objects converted to nested dictionaries. + """ + from scipy.io.matlab.mio5_params import mat_struct + + dict_ = {} + + for strg in matobj._fieldnames: + elem = matobj.__dict__[strg] + if isinstance(elem, mat_struct): + dict_[strg] = todict(elem) + else: + dict_[strg] = elem + return dict_ diff --git a/src/roiextractors/extractors/schnitzerextractor/cnmfesegmentationextractor.py b/src/roiextractors/extractors/schnitzerextractor/cnmfesegmentationextractor.py index ae5fc9ba..ae951116 100644 --- a/src/roiextractors/extractors/schnitzerextractor/cnmfesegmentationextractor.py +++ b/src/roiextractors/extractors/schnitzerextractor/cnmfesegmentationextractor.py @@ -13,7 +13,7 @@ from lazy_ops import DatasetView from scipy.sparse import csc_matrix -from ...extraction_tools import PathType +from ...tools.typing import PathType from ...multisegmentationextractor import MultiSegmentationExtractor from ...segmentationextractor import SegmentationExtractor diff --git a/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py b/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py index 53e63f8c..0754204b 100644 --- a/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py +++ b/src/roiextractors/extractors/schnitzerextractor/extractsegmentationextractor.py @@ -20,7 +20,7 @@ import h5py -from ...extraction_tools import PathType, ArrayType +from ...tools.typing import PathType, ArrayType from ...segmentationextractor import SegmentationExtractor diff --git a/src/roiextractors/extractors/simaextractor/simasegmentationextractor.py b/src/roiextractors/extractors/simaextractor/simasegmentationextractor.py index f4310090..9ac4ca1c 100644 --- a/src/roiextractors/extractors/simaextractor/simasegmentationextractor.py +++ b/src/roiextractors/extractors/simaextractor/simasegmentationextractor.py @@ -13,7 +13,7 @@ import numpy as np -from ...extraction_tools import PathType +from ...tools.typing import PathType from ...segmentationextractor import SegmentationExtractor diff --git a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py index eb55b98c..77242b75 100644 --- a/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py +++ b/src/roiextractors/extractors/suite2p/suite2psegmentationextractor.py @@ -12,10 +12,9 @@ from warnings 
import warn import numpy as np -from ...extraction_tools import PathType -from ...extraction_tools import _image_mask_extractor +from ...tools.typing import PathType from ...multisegmentationextractor import MultiSegmentationExtractor -from ...segmentationextractor import SegmentationExtractor +from ...segmentationextractor import SegmentationExtractor, convert_pixel_masks_to_image_masks class Suite2pSegmentationExtractor(SegmentationExtractor): @@ -181,7 +180,7 @@ def __init__( image_mean_name = "meanImg" if channel_name == "chan1" else f"meanImg_chan2" self._image_mean = self.options[image_mean_name] if image_mean_name in self.options else None roi_indices = list(range(self.get_num_rois())) - self._image_masks = _image_mask_extractor( + self._image_masks = convert_pixel_masks_to_image_masks( self.get_roi_pixel_masks(), roi_indices, self.get_image_size(), diff --git a/src/roiextractors/extractors/tiffimagingextractors/brukertiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/brukertiffimagingextractor.py index 5f908560..9f4376d8 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/brukertiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/brukertiffimagingextractor.py @@ -22,7 +22,8 @@ from ...multiimagingextractor import MultiImagingExtractor from ...imagingextractor import ImagingExtractor -from ...extraction_tools import PathType, get_package, DtypeType, ArrayType +from ...tools.typing import PathType, DtypeType, ArrayType +from ...tools.importing import get_package def filter_read_uic_tag_warnings(record): diff --git a/src/roiextractors/extractors/tiffimagingextractors/micromanagertiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/micromanagertiffimagingextractor.py index 1df9f3d7..ec50b849 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/micromanagertiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/micromanagertiffimagingextractor.py @@ -19,7 +19,8 @@ import numpy as np from ...imagingextractor import ImagingExtractor -from ...extraction_tools import PathType, get_package, DtypeType +from ...tools.typing import PathType, DtypeType +from ...tools.importing import get_package from ...multiimagingextractor import MultiImagingExtractor diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiff_utils.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiff_utils.py index bf559573..5a987f74 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiff_utils.py +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiff_utils.py @@ -2,7 +2,8 @@ import numpy as np import json -from ...extraction_tools import PathType, get_package +from ...tools.typing import PathType +from ...tools.importing import get_package def _get_scanimage_reader() -> type: diff --git a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py index 155ace8f..0da27458 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/scanimagetiffimagingextractor.py @@ -11,7 +11,7 @@ from warnings import warn import numpy as np -from ...extraction_tools import PathType, FloatType, ArrayType, DtypeType, get_package +from ...tools.typing import PathType, FloatType, ArrayType, DtypeType from ...imagingextractor import ImagingExtractor 
from ...volumetricimagingextractor import VolumetricImagingExtractor from ...multiimagingextractor import MultiImagingExtractor diff --git a/src/roiextractors/extractors/tiffimagingextractors/tiffimagingextractor.py b/src/roiextractors/extractors/tiffimagingextractors/tiffimagingextractor.py index 2d77ae49..9c156755 100644 --- a/src/roiextractors/extractors/tiffimagingextractors/tiffimagingextractor.py +++ b/src/roiextractors/extractors/tiffimagingextractors/tiffimagingextractor.py @@ -15,12 +15,8 @@ from tqdm import tqdm from ...imagingextractor import ImagingExtractor -from ...extraction_tools import ( - PathType, - FloatType, - raise_multi_channel_or_depth_not_implemented, - get_package, -) +from ...tools.typing import PathType, FloatType +from ...tools.importing import get_package class TiffImagingExtractor(ImagingExtractor): diff --git a/src/roiextractors/imagingextractor.py b/src/roiextractors/imagingextractor.py index 987eea63..1c370271 100644 --- a/src/roiextractors/imagingextractor.py +++ b/src/roiextractors/imagingextractor.py @@ -15,7 +15,7 @@ import numpy as np from .baseextractor import BaseExtractor -from .extraction_tools import ArrayType, PathType, DtypeType, FloatType, IntType +from .tools.typing import ArrayType, PathType, DtypeType, FloatType, IntType class ImagingExtractor(BaseExtractor): @@ -137,31 +137,6 @@ def _validate_get_frames_arguments(self, frame_idxs: ArrayType) -> Tuple[int, in return start_frame, end_frame - def __eq__(self, imaging_extractor2): - image_size_equal = self.get_image_size() == imaging_extractor2.get_image_size() - num_frames_equal = self.get_num_frames() == imaging_extractor2.get_num_frames() - sampling_frequency_equal = np.isclose( - self.get_sampling_frequency(), imaging_extractor2.get_sampling_frequency() - ) - dtype_equal = self.get_dtype() == imaging_extractor2.get_dtype() - video_equal = np.array_equal(self.get_video(), imaging_extractor2.get_video()) - times_equal = np.allclose( - self.frame_to_time(np.arange(self.get_num_frames())), - imaging_extractor2.frame_to_time(np.arange(imaging_extractor2.get_num_frames())), - ) - imaging_extractors_equal = all( - [ - image_size_equal, - num_frames_equal, - sampling_frequency_equal, - dtype_equal, - video_equal, - times_equal, - ] - ) - - return imaging_extractors_equal - def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None): """Return a new ImagingExtractor ranging from the start_frame to the end_frame. diff --git a/src/roiextractors/multiimagingextractor.py b/src/roiextractors/multiimagingextractor.py index 294daa24..1352a11d 100644 --- a/src/roiextractors/multiimagingextractor.py +++ b/src/roiextractors/multiimagingextractor.py @@ -11,7 +11,7 @@ import numpy as np -from .extraction_tools import ArrayType, NumpyArray +from .tools.typing import ArrayType from .imagingextractor import ImagingExtractor @@ -99,7 +99,7 @@ def _get_times(self) -> np.ndarray: return times - def _get_frames_from_an_imaging_extractor(self, extractor_index: int, frame_idxs: ArrayType) -> NumpyArray: + def _get_frames_from_an_imaging_extractor(self, extractor_index: int, frame_idxs: ArrayType) -> np.ndarray: """Get frames from a single imaging extractor. 
Parameters @@ -121,7 +121,7 @@ def _get_frames_from_an_imaging_extractor(self, extractor_index: int, frame_idxs def get_dtype(self): return self._imaging_extractors[0].get_dtype() - def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0) -> NumpyArray: + def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0) -> np.ndarray: if isinstance(frame_idxs, (int, np.integer)): frame_idxs = [frame_idxs] frame_idxs = np.array(frame_idxs) diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 6ae3fab1..e74fc369 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -17,8 +17,7 @@ from numpy.typing import ArrayLike from .baseextractor import BaseExtractor -from .extraction_tools import ArrayType, IntType, FloatType -from .extraction_tools import _pixel_mask_extractor +from .tools.typing import ArrayType, IntType, FloatType class SegmentationExtractor(BaseExtractor): @@ -140,10 +139,10 @@ def get_roi_pixel_masks(self, roi_ids=None) -> np.array: ------- pixel_masks: list List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3). - Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of + Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of the pixel. """ - return _pixel_mask_extractor(image_masks=self.get_roi_image_masks(roi_ids=roi_ids)) + return convert_image_masks_to_pixel_masks(image_masks=self.get_roi_image_masks(roi_ids=roi_ids)) @abstractmethod def get_roi_response_traces( @@ -226,7 +225,7 @@ def get_background_pixel_masks(self, background_ids=None) -> np.array: Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of the pixel. """ - return _pixel_mask_extractor(self.get_background_image_masks(background_ids=background_ids)) + return convert_image_masks_to_pixel_masks(self.get_background_image_masks(background_ids=background_ids)) @abstractmethod def get_background_response_traces( @@ -410,3 +409,81 @@ def get_background_response_traces( def get_summary_images(self, names: Optional[list[str]] = None) -> dict: return self._parent_segmentation.get_summary_images(names=names) + + +def convert_image_masks_to_pixel_masks(image_masks: np.ndarray) -> list: + """Convert image masks to pixel masks. + + Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images. + The location and weight of each non-zero pixel is stored for each mask. + + Parameters + ---------- + image_masks: numpy.ndarray + Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois). + + Returns + ------- + pixel_masks: list + List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3). + Columns 1 and 2 are the row and column coordinates of the pixel, while the third column represents the weight of + the pixel. + """ + pixel_masks = [] + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + locs = np.where(image_mask > 0) + pix_values = image_mask[image_mask > 0] + pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T) + return pixel_masks + + +def convert_pixel_masks_to_image_masks(pixel_masks: list[np.ndarray], image_shape: tuple) -> np.ndarray: + """Convert pixel masks to image masks. 
+ + Parameters + ---------- + pixel_masks: list[np.ndarray] + List of pixel mask arrays (number_of_non_zero_pixels X 3) for each ROI. + image_shape: tuple + Shape of the image (number_of_rows, number_of_columns). + + Returns + ------- + image_masks: np.ndarray + Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois). + """ + shape = (*image_shape, len(pixel_masks)) + image_masks = np.zeros(shape=shape) + for i, pixel_mask in enumerate(pixel_masks): + for row, column, wt in pixel_mask: + image_masks[int(row), int(column), i] = wt + return image_masks + + +def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.ndarray: + """Calculate the default ROI locations from given image masks. + + This function takes a 3D numpy array of image masks and computes the coordinates (row, column) + of the maximum values in each 2D mask. In the case of a tie, the integer median of the coordinates is used. + The result is a 2D numpy array where each column represents the (row, column) coordinates of the ROI for + each mask. + + Parameters + ---------- + image_masks : np.ndarray + A 3D numpy array of shape (height, width, num_rois) containing the image masks. + + Returns + ------- + np.ndarray + A 2D numpy array of shape (2, num_rois) where each column contains the + (row, column) coordinates of the ROI for each mask. + """ + num_rois = image_masks.shape[2] + roi_locations = np.zeros([2, num_rois], dtype="int") + for i in range(num_rois): + image_mask = image_masks[:, :, i] + max_value_indices = np.where(image_mask == np.amax(image_mask)) + roi_locations[:, i] = np.array([int(np.median(max_value_indices[0])), int(np.median(max_value_indices[1]))]).T + return roi_locations diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py deleted file mode 100644 index 3ea0d707..00000000 --- a/src/roiextractors/testing.py +++ /dev/null @@ -1,423 +0,0 @@ -"""Testing utilities for the roiextractors package.""" - -from collections.abc import Iterable -from typing import Tuple, Optional, List - -import numpy as np -from numpy.testing import assert_array_equal, assert_array_almost_equal - -from .segmentationextractor import SegmentationExtractor -from .imagingextractor import ImagingExtractor - -from roiextractors import NumpyImagingExtractor, NumpySegmentationExtractor - -from roiextractors.extraction_tools import DtypeType - -NoneType = type(None) -floattype = (float, np.floating) -inttype = (int, np.integer) - - -def generate_dummy_video(size: Tuple[int], dtype: DtypeType = "uint16", seed: int = 0): - """Generate a dummy video of a given size and dtype. - - Parameters - ---------- - size : Tuple[int] - Size of the video to generate. - dtype : DtypeType, optional - Dtype of the video to generate, by default "uint16". - seed : int, default 0 - seed for the random number generator, by default 0. - - Returns - ------- - video : np.ndarray - A dummy video of the given size and dtype. 
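A minimal round-trip sketch of the three module-level helpers added above (convert_image_masks_to_pixel_masks, convert_pixel_masks_to_image_masks, get_default_roi_locations_from_image_masks); the toy array values are arbitrary and only illustrate the dense (rows, columns, rois) versus sparse (row, column, weight) layouts:

    import numpy as np
    from roiextractors.segmentationextractor import (
        convert_image_masks_to_pixel_masks,
        convert_pixel_masks_to_image_masks,
        get_default_roi_locations_from_image_masks,
    )

    # Two 4x4 ROIs in the dense (rows, columns, rois) layout, one non-zero pixel each.
    image_masks = np.zeros((4, 4, 2))
    image_masks[1, 1, 0] = 0.5
    image_masks[2, 3, 1] = 1.0

    # Sparse form: one (number_of_non_zero_pixels, 3) array of (row, column, weight) per ROI.
    pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks)
    assert pixel_masks[0].shape == (1, 3)

    # Converting back recovers the original dense masks.
    dense_again = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=(4, 4))
    np.testing.assert_array_equal(dense_again, image_masks)

    # Default ROI locations are the (row, column) of each mask's maximum, shape (2, num_rois).
    roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks)
    np.testing.assert_array_equal(roi_locations, np.array([[1, 2], [1, 3]]))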
- """ - dtype = np.dtype(dtype) - number_of_bytes = dtype.itemsize - - rng = np.random.default_rng(seed) - - low = 0 if "u" in dtype.name else 2 ** (number_of_bytes - 1) - 2**number_of_bytes - high = 2**number_of_bytes - 1 if "u" in dtype.name else 2**number_of_bytes - 2 ** (number_of_bytes - 1) - 1 - video = rng.random(size=size) if "float" in dtype.name else rng.integers(low=low, high=high, size=size, dtype=dtype) - - return video - - -def generate_dummy_imaging_extractor( - num_frames: int = 30, - num_rows: int = 10, - num_columns: int = 10, - sampling_frequency: float = 30.0, - dtype: DtypeType = "uint16", - seed: int = 0, -): - """Generate a dummy imaging extractor for testing. - - The imaging extractor is built by feeding random data into the `NumpyImagingExtractor`. - - Parameters - ---------- - num_frames : int, optional - number of frames in the video, by default 30. - num_rows : int, optional - number of rows in the video, by default 10. - num_columns : int, optional - number of columns in the video, by default 10. - sampling_frequency : float, optional - sampling frequency of the video, by default 30. - dtype : DtypeType, optional - dtype of the video, by default "uint16". - seed : int, default 0 - seed for the random number generator, by default 0. - - Returns - ------- - ImagingExtractor - An imaging extractor with random data fed into `NumpyImagingExtractor`. - """ - size = (num_frames, num_rows, num_columns) - video = generate_dummy_video(size=size, dtype=dtype, seed=seed) - imaging_extractor = NumpyImagingExtractor(timeseries=video, sampling_frequency=sampling_frequency) - - return imaging_extractor - - -def generate_dummy_segmentation_extractor( - num_rois: int = 10, - num_frames: int = 30, - num_rows: int = 25, - num_columns: int = 25, - sampling_frequency: float = 30.0, - has_summary_images: bool = True, - has_raw_signal: bool = True, - has_dff_signal: bool = True, - has_deconvolved_signal: bool = True, - has_neuropil_signal: bool = True, - rejected_list: Optional[list] = None, - seed: int = 0, -) -> SegmentationExtractor: - """Generate a dummy segmentation extractor for testing. - - The segmentation extractor is built by feeding random data into the - `NumpySegmentationExtractor`. - - Parameters - ---------- - num_rois : int, optional - number of regions of interest, by default 10. - num_frames : int, optional - Number of frames in the recording, by default 30. - num_rows : int, optional - number of rows in the hypothetical video from which the data was extracted, by default 25. - num_columns : int, optional - number of columns in the hypothetical video from which the data was extracted, by default 25. - sampling_frequency : float, optional - sampling frequency of the hypothetical video from which the data was extracted, by default 30.0. - has_summary_images : bool, optional - whether the dummy segmentation extractor has summary images or not (mean and correlation). - has_raw_signal : bool, optional - whether a raw fluorescence signal is desired in the object, by default True. - has_dff_signal : bool, optional - whether a relative (df/f) fluorescence signal is desired in the object, by default True. - has_deconvolved_signal : bool, optional - whether a deconvolved signal is desired in the object, by default True. - has_neuropil_signal : bool, optional - whether a neuropil signal is desired in the object, by default True. - rejected_list: list, optional - A list of rejected rois, None by default. 
- seed : int, default 0 - seed for the random number generator, by default 0. - - Returns - ------- - SegmentationExtractor - A segmentation extractor with random data fed into `NumpySegmentationExtractor` - - Notes - ----- - Note that this dummy example is meant to be a mock object with the right shape, structure and objects but does not - contain meaningful content. That is, the image masks matrices are not plausible image mask for a roi, the raw signal - is not a meaningful biological signal and is not related appropriately to the deconvolved signal , etc. - """ - rng = np.random.default_rng(seed) - - # Create dummy image masks - image_masks = rng.random((num_rows, num_columns, num_rois)) - movie_dims = (num_rows, num_columns) - - # Create signals - raw = rng.random((num_frames, num_rois)) if has_raw_signal else None - dff = rng.random((num_frames, num_rois)) if has_dff_signal else None - deconvolved = rng.random((num_frames, num_rois)) if has_deconvolved_signal else None - neuropil = rng.random((num_frames, num_rois)) if has_neuropil_signal else None - - # Summary images - mean_image = rng.random((num_rows, num_columns)) if has_summary_images else None - correlation_image = rng.random((num_rows, num_columns)) if has_summary_images else None - - # Rois - roi_ids = [id for id in range(num_rois)] - roi_locations_rows = rng.integers(low=0, high=num_rows, size=num_rois) - roi_locations_columns = rng.integers(low=0, high=num_columns, size=num_rois) - roi_locations = np.vstack((roi_locations_rows, roi_locations_columns)) - - rejected_list = rejected_list if rejected_list else None - - accepeted_list = roi_ids - if rejected_list is not None: - accepeted_list = list(set(accepeted_list).difference(rejected_list)) - - dummy_segmentation_extractor = NumpySegmentationExtractor( - sampling_frequency=sampling_frequency, - image_masks=image_masks, - raw=raw, - dff=dff, - deconvolved=deconvolved, - neuropil=neuropil, - mean_image=mean_image, - correlation_image=correlation_image, - roi_ids=roi_ids, - roi_locations=roi_locations, - accepted_roi_ids=accepeted_list, - rejected_roi_ids=rejected_list, - movie_dims=movie_dims, - channel_names=["channel_num_0"], - ) - - return dummy_segmentation_extractor - - -def _assert_iterable_shape(iterable, shape): - """Assert that the iterable has the given shape. If the iterable is a numpy array, the shape is checked directly.""" - ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable) - for ar_shape, given_shape in zip(ar.shape, shape): - if isinstance(given_shape, int): - assert ar_shape == given_shape, f"Expected {given_shape}, received {ar_shape}!" - - -def _assert_iterable_shape_max(iterable, shape_max): - """Assert that the iterable has a shape less than or equal to the given maximum shape.""" - ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable) - for ar_shape, given_shape in zip(ar.shape, shape_max): - if isinstance(given_shape, int): - assert ar_shape <= given_shape - - -def _assert_iterable_element_dtypes(iterable, dtypes): - """Assert that the iterable has elements of the given dtypes.""" - if isinstance(iterable, Iterable) and not isinstance(iterable, str): - for iter in iterable: - _assert_iterable_element_dtypes(iter, dtypes) - else: - assert isinstance(iterable, dtypes), f"array is none of the types {dtypes}" - - -def _assert_iterable_complete(iterable, dtypes=None, element_dtypes=None, shape=None, shape_max=None): - """Assert that the iterable is complete, i.e. 
it is not None and has the given dtypes, element_dtypes, shape and shape_max.""" - assert isinstance(iterable, dtypes), f"iterable {type(iterable)} is none of the types {dtypes}" - if not isinstance(iterable, NoneType): - if shape is not None: - _assert_iterable_shape(iterable, shape=shape) - if shape_max is not None: - _assert_iterable_shape_max(iterable, shape_max=shape_max) - if element_dtypes is not None: - _assert_iterable_element_dtypes(iterable, element_dtypes) - - -def check_segmentations_equal( - segmentation_extractor1: SegmentationExtractor, segmentation_extractor2: SegmentationExtractor -): - """Check that two segmentation extractors have equal fields.""" - check_segmentation_return_types(segmentation_extractor1) - check_segmentation_return_types(segmentation_extractor2) - # assert equality: - assert segmentation_extractor1.get_num_rois() == segmentation_extractor2.get_num_rois() - assert segmentation_extractor1.get_num_frames() == segmentation_extractor2.get_num_frames() - assert segmentation_extractor1.get_num_channels() == segmentation_extractor2.get_num_channels() - assert np.isclose( - segmentation_extractor1.get_sampling_frequency(), segmentation_extractor2.get_sampling_frequency() - ) - assert_array_equal(segmentation_extractor1.get_channel_names(), segmentation_extractor2.get_channel_names()) - assert_array_equal(segmentation_extractor1.get_image_size(), segmentation_extractor2.get_image_size()) - assert_array_equal( - segmentation_extractor1.get_roi_image_masks(roi_ids=segmentation_extractor1.get_roi_ids()[:1]), - segmentation_extractor2.get_roi_image_masks(roi_ids=segmentation_extractor2.get_roi_ids()[:1]), - ) - assert set( - segmentation_extractor1.get_roi_pixel_masks(roi_ids=segmentation_extractor1.get_roi_ids()[:1])[0].flatten() - ) == set( - segmentation_extractor2.get_roi_pixel_masks(roi_ids=segmentation_extractor1.get_roi_ids()[:1])[0].flatten() - ) - - check_segmentations_images(segmentation_extractor1, segmentation_extractor2) - - assert_array_equal(segmentation_extractor1.get_accepted_list(), segmentation_extractor2.get_accepted_list()) - assert_array_equal(segmentation_extractor1.get_rejected_list(), segmentation_extractor2.get_rejected_list()) - assert_array_equal(segmentation_extractor1.get_roi_locations(), segmentation_extractor2.get_roi_locations()) - assert_array_equal(segmentation_extractor1.get_roi_ids(), segmentation_extractor2.get_roi_ids()) - assert_array_equal(segmentation_extractor1.get_traces(), segmentation_extractor2.get_traces()) - - assert_array_equal( - segmentation_extractor1.frame_to_time(np.arange(segmentation_extractor1.get_num_frames())), - segmentation_extractor2.frame_to_time(np.arange(segmentation_extractor2.get_num_frames())), - ) - - -def check_segmentations_images( - segmentation_extractor1: SegmentationExtractor, - segmentation_extractor2: SegmentationExtractor, -): - """Check that the segmentation images are equal for the given segmentation extractors.""" - images_in_extractor1 = segmentation_extractor1.get_images_dict() - images_in_extractor2 = segmentation_extractor2.get_images_dict() - - assert len(images_in_extractor1) == len(images_in_extractor2) - - image_names_are_equal = all(image_name in images_in_extractor1.keys() for image_name in images_in_extractor2.keys()) - assert image_names_are_equal, "The names of segmentation images in the segmentation extractors are not the same." 
- - for image_name in images_in_extractor1.keys(): - assert_array_equal( - images_in_extractor1[image_name], - images_in_extractor2[image_name], - ), f"The segmentation images for {image_name} are not equal." - - -def check_segmentation_return_types(seg: SegmentationExtractor): - """Check that the return types of the segmentation extractor are correct.""" - assert isinstance(seg.get_num_rois(), int) - assert isinstance(seg.get_num_frames(), int) - assert isinstance(seg.get_num_channels(), int) - assert isinstance(seg.get_sampling_frequency(), (NoneType, floattype)) - _assert_iterable_complete( - seg.get_channel_names(), - dtypes=list, - element_dtypes=str, - shape_max=(seg.get_num_channels(),), - ) - _assert_iterable_complete(seg.get_image_size(), dtypes=Iterable, element_dtypes=inttype, shape=(2,)) - _assert_iterable_complete( - seg.get_roi_image_masks(roi_ids=seg.get_roi_ids()[:1]), - dtypes=(np.ndarray,), - element_dtypes=floattype, - shape=(*seg.get_image_size(), 1), - ) - _assert_iterable_complete( - seg.get_roi_ids(), - dtypes=(list,), - shape=(seg.get_num_rois(),), - element_dtypes=inttype, - ) - assert isinstance(seg.get_roi_pixel_masks(roi_ids=seg.get_roi_ids()[:2]), list) - _assert_iterable_complete( - seg.get_roi_pixel_masks(roi_ids=seg.get_roi_ids()[:1])[0], - dtypes=(np.ndarray,), - element_dtypes=floattype, - shape_max=(np.prod(seg.get_image_size()), 3), - ) - for image_name in seg.get_images_dict(): - _assert_iterable_complete( - seg.get_image(image_name), - dtypes=(np.ndarray, NoneType), - element_dtypes=floattype, - shape_max=(*seg.get_image_size(),), - ) - _assert_iterable_complete( - seg.get_accepted_list(), - dtypes=(list, NoneType), - element_dtypes=inttype, - shape_max=(seg.get_num_rois(),), - ) - _assert_iterable_complete( - seg.get_rejected_list(), - dtypes=(list, NoneType), - element_dtypes=inttype, - shape_max=(seg.get_num_rois(),), - ) - _assert_iterable_complete( - seg.get_roi_locations(), - dtypes=(np.ndarray,), - shape=(2, seg.get_num_rois()), - element_dtypes=inttype, - ) - _assert_iterable_complete( - seg.get_traces(), - dtypes=(np.ndarray, NoneType), - element_dtypes=floattype, - shape=(np.prod(seg.get_num_rois()), None), - ) - assert isinstance(seg.get_traces_dict(), dict) - assert isinstance(seg.get_images_dict(), dict) - assert {"raw", "dff", "neuropil", "deconvolved", "denoised"} == set(seg.get_traces_dict().keys()) - - -def check_imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor): - """Check that two imaging extractors have equal fields.""" - assert imaging_extractor1.get_image_size() == imaging_extractor2.get_image_size() - assert imaging_extractor1.get_num_frames() == imaging_extractor2.get_num_frames() - assert np.close(imaging_extractor1.get_sampling_frequency(), imaging_extractor2.get_sampling_frequency()) - assert imaging_extractor1.get_dtype() == imaging_extractor2.get_dtype() - assert_array_equal(imaging_extractor1.get_video(), imaging_extractor2.get_video()) - assert_array_almost_equal( - imaging_extractor1.frame_to_time(np.arange(imaging_extractor1.get_num_frames())), - imaging_extractor2.frame_to_time(np.arange(imaging_extractor2.get_num_frames())), - ) - - -def assert_get_frames_return_shape(imaging_extractor: ImagingExtractor): - """Check whether an ImagingExtractor get_frames function behaves as expected. - - We aim for the function to behave as numpy slicing and indexing as much as possible. 
- """ - image_size = imaging_extractor.get_image_size() - - frame_idxs = 0 - frames_with_scalar = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0) - assert frames_with_scalar.shape == image_size, "get_frames does not work correctly with frame_idxs=0" - - frame_idxs = [0] - frames_with_single_element_list = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0) - assert_msg = "get_frames does not work correctly with frame_idxs=[0]" - assert frames_with_single_element_list.shape == (1, image_size[0], image_size[1]), assert_msg - - frame_idxs = [0, 1] - frames_with_list = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0) - assert_msg = "get_frames does not work correctly with frame_idxs=[0, 1]" - assert frames_with_list.shape == (2, image_size[0], image_size[1]), assert_msg - - frame_idxs = np.array([0, 1]) - frames_with_array = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0) - assert_msg = "get_frames does not work correctly with frame_idxs=np.arrray([0, 1])" - assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg - - frame_idxs = [0, 2] - frames_with_array = imaging_extractor.get_frames(frame_idxs=frame_idxs, channel=0) - assert_msg = "get_frames does not work correctly with frame_idxs=[0, 2]" - assert frames_with_array.shape == (2, image_size[0], image_size[1]), assert_msg - - -def check_imaging_return_types(img_ex: ImagingExtractor): - """Check that the return types of the imaging extractor are correct.""" - assert isinstance(img_ex.get_num_frames(), inttype) - assert isinstance(img_ex.get_num_channels(), inttype) - assert isinstance(img_ex.get_sampling_frequency(), floattype) - _assert_iterable_complete( - iterable=img_ex.get_channel_names(), - dtypes=(list, NoneType), - element_dtypes=str, - shape_max=(img_ex.get_num_channels(),), - ) - _assert_iterable_complete(iterable=img_ex.get_image_size(), dtypes=Iterable, element_dtypes=inttype, shape=(2,)) - - # This needs a method for getting frame shape not image size. It only works for n_channel==1 - # two_first_frames = img_ex.get_frames(frame_idxs=[0, 1]) - # _assert_iterable_complete( - # iterable=two_first_frames, - # dtypes=(np.ndarray,), - # element_dtypes=inttype + floattype, - # shape=(2, *img_ex.get_image_size()), - # ) diff --git a/src/roiextractors/tools/__init__.py b/src/roiextractors/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/roiextractors/tools/importing.py b/src/roiextractors/tools/importing.py new file mode 100644 index 00000000..56f8e6bd --- /dev/null +++ b/src/roiextractors/tools/importing.py @@ -0,0 +1,59 @@ +import sys +import importlib.util +from typing import Optional, Dict, List +from types import ModuleType +from platform import python_version +from packaging import version + + +def get_package( + package_name: str, + installation_instructions: Optional[str] = None, + excluded_platforms_and_python_versions: Optional[Dict[str, List[str]]] = None, +) -> ModuleType: + """Check if package is installed and return module if so. + + Otherwise, raise informative error describing how to perform the installation. + Inspired by https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported. + + Parameters + ---------- + package_name : str + Name of the package to be imported. + installation_instructions : str, optional + String describing the source, options, and alias of package name (if needed) for installation. 
+ For example, + >>> installation_source = "conda install -c conda-forge my-package-name" + Defaults to f"pip install {package_name}". + excluded_platforms_and_python_versions : dict mapping string platform names to a list of string versions, optional + In case some combinations of platforms or Python versions are not allowed for the given package, specify + this dictionary to raise a more specific error to that issue. + For example, `excluded_platforms_and_python_versions = dict(darwin=["3.7"])` will raise an informative error + when running on MacOS with Python version 3.7. + Allows all platforms and Python versions used by default. + + Raises + ------ + ModuleNotFoundError + If the package is not installed. + """ + installation_instructions = installation_instructions or f"pip install {package_name}" + excluded_platforms_and_python_versions = excluded_platforms_and_python_versions or dict() + + if package_name in sys.modules: + return sys.modules[package_name] + + if importlib.util.find_spec(package_name) is not None: + return importlib.import_module(name=package_name) + + for excluded_version in excluded_platforms_and_python_versions.get(sys.platform, list()): + if version.parse(python_version()).minor == version.parse(excluded_version).minor: + raise ModuleNotFoundError( + f"\nThe package '{package_name}' is not available on the {sys.platform} platform for " + f"Python version {excluded_version}!" + ) + + raise ModuleNotFoundError( + f"\nThe required package'{package_name}' is not installed!\n" + f"To install this package, please run\n\n\t{installation_instructions}\n" + ) diff --git a/src/roiextractors/tools/plotting.py b/src/roiextractors/tools/plotting.py new file mode 100644 index 00000000..9c440bae --- /dev/null +++ b/src/roiextractors/tools/plotting.py @@ -0,0 +1,38 @@ +def show_video(imaging, ax=None): + """Show video as animation. + + Parameters + ---------- + imaging: ImagingExtractor + The imaging extractor object to be saved in the .h5 file + ax: matplotlib axis + Axis to plot the video. If None, a new axis is created. + + Returns + ------- + anim: matplotlib.animation.FuncAnimation + Animation of the video. 
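A hedged usage sketch of the relocated get_package helper, using only the signature added above; tifffile merely stands in for any optional dependency an extractor might import lazily:

    from roiextractors.tools.importing import get_package

    # Returns the already-imported or freshly imported module when it is available.
    numpy_module = get_package(package_name="numpy")

    # For optional dependencies, pass an install hint and, if needed, platform/Python
    # exclusions; a ModuleNotFoundError carrying that hint is raised when import fails.
    tifffile = get_package(
        package_name="tifffile",
        installation_instructions="pip install tifffile",
        excluded_platforms_and_python_versions={"darwin": ["3.7"]},
    )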
+ """ + import matplotlib.pyplot as plt + import matplotlib.animation as animation + + def animate_func(i, imaging, im, ax): + ax.set_title(f"{i}") + im.set_array(imaging.get_frames([i])[0]) + return [im] + + if ax is None: + fig = plt.figure(figsize=(5, 5)) + ax = fig.add_subplot(111) + im0 = imaging.get_frames([0])[0] + im = ax.imshow(im0, interpolation="none", aspect="auto", vmin=0, vmax=1) + interval = 1 / imaging.get_sampling_frequency() * 1000 + anim = animation.FuncAnimation( + fig, + animate_func, + frames=imaging.get_num_frames(), + fargs=(imaging, im, ax), + interval=interval, + blit=False, + ) + return anim diff --git a/src/roiextractors/tools/testing.py b/src/roiextractors/tools/testing.py new file mode 100644 index 00000000..dd1057ea --- /dev/null +++ b/src/roiextractors/tools/testing.py @@ -0,0 +1,496 @@ +"""Testing utilities for the roiextractors package.""" + +from collections.abc import Iterable +from typing import Tuple, Optional, List + +import numpy as np +from numpy.testing import assert_array_equal, assert_array_almost_equal + +from ..segmentationextractor import SegmentationExtractor +from ..imagingextractor import ImagingExtractor + +from roiextractors import NumpyImagingExtractor, NumpySegmentationExtractor + +from roiextractors.tools.typing import DtypeType, ArrayType + + +def generate_mock_video(size: Tuple[int], dtype: DtypeType = "uint16", seed: int = 0): + """Generate a mock video of a given size and dtype. + + Parameters + ---------- + size : Tuple[int] + Size of the video to generate. + dtype : DtypeType, optional + Dtype of the video to generate, by default "uint16". + seed : int, default 0 + seed for the random number generator, by default 0. + + Returns + ------- + video : np.ndarray + A mock video of the given size and dtype. + """ + dtype = np.dtype(dtype) + number_of_bytes = dtype.itemsize + + rng = np.random.default_rng(seed) + + low = 0 if "u" in dtype.name else 2 ** (number_of_bytes - 1) - 2**number_of_bytes + high = 2**number_of_bytes - 1 if "u" in dtype.name else 2**number_of_bytes - 2 ** (number_of_bytes - 1) - 1 + video = ( + rng.random(size=size, dtype=dtype) + if "float" in dtype.name + else rng.integers(low=low, high=high, size=size, dtype=dtype) + ) + + return video + + +def generate_mock_imaging_extractor( + num_frames: int = 30, + num_rows: int = 10, + num_columns: int = 10, + sampling_frequency: float = 30.0, + dtype: DtypeType = "uint16", + seed: int = 0, +): + """Generate a mock imaging extractor for testing. + + The imaging extractor is built by feeding random data into the `NumpyImagingExtractor`. + + Parameters + ---------- + num_frames : int, optional + number of frames in the video, by default 30. + num_rows : int, optional + number of rows in the video, by default 10. + num_columns : int, optional + number of columns in the video, by default 10. + sampling_frequency : float, optional + sampling frequency of the video, by default 30. + dtype : DtypeType, optional + dtype of the video, by default "uint16". + seed : int, default 0 + seed for the random number generator, by default 0. + + Returns + ------- + NumpyImagingExtractor + An imaging extractor with random data fed into `NumpyImagingExtractor`. 
+ """ + size = (num_frames, num_rows, num_columns) + video = generate_mock_video(size=size, dtype=dtype, seed=seed) + imaging_extractor = NumpyImagingExtractor(timeseries=video, sampling_frequency=sampling_frequency) + + return imaging_extractor + + +def assert_imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor): + """Assert that two ImagingExtractor objects are equal by comparing their attributes and data. + + Parameters + ---------- + imaging_extractor1 : ImagingExtractor + The first ImagingExtractor object to compare. + imaging_extractor2 : ImagingExtractor + The second ImagingExtractor object to compare. + + Raises + ------ + AssertionError + If any of the following attributes or data do not match between the two ImagingExtractor objects: + - Image size + - Number of frames + - Sampling frequency + - Data type (dtype) + - Video data + - Time points (_times) + """ + assert ( + imaging_extractor1.get_image_size() == imaging_extractor2.get_image_size() + ), "ImagingExtractors are not equal: image_sizes do not match." + assert ( + imaging_extractor1.get_num_frames() == imaging_extractor2.get_num_frames() + ), "ImagingExtractors are not equal: num_frames do not match." + assert np.isclose( + imaging_extractor1.get_sampling_frequency(), imaging_extractor2.get_sampling_frequency() + ), "ImagingExtractors are not equal: sampling_frequencies do not match." + assert ( + imaging_extractor1.get_dtype() == imaging_extractor2.get_dtype() + ), "ImagingExtractors are not equal: dtypes do not match." + assert_array_equal( + imaging_extractor1.get_video(), + imaging_extractor2.get_video(), + err_msg="ImagingExtractors are not equal: videos do not match.", + ) + assert_array_equal( + imaging_extractor1._times, + imaging_extractor2._times, + err_msg="ImagingExtractors are not equal: _times do not match.", + ) + + +def imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor) -> bool: + """Return True if two ImagingExtractors are equal, False otherwise. + + Parameters + ---------- + imaging_extractor1 : ImagingExtractor + The first ImagingExtractor object to compare. + imaging_extractor2 : ImagingExtractor + The second ImagingExtractor object to compare. + + Returns + ------- + bool + True if all of the following fields match between the two ImagingExtractor objects: + - Image size + - Number of frames + - Sampling frequency + - Data type (dtype) + - Video data + - Time points (_times) + """ + try: + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + return True + except AssertionError: + return False + + +def generate_mock_segmentation_extractor( + num_rois: int = 10, + num_frames: int = 30, + num_rows: int = 25, + num_columns: int = 25, + num_background_components: int = 2, + sampling_frequency: float = 30.0, + summary_image_names: List[str] = ["mean", "correlation"], + roi_response_names: List[str] = ["raw", "dff", "deconvolved", "denoised"], + background_response_names: List[str] = ["background"], + rejected_roi_ids: Optional[list] = None, + roi_locations: Optional[ArrayType] = None, + image_masks: Optional[ArrayType] = None, + roi_response_traces: Optional[dict] = None, + background_image_masks: Optional[ArrayType] = None, + background_response_traces: Optional[dict] = None, + summary_images: Optional[dict] = None, + seed: int = 0, +) -> NumpySegmentationExtractor: + """Generate a mock segmentation extractor for testing. 
+ + The segmentation extractor is built by feeding random data into the + `NumpySegmentationExtractor`. + + Parameters + ---------- + num_rois : int, optional + number of regions of interest, by default 10. + num_frames : int, optional + Number of frames in the recording, by default 30. + num_rows : int, optional + number of rows in the hypothetical video from which the data was extracted, by default 25. + num_columns : int, optional + number of columns in the hypothetical video from which the data was extracted, by default 25. + num_background_components : int, optional + number of background components, by default 2. + sampling_frequency : float, optional + sampling frequency of the hypothetical video from which the data was extracted, by default 30.0. + summary_image_names : List[str], optional + names of summary images, by default ["mean", "correlation"]. + roi_response_names : List[str], optional + names of roi response traces, by default ["raw", "dff", "deconvolved", "denoised"]. + background_response_names : List[str], optional + names of background response traces, by default ["background"]. + rejected_roi_ids: Optional[list], optional + A list of rejected rois, None by default. + roi_locations : Optional[ArrayType], optional + A 2D array of shape (2, num_rois) containing the locations of the rois, None by default. + seed : int, default 0 + seed for the random number generator, by default 0. + + Returns + ------- + NumpySegmentationExtractor + A segmentation extractor with random data fed into `NumpySegmentationExtractor` + + Notes + ----- + Note that this dummy example is meant to be a mock object with the right shape, structure and objects but does not + contain meaningful content. That is, the image masks matrices are not plausible image mask for a roi, the raw signal + is not a meaningful biological signal and is not related appropriately to the deconvolved signal , etc. + """ + rng = np.random.default_rng(seed) + + # Create dummy image masks + if image_masks is None: + image_masks = rng.random((num_rows, num_columns, num_rois)) + else: + assert image_masks.shape == ( + num_rows, + num_columns, + num_rois, + ), f"image_masks should have shape (num_rows, num_columns, num_rois) but got {image_masks.shape}." + if background_image_masks is None: + background_image_masks = rng.random((num_rows, num_columns, num_background_components)) + else: + assert background_image_masks.shape == ( + num_rows, + num_columns, + num_background_components, + ), f"background_image_masks should have shape (num_rows, num_columns, num_background_components) but got {background_image_masks.shape}." + + # Create signals + if roi_response_traces is None: + roi_response_traces = {name: rng.random((num_frames, num_rois)) for name in roi_response_names} + else: + for name, trace in roi_response_traces.items(): + assert trace.shape == ( + num_frames, + num_rois, + ), f"roi_response_traces[{name}] should have shape (num_frames, num_rois) but got {trace.shape}." + if background_response_traces is None: + background_response_traces = { + name: rng.random((num_frames, num_background_components)) for name in background_response_names + } + else: + for name, trace in background_response_traces.items(): + assert trace.shape == ( + num_frames, + num_background_components, + ), f"background_response_traces[{name}] should have shape (num_frames, num_background_components) but got {trace.shape}." 
+ + # Summary images + if summary_images is None: + summary_images = {name: rng.random((num_rows, num_columns)) for name in summary_image_names} + else: + for name, image in summary_images.items(): + assert image.shape == ( + num_rows, + num_columns, + ), f"summary_images[{name}] should have shape (num_rows, num_columns) but got {image.shape}." + + # Rois + roi_ids = [id for id in range(num_rois)] + if roi_locations is None: + roi_locations_rows = rng.integers(low=0, high=num_rows, size=num_rois) + roi_locations_columns = rng.integers(low=0, high=num_columns, size=num_rois) + roi_locations = np.vstack((roi_locations_rows, roi_locations_columns)) + else: + assert roi_locations.shape == ( + 2, + num_rois, + ), f"roi_locations should have shape (2, num_rois) but got {roi_locations.shape}." + background_ids = [i for i in range(num_background_components)] + + rejected_roi_ids = rejected_roi_ids if rejected_roi_ids else None + + accepted_roi_ids = roi_ids + if rejected_roi_ids is not None: + accepted_roi_ids = list(set(accepted_roi_ids).difference(rejected_roi_ids)) + + dummy_segmentation_extractor = NumpySegmentationExtractor( + image_masks=image_masks, + roi_response_traces=roi_response_traces, + sampling_frequency=sampling_frequency, + roi_ids=roi_ids, + accepted_roi_ids=accepted_roi_ids, + rejected_roi_ids=rejected_roi_ids, + roi_locations=roi_locations, + summary_images=summary_images, + background_image_masks=background_image_masks, + background_response_traces=background_response_traces, + background_ids=background_ids, + ) + + return dummy_segmentation_extractor + + +def assert_segmentation_equal( + segmentation_extractor1: SegmentationExtractor, segmentation_extractor2: SegmentationExtractor +): + """Assert that two segmentation extractors have equal fields. + + Parameters + ---------- + segmentation_extractor1 : SegmentationExtractor + First segmentation extractor to compare. + segmentation_extractor2 : SegmentationExtractor + Second segmentation extractor to compare. + + Raises + ------ + AssertionError + If any of the following attributes or data do not match between the two SegmentationExtractor objects: + - image_size + - num_frames + - sampling_frequency + - _times + - roi_ids + - num_rois + - accepted_roi_ids + - rejected_roi_ids + - roi_locations + - roi_image_masks + - roi_pixel_masks + - roi_response_traces + - background_ids + - num_background_components + - background_image_masks + - background_response_traces + - summary_images + """ + assert ( + segmentation_extractor1.get_image_size() == segmentation_extractor2.get_image_size() + ), "SegmentationExtractors are not equal: image_sizes do not match." + assert ( + segmentation_extractor1.get_num_frames() == segmentation_extractor2.get_num_frames() + ), "SegmentationExtractors are not equal: num_frames do not match." + assert ( + segmentation_extractor1.get_sampling_frequency() == segmentation_extractor2.get_sampling_frequency() + ), "SegmentationExtractors are not equal: sampling_frequencies do not match." + assert_array_equal( + segmentation_extractor1._times, + segmentation_extractor2._times, + err_msg="SegmentationExtractors are not equal: _times do not match.", + ) + assert_array_equal( + segmentation_extractor1.get_roi_ids(), + segmentation_extractor2.get_roi_ids(), + err_msg="SegmentationExtractors are not equal: roi_ids do not match.", + ) + assert ( + segmentation_extractor1.get_num_rois() == segmentation_extractor2.get_num_rois() + ), "SegmentationExtractors are not equal: num_rois do not match." 
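A brief sketch of generate_mock_segmentation_extractor with a few explicit overrides; arguments left as None are filled with random data of the matching shape, while mismatched shapes are rejected by the assertions shown above (the response names and rejected ids below are arbitrary choices for illustration):

    from roiextractors.tools.testing import generate_mock_segmentation_extractor

    segmentation = generate_mock_segmentation_extractor(
        num_rois=5,
        num_frames=20,
        num_rows=8,
        num_columns=8,
        num_background_components=1,
        roi_response_names=["raw", "dff"],  # random traces are generated for these names
        rejected_roi_ids=[3, 4],
    )
    assert segmentation.get_num_rois() == 5
    assert list(segmentation.get_roi_response_traces().keys()) == ["raw", "dff"]
    assert segmentation.get_rejected_roi_ids() == [3, 4]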
+ assert_array_equal( + segmentation_extractor1.get_accepted_roi_ids(), + segmentation_extractor2.get_accepted_roi_ids(), + err_msg="SegmentationExtractors are not equal: accepted_roi_ids do not match.", + ) + assert_array_equal( + segmentation_extractor1.get_rejected_roi_ids(), + segmentation_extractor2.get_rejected_roi_ids(), + err_msg="SegmentationExtractors are not equal: rejected_roi_ids do not match.", + ) + assert_array_equal( + segmentation_extractor1.get_roi_locations(), + segmentation_extractor2.get_roi_locations(), + err_msg="SegmentationExtractors are not equal: roi_locations do not match.", + ) + assert_array_equal( + segmentation_extractor1.get_roi_image_masks(), + segmentation_extractor2.get_roi_image_masks(), + err_msg="SegmentationExtractors are not equal: roi_image_masks do not match.", + ) + for pixel_mask1, pixel_mask2 in zip( + segmentation_extractor1.get_roi_pixel_masks(), segmentation_extractor2.get_roi_pixel_masks() + ): + assert_array_equal( + pixel_mask1, + pixel_mask2, + err_msg="SegmentationExtractors are not equal: roi_pixel_masks do not match.", + ) + roi_response_traces1 = segmentation_extractor1.get_roi_response_traces() + roi_response_traces2 = segmentation_extractor2.get_roi_response_traces() + for name, trace1 in roi_response_traces1.items(): + assert ( + name in roi_response_traces2 + ), f"SegmentationExtractors are not equal: SegmentationExtractor1 has roi_response_trace {name} but SegmentationExtractor2 does not." + trace2 = roi_response_traces2[name] + assert_array_equal( + trace1, + trace2, + err_msg=f"SegmentationExtractors are not equal: roi_response_trace {name} does not match.", + ) + for name2 in roi_response_traces2: + assert ( + name2 in roi_response_traces1 + ), f"SegmentationExtractors are not equal: SegmentationExtractor2 has roi_response_trace {name2} but SegmentationExtractor1 does not." + assert_array_equal( + segmentation_extractor1.get_background_ids(), + segmentation_extractor2.get_background_ids(), + err_msg="SegmentationExtractors are not equal: background_ids do not match.", + ) + assert ( + segmentation_extractor1.get_num_background_components() + == segmentation_extractor2.get_num_background_components() + ), "SegmentationExtractors are not equal: num_background_components do not match." + assert_array_equal( + segmentation_extractor1.get_background_image_masks(), + segmentation_extractor2.get_background_image_masks(), + err_msg="SegmentationExtractors are not equal: background_image_masks do not match.", + ) + background_response_traces1 = segmentation_extractor1.get_background_response_traces() + background_response_traces2 = segmentation_extractor2.get_background_response_traces() + for name, trace1 in background_response_traces1.items(): + assert ( + name in background_response_traces2 + ), f"SegmentationExtractors are not equal: SegmentationExtractor1 has background_response_trace {name} but SegmentationExtractor2 does not." + trace2 = background_response_traces2[name] + assert_array_equal( + trace1, + trace2, + err_msg=f"SegmentationExtractors are not equal: background_response_trace {name} does not match.", + ) + for name2 in background_response_traces2: + assert ( + name2 in background_response_traces1 + ), f"SegmentationExtractors are not equal: SegmentationExtractor2 has background_response_trace {name2} but SegmentationExtractor1 does not." 
+ summary_images1 = segmentation_extractor1.get_summary_images() + summary_images2 = segmentation_extractor2.get_summary_images() + for name, image1 in summary_images1.items(): + assert ( + name in summary_images2 + ), f"SegmentationExtractors are not equal: SegmentationExtractor1 has summary_image {name} but SegmentationExtractor2 does not." + image2 = summary_images2[name] + assert_array_equal( + image1, + image2, + err_msg=f"SegmentationExtractors are not equal: summary_image {name} does not match.", + ) + for name2 in summary_images2: + assert ( + name2 in summary_images1 + ), f"SegmentationExtractors are not equal: SegmentationExtractor2 has summary_image {name2} but SegmentationExtractor1 does not." + + +def segmentation_equal( + segmentation_extractor1: SegmentationExtractor, segmentation_extractor2: SegmentationExtractor +) -> bool: + """Return True if two SegmentationExtractors have equal fields, False otherwise. + + Parameters + ---------- + segmentation_extractor1 : SegmentationExtractor + First segmentation extractor to compare. + segmentation_extractor2 : SegmentationExtractor + Second segmentation extractor to compare. + + Returns + ------- + bool + True if all of the following fields match between the two SegmentationExtractor objects: + - image_size + - num_frames + - sampling_frequency + - _times + - roi_ids + - num_rois + - accepted_roi_ids + - rejected_roi_ids + - roi_locations + - roi_image_masks + - roi_pixel_masks + - roi_response_traces + - background_ids + - num_background_components + - background_image_masks + - background_response_traces + - summary_images + """ + try: + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + return True + except AssertionError: + return False diff --git a/src/roiextractors/tools/typing.py b/src/roiextractors/tools/typing.py new file mode 100644 index 00000000..35b5bae6 --- /dev/null +++ b/src/roiextractors/tools/typing.py @@ -0,0 +1,11 @@ +from typing import Union +from pathlib import Path +import numpy as np +from numpy.typing import ArrayLike, DTypeLike + +ArrayType = ArrayLike +PathType = Union[str, Path] +DtypeType = DTypeLike +IntType = Union[int, np.integer] +FloatType = Union[float, np.floating] +NoneType = type(None) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py index 893ae039..2b53b7ee 100644 --- a/src/roiextractors/volumetricimagingextractor.py +++ b/src/roiextractors/volumetricimagingextractor.py @@ -3,7 +3,7 @@ from typing import Tuple, List, Iterable, Optional import numpy as np -from .extraction_tools import ArrayType, DtypeType +from .tools.typing import ArrayType, DtypeType from .imagingextractor import ImagingExtractor diff --git a/tests/mixins/imaging_extractor_mixin.py b/tests/mixins/imaging_extractor_mixin.py index aa2c2c59..b7901c65 100644 --- a/tests/mixins/imaging_extractor_mixin.py +++ b/tests/mixins/imaging_extractor_mixin.py @@ -65,9 +65,6 @@ def test_get_frames_invalid_frame_idxs(self, imaging_extractor): with pytest.raises(AssertionError): imaging_extractor.get_frames(frame_idxs=[0.5]) - def test_eq(self, imaging_extractor, imaging_extractor2): - assert imaging_extractor == imaging_extractor2 - @pytest.mark.parametrize("start_frame, end_frame", [(None, None), (1, 3), (0, 1)]) def test_frame_slice(self, imaging_extractor, start_frame, end_frame): frame_slice_imaging_extractor = imaging_extractor.frame_slice(start_frame=start_frame, end_frame=end_frame) @@ -218,9 +215,6 @@ def test_copy_times_frame_slice(self, 
frame_slice_imaging_extractor, frame_slice assert np.array_equal(frame_slice_imaging_extractor2._times, expected_times) assert frame_slice_imaging_extractor2._times is not expected_times - def test_eq_frame_slice(self, frame_slice_imaging_extractor, frame_slice_imaging_extractor2): - assert frame_slice_imaging_extractor == frame_slice_imaging_extractor2 - @pytest.mark.parametrize("start_frame, end_frame", [(None, None), (1, 2), (0, 1)]) def test_frame_slice_on_frame_slice(self, frame_slice_imaging_extractor, start_frame, end_frame): twice_sliced_imaging_extractor = frame_slice_imaging_extractor.frame_slice( diff --git a/tests/test_minimal/test_numpy_imaging_extractor.py b/tests/test_minimal/test_numpy_imaging_extractor.py index 332032f7..3dd8833f 100644 --- a/tests/test_minimal/test_numpy_imaging_extractor.py +++ b/tests/test_minimal/test_numpy_imaging_extractor.py @@ -1,6 +1,6 @@ from ..mixins.imaging_extractor_mixin import ImagingExtractorMixin, FrameSliceImagingExtractorMixin from roiextractors import NumpyImagingExtractor -from roiextractors.testing import generate_dummy_video +from roiextractors.tools.testing import generate_mock_video import pytest import numpy as np @@ -8,7 +8,7 @@ class TestNumpyImagingExtractor(ImagingExtractorMixin, FrameSliceImagingExtractorMixin): @pytest.fixture(scope="class") def expected_video(self): - return generate_dummy_video(size=(3, 2, 4)) + return generate_mock_video(size=(3, 2, 4)) @pytest.fixture(scope="class") def expected_sampling_frequency(self): @@ -26,7 +26,7 @@ def imaging_extractor2(self, expected_video, expected_sampling_frequency): class TestNumpyImagingExtractorFromFile(ImagingExtractorMixin, FrameSliceImagingExtractorMixin): @pytest.fixture(scope="class") def expected_video(self): - return generate_dummy_video(size=(3, 2, 4)) + return generate_mock_video(size=(3, 2, 4)) @pytest.fixture(scope="class") def expected_sampling_frequency(self): diff --git a/tests/test_minimal/test_numpy_segmentation_extractor.py b/tests/test_minimal/test_numpy_segmentation_extractor.py index 0b92e77e..743982a6 100644 --- a/tests/test_minimal/test_numpy_segmentation_extractor.py +++ b/tests/test_minimal/test_numpy_segmentation_extractor.py @@ -232,7 +232,7 @@ def segmentation_extractor2( ) name_to_file_path = {} for name, ndarray in name_to_ndarray.items(): - file_path = tmp_path / f"{name}.npy" + file_path = tmp_path / f"{name}2.npy" file_path.parent.mkdir(parents=True, exist_ok=True) np.save(file_path, ndarray) name_to_file_path[name] = file_path @@ -245,7 +245,7 @@ def segmentation_extractor2( for name, dict_of_ndarrays in name_to_dict_of_ndarrays.items(): name_to_dict_of_file_paths[name] = {} for key, ndarray in dict_of_ndarrays.items(): - file_path = tmp_path / f"{name}_{key}.npy" + file_path = tmp_path / f"{name}_{key}2.npy" np.save(file_path, ndarray) name_to_dict_of_file_paths[name][key] = file_path diff --git a/tests/test_minimal/test_segmentation_extractor_functions.py b/tests/test_minimal/test_segmentation_extractor_functions.py new file mode 100644 index 00000000..30403236 --- /dev/null +++ b/tests/test_minimal/test_segmentation_extractor_functions.py @@ -0,0 +1,132 @@ +import pytest +import numpy as np + +from roiextractors.segmentationextractor import ( + convert_image_masks_to_pixel_masks, + convert_pixel_masks_to_image_masks, + get_default_roi_locations_from_image_masks, +) + + +@pytest.fixture(scope="module") +def rng(): + seed = 1728084845 # int(datetime.datetime.now().timestamp()) at the time of writing + return 
np.random.default_rng(seed=seed) + + +@pytest.fixture(scope="function") +def image_masks(rng): + return rng.random((3, 3, 3)) + + +def test_convert_image_masks_to_pixel_masks(image_masks): + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + for i, pixel_mask in enumerate(pixel_masks): + assert pixel_mask.shape == (image_masks.shape[0] * image_masks.shape[1], 3) + for row, column, wt in pixel_mask: + assert row == int(row) + assert column == int(column) + assert image_masks[int(row), int(column), i] == wt + + +def test_convert_image_masks_to_pixel_masks_with_zeros(image_masks): + image_masks[0, 0, 0] = 0 + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + assert pixel_masks[0].shape == (image_masks.shape[0] * image_masks.shape[1] - 1, 3) + for i, pixel_mask in enumerate(pixel_masks): + for row, column, wt in pixel_mask: + assert row == int(row) + assert column == int(column) + assert image_masks[int(row), int(column), i] == wt + + +def test_convert_image_masks_to_pixel_masks_all_zeros(image_masks): + image_masks = np.zeros(image_masks.shape) + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + for pixel_mask in pixel_masks: + assert pixel_mask.shape == (0, 3) + + +def test_convert_pixel_masks_to_image_masks(image_masks): + pixel_masks = [] + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + locs = np.where(image_mask > 0) + pix_values = image_mask[image_mask > 0] + pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T) + + image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + indices = np.ndindex(image_mask.shape) + for row, column in indices: + pixel_mask_mask = np.logical_and(pixel_masks[i][:, 0] == row, pixel_masks[i][:, 1] == column) + assert image_mask[row, column] == pixel_masks[i][pixel_mask_mask, 2] + + +def test_convert_pixel_masks_to_image_masks_with_zeros(image_masks): + pixel_masks = [] + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + locs = np.where(image_mask > 0) + pix_values = image_mask[image_mask > 0] + pixel_masks.append(np.vstack((locs[0], locs[1], pix_values)).T) + + pixel_masks[0] = pixel_masks[0][1:] + image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + indices = np.ndindex(image_mask.shape) + for row, column in indices: + pixel_mask_mask = np.logical_and(pixel_masks[i][:, 0] == row, pixel_masks[i][:, 1] == column) + if i == 0 and row == 0 and column == 0: + assert np.all(np.logical_not(pixel_mask_mask)) + else: + assert image_mask[row, column] == pixel_masks[i][pixel_mask_mask, 2] + + +def test_convert_pixel_masks_to_image_masks_all_zeros(image_masks): + pixel_masks = [np.zeros((0, 0)) for _ in range(image_masks.shape[2])] + output_image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + assert output_image_masks.shape == image_masks.shape + for image_mask in output_image_masks: + assert np.all(image_mask == 0) + + +def test_convert_masks_roundtrip(image_masks): + pixel_masks = convert_image_masks_to_pixel_masks(image_masks=image_masks) + output_image_masks = convert_pixel_masks_to_image_masks(pixel_masks=pixel_masks, image_shape=image_masks.shape[:2]) + np.testing.assert_array_equal(image_masks, 
output_image_masks) + + +def test_get_default_roi_locations_from_image_masks(): + image_masks = np.zeros((3, 3, 3)) + image_masks[0, 0, 0] = 1 + image_masks[1, 1, 1] = 1 + image_masks[2, 2, 2] = 1 + roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks) + expected_roi_locations = np.array([[0, 0], [1, 1], [2, 2]]).T + np.testing.assert_array_equal(roi_locations, expected_roi_locations) + + +def test_get_default_roi_locations_from_image_masks_tie1(): + image_masks = np.zeros((3, 3, 3)) + image_masks[0, 0, 0] = 1 + image_masks[0, 1, 0] = 1 + image_masks[1, 1, 1] = 1 + image_masks[2, 2, 2] = 1 + roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks) + expected_roi_locations = np.array([[0, 0], [1, 1], [2, 2]]).T + np.testing.assert_array_equal(roi_locations, expected_roi_locations) + + +def test_get_default_roi_locations_from_image_masks_tie2(): + image_masks = np.zeros((3, 3, 3)) + image_masks[0, 0, 0] = 1 + image_masks[0, 1, 0] = 1 + image_masks[1, 1, 0] = 1 + image_masks[1, 1, 1] = 1 + image_masks[2, 2, 2] = 1 + roi_locations = get_default_roi_locations_from_image_masks(image_masks=image_masks) + expected_roi_locations = np.array([[0, 1], [1, 1], [2, 2]]).T + np.testing.assert_array_equal(roi_locations, expected_roi_locations) diff --git a/tests/test_minimal/test_tools/__init__.py b/tests/test_minimal/test_tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_minimal/test_tools/test_importing.py b/tests/test_minimal/test_tools/test_importing.py new file mode 100644 index 00000000..b0042a4e --- /dev/null +++ b/tests/test_minimal/test_tools/test_importing.py @@ -0,0 +1,46 @@ +import pytest +from roiextractors.tools.importing import get_package +import sys +from platform import python_version + + +def test_get_package_in_sys_modules(): + package_name = "roiextractors" + package = get_package(package_name=package_name) + assert package.__name__ == package_name + + +def test_get_package_not_in_sys_modules(): + package_name = "roiextractors" + del sys.modules[package_name] + package = get_package(package_name=package_name) + assert package.__name__ == package_name + + +def test_get_package_excluded_versions(): + package_name = "invalid-package" + excluded_version = python_version() + excluded_platform = sys.platform + excluded_platforms_and_python_versions = {excluded_platform: [excluded_version]} + expected_error = ( + f"\nThe package '{package_name}' is not available on the {excluded_platform} platform for " + f"Python version {excluded_version}!" 
+ ) + with pytest.raises(ModuleNotFoundError, match=expected_error): + package = get_package( + package_name=package_name, excluded_platforms_and_python_versions=excluded_platforms_and_python_versions + ) + + +@pytest.mark.parametrize( + "installation_instructions", [None, "conda install -c conda-forge my-package-name", "pip install my-package-name"] +) +def test_get_package_not_installed(installation_instructions): + package_name = "invalid-package" + expected_installation_instructions = installation_instructions or f"pip install {package_name}" + expected_error = ( + f"\nThe required package'{package_name}' is not installed!\n" + f"To install this package, please run\n\n\t{expected_installation_instructions}\n" + ) + with pytest.raises(ModuleNotFoundError, match=expected_error): + package = get_package(package_name=package_name, installation_instructions=installation_instructions) diff --git a/tests/test_minimal/test_tools/test_plotting.py b/tests/test_minimal/test_tools/test_plotting.py new file mode 100644 index 00000000..d4d58e48 --- /dev/null +++ b/tests/test_minimal/test_tools/test_plotting.py @@ -0,0 +1,7 @@ +from roiextractors.tools.plotting import show_video +from roiextractors.tools.testing import generate_mock_imaging_extractor + + +def test_show_video(): + imaging_extractor = generate_mock_imaging_extractor() + anim = show_video(imaging=imaging_extractor) diff --git a/tests/test_minimal/test_tools/test_testing.py b/tests/test_minimal/test_tools/test_testing.py new file mode 100644 index 00000000..42e4c45c --- /dev/null +++ b/tests/test_minimal/test_tools/test_testing.py @@ -0,0 +1,407 @@ +from roiextractors.tools.testing import ( + generate_mock_video, + generate_mock_imaging_extractor, + generate_mock_segmentation_extractor, + assert_imaging_equal, + imaging_equal, + assert_segmentation_equal, + segmentation_equal, +) +import pytest +import numpy as np +from numpy.testing import assert_array_equal + + +@pytest.mark.parametrize("size", [(1, 2, 3), (3, 2, 4), (5, 3, 2)]) +def test_generate_mock_video_size(size): + video = generate_mock_video(size=size) + assert video.shape == size + + +@pytest.mark.parametrize("dtype", [np.uint8, np.float32, "uint8", "float32"]) +def test_generate_mock_video_dtype(dtype): + video = generate_mock_video(size=(3, 2, 4), dtype=dtype) + assert video.dtype == np.dtype(dtype) + + +def test_generate_mock_video_seed(): + size = (1, 2, 3) + video1 = generate_mock_video(size=size, seed=0) + video2 = generate_mock_video(size=size, seed=0) + video3 = generate_mock_video(size=size, seed=1) + assert_array_equal(video1, video2) + assert not np.array_equal(video1, video3) + + +@pytest.mark.parametrize("num_frames, num_rows, num_columns", [(1, 2, 3), (3, 2, 4), (5, 3, 2)]) +def test_generate_mock_imaging_extractor_shape(num_frames, num_rows, num_columns): + imaging_extractor = generate_mock_imaging_extractor( + num_frames=num_frames, num_rows=num_rows, num_columns=num_columns + ) + video = imaging_extractor.get_video() + assert video.shape == (num_frames, num_rows, num_columns) + + +@pytest.mark.parametrize("sampling_frequency", [10.0, 20.0, 30.0]) +def test_generate_mock_imaging_extractor_sampling_frequency(sampling_frequency): + imaging_extractor = generate_mock_imaging_extractor(sampling_frequency=sampling_frequency) + assert imaging_extractor.get_sampling_frequency() == sampling_frequency + + +@pytest.mark.parametrize("dtype", [np.uint8, np.float32, "uint8", "float32"]) +def test_generate_mock_imaging_extractor_dtype(dtype): + imaging_extractor = 
generate_mock_imaging_extractor(dtype=dtype) + assert imaging_extractor.get_dtype() == np.dtype(dtype) + + +def test_generate_mock_imaging_extractor_seed(): + imaging_extractor1 = generate_mock_imaging_extractor(seed=0) + imaging_extractor2 = generate_mock_imaging_extractor(seed=0) + imaging_extractor3 = generate_mock_imaging_extractor(seed=1) + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + assert not imaging_equal(imaging_extractor1, imaging_extractor3) + + +# def assert_imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor): +# """Assert that two ImagingExtractor objects are equal by comparing their attributes and data. + +# Parameters +# ---------- +# imaging_extractor1 : ImagingExtractor +# The first ImagingExtractor object to compare. +# imaging_extractor2 : ImagingExtractor +# The second ImagingExtractor object to compare. + +# Raises +# ------ +# AssertionError +# If any of the following attributes or data do not match between the two ImagingExtractor objects: +# - Image size +# - Number of frames +# - Sampling frequency +# - Data type (dtype) +# - Video data +# - Time points (_times) +# """ +# assert ( +# imaging_extractor1.get_image_size() == imaging_extractor2.get_image_size() +# ), "ImagingExtractors are not equal: image_sizes do not match." +# assert ( +# imaging_extractor1.get_num_frames() == imaging_extractor2.get_num_frames() +# ), "ImagingExtractors are not equal: num_frames do not match." +# assert np.isclose( +# imaging_extractor1.get_sampling_frequency(), imaging_extractor2.get_sampling_frequency() +# ), "ImagingExtractors are not equal: sampling_frequencies do not match." +# assert ( +# imaging_extractor1.get_dtype() == imaging_extractor2.get_dtype() +# ), "ImagingExtractors are not equal: dtypes do not match." 
+# assert_array_equal( +# imaging_extractor1.get_video(), +# imaging_extractor2.get_video(), +# err_msg="ImagingExtractors are not equal: videos do not match.", +# ) +# assert_array_equal( +# imaging_extractor1._times, +# imaging_extractor2._times, +# err_msg="ImagingExtractors are not equal: _times do not match.", +# ) + + +def test_assert_imaging_equal_image_size(): + imaging_extractor1 = generate_mock_imaging_extractor(num_rows=1) + imaging_extractor2 = generate_mock_imaging_extractor(num_rows=1) + imaging_extractor3 = generate_mock_imaging_extractor(num_rows=2) + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + with pytest.raises(AssertionError): + assert_imaging_equal(imaging_extractor1, imaging_extractor3) + + +def test_assert_imaging_equal_num_frames(): + imaging_extractor1 = generate_mock_imaging_extractor(num_frames=1) + imaging_extractor2 = generate_mock_imaging_extractor(num_frames=1) + imaging_extractor3 = generate_mock_imaging_extractor(num_frames=2) + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + with pytest.raises(AssertionError): + assert_imaging_equal(imaging_extractor1, imaging_extractor3) + + +def test_assert_imaging_equal_sampling_frequency(): + imaging_extractor1 = generate_mock_imaging_extractor(sampling_frequency=30.0) + imaging_extractor2 = generate_mock_imaging_extractor(sampling_frequency=30.0) + imaging_extractor3 = generate_mock_imaging_extractor(sampling_frequency=20.0) + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + with pytest.raises(AssertionError): + assert_imaging_equal(imaging_extractor1, imaging_extractor3) + + +def test_assert_imaging_equal_dtype(): + imaging_extractor1 = generate_mock_imaging_extractor(dtype="uint16") + imaging_extractor2 = generate_mock_imaging_extractor(dtype="uint16") + imaging_extractor3 = generate_mock_imaging_extractor(dtype="float32") + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + with pytest.raises(AssertionError): + assert_imaging_equal(imaging_extractor1, imaging_extractor3) + + +def test_assert_imaging_equal_video(): + imaging_extractor1 = generate_mock_imaging_extractor(seed=0) + imaging_extractor2 = generate_mock_imaging_extractor(seed=0) + imaging_extractor3 = generate_mock_imaging_extractor(seed=1) + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + with pytest.raises(AssertionError): + assert_imaging_equal(imaging_extractor1, imaging_extractor3) + + +def test_assert_imaging_equal_times(): + imaging_extractor1 = generate_mock_imaging_extractor() + imaging_extractor2 = generate_mock_imaging_extractor() + imaging_extractor3 = generate_mock_imaging_extractor() + imaging_extractor1._times = np.array([0, 1, 2]) + imaging_extractor2._times = np.array([0, 1, 2]) + imaging_extractor3._times = np.array([0, 1, 3]) + assert_imaging_equal(imaging_extractor1, imaging_extractor2) + with pytest.raises(AssertionError): + assert_imaging_equal(imaging_extractor1, imaging_extractor3) + + +def test_imaging_equal(): + imaging_extractor1 = generate_mock_imaging_extractor(seed=0) + imaging_extractor2 = generate_mock_imaging_extractor(seed=0) + imaging_extractor3 = generate_mock_imaging_extractor(seed=1) + assert imaging_equal(imaging_extractor1, imaging_extractor2) + assert not imaging_equal(imaging_extractor1, imaging_extractor3) + + +@pytest.mark.parametrize( + "num_rois, num_frames, num_rows, num_columns, num_background_components", + [(1, 2, 3, 4, 5), (3, 2, 4, 5, 6), (5, 3, 2, 1, 0)], +) +def test_generate_mock_segmentation_extractor_shape( + num_rois, 
num_frames, num_rows, num_columns, num_background_components +): + segmentation_extractor = generate_mock_segmentation_extractor( + num_rois=num_rois, + num_frames=num_frames, + num_rows=num_rows, + num_columns=num_columns, + num_background_components=num_background_components, + ) + assert segmentation_extractor.get_num_rois() == num_rois + assert segmentation_extractor.get_num_frames() == num_frames + assert segmentation_extractor.get_image_size() == (num_rows, num_columns) + assert segmentation_extractor.get_num_background_components() == num_background_components + + +@pytest.mark.parametrize("sampling_frequency", [10.0, 20.0, 30.0]) +def test_generate_mock_segmentation_extractor_sampling_frequency(sampling_frequency): + segmentation_extractor = generate_mock_segmentation_extractor(sampling_frequency=sampling_frequency) + assert segmentation_extractor.get_sampling_frequency() == sampling_frequency + + +@pytest.mark.parametrize( + "summary_image_names, roi_response_names, background_response_names", + [ + ([], ["denoised"], []), + (["mean"], ["raw", "dff"], ["background"]), + (["correlation"], ["deconvolved"], ["background"]), + ], +) +def test_generate_mock_segmentation_extractor_names(summary_image_names, roi_response_names, background_response_names): + segmentation_extractor = generate_mock_segmentation_extractor( + summary_image_names=summary_image_names, + roi_response_names=roi_response_names, + background_response_names=background_response_names, + ) + assert list(segmentation_extractor.get_summary_images().keys()) == summary_image_names + assert list(segmentation_extractor.get_roi_response_traces().keys()) == roi_response_names + assert list(segmentation_extractor.get_background_response_traces().keys()) == background_response_names + + +@pytest.mark.parametrize("rejected_roi_ids", [[], [0, 1], [1, 2, 3]]) +def test_generate_mock_segmentation_extractor_rejected_list(rejected_roi_ids): + segmentation_extractor = generate_mock_segmentation_extractor(rejected_roi_ids=rejected_roi_ids) + assert segmentation_extractor.get_rejected_roi_ids() == rejected_roi_ids + + +def test_generate_mock_segmentation_extractor_seed(): + segmentation_extractor1 = generate_mock_segmentation_extractor(seed=0) + segmentation_extractor2 = generate_mock_segmentation_extractor(seed=0) + segmentation_extractor3 = generate_mock_segmentation_extractor(seed=1) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + assert not segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_image_size(): + segmentation_extractor1 = generate_mock_segmentation_extractor(num_rows=1) + segmentation_extractor2 = generate_mock_segmentation_extractor(num_rows=1) + segmentation_extractor3 = generate_mock_segmentation_extractor(num_rows=2) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_num_frames(): + segmentation_extractor1 = generate_mock_segmentation_extractor(num_frames=1) + segmentation_extractor2 = generate_mock_segmentation_extractor(num_frames=1) + segmentation_extractor3 = generate_mock_segmentation_extractor(num_frames=2) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def 
test_assert_segmentation_equal_sampling_frequency(): + segmentation_extractor1 = generate_mock_segmentation_extractor(sampling_frequency=30.0) + segmentation_extractor2 = generate_mock_segmentation_extractor(sampling_frequency=30.0) + segmentation_extractor3 = generate_mock_segmentation_extractor(sampling_frequency=20.0) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_times(): + segmentation_extractor1 = generate_mock_segmentation_extractor() + segmentation_extractor2 = generate_mock_segmentation_extractor() + segmentation_extractor3 = generate_mock_segmentation_extractor() + segmentation_extractor1._times = np.array([0, 1, 2]) + segmentation_extractor2._times = np.array([0, 1, 2]) + segmentation_extractor3._times = np.array([0, 1, 3]) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_roi_ids(): + segmentation_extractor1 = generate_mock_segmentation_extractor(num_rois=3) + segmentation_extractor2 = generate_mock_segmentation_extractor(num_rois=3) + segmentation_extractor3 = generate_mock_segmentation_extractor(num_rois=4) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_accepted_rejected_roi_ids(): + segmentation_extractor1 = generate_mock_segmentation_extractor(rejected_roi_ids=[1]) + segmentation_extractor2 = generate_mock_segmentation_extractor(rejected_roi_ids=[1]) + segmentation_extractor3 = generate_mock_segmentation_extractor(rejected_roi_ids=[2]) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_roi_locations(): + roi_locations1 = np.array([[0, 1], [1, 2]]) + roi_locations2 = np.array([[0, 1], [1, 2]]) + roi_locations3 = np.array([[0, 1], [1, 3]]) + segmentation_extractor1 = generate_mock_segmentation_extractor(num_rois=2, roi_locations=roi_locations1) + segmentation_extractor2 = generate_mock_segmentation_extractor(num_rois=2, roi_locations=roi_locations2) + segmentation_extractor3 = generate_mock_segmentation_extractor(num_rois=2, roi_locations=roi_locations3) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_roi_image_pixel_masks(): + image_masks1 = np.array([[[0, 1], [1, 2]], [[0, 1], [1, 2]]]) + image_masks2 = np.array([[[0, 1], [1, 2]], [[0, 1], [1, 2]]]) + image_masks3 = np.array([[[0, 1], [1, 2]], [[0, 1], [1, 3]]]) + segmentation_extractor1 = generate_mock_segmentation_extractor( + image_masks=image_masks1, num_rois=2, num_rows=2, num_columns=2 + ) + segmentation_extractor2 = generate_mock_segmentation_extractor( + image_masks=image_masks2, num_rois=2, num_rows=2, num_columns=2 + ) + segmentation_extractor3 = generate_mock_segmentation_extractor( + image_masks=image_masks3, num_rois=2, num_rows=2, num_columns=2 + ) + assert_segmentation_equal(segmentation_extractor1, 
segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_roi_response_traces(): + trace1 = np.array([[0, 1], [1, 2]]) + trace2 = np.array([[0, 1], [1, 2]]) + trace3 = np.array([[0, 1], [1, 3]]) + response_traces1 = {"raw": trace1} + response_traces2 = {"raw": trace2} + response_traces3 = {"raw": trace3} + response_traces4 = {"raw": trace1, "dff": trace1} + segmentation_extractor1 = generate_mock_segmentation_extractor( + roi_response_traces=response_traces1, num_rois=2, num_frames=2 + ) + segmentation_extractor2 = generate_mock_segmentation_extractor( + roi_response_traces=response_traces2, num_rois=2, num_frames=2 + ) + segmentation_extractor3 = generate_mock_segmentation_extractor( + roi_response_traces=response_traces3, num_rois=2, num_frames=2 + ) + segmentation_extractor4 = generate_mock_segmentation_extractor( + roi_response_traces=response_traces4, num_rois=2, num_frames=2 + ) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor4) + + +def test_assert_segmentation_equal_background_ids(): + segmentation_extractor1 = generate_mock_segmentation_extractor(num_background_components=2) + segmentation_extractor2 = generate_mock_segmentation_extractor(num_background_components=2) + segmentation_extractor3 = generate_mock_segmentation_extractor(num_background_components=3) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_background_image_masks(): + image_masks1 = np.array([[[0, 1], [1, 2]], [[0, 1], [1, 2]]]) + image_masks2 = np.array([[[0, 1], [1, 2]], [[0, 1], [1, 2]]]) + image_masks3 = np.array([[[0, 1], [1, 2]], [[0, 1], [1, 3]]]) + segmentation_extractor1 = generate_mock_segmentation_extractor( + background_image_masks=image_masks1, num_background_components=2, num_rows=2, num_columns=2 + ) + segmentation_extractor2 = generate_mock_segmentation_extractor( + background_image_masks=image_masks2, num_background_components=2, num_rows=2, num_columns=2 + ) + segmentation_extractor3 = generate_mock_segmentation_extractor( + background_image_masks=image_masks3, num_background_components=2, num_rows=2, num_columns=2 + ) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + + +def test_assert_segmentation_equal_background_response_traces(): + trace1 = np.array([[0, 1], [1, 2]]) + trace2 = np.array([[0, 1], [1, 2]]) + trace3 = np.array([[0, 1], [1, 3]]) + response_traces1 = {"background": trace1} + response_traces2 = {"background": trace2} + response_traces3 = {"background": trace3} + response_traces4 = {"background": trace1, "dff": trace1} + segmentation_extractor1 = generate_mock_segmentation_extractor( + background_response_traces=response_traces1, num_background_components=2, num_frames=2 + ) + segmentation_extractor2 = generate_mock_segmentation_extractor( + background_response_traces=response_traces2, num_background_components=2, num_frames=2 + ) + segmentation_extractor3 = 
generate_mock_segmentation_extractor( + background_response_traces=response_traces3, num_background_components=2, num_frames=2 + ) + segmentation_extractor4 = generate_mock_segmentation_extractor( + background_response_traces=response_traces4, num_background_components=2, num_frames=2 + ) + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor2) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor3) + with pytest.raises(AssertionError): + assert_segmentation_equal(segmentation_extractor1, segmentation_extractor4) + + +def test_segmentation_equal(): + segmentation_extractor1 = generate_mock_segmentation_extractor(seed=0) + segmentation_extractor2 = generate_mock_segmentation_extractor(seed=0) + segmentation_extractor3 = generate_mock_segmentation_extractor(seed=1) + assert segmentation_equal(segmentation_extractor1, segmentation_extractor2) + assert not segmentation_equal(segmentation_extractor1, segmentation_extractor3) diff --git a/tests/test_minimal/test_tools/test_typing.py b/tests/test_minimal/test_tools/test_typing.py new file mode 100644 index 00000000..0fdf2fb2 --- /dev/null +++ b/tests/test_minimal/test_tools/test_typing.py @@ -0,0 +1,36 @@ +from roiextractors.tools.typing import ( + ArrayType, + PathType, + DtypeType, + IntType, + FloatType, + NoneType, +) +from numpy.typing import ArrayLike, DTypeLike +import numpy as np +from typing import Union +from pathlib import Path + + +def test_ArrayType(): + assert ArrayType == ArrayLike + + +def test_PathType(): + assert PathType == Union[str, Path] + + +def test_DtypeType(): + assert DtypeType == DTypeLike + + +def test_IntType(): + assert IntType == Union[int, np.integer] + + +def test_FloatType(): + assert FloatType == Union[float, np.floating] + + +def test_NoneType(): + assert NoneType == type(None)
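Note: the test_typing.py assertions above fully determine the aliases that roiextractors.tools.typing must expose. A minimal sketch that would satisfy these tests, inferred from the assertions rather than copied from the module's actual source, is:

    from pathlib import Path
    from typing import Union

    import numpy as np
    from numpy.typing import ArrayLike, DTypeLike

    # Shared type aliases; each line mirrors one equality asserted in test_typing.py.
    ArrayType = ArrayLike                  # any array-like accepted by numpy
    PathType = Union[str, Path]            # file-system paths as str or pathlib.Path
    DtypeType = DTypeLike                  # anything np.dtype() understands
    IntType = Union[int, np.integer]       # Python or numpy integers
    FloatType = Union[float, np.floating]  # Python or numpy floats
    NoneType = type(None)                  # explicit alias for the type of None

Because these are plain typing constructs rather than new classes, the tests can compare them with == directly, which keeps the module's contract easy to verify.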