From 6e0e83da820f5ce1ffca5012614632d41f9acd47 Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Wed, 20 Sep 2023 13:46:58 -0700 Subject: [PATCH 1/8] added VolumetricImagingExtractor to its own file --- .../volumetricimagingextractor.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 src/roiextractors/volumetricimagingextractor.py diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py new file mode 100644 index 00000000..8361c43b --- /dev/null +++ b/src/roiextractors/volumetricimagingextractor.py @@ -0,0 +1,148 @@ +from typing import Tuple, List, Iterable, Optional +import numpy as np + +from .extraction_tools import ArrayType, DtypeType +from .imagingextractor import ImagingExtractor + + +class VolumetricImagingExtractor(ImagingExtractor): + """Class to combine multiple ImagingExtractor objects by depth plane.""" + + extractor_name = "VolumetricImaging" + installed = True + installatiuon_mesage = "" + + def __init__(self, imaging_extractors: List[ImagingExtractor]): + """Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + """ + super().__init__() + assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" + assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) + self._check_consistency_between_imaging_extractors(imaging_extractors) + self._imaging_extractors = imaging_extractors + self._num_planes = len(imaging_extractors) + + def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]): + """Check that essential properties are consistent between extractors so that they can be combined appropriately. + + Parameters + ---------- + imaging_extractors: list of ImagingExtractor + list of imaging extractor objects + + Raises + ------ + AssertionError + If any of the properties are not consistent between extractors. + + Notes + ----- + This method checks the following properties: + - sampling frequency + - image size + - number of channels + - channel names + - data type + - num_frames + """ + properties_to_check = dict( + get_sampling_frequency="The sampling frequency", + get_image_size="The size of a frame", + get_num_channels="The number of channels", + get_channel_names="The name of the channels", + get_dtype="The data type", + get_num_frames="The number of frames", + ) + for method, property_message in properties_to_check.items(): + values = [getattr(extractor, method)() for extractor in imaging_extractors] + unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values) + assert ( + len(unique_values) == 1 + ), f"{property_message} is not consistent over the files (found {unique_values})." + + def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: + """Get the video frames. + + Parameters + ---------- + start_frame: int, optional + Start frame index (inclusive). + end_frame: int, optional + End frame index (exclusive). + + Returns + ------- + video: numpy.ndarray + The 3D video frames (num_rows, num_columns, num_planes). + """ + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else self.get_num_frames() + + video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + video[..., i] = imaging_extractor.get_video(start_frame, end_frame) + return video + + def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: + """Get specific video frames from indices (not necessarily continuous). + + Parameters + ---------- + frame_idxs: array-like + Indices of frames to return. + + Returns + ------- + frames: numpy.ndarray + The 3D video frames (num_rows, num_columns, num_planes). + """ + if isinstance(frame_idxs, int): + frame_idxs = [frame_idxs] + + if not all(np.diff(frame_idxs) == 1): + frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) + for i, imaging_extractor in enumerate(self._imaging_extractors): + frames[..., i] = imaging_extractor.get_frames(frame_idxs) + else: + return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) + + def get_image_size(self) -> Tuple: + """Get the size of a single frame. + + Returns + ------- + image_size: tuple + The size of a single frame (num_rows, num_columns, num_planes). + """ + image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes()) + return image_size + + def get_num_planes(self) -> int: + """Get the number of depth planes. + + Returns + ------- + _num_planes: int + The number of depth planes. + """ + return self._num_planes + + def get_num_frames(self) -> int: + return self._imaging_extractors[0].get_num_frames() + + def get_sampling_frequency(self) -> float: + return self._imaging_extractors[0].get_sampling_frequency() + + def get_channel_names(self) -> list: + return self._imaging_extractors[0].get_channel_names() + + def get_num_channels(self) -> int: + return self._imaging_extractors[0].get_num_channels() + + def get_dtype(self) -> DtypeType: + return self._imaging_extractors[0].get_dtype() From 7532f1127dc235878f15e282c5218082a15b9d6f Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Thu, 21 Sep 2023 19:15:58 -0700 Subject: [PATCH 2/8] added first set of tests to volumeImagingExtractor and caught some bugs --- src/roiextractors/extractorlist.py | 2 + src/roiextractors/testing.py | 4 +- .../volumetricimagingextractor.py | 27 +- tests/test_scanimagetiffimagingextractor.py | 308 ++++++++++++++++++ tests/test_volumetricimagingextractor.py | 69 ++++ 5 files changed, 407 insertions(+), 3 deletions(-) create mode 100644 tests/test_scanimagetiffimagingextractor.py create mode 100644 tests/test_volumetricimagingextractor.py diff --git a/src/roiextractors/extractorlist.py b/src/roiextractors/extractorlist.py index 091265c2..7241de80 100644 --- a/src/roiextractors/extractorlist.py +++ b/src/roiextractors/extractorlist.py @@ -25,6 +25,7 @@ from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor from .multisegmentationextractor import MultiSegmentationExtractor from .multiimagingextractor import MultiImagingExtractor +from .volumetricimagingextractor import VolumetricImagingExtractor imaging_extractor_full_list = [ NumpyImagingExtractor, @@ -39,6 +40,7 @@ SbxImagingExtractor, NumpyMemmapImagingExtractor, MemmapImagingExtractor, + VolumetricImagingExtractor, ] segmentation_extractor_full_list = [ diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index d1720898..47d694d9 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -53,6 +53,7 @@ def generate_dummy_imaging_extractor( num_channels: int = 1, sampling_frequency: float = 30, dtype: DtypeType = "uint16", + channel_names: Optional[list] = None, ): """Generate a dummy imaging extractor for testing. @@ -78,7 +79,8 @@ def generate_dummy_imaging_extractor( ImagingExtractor An imaging extractor with random data fed into `NumpyImagingExtractor`. """ - channel_names = [f"channel_num_{num}" for num in range(num_channels)] + if channel_names is None: + channel_names = [f"channel_num_{num}" for num in range(num_channels)] size = (num_frames, num_rows, num_columns, num_channels) video = generate_dummy_video(size=size, dtype=dtype) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py index 8361c43b..bef728e8 100644 --- a/src/roiextractors/volumetricimagingextractor.py +++ b/src/roiextractors/volumetricimagingextractor.py @@ -80,8 +80,22 @@ def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] video: numpy.ndarray The 3D video frames (num_rows, num_columns, num_planes). """ - start_frame = start_frame if start_frame is not None else 0 - end_frame = end_frame if end_frame is not None else self.get_num_frames() + if start_frame is None: + start_frame = 0 + elif start_frame < 0: + start_frame = self.get_num_frames() + start_frame + elif start_frame >= self.get_num_frames(): + raise ValueError( + f"start_frame {start_frame} is greater than or equal to the number of frames {self.get_num_frames()}" + ) + if end_frame is None: + end_frame = self.get_num_frames() + elif end_frame < 0: + end_frame = self.get_num_frames() + end_frame + elif end_frame > self.get_num_frames(): + raise ValueError(f"end_frame {end_frame} is greater than the number of frames {self.get_num_frames()}") + if end_frame <= start_frame: + raise ValueError(f"end_frame {end_frame} is less than or equal to start_frame {start_frame}") video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype()) for i, imaging_extractor in enumerate(self._imaging_extractors): @@ -101,13 +115,22 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: frames: numpy.ndarray The 3D video frames (num_rows, num_columns, num_planes). """ + squeeze_data = False if isinstance(frame_idxs, int): frame_idxs = [frame_idxs] + squeeze_data = True + for frame_idx in frame_idxs: + if frame_idx < -1 * self.get_num_frames() or frame_idx >= self.get_num_frames(): + raise ValueError(f"frame_idx {frame_idx} is out of bounds") if not all(np.diff(frame_idxs) == 1): frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) for i, imaging_extractor in enumerate(self._imaging_extractors): frames[..., i] = imaging_extractor.get_frames(frame_idxs) + if squeeze_data: + return frames.squeeze() + else: + return frames else: return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) diff --git a/tests/test_scanimagetiffimagingextractor.py b/tests/test_scanimagetiffimagingextractor.py new file mode 100644 index 00000000..efb8604d --- /dev/null +++ b/tests/test_scanimagetiffimagingextractor.py @@ -0,0 +1,308 @@ +import pytest +from pathlib import Path +from tempfile import mkdtemp +from shutil import rmtree, copy +from numpy.testing import assert_array_equal + +from ScanImageTiffReader import ScanImageTiffReader +from roiextractors import ScanImageTiffSinglePlaneImagingExtractor, ScanImageTiffMultiPlaneImagingExtractor +from roiextractors.extractors.tiffimagingextractors.scanimagetiffimagingextractor import ( + extract_extra_metadata, + parse_metadata, + parse_metadata_v3_8, +) + +from .setup_paths import OPHYS_DATA_PATH + +scan_image_path = OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" +test_files = [ + "scanimage_20220801_volume.tif", + "scanimage_20220801_multivolume.tif", + "scanimage_20230119_adesnik_00001.tif", +] +file_paths = [scan_image_path / test_file for test_file in test_files] + + +def metadata_string_to_dict(metadata_string): + metadata_dict = { + x.split("=")[0].strip(): x.split("=")[1].strip() + for x in metadata_string.replace("\n", "\r").split("\r") + if "=" in x + } + return metadata_dict + + +@pytest.fixture(scope="module", params=file_paths) +def scan_image_tiff_single_plane_imaging_extractor(request): + return ScanImageTiffSinglePlaneImagingExtractor(file_path=request.param, channel_name="Channel 1", plane_name="0") + + +@pytest.fixture( + scope="module", + params=[ + dict(channel_name="Channel 1", plane_name="0"), + dict(channel_name="Channel 1", plane_name="1"), + dict(channel_name="Channel 1", plane_name="2"), + dict(channel_name="Channel 2", plane_name="0"), + dict(channel_name="Channel 2", plane_name="1"), + dict(channel_name="Channel 2", plane_name="2"), + ], +) # Only the adesnik file has many (>2) frames per plane and multiple (2) channels. +def scan_image_tiff_single_plane_imaging_extractor_adesnik(request): + file_path = scan_image_path / "scanimage_20230119_adesnik_00001.tif" + return ScanImageTiffSinglePlaneImagingExtractor(file_path=file_path, **request.param) + + +@pytest.fixture(scope="module") +def num_planes_adesnik(): + return 3 + + +@pytest.fixture(scope="module") +def num_channels_adesnik(): + return 2 + + +@pytest.mark.parametrize("frame_idxs", (0, [0])) +def test_get_frames(scan_image_tiff_single_plane_imaging_extractor, frame_idxs): + frames = scan_image_tiff_single_plane_imaging_extractor.get_frames(frame_idxs=frame_idxs) + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + if isinstance(frame_idxs, int): + frame_idxs = [frame_idxs] + assert_array_equal(frames, io.data()[frame_idxs]) + + +@pytest.mark.parametrize("frame_idxs", ([0, 1, 2], [1, 3, 31])) # 31 is the last frame in the adesnik file +def test_get_frames_adesnik( + scan_image_tiff_single_plane_imaging_extractor_adesnik, num_planes_adesnik, num_channels_adesnik, frame_idxs +): + frames = scan_image_tiff_single_plane_imaging_extractor_adesnik.get_frames(frame_idxs=frame_idxs) + file_path = str(scan_image_tiff_single_plane_imaging_extractor_adesnik.file_path) + plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane + channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel + raw_idxs = [ + idx * num_planes_adesnik * num_channels_adesnik + plane * num_channels_adesnik + channel for idx in frame_idxs + ] + with ScanImageTiffReader(file_path) as io: + assert_array_equal(frames, io.data()[raw_idxs]) + + +@pytest.mark.parametrize("frame_idxs", ([-1], [50])) +def test_get_frames_adesnik_invalid(scan_image_tiff_single_plane_imaging_extractor_adesnik, frame_idxs): + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor_adesnik.get_frames(frame_idxs=frame_idxs) + + +def test_get_single_frame(scan_image_tiff_single_plane_imaging_extractor): + frame = scan_image_tiff_single_plane_imaging_extractor._get_single_frame(frame=0) + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + assert_array_equal(frame, io.data()[:1]) + + +@pytest.mark.parametrize("frame_idx", (5, 10, 31)) +def test_get_single_frame_adesnik( + scan_image_tiff_single_plane_imaging_extractor_adesnik, num_planes_adesnik, num_channels_adesnik, frame_idx +): + frame = scan_image_tiff_single_plane_imaging_extractor_adesnik._get_single_frame(frame=frame_idx) + file_path = str(scan_image_tiff_single_plane_imaging_extractor_adesnik.file_path) + plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane + channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel + raw_idx = frame_idx * num_planes_adesnik * num_channels_adesnik + plane * num_channels_adesnik + channel + print(raw_idx) + with ScanImageTiffReader(file_path) as io: + assert_array_equal(frame, io.data()[raw_idx : raw_idx + 1]) + + +@pytest.mark.parametrize("frame", (-1, 50)) +def test_get_single_frame_adesnik_invalid(scan_image_tiff_single_plane_imaging_extractor_adesnik, frame): + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor_adesnik._get_single_frame(frame=frame) + + +def test_get_video(scan_image_tiff_single_plane_imaging_extractor): + video = scan_image_tiff_single_plane_imaging_extractor.get_video() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() + num_planes = scan_image_tiff_single_plane_imaging_extractor.get_num_planes() + with ScanImageTiffReader(file_path) as io: + assert_array_equal(video, io.data()[:: num_planes * num_channels]) + + +@pytest.mark.parametrize("start_frame, end_frame", [(0, 2), (5, 10), (20, 32)]) +def test_get_video_adesnik( + scan_image_tiff_single_plane_imaging_extractor_adesnik, + num_planes_adesnik, + num_channels_adesnik, + start_frame, + end_frame, +): + video = scan_image_tiff_single_plane_imaging_extractor_adesnik.get_video( + start_frame=start_frame, end_frame=end_frame + ) + file_path = str(scan_image_tiff_single_plane_imaging_extractor_adesnik.file_path) + plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane + channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel + raw_idxs = [ + idx * num_planes_adesnik * num_channels_adesnik + plane * num_channels_adesnik + channel + for idx in range(start_frame, end_frame) + ] + with ScanImageTiffReader(file_path) as io: + assert_array_equal(video, io.data()[raw_idxs]) + + +@pytest.mark.parametrize("start_frame, end_frame", [(-1, 2), (0, 50)]) +def test_get_video_adesnik_invalid( + scan_image_tiff_single_plane_imaging_extractor_adesnik, + start_frame, + end_frame, +): + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor_adesnik.get_video(start_frame=start_frame, end_frame=end_frame) + + +def test_get_image_size(scan_image_tiff_single_plane_imaging_extractor): + image_size = scan_image_tiff_single_plane_imaging_extractor.get_image_size() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + assert image_size == tuple(io.shape()[1:]) + + +def test_get_num_frames(scan_image_tiff_single_plane_imaging_extractor): + num_frames = scan_image_tiff_single_plane_imaging_extractor.get_num_frames() + num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() + num_planes = scan_image_tiff_single_plane_imaging_extractor.get_num_planes() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + assert num_frames == io.shape()[0] // (num_channels * num_planes) + + +def test_get_sampling_frequency(scan_image_tiff_single_plane_imaging_extractor): + sampling_frequency = scan_image_tiff_single_plane_imaging_extractor.get_sampling_frequency() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + metadata_string = io.metadata() + metadata_dict = metadata_string_to_dict(metadata_string) + assert sampling_frequency == float(metadata_dict["SI.hRoiManager.scanVolumeRate"]) + + +def test_get_num_channels(scan_image_tiff_single_plane_imaging_extractor): + num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + metadata_string = io.metadata() + metadata_dict = metadata_string_to_dict(metadata_string) + assert num_channels == len(metadata_dict["SI.hChannels.channelsActive"].split(";")) + + +def test_get_channel_names(scan_image_tiff_single_plane_imaging_extractor): + channel_names = scan_image_tiff_single_plane_imaging_extractor.get_channel_names() + num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + metadata_string = io.metadata() + metadata_dict = metadata_string_to_dict(metadata_string) + assert channel_names == metadata_dict["SI.hChannels.channelName"].split("'")[1::2][:num_channels] + + +def test_get_num_planes(scan_image_tiff_single_plane_imaging_extractor): + num_planes = scan_image_tiff_single_plane_imaging_extractor.get_num_planes() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + metadata_string = io.metadata() + metadata_dict = metadata_string_to_dict(metadata_string) + assert num_planes == int(metadata_dict["SI.hStackManager.numSlices"]) + + +def test_get_dtype(scan_image_tiff_single_plane_imaging_extractor): + dtype = scan_image_tiff_single_plane_imaging_extractor.get_dtype() + file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) + with ScanImageTiffReader(file_path) as io: + assert dtype == io.data().dtype + + +def test_check_frame_inputs_valid(scan_image_tiff_single_plane_imaging_extractor): + scan_image_tiff_single_plane_imaging_extractor.check_frame_inputs(frame=0) + + +def test_check_frame_inputs_invalid(scan_image_tiff_single_plane_imaging_extractor): + num_frames = scan_image_tiff_single_plane_imaging_extractor.get_num_frames() + with pytest.raises(ValueError): + scan_image_tiff_single_plane_imaging_extractor.check_frame_inputs(frame=num_frames + 1) + + +@pytest.mark.parametrize("frame", (0, 10, 31)) +def test_frame_to_raw_index_adesnik( + scan_image_tiff_single_plane_imaging_extractor_adesnik, num_channels_adesnik, num_planes_adesnik, frame +): + raw_index = scan_image_tiff_single_plane_imaging_extractor_adesnik.frame_to_raw_index(frame=frame) + plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane + channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel + assert raw_index == (frame * num_planes_adesnik * num_channels_adesnik) + (plane * num_channels_adesnik) + channel + + +@pytest.mark.parametrize("file_path", file_paths) +def test_extract_extra_metadata(file_path): + metadata = extract_extra_metadata(file_path) + io = ScanImageTiffReader(str(file_path)) + extra_metadata = {} + for metadata_string in (io.description(iframe=0), io.metadata()): + metadata_dict = { + x.split("=")[0].strip(): x.split("=")[1].strip() + for x in metadata_string.replace("\n", "\r").split("\r") + if "=" in x + } + extra_metadata = dict(**extra_metadata, **metadata_dict) + assert metadata == extra_metadata + + +@pytest.mark.parametrize("file_path", file_paths) +def test_parse_metadata(file_path): + metadata = extract_extra_metadata(file_path) + parsed_metadata = parse_metadata(metadata) + sampling_frequency = float(metadata["SI.hRoiManager.scanVolumeRate"]) + num_channels = len(metadata["SI.hChannels.channelsActive"].split(";")) + num_planes = int(metadata["SI.hStackManager.numSlices"]) + frames_per_slice = int(metadata["SI.hStackManager.framesPerSlice"]) + channel_names = metadata["SI.hChannels.channelName"].split("'")[1::2][:num_channels] + assert parsed_metadata == dict( + sampling_frequency=sampling_frequency, + num_channels=num_channels, + num_planes=num_planes, + frames_per_slice=frames_per_slice, + channel_names=channel_names, + ) + + +def test_parse_metadata_v3_8(): + file_path = scan_image_path / "sample_scanimage_version_3_8.tiff" + metadata = extract_extra_metadata(file_path) + parsed_metadata = parse_metadata_v3_8(metadata) + sampling_frequency = float(metadata["state.acq.frameRate"]) + num_channels = int(metadata["state.acq.numberOfChannelsSave"]) + num_planes = int(metadata["state.acq.numberOfZSlices"]) + assert parsed_metadata == dict( + sampling_frequency=sampling_frequency, + num_channels=num_channels, + num_planes=num_planes, + ) + + +@pytest.mark.parametrize("file_path", file_paths) +def test_ScanImageTiffMultiPlaneImagingExtractor__init__(file_path): + extractor = ScanImageTiffMultiPlaneImagingExtractor(file_path=file_path) + assert extractor.file_path == file_path + + +@pytest.mark.parametrize("channel_name, plane_name", [("Invalid Channel", "0"), ("Channel 1", "Invalid Plane")]) +def test_ScanImageTiffSinglePlaneImagingExtractor__init__invalid(channel_name, plane_name): + with pytest.raises(ValueError): + ScanImageTiffSinglePlaneImagingExtractor( + file_path=file_paths[0], channel_name=channel_name, plane_name=plane_name + ) + + +def test_ScanImageTiffMultiPlaneImagingExtractor__init__invalid(): + with pytest.raises(ValueError): + ScanImageTiffMultiPlaneImagingExtractor(file_path=file_paths[0], channel_name="Invalid Channel") diff --git a/tests/test_volumetricimagingextractor.py b/tests/test_volumetricimagingextractor.py new file mode 100644 index 00000000..47d8f444 --- /dev/null +++ b/tests/test_volumetricimagingextractor.py @@ -0,0 +1,69 @@ +import pytest +import numpy as np +from roiextractors.testing import generate_dummy_imaging_extractor +from roiextractors import VolumetricImagingExtractor + +num_frames = 10 + + +@pytest.fixture(scope="module", params=[1, 2]) +def imaging_extractors(request): + num_channels = request.param + return [generate_dummy_imaging_extractor(num_channels=num_channels, num_frames=num_frames) for _ in range(3)] + + +@pytest.fixture(scope="module") +def volumetric_imaging_extractor(imaging_extractors): + return VolumetricImagingExtractor(imaging_extractors) + + +@pytest.mark.parametrize( + "params", + [ + [dict(sampling_frequency=1), dict(sampling_frequency=2)], + [dict(num_rows=1), dict(num_rows=2)], + [dict(num_channels=1), dict(num_channels=2)], + [dict(channel_names=["a"], num_channels=1), dict(channel_names=["b"], num_channels=1)], + [dict(dtype=np.int16), dict(dtype=np.float32)], + [dict(num_frames=1), dict(num_frames=2)], + ], +) +def test_check_consistency_between_imaging_extractors(params): + imaging_extractors = [generate_dummy_imaging_extractor(**param) for param in params] + with pytest.raises(AssertionError): + VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + + +@pytest.mark.parametrize("start_frame, end_frame", [(None, None), (0, num_frames), (3, 7), (-2, -1)]) +def test_get_video(volumetric_imaging_extractor, start_frame, end_frame): + video = volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + expected_video = [] + for extractor in volumetric_imaging_extractor._imaging_extractors: + expected_video.append(extractor.get_video(start_frame=start_frame, end_frame=end_frame)) + expected_video = np.array(expected_video) + expected_video = np.moveaxis(expected_video, 0, -1) + assert np.all(video == expected_video) + + +@pytest.mark.parametrize("start_frame, end_frame", [(num_frames + 1, None), (None, num_frames + 1), (2, 1)]) +def test_get_video_invalid(volumetric_imaging_extractor, start_frame, end_frame): + with pytest.raises(ValueError): + volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) + + +@pytest.mark.parametrize("frame_idxs", [0, [0, 1, 2], [0, num_frames - 1], [-3, -1]]) +def test_get_frames(volumetric_imaging_extractor, frame_idxs): + frames = volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) + expected_frames = [] + for extractor in volumetric_imaging_extractor._imaging_extractors: + gotten_frames = extractor.get_frames(frame_idxs=frame_idxs) + expected_frames.append(extractor.get_frames(frame_idxs=frame_idxs)) + expected_frames = np.array(expected_frames) + expected_frames = np.moveaxis(expected_frames, 0, -1) + assert np.all(frames == expected_frames) + + +@pytest.mark.parametrize("frame_idxs", [num_frames, [0, num_frames], [-num_frames - 1, -1]]) +def test_get_frames_invalid(volumetric_imaging_extractor, frame_idxs): + with pytest.raises(ValueError): + volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) From af4ec855ce8a411a9150ad3a6fb059625229015f Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Fri, 22 Sep 2023 10:28:00 -0700 Subject: [PATCH 3/8] removed a debugging variable --- tests/test_volumetricimagingextractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_volumetricimagingextractor.py b/tests/test_volumetricimagingextractor.py index 47d8f444..e72005bb 100644 --- a/tests/test_volumetricimagingextractor.py +++ b/tests/test_volumetricimagingextractor.py @@ -56,7 +56,6 @@ def test_get_frames(volumetric_imaging_extractor, frame_idxs): frames = volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) expected_frames = [] for extractor in volumetric_imaging_extractor._imaging_extractors: - gotten_frames = extractor.get_frames(frame_idxs=frame_idxs) expected_frames.append(extractor.get_frames(frame_idxs=frame_idxs)) expected_frames = np.array(expected_frames) expected_frames = np.moveaxis(expected_frames, 0, -1) From 7707ad7e488aabeac5673f268fa3ea9a8598b8f3 Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Thu, 28 Sep 2023 15:39:40 -0700 Subject: [PATCH 4/8] removed scanimage testing file --- tests/test_scanimagetiffimagingextractor.py | 308 -------------------- 1 file changed, 308 deletions(-) delete mode 100644 tests/test_scanimagetiffimagingextractor.py diff --git a/tests/test_scanimagetiffimagingextractor.py b/tests/test_scanimagetiffimagingextractor.py deleted file mode 100644 index efb8604d..00000000 --- a/tests/test_scanimagetiffimagingextractor.py +++ /dev/null @@ -1,308 +0,0 @@ -import pytest -from pathlib import Path -from tempfile import mkdtemp -from shutil import rmtree, copy -from numpy.testing import assert_array_equal - -from ScanImageTiffReader import ScanImageTiffReader -from roiextractors import ScanImageTiffSinglePlaneImagingExtractor, ScanImageTiffMultiPlaneImagingExtractor -from roiextractors.extractors.tiffimagingextractors.scanimagetiffimagingextractor import ( - extract_extra_metadata, - parse_metadata, - parse_metadata_v3_8, -) - -from .setup_paths import OPHYS_DATA_PATH - -scan_image_path = OPHYS_DATA_PATH / "imaging_datasets" / "ScanImage" -test_files = [ - "scanimage_20220801_volume.tif", - "scanimage_20220801_multivolume.tif", - "scanimage_20230119_adesnik_00001.tif", -] -file_paths = [scan_image_path / test_file for test_file in test_files] - - -def metadata_string_to_dict(metadata_string): - metadata_dict = { - x.split("=")[0].strip(): x.split("=")[1].strip() - for x in metadata_string.replace("\n", "\r").split("\r") - if "=" in x - } - return metadata_dict - - -@pytest.fixture(scope="module", params=file_paths) -def scan_image_tiff_single_plane_imaging_extractor(request): - return ScanImageTiffSinglePlaneImagingExtractor(file_path=request.param, channel_name="Channel 1", plane_name="0") - - -@pytest.fixture( - scope="module", - params=[ - dict(channel_name="Channel 1", plane_name="0"), - dict(channel_name="Channel 1", plane_name="1"), - dict(channel_name="Channel 1", plane_name="2"), - dict(channel_name="Channel 2", plane_name="0"), - dict(channel_name="Channel 2", plane_name="1"), - dict(channel_name="Channel 2", plane_name="2"), - ], -) # Only the adesnik file has many (>2) frames per plane and multiple (2) channels. -def scan_image_tiff_single_plane_imaging_extractor_adesnik(request): - file_path = scan_image_path / "scanimage_20230119_adesnik_00001.tif" - return ScanImageTiffSinglePlaneImagingExtractor(file_path=file_path, **request.param) - - -@pytest.fixture(scope="module") -def num_planes_adesnik(): - return 3 - - -@pytest.fixture(scope="module") -def num_channels_adesnik(): - return 2 - - -@pytest.mark.parametrize("frame_idxs", (0, [0])) -def test_get_frames(scan_image_tiff_single_plane_imaging_extractor, frame_idxs): - frames = scan_image_tiff_single_plane_imaging_extractor.get_frames(frame_idxs=frame_idxs) - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - if isinstance(frame_idxs, int): - frame_idxs = [frame_idxs] - assert_array_equal(frames, io.data()[frame_idxs]) - - -@pytest.mark.parametrize("frame_idxs", ([0, 1, 2], [1, 3, 31])) # 31 is the last frame in the adesnik file -def test_get_frames_adesnik( - scan_image_tiff_single_plane_imaging_extractor_adesnik, num_planes_adesnik, num_channels_adesnik, frame_idxs -): - frames = scan_image_tiff_single_plane_imaging_extractor_adesnik.get_frames(frame_idxs=frame_idxs) - file_path = str(scan_image_tiff_single_plane_imaging_extractor_adesnik.file_path) - plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane - channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel - raw_idxs = [ - idx * num_planes_adesnik * num_channels_adesnik + plane * num_channels_adesnik + channel for idx in frame_idxs - ] - with ScanImageTiffReader(file_path) as io: - assert_array_equal(frames, io.data()[raw_idxs]) - - -@pytest.mark.parametrize("frame_idxs", ([-1], [50])) -def test_get_frames_adesnik_invalid(scan_image_tiff_single_plane_imaging_extractor_adesnik, frame_idxs): - with pytest.raises(ValueError): - scan_image_tiff_single_plane_imaging_extractor_adesnik.get_frames(frame_idxs=frame_idxs) - - -def test_get_single_frame(scan_image_tiff_single_plane_imaging_extractor): - frame = scan_image_tiff_single_plane_imaging_extractor._get_single_frame(frame=0) - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - assert_array_equal(frame, io.data()[:1]) - - -@pytest.mark.parametrize("frame_idx", (5, 10, 31)) -def test_get_single_frame_adesnik( - scan_image_tiff_single_plane_imaging_extractor_adesnik, num_planes_adesnik, num_channels_adesnik, frame_idx -): - frame = scan_image_tiff_single_plane_imaging_extractor_adesnik._get_single_frame(frame=frame_idx) - file_path = str(scan_image_tiff_single_plane_imaging_extractor_adesnik.file_path) - plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane - channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel - raw_idx = frame_idx * num_planes_adesnik * num_channels_adesnik + plane * num_channels_adesnik + channel - print(raw_idx) - with ScanImageTiffReader(file_path) as io: - assert_array_equal(frame, io.data()[raw_idx : raw_idx + 1]) - - -@pytest.mark.parametrize("frame", (-1, 50)) -def test_get_single_frame_adesnik_invalid(scan_image_tiff_single_plane_imaging_extractor_adesnik, frame): - with pytest.raises(ValueError): - scan_image_tiff_single_plane_imaging_extractor_adesnik._get_single_frame(frame=frame) - - -def test_get_video(scan_image_tiff_single_plane_imaging_extractor): - video = scan_image_tiff_single_plane_imaging_extractor.get_video() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() - num_planes = scan_image_tiff_single_plane_imaging_extractor.get_num_planes() - with ScanImageTiffReader(file_path) as io: - assert_array_equal(video, io.data()[:: num_planes * num_channels]) - - -@pytest.mark.parametrize("start_frame, end_frame", [(0, 2), (5, 10), (20, 32)]) -def test_get_video_adesnik( - scan_image_tiff_single_plane_imaging_extractor_adesnik, - num_planes_adesnik, - num_channels_adesnik, - start_frame, - end_frame, -): - video = scan_image_tiff_single_plane_imaging_extractor_adesnik.get_video( - start_frame=start_frame, end_frame=end_frame - ) - file_path = str(scan_image_tiff_single_plane_imaging_extractor_adesnik.file_path) - plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane - channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel - raw_idxs = [ - idx * num_planes_adesnik * num_channels_adesnik + plane * num_channels_adesnik + channel - for idx in range(start_frame, end_frame) - ] - with ScanImageTiffReader(file_path) as io: - assert_array_equal(video, io.data()[raw_idxs]) - - -@pytest.mark.parametrize("start_frame, end_frame", [(-1, 2), (0, 50)]) -def test_get_video_adesnik_invalid( - scan_image_tiff_single_plane_imaging_extractor_adesnik, - start_frame, - end_frame, -): - with pytest.raises(ValueError): - scan_image_tiff_single_plane_imaging_extractor_adesnik.get_video(start_frame=start_frame, end_frame=end_frame) - - -def test_get_image_size(scan_image_tiff_single_plane_imaging_extractor): - image_size = scan_image_tiff_single_plane_imaging_extractor.get_image_size() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - assert image_size == tuple(io.shape()[1:]) - - -def test_get_num_frames(scan_image_tiff_single_plane_imaging_extractor): - num_frames = scan_image_tiff_single_plane_imaging_extractor.get_num_frames() - num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() - num_planes = scan_image_tiff_single_plane_imaging_extractor.get_num_planes() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - assert num_frames == io.shape()[0] // (num_channels * num_planes) - - -def test_get_sampling_frequency(scan_image_tiff_single_plane_imaging_extractor): - sampling_frequency = scan_image_tiff_single_plane_imaging_extractor.get_sampling_frequency() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - metadata_string = io.metadata() - metadata_dict = metadata_string_to_dict(metadata_string) - assert sampling_frequency == float(metadata_dict["SI.hRoiManager.scanVolumeRate"]) - - -def test_get_num_channels(scan_image_tiff_single_plane_imaging_extractor): - num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - metadata_string = io.metadata() - metadata_dict = metadata_string_to_dict(metadata_string) - assert num_channels == len(metadata_dict["SI.hChannels.channelsActive"].split(";")) - - -def test_get_channel_names(scan_image_tiff_single_plane_imaging_extractor): - channel_names = scan_image_tiff_single_plane_imaging_extractor.get_channel_names() - num_channels = scan_image_tiff_single_plane_imaging_extractor.get_num_channels() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - metadata_string = io.metadata() - metadata_dict = metadata_string_to_dict(metadata_string) - assert channel_names == metadata_dict["SI.hChannels.channelName"].split("'")[1::2][:num_channels] - - -def test_get_num_planes(scan_image_tiff_single_plane_imaging_extractor): - num_planes = scan_image_tiff_single_plane_imaging_extractor.get_num_planes() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - metadata_string = io.metadata() - metadata_dict = metadata_string_to_dict(metadata_string) - assert num_planes == int(metadata_dict["SI.hStackManager.numSlices"]) - - -def test_get_dtype(scan_image_tiff_single_plane_imaging_extractor): - dtype = scan_image_tiff_single_plane_imaging_extractor.get_dtype() - file_path = str(scan_image_tiff_single_plane_imaging_extractor.file_path) - with ScanImageTiffReader(file_path) as io: - assert dtype == io.data().dtype - - -def test_check_frame_inputs_valid(scan_image_tiff_single_plane_imaging_extractor): - scan_image_tiff_single_plane_imaging_extractor.check_frame_inputs(frame=0) - - -def test_check_frame_inputs_invalid(scan_image_tiff_single_plane_imaging_extractor): - num_frames = scan_image_tiff_single_plane_imaging_extractor.get_num_frames() - with pytest.raises(ValueError): - scan_image_tiff_single_plane_imaging_extractor.check_frame_inputs(frame=num_frames + 1) - - -@pytest.mark.parametrize("frame", (0, 10, 31)) -def test_frame_to_raw_index_adesnik( - scan_image_tiff_single_plane_imaging_extractor_adesnik, num_channels_adesnik, num_planes_adesnik, frame -): - raw_index = scan_image_tiff_single_plane_imaging_extractor_adesnik.frame_to_raw_index(frame=frame) - plane = scan_image_tiff_single_plane_imaging_extractor_adesnik.plane - channel = scan_image_tiff_single_plane_imaging_extractor_adesnik.channel - assert raw_index == (frame * num_planes_adesnik * num_channels_adesnik) + (plane * num_channels_adesnik) + channel - - -@pytest.mark.parametrize("file_path", file_paths) -def test_extract_extra_metadata(file_path): - metadata = extract_extra_metadata(file_path) - io = ScanImageTiffReader(str(file_path)) - extra_metadata = {} - for metadata_string in (io.description(iframe=0), io.metadata()): - metadata_dict = { - x.split("=")[0].strip(): x.split("=")[1].strip() - for x in metadata_string.replace("\n", "\r").split("\r") - if "=" in x - } - extra_metadata = dict(**extra_metadata, **metadata_dict) - assert metadata == extra_metadata - - -@pytest.mark.parametrize("file_path", file_paths) -def test_parse_metadata(file_path): - metadata = extract_extra_metadata(file_path) - parsed_metadata = parse_metadata(metadata) - sampling_frequency = float(metadata["SI.hRoiManager.scanVolumeRate"]) - num_channels = len(metadata["SI.hChannels.channelsActive"].split(";")) - num_planes = int(metadata["SI.hStackManager.numSlices"]) - frames_per_slice = int(metadata["SI.hStackManager.framesPerSlice"]) - channel_names = metadata["SI.hChannels.channelName"].split("'")[1::2][:num_channels] - assert parsed_metadata == dict( - sampling_frequency=sampling_frequency, - num_channels=num_channels, - num_planes=num_planes, - frames_per_slice=frames_per_slice, - channel_names=channel_names, - ) - - -def test_parse_metadata_v3_8(): - file_path = scan_image_path / "sample_scanimage_version_3_8.tiff" - metadata = extract_extra_metadata(file_path) - parsed_metadata = parse_metadata_v3_8(metadata) - sampling_frequency = float(metadata["state.acq.frameRate"]) - num_channels = int(metadata["state.acq.numberOfChannelsSave"]) - num_planes = int(metadata["state.acq.numberOfZSlices"]) - assert parsed_metadata == dict( - sampling_frequency=sampling_frequency, - num_channels=num_channels, - num_planes=num_planes, - ) - - -@pytest.mark.parametrize("file_path", file_paths) -def test_ScanImageTiffMultiPlaneImagingExtractor__init__(file_path): - extractor = ScanImageTiffMultiPlaneImagingExtractor(file_path=file_path) - assert extractor.file_path == file_path - - -@pytest.mark.parametrize("channel_name, plane_name", [("Invalid Channel", "0"), ("Channel 1", "Invalid Plane")]) -def test_ScanImageTiffSinglePlaneImagingExtractor__init__invalid(channel_name, plane_name): - with pytest.raises(ValueError): - ScanImageTiffSinglePlaneImagingExtractor( - file_path=file_paths[0], channel_name=channel_name, plane_name=plane_name - ) - - -def test_ScanImageTiffMultiPlaneImagingExtractor__init__invalid(): - with pytest.raises(ValueError): - ScanImageTiffMultiPlaneImagingExtractor(file_path=file_paths[0], channel_name="Invalid Channel") From f6d1896f16194fb284251d5d534d4f996160adb2 Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Thu, 28 Sep 2023 16:05:30 -0700 Subject: [PATCH 5/8] added tests for the rest of the volumetric extractor methods --- tests/test_volumetricimagingextractor.py | 53 ++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/test_volumetricimagingextractor.py b/tests/test_volumetricimagingextractor.py index e72005bb..62424dae 100644 --- a/tests/test_volumetricimagingextractor.py +++ b/tests/test_volumetricimagingextractor.py @@ -66,3 +66,56 @@ def test_get_frames(volumetric_imaging_extractor, frame_idxs): def test_get_frames_invalid(volumetric_imaging_extractor, frame_idxs): with pytest.raises(ValueError): volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) + + +@pytest.mark.parametrize("num_rows, num_columns, num_planes", [(1, 2, 3), (2, 1, 3), (3, 2, 1)]) +def test_get_image_size(num_rows, num_columns, num_planes): + imaging_extractors = [ + generate_dummy_imaging_extractor(num_rows=num_rows, num_columns=num_columns) for _ in range(num_planes) + ] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_image_size() == (num_rows, num_columns, num_planes) + + +@pytest.mark.parametrize("num_planes", [1, 2, 3]) +def test_get_num_planes(num_planes): + imaging_extractors = [generate_dummy_imaging_extractor() for _ in range(num_planes)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_planes() == num_planes + + +@pytest.mark.parametrize("num_frames", [1, 2, 3]) +def test_get_num_frames(num_frames): + imaging_extractors = [generate_dummy_imaging_extractor(num_frames=num_frames)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_frames() == num_frames + + +@pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) +def test_get_sampling_frequency(sampling_frequency): + imaging_extractors = [generate_dummy_imaging_extractor(sampling_frequency=sampling_frequency)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_sampling_frequency() == sampling_frequency + + +@pytest.mark.parametrize("channel_names", [["Channel 1"], [" Channel 1 ", "Channel 2"]]) +def test_get_channel_names(channel_names): + imaging_extractors = [ + generate_dummy_imaging_extractor(channel_names=channel_names, num_channels=len(channel_names)) + ] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_channel_names() == channel_names + + +@pytest.mark.parametrize("num_channels", [1, 2, 3]) +def test_get_num_channels(num_channels): + imaging_extractors = [generate_dummy_imaging_extractor(num_channels=num_channels)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_num_channels() == num_channels + + +@pytest.mark.parametrize("dtype", [np.float64, np.int16, np.uint8]) +def test_get_dtype(dtype): + imaging_extractors = [generate_dummy_imaging_extractor(dtype=dtype)] + volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) + assert volumetric_imaging_extractor.get_dtype() == dtype From 4596212a4282439f9fc9c51e3bdb884fad90744a Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Fri, 29 Sep 2023 11:50:31 -0700 Subject: [PATCH 6/8] removed unnecessary squeezing logic and a clarifying comment --- src/roiextractors/volumetricimagingextractor.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py index bef728e8..394e11a5 100644 --- a/src/roiextractors/volumetricimagingextractor.py +++ b/src/roiextractors/volumetricimagingextractor.py @@ -115,22 +115,18 @@ def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: frames: numpy.ndarray The 3D video frames (num_rows, num_columns, num_planes). """ - squeeze_data = False if isinstance(frame_idxs, int): frame_idxs = [frame_idxs] - squeeze_data = True for frame_idx in frame_idxs: if frame_idx < -1 * self.get_num_frames() or frame_idx >= self.get_num_frames(): raise ValueError(f"frame_idx {frame_idx} is out of bounds") + # Note np.all([]) returns True so not all(np.diff(frame_idxs) == 1) returns False if frame_idxs is a single int if not all(np.diff(frame_idxs) == 1): frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) for i, imaging_extractor in enumerate(self._imaging_extractors): frames[..., i] = imaging_extractor.get_frames(frame_idxs) - if squeeze_data: - return frames.squeeze() - else: - return frames + return frames else: return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) From bb692f569e0e8cbec3fb2516cd5e07d3e6cab153 Mon Sep 17 00:00:00 2001 From: Paul Adkisson Date: Mon, 23 Oct 2023 18:54:14 -0400 Subject: [PATCH 7/8] fix docstring typo Co-authored-by: Alessandra Trapani <55453048+alessandratrapani@users.noreply.github.com> --- src/roiextractors/volumetricimagingextractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py index 394e11a5..759caf8f 100644 --- a/src/roiextractors/volumetricimagingextractor.py +++ b/src/roiextractors/volumetricimagingextractor.py @@ -78,7 +78,7 @@ def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] Returns ------- video: numpy.ndarray - The 3D video frames (num_rows, num_columns, num_planes). + The 3D video frames (num_frames, num_rows, num_columns, num_planes). """ if start_frame is None: start_frame = 0 From cdf71925de6a490eebc7557aed215a7d9bf5ced5 Mon Sep 17 00:00:00 2001 From: pauladkisson Date: Mon, 23 Oct 2023 15:57:57 -0700 Subject: [PATCH 8/8] added module docstring --- src/roiextractors/volumetricimagingextractor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/roiextractors/volumetricimagingextractor.py b/src/roiextractors/volumetricimagingextractor.py index 759caf8f..2abf0c1a 100644 --- a/src/roiextractors/volumetricimagingextractor.py +++ b/src/roiextractors/volumetricimagingextractor.py @@ -1,3 +1,5 @@ +"""Base class definition for volumetric imaging extractors.""" + from typing import Tuple, List, Iterable, Optional import numpy as np