diff --git a/src/roiextractors/extractorlist.py b/src/roiextractors/extractorlist.py index 091265c2..7241de80 100644 --- a/src/roiextractors/extractorlist.py +++ b/src/roiextractors/extractorlist.py @@ -25,6 +25,7 @@ from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor from .multisegmentationextractor import MultiSegmentationExtractor from .multiimagingextractor import MultiImagingExtractor +from .volumetricimagingextractor import VolumetricImagingExtractor imaging_extractor_full_list = [ NumpyImagingExtractor, @@ -39,6 +40,7 @@ SbxImagingExtractor, NumpyMemmapImagingExtractor, MemmapImagingExtractor, + VolumetricImagingExtractor, ] segmentation_extractor_full_list = [ diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index d1720898..47d694d9 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -53,6 +53,7 @@ def generate_dummy_imaging_extractor( num_channels: int = 1, sampling_frequency: float = 30, dtype: DtypeType = "uint16", + channel_names: Optional[list] = None, ): """Generate a dummy imaging extractor for testing. @@ -78,7 +79,8 @@ def generate_dummy_imaging_extractor( ImagingExtractor An imaging extractor with random data fed into `NumpyImagingExtractor`. """ - channel_names = [f"channel_num_{num}" for num in range(num_channels)] + if channel_names is None: + channel_names = [f"channel_num_{num}" for num in range(num_channels)] size = (num_frames, num_rows, num_columns, num_channels) video = generate_dummy_video(size=size, dtype=dtype) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py new file mode 100644 index 00000000..2abf0c1a --- /dev/null +++ b/src/roiextractors/volumetricimagingextractor.py @@ -0,0 +1,169 @@ +"""Base class definition for volumetric imaging extractors.""" + +from typing import Tuple, List, Iterable, Optional +import numpy as np + +from .extraction_tools import ArrayType, DtypeType +from .imagingextractor import ImagingExtractor + + +class VolumetricImagingExtractor(ImagingExtractor): + """Class to combine multiple ImagingExtractor objects by depth plane.""" + + extractor_name = "VolumetricImaging" + installed = True + installatiuon_mesage = "" + + def __init__(self, imaging_extractors: List[ImagingExtractor]): + """Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + """ + super().__init__() + assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" + assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) + self._check_consistency_between_imaging_extractors(imaging_extractors) + self._imaging_extractors = imaging_extractors + self._num_planes = len(imaging_extractors) + + def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]): + """Check that essential properties are consistent between extractors so that they can be combined appropriately. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + + Raises + ------ + AssertionError + If any of the properties are not consistent between extractors. + + Notes + ----- + This method checks the following properties: + - sampling frequency + - image size + - number of channels + - channel names + - data type + - num_frames + """ + properties_to_check = dict( + get_sampling_frequency="The sampling frequency", + get_image_size="The size of a frame", + get_num_channels="The number of channels", + get_channel_names="The name of the channels", + get_dtype="The data type", + get_num_frames="The number of frames", + ) + for method, property_message in properties_to_check.items(): + values = [getattr(extractor, method)() for extractor in imaging_extractors] + unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values) + assert ( + len(unique_values) == 1 + ), f"{property_message} is not consistent over the files (found {unique_values})." + + def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: + """Get the video frames. + + Parameters + ---------- + start_frame: int, optional + Start frame index (inclusive). + end_frame: int, optional + End frame index (exclusive). + + Returns + ------- + video: numpy.ndarray + The 3D video frames (num_frames, num_rows, num_columns, num_planes). + """ + if start_frame is None: + start_frame = 0 + elif start_frame < 0: + start_frame = self.get_num_frames() + start_frame + elif start_frame >= self.get_num_frames(): + raise ValueError( + f"start_frame {start_frame} is greater than or equal to the number of frames {self.get_num_frames()}" + ) + if end_frame is None: + end_frame = self.get_num_frames() + elif end_frame < 0: + end_frame = self.get_num_frames() + end_frame + elif end_frame > self.get_num_frames(): + raise ValueError(f"end_frame {end_frame} is greater than the number of frames {self.get_num_frames()}") + if end_frame <= start_frame: + raise ValueError(f"end_frame {end_frame} is less than or equal to start_frame {start_frame}") + + video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + video[..., i] = imaging_extractor.get_video(start_frame, end_frame) + return video + + def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: + """Get specific video frames from indices (not necessarily continuous). + + Parameters + ---------- + frame_idxs: array-like + Indices of frames to return. + + Returns + ------- + frames: numpy.ndarray + The 3D video frames (num_rows, num_columns, num_planes). + """ + if isinstance(frame_idxs, int): + frame_idxs = [frame_idxs] + for frame_idx in frame_idxs: + if frame_idx < -1 * self.get_num_frames() or frame_idx >= self.get_num_frames(): + raise ValueError(f"frame_idx {frame_idx} is out of bounds") + + # Note np.all([]) returns True so not all(np.diff(frame_idxs) == 1) returns False if frame_idxs is a single int + if not all(np.diff(frame_idxs) == 1): + frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + frames[..., i] = imaging_extractor.get_frames(frame_idxs) + return frames + else: + return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) + + def get_image_size(self) -> Tuple: + """Get the size of a single frame. + + Returns + ------- + image_size: tuple + The size of a single frame (num_rows, num_columns, num_planes). + """ + image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes()) + return image_size + + def get_num_planes(self) -> int: + """Get the number of depth planes. + + Returns + ------- + _num_planes: int + The number of depth planes. + """ + return self._num_planes + + def get_num_frames(self) -> int: + return self._imaging_extractors[0].get_num_frames() + + def get_sampling_frequency(self) -> float: + return self._imaging_extractors[0].get_sampling_frequency() + + def get_channel_names(self) -> list: + return self._imaging_extractors[0].get_channel_names() + + def get_num_channels(self) -> int: + return self._imaging_extractors[0].get_num_channels() + + def get_dtype(self) -> DtypeType: + return self._imaging_extractors[0].get_dtype() diff --git a/tests/test_volumetricimagingextractor.py b/tests/test_volumetricimagingextractor.py new file mode 100644 index 00000000..62424dae --- /dev/null +++ b/tests/test_volumetricimagingextractor.py @@ -0,0 +1,121 @@ +import pytest +import numpy as np +from roiextractors.testing import generate_dummy_imaging_extractor +from roiextractors import VolumetricImagingExtractor + +num_frames = 10 + + +@pytest.fixture(scope="module", params=[1, 2]) +def imaging_extractors(request): + num_channels = request.param + return [generate_dummy_imaging_extractor(num_channels=num_channels, num_frames=num_frames) for _ in range(3)] + + +@pytest.fixture(scope="module") +def volumetric_imaging_extractor(imaging_extractors): + return VolumetricImagingExtractor(imaging_extractors) + + +@pytest.mark.parametrize( + "params", + [ + [dict(sampling_frequency=1), dict(sampling_frequency=2)], + [dict(num_rows=1), dict(num_rows=2)], + [dict(num_channels=1), dict(num_channels=2)], + [dict(channel_names=["a"], num_channels=1), dict(channel_names=["b"], num_channels=1)], + [dict(dtype=np.int16), dict(dtype=np.float32)], + [dict(num_frames=1), dict(num_frames=2)], + ], +) +def test_check_consistency_between_imaging_extractors(params): + imaging_extractors = [generate_dummy_imaging_extractor(**param) for param in params] + with pytest.raises(AssertionError): + VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + + +@pytest.mark.parametrize("start_frame, end_frame", [(None, None), (0, num_frames), (3, 7), (-2, -1)]) +def test_get_video(volumetric_imaging_extractor, start_frame, end_frame): + video = volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + expected_video = [] + for extractor in volumetric_imaging_extractor._imaging_extractors: + expected_video.append(extractor.get_video(start_frame=start_frame, end_frame=end_frame)) + expected_video = np.array(expected_video) + expected_video = np.moveaxis(expected_video, 0, -1) + assert np.all(video == expected_video) + + +@pytest.mark.parametrize("start_frame, end_frame", [(num_frames + 1, None), (None, num_frames + 1), (2, 1)]) +def test_get_video_invalid(volumetric_imaging_extractor, start_frame, end_frame): + with pytest.raises(ValueError): + volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + + +@pytest.mark.parametrize("frame_idxs", [0, [0, 1, 2], [0, num_frames - 1], [-3, -1]]) +def test_get_frames(volumetric_imaging_extractor, frame_idxs): + frames = volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) + expected_frames = [] + for extractor in volumetric_imaging_extractor._imaging_extractors: + expected_frames.append(extractor.get_frames(frame_idxs=frame_idxs)) + expected_frames = np.array(expected_frames) + expected_frames = np.moveaxis(expected_frames, 0, -1) + assert np.all(frames == expected_frames) + + +@pytest.mark.parametrize("frame_idxs", [num_frames, [0, num_frames], [-num_frames - 1, -1]]) +def test_get_frames_invalid(volumetric_imaging_extractor, frame_idxs): + with pytest.raises(ValueError): + volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) + + +@pytest.mark.parametrize("num_rows, num_columns, num_planes", [(1, 2, 3), (2, 1, 3), (3, 2, 1)]) +def test_get_image_size(num_rows, num_columns, num_planes): + imaging_extractors = [ + generate_dummy_imaging_extractor(num_rows=num_rows, num_columns=num_columns) for _ in range(num_planes) + ] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_image_size() == (num_rows, num_columns, num_planes) + + +@pytest.mark.parametrize("num_planes", [1, 2, 3]) +def test_get_num_planes(num_planes): + imaging_extractors = [generate_dummy_imaging_extractor() for _ in range(num_planes)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_planes() == num_planes + + +@pytest.mark.parametrize("num_frames", [1, 2, 3]) +def test_get_num_frames(num_frames): + imaging_extractors = [generate_dummy_imaging_extractor(num_frames=num_frames)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_frames() == num_frames + + +@pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) +def test_get_sampling_frequency(sampling_frequency): + imaging_extractors = [generate_dummy_imaging_extractor(sampling_frequency=sampling_frequency)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_sampling_frequency() == sampling_frequency + + +@pytest.mark.parametrize("channel_names", [["Channel 1"], [" Channel 1 ", "Channel 2"]]) +def test_get_channel_names(channel_names): + imaging_extractors = [ + generate_dummy_imaging_extractor(channel_names=channel_names, num_channels=len(channel_names)) + ] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_channel_names() == channel_names + + +@pytest.mark.parametrize("num_channels", [1, 2, 3]) +def test_get_num_channels(num_channels): + imaging_extractors = [generate_dummy_imaging_extractor(num_channels=num_channels)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_channels() == num_channels + + +@pytest.mark.parametrize("dtype", [np.float64, np.int16, np.uint8]) +def test_get_dtype(dtype): + imaging_extractors = [generate_dummy_imaging_extractor(dtype=dtype)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_dtype() == dtype