Skip to content

Commit

Permalink
Merge pull request #248 from catalystneuro/volumetric
Browse files Browse the repository at this point in the history
VolumetricImagingExtractor
  • Loading branch information
pauladkisson authored Oct 25, 2023
2 parents 5bc3293 + 60e89ca commit 871552c
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/roiextractors/extractorlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor
from .multisegmentationextractor import MultiSegmentationExtractor
from .multiimagingextractor import MultiImagingExtractor
from .volumetricimagingextractor import VolumetricImagingExtractor

imaging_extractor_full_list = [
NumpyImagingExtractor,
Expand All @@ -39,6 +40,7 @@
SbxImagingExtractor,
NumpyMemmapImagingExtractor,
MemmapImagingExtractor,
VolumetricImagingExtractor,
]

segmentation_extractor_full_list = [
Expand Down
4 changes: 3 additions & 1 deletion src/roiextractors/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def generate_dummy_imaging_extractor(
num_channels: int = 1,
sampling_frequency: float = 30,
dtype: DtypeType = "uint16",
channel_names: Optional[list] = None,
):
"""Generate a dummy imaging extractor for testing.
Expand All @@ -78,7 +79,8 @@ def generate_dummy_imaging_extractor(
ImagingExtractor
An imaging extractor with random data fed into `NumpyImagingExtractor`.
"""
channel_names = [f"channel_num_{num}" for num in range(num_channels)]
if channel_names is None:
channel_names = [f"channel_num_{num}" for num in range(num_channels)]

size = (num_frames, num_rows, num_columns, num_channels)
video = generate_dummy_video(size=size, dtype=dtype)
Expand Down
169 changes: 169 additions & 0 deletions src/roiextractors/volumetricimagingextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""Base class definition for volumetric imaging extractors."""

from typing import Tuple, List, Iterable, Optional
import numpy as np

from .extraction_tools import ArrayType, DtypeType
from .imagingextractor import ImagingExtractor


class VolumetricImagingExtractor(ImagingExtractor):
"""Class to combine multiple ImagingExtractor objects by depth plane."""

extractor_name = "VolumetricImaging"
installed = True
installatiuon_mesage = ""

def __init__(self, imaging_extractors: List[ImagingExtractor]):
"""Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
"""
super().__init__()
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors)
self._check_consistency_between_imaging_extractors(imaging_extractors)
self._imaging_extractors = imaging_extractors
self._num_planes = len(imaging_extractors)

def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]):
"""Check that essential properties are consistent between extractors so that they can be combined appropriately.
Parameters
----------
imaging_extractors: list of ImagingExtractor
list of imaging extractor objects
Raises
------
AssertionError
If any of the properties are not consistent between extractors.
Notes
-----
This method checks the following properties:
- sampling frequency
- image size
- number of channels
- channel names
- data type
- num_frames
"""
properties_to_check = dict(
get_sampling_frequency="The sampling frequency",
get_image_size="The size of a frame",
get_num_channels="The number of channels",
get_channel_names="The name of the channels",
get_dtype="The data type",
get_num_frames="The number of frames",
)
for method, property_message in properties_to_check.items():
values = [getattr(extractor, method)() for extractor in imaging_extractors]
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values)
assert (
len(unique_values) == 1
), f"{property_message} is not consistent over the files (found {unique_values})."

def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray:
"""Get the video frames.
Parameters
----------
start_frame: int, optional
Start frame index (inclusive).
end_frame: int, optional
End frame index (exclusive).
Returns
-------
video: numpy.ndarray
The 3D video frames (num_frames, num_rows, num_columns, num_planes).
"""
if start_frame is None:
start_frame = 0
elif start_frame < 0:
start_frame = self.get_num_frames() + start_frame
elif start_frame >= self.get_num_frames():
raise ValueError(
f"start_frame {start_frame} is greater than or equal to the number of frames {self.get_num_frames()}"
)
if end_frame is None:
end_frame = self.get_num_frames()
elif end_frame < 0:
end_frame = self.get_num_frames() + end_frame
elif end_frame > self.get_num_frames():
raise ValueError(f"end_frame {end_frame} is greater than the number of frames {self.get_num_frames()}")
if end_frame <= start_frame:
raise ValueError(f"end_frame {end_frame} is less than or equal to start_frame {start_frame}")

video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
video[..., i] = imaging_extractor.get_video(start_frame, end_frame)
return video

def get_frames(self, frame_idxs: ArrayType) -> np.ndarray:
"""Get specific video frames from indices (not necessarily continuous).
Parameters
----------
frame_idxs: array-like
Indices of frames to return.
Returns
-------
frames: numpy.ndarray
The 3D video frames (num_rows, num_columns, num_planes).
"""
if isinstance(frame_idxs, int):
frame_idxs = [frame_idxs]
for frame_idx in frame_idxs:
if frame_idx < -1 * self.get_num_frames() or frame_idx >= self.get_num_frames():
raise ValueError(f"frame_idx {frame_idx} is out of bounds")

# Note np.all([]) returns True so not all(np.diff(frame_idxs) == 1) returns False if frame_idxs is a single int
if not all(np.diff(frame_idxs) == 1):
frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype())
for i, imaging_extractor in enumerate(self._imaging_extractors):
frames[..., i] = imaging_extractor.get_frames(frame_idxs)
return frames
else:
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1)

def get_image_size(self) -> Tuple:
"""Get the size of a single frame.
Returns
-------
image_size: tuple
The size of a single frame (num_rows, num_columns, num_planes).
"""
image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes())
return image_size

def get_num_planes(self) -> int:
"""Get the number of depth planes.
Returns
-------
_num_planes: int
The number of depth planes.
"""
return self._num_planes

def get_num_frames(self) -> int:
return self._imaging_extractors[0].get_num_frames()

def get_sampling_frequency(self) -> float:
return self._imaging_extractors[0].get_sampling_frequency()

def get_channel_names(self) -> list:
return self._imaging_extractors[0].get_channel_names()

def get_num_channels(self) -> int:
return self._imaging_extractors[0].get_num_channels()

def get_dtype(self) -> DtypeType:
return self._imaging_extractors[0].get_dtype()
121 changes: 121 additions & 0 deletions tests/test_volumetricimagingextractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
import numpy as np
from roiextractors.testing import generate_dummy_imaging_extractor
from roiextractors import VolumetricImagingExtractor

num_frames = 10


@pytest.fixture(scope="module", params=[1, 2])
def imaging_extractors(request):
num_channels = request.param
return [generate_dummy_imaging_extractor(num_channels=num_channels, num_frames=num_frames) for _ in range(3)]


@pytest.fixture(scope="module")
def volumetric_imaging_extractor(imaging_extractors):
return VolumetricImagingExtractor(imaging_extractors)


@pytest.mark.parametrize(
"params",
[
[dict(sampling_frequency=1), dict(sampling_frequency=2)],
[dict(num_rows=1), dict(num_rows=2)],
[dict(num_channels=1), dict(num_channels=2)],
[dict(channel_names=["a"], num_channels=1), dict(channel_names=["b"], num_channels=1)],
[dict(dtype=np.int16), dict(dtype=np.float32)],
[dict(num_frames=1), dict(num_frames=2)],
],
)
def test_check_consistency_between_imaging_extractors(params):
imaging_extractors = [generate_dummy_imaging_extractor(**param) for param in params]
with pytest.raises(AssertionError):
VolumetricImagingExtractor(imaging_extractors=imaging_extractors)


@pytest.mark.parametrize("start_frame, end_frame", [(None, None), (0, num_frames), (3, 7), (-2, -1)])
def test_get_video(volumetric_imaging_extractor, start_frame, end_frame):
video = volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame)
expected_video = []
for extractor in volumetric_imaging_extractor._imaging_extractors:
expected_video.append(extractor.get_video(start_frame=start_frame, end_frame=end_frame))
expected_video = np.array(expected_video)
expected_video = np.moveaxis(expected_video, 0, -1)
assert np.all(video == expected_video)


@pytest.mark.parametrize("start_frame, end_frame", [(num_frames + 1, None), (None, num_frames + 1), (2, 1)])
def test_get_video_invalid(volumetric_imaging_extractor, start_frame, end_frame):
with pytest.raises(ValueError):
volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame)


@pytest.mark.parametrize("frame_idxs", [0, [0, 1, 2], [0, num_frames - 1], [-3, -1]])
def test_get_frames(volumetric_imaging_extractor, frame_idxs):
frames = volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs)
expected_frames = []
for extractor in volumetric_imaging_extractor._imaging_extractors:
expected_frames.append(extractor.get_frames(frame_idxs=frame_idxs))
expected_frames = np.array(expected_frames)
expected_frames = np.moveaxis(expected_frames, 0, -1)
assert np.all(frames == expected_frames)


@pytest.mark.parametrize("frame_idxs", [num_frames, [0, num_frames], [-num_frames - 1, -1]])
def test_get_frames_invalid(volumetric_imaging_extractor, frame_idxs):
with pytest.raises(ValueError):
volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs)


@pytest.mark.parametrize("num_rows, num_columns, num_planes", [(1, 2, 3), (2, 1, 3), (3, 2, 1)])
def test_get_image_size(num_rows, num_columns, num_planes):
imaging_extractors = [
generate_dummy_imaging_extractor(num_rows=num_rows, num_columns=num_columns) for _ in range(num_planes)
]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_image_size() == (num_rows, num_columns, num_planes)


@pytest.mark.parametrize("num_planes", [1, 2, 3])
def test_get_num_planes(num_planes):
imaging_extractors = [generate_dummy_imaging_extractor() for _ in range(num_planes)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_num_planes() == num_planes


@pytest.mark.parametrize("num_frames", [1, 2, 3])
def test_get_num_frames(num_frames):
imaging_extractors = [generate_dummy_imaging_extractor(num_frames=num_frames)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_num_frames() == num_frames


@pytest.mark.parametrize("sampling_frequency", [1, 2, 3])
def test_get_sampling_frequency(sampling_frequency):
imaging_extractors = [generate_dummy_imaging_extractor(sampling_frequency=sampling_frequency)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_sampling_frequency() == sampling_frequency


@pytest.mark.parametrize("channel_names", [["Channel 1"], [" Channel 1 ", "Channel 2"]])
def test_get_channel_names(channel_names):
imaging_extractors = [
generate_dummy_imaging_extractor(channel_names=channel_names, num_channels=len(channel_names))
]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_channel_names() == channel_names


@pytest.mark.parametrize("num_channels", [1, 2, 3])
def test_get_num_channels(num_channels):
imaging_extractors = [generate_dummy_imaging_extractor(num_channels=num_channels)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_num_channels() == num_channels


@pytest.mark.parametrize("dtype", [np.float64, np.int16, np.uint8])
def test_get_dtype(dtype):
imaging_extractors = [generate_dummy_imaging_extractor(dtype=dtype)]
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors)
assert volumetric_imaging_extractor.get_dtype() == dtype

0 comments on commit 871552c

Please sign in to comment.