-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into docstring_tests
- Loading branch information
Showing
4 changed files
with
295 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
"""Base class definition for volumetric imaging extractors.""" | ||
|
||
from typing import Tuple, List, Iterable, Optional | ||
import numpy as np | ||
|
||
from .extraction_tools import ArrayType, DtypeType | ||
from .imagingextractor import ImagingExtractor | ||
|
||
|
||
class VolumetricImagingExtractor(ImagingExtractor): | ||
"""Class to combine multiple ImagingExtractor objects by depth plane.""" | ||
|
||
extractor_name = "VolumetricImaging" | ||
installed = True | ||
installatiuon_mesage = "" | ||
|
||
def __init__(self, imaging_extractors: List[ImagingExtractor]): | ||
"""Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors. | ||
Parameters | ||
---------- | ||
imaging_extractors: list of ImagingExtractor | ||
list of imaging extractor objects | ||
""" | ||
super().__init__() | ||
assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument" | ||
assert all(isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors) | ||
self._check_consistency_between_imaging_extractors(imaging_extractors) | ||
self._imaging_extractors = imaging_extractors | ||
self._num_planes = len(imaging_extractors) | ||
|
||
def _check_consistency_between_imaging_extractors(self, imaging_extractors: List[ImagingExtractor]): | ||
"""Check that essential properties are consistent between extractors so that they can be combined appropriately. | ||
Parameters | ||
---------- | ||
imaging_extractors: list of ImagingExtractor | ||
list of imaging extractor objects | ||
Raises | ||
------ | ||
AssertionError | ||
If any of the properties are not consistent between extractors. | ||
Notes | ||
----- | ||
This method checks the following properties: | ||
- sampling frequency | ||
- image size | ||
- number of channels | ||
- channel names | ||
- data type | ||
- num_frames | ||
""" | ||
properties_to_check = dict( | ||
get_sampling_frequency="The sampling frequency", | ||
get_image_size="The size of a frame", | ||
get_num_channels="The number of channels", | ||
get_channel_names="The name of the channels", | ||
get_dtype="The data type", | ||
get_num_frames="The number of frames", | ||
) | ||
for method, property_message in properties_to_check.items(): | ||
values = [getattr(extractor, method)() for extractor in imaging_extractors] | ||
unique_values = set(tuple(v) if isinstance(v, Iterable) else v for v in values) | ||
assert ( | ||
len(unique_values) == 1 | ||
), f"{property_message} is not consistent over the files (found {unique_values})." | ||
|
||
def get_video(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None) -> np.ndarray: | ||
"""Get the video frames. | ||
Parameters | ||
---------- | ||
start_frame: int, optional | ||
Start frame index (inclusive). | ||
end_frame: int, optional | ||
End frame index (exclusive). | ||
Returns | ||
------- | ||
video: numpy.ndarray | ||
The 3D video frames (num_frames, num_rows, num_columns, num_planes). | ||
""" | ||
if start_frame is None: | ||
start_frame = 0 | ||
elif start_frame < 0: | ||
start_frame = self.get_num_frames() + start_frame | ||
elif start_frame >= self.get_num_frames(): | ||
raise ValueError( | ||
f"start_frame {start_frame} is greater than or equal to the number of frames {self.get_num_frames()}" | ||
) | ||
if end_frame is None: | ||
end_frame = self.get_num_frames() | ||
elif end_frame < 0: | ||
end_frame = self.get_num_frames() + end_frame | ||
elif end_frame > self.get_num_frames(): | ||
raise ValueError(f"end_frame {end_frame} is greater than the number of frames {self.get_num_frames()}") | ||
if end_frame <= start_frame: | ||
raise ValueError(f"end_frame {end_frame} is less than or equal to start_frame {start_frame}") | ||
|
||
video = np.zeros((end_frame - start_frame, *self.get_image_size()), self.get_dtype()) | ||
for i, imaging_extractor in enumerate(self._imaging_extractors): | ||
video[..., i] = imaging_extractor.get_video(start_frame, end_frame) | ||
return video | ||
|
||
def get_frames(self, frame_idxs: ArrayType) -> np.ndarray: | ||
"""Get specific video frames from indices (not necessarily continuous). | ||
Parameters | ||
---------- | ||
frame_idxs: array-like | ||
Indices of frames to return. | ||
Returns | ||
------- | ||
frames: numpy.ndarray | ||
The 3D video frames (num_rows, num_columns, num_planes). | ||
""" | ||
if isinstance(frame_idxs, int): | ||
frame_idxs = [frame_idxs] | ||
for frame_idx in frame_idxs: | ||
if frame_idx < -1 * self.get_num_frames() or frame_idx >= self.get_num_frames(): | ||
raise ValueError(f"frame_idx {frame_idx} is out of bounds") | ||
|
||
# Note np.all([]) returns True so not all(np.diff(frame_idxs) == 1) returns False if frame_idxs is a single int | ||
if not all(np.diff(frame_idxs) == 1): | ||
frames = np.zeros((len(frame_idxs), *self.get_image_size()), self.get_dtype()) | ||
for i, imaging_extractor in enumerate(self._imaging_extractors): | ||
frames[..., i] = imaging_extractor.get_frames(frame_idxs) | ||
return frames | ||
else: | ||
return self.get_video(start_frame=frame_idxs[0], end_frame=frame_idxs[-1] + 1) | ||
|
||
def get_image_size(self) -> Tuple: | ||
"""Get the size of a single frame. | ||
Returns | ||
------- | ||
image_size: tuple | ||
The size of a single frame (num_rows, num_columns, num_planes). | ||
""" | ||
image_size = (*self._imaging_extractors[0].get_image_size(), self.get_num_planes()) | ||
return image_size | ||
|
||
def get_num_planes(self) -> int: | ||
"""Get the number of depth planes. | ||
Returns | ||
------- | ||
_num_planes: int | ||
The number of depth planes. | ||
""" | ||
return self._num_planes | ||
|
||
def get_num_frames(self) -> int: | ||
return self._imaging_extractors[0].get_num_frames() | ||
|
||
def get_sampling_frequency(self) -> float: | ||
return self._imaging_extractors[0].get_sampling_frequency() | ||
|
||
def get_channel_names(self) -> list: | ||
return self._imaging_extractors[0].get_channel_names() | ||
|
||
def get_num_channels(self) -> int: | ||
return self._imaging_extractors[0].get_num_channels() | ||
|
||
def get_dtype(self) -> DtypeType: | ||
return self._imaging_extractors[0].get_dtype() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import pytest | ||
import numpy as np | ||
from roiextractors.testing import generate_dummy_imaging_extractor | ||
from roiextractors import VolumetricImagingExtractor | ||
|
||
num_frames = 10 | ||
|
||
|
||
@pytest.fixture(scope="module", params=[1, 2]) | ||
def imaging_extractors(request): | ||
num_channels = request.param | ||
return [generate_dummy_imaging_extractor(num_channels=num_channels, num_frames=num_frames) for _ in range(3)] | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def volumetric_imaging_extractor(imaging_extractors): | ||
return VolumetricImagingExtractor(imaging_extractors) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"params", | ||
[ | ||
[dict(sampling_frequency=1), dict(sampling_frequency=2)], | ||
[dict(num_rows=1), dict(num_rows=2)], | ||
[dict(num_channels=1), dict(num_channels=2)], | ||
[dict(channel_names=["a"], num_channels=1), dict(channel_names=["b"], num_channels=1)], | ||
[dict(dtype=np.int16), dict(dtype=np.float32)], | ||
[dict(num_frames=1), dict(num_frames=2)], | ||
], | ||
) | ||
def test_check_consistency_between_imaging_extractors(params): | ||
imaging_extractors = [generate_dummy_imaging_extractor(**param) for param in params] | ||
with pytest.raises(AssertionError): | ||
VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
|
||
|
||
@pytest.mark.parametrize("start_frame, end_frame", [(None, None), (0, num_frames), (3, 7), (-2, -1)]) | ||
def test_get_video(volumetric_imaging_extractor, start_frame, end_frame): | ||
video = volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) | ||
expected_video = [] | ||
for extractor in volumetric_imaging_extractor._imaging_extractors: | ||
expected_video.append(extractor.get_video(start_frame=start_frame, end_frame=end_frame)) | ||
expected_video = np.array(expected_video) | ||
expected_video = np.moveaxis(expected_video, 0, -1) | ||
assert np.all(video == expected_video) | ||
|
||
|
||
@pytest.mark.parametrize("start_frame, end_frame", [(num_frames + 1, None), (None, num_frames + 1), (2, 1)]) | ||
def test_get_video_invalid(volumetric_imaging_extractor, start_frame, end_frame): | ||
with pytest.raises(ValueError): | ||
volumetric_imaging_extractor.get_video(start_frame=start_frame, end_frame=end_frame) | ||
|
||
|
||
@pytest.mark.parametrize("frame_idxs", [0, [0, 1, 2], [0, num_frames - 1], [-3, -1]]) | ||
def test_get_frames(volumetric_imaging_extractor, frame_idxs): | ||
frames = volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) | ||
expected_frames = [] | ||
for extractor in volumetric_imaging_extractor._imaging_extractors: | ||
expected_frames.append(extractor.get_frames(frame_idxs=frame_idxs)) | ||
expected_frames = np.array(expected_frames) | ||
expected_frames = np.moveaxis(expected_frames, 0, -1) | ||
assert np.all(frames == expected_frames) | ||
|
||
|
||
@pytest.mark.parametrize("frame_idxs", [num_frames, [0, num_frames], [-num_frames - 1, -1]]) | ||
def test_get_frames_invalid(volumetric_imaging_extractor, frame_idxs): | ||
with pytest.raises(ValueError): | ||
volumetric_imaging_extractor.get_frames(frame_idxs=frame_idxs) | ||
|
||
|
||
@pytest.mark.parametrize("num_rows, num_columns, num_planes", [(1, 2, 3), (2, 1, 3), (3, 2, 1)]) | ||
def test_get_image_size(num_rows, num_columns, num_planes): | ||
imaging_extractors = [ | ||
generate_dummy_imaging_extractor(num_rows=num_rows, num_columns=num_columns) for _ in range(num_planes) | ||
] | ||
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
assert volumetric_imaging_extractor.get_image_size() == (num_rows, num_columns, num_planes) | ||
|
||
|
||
@pytest.mark.parametrize("num_planes", [1, 2, 3]) | ||
def test_get_num_planes(num_planes): | ||
imaging_extractors = [generate_dummy_imaging_extractor() for _ in range(num_planes)] | ||
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
assert volumetric_imaging_extractor.get_num_planes() == num_planes | ||
|
||
|
||
@pytest.mark.parametrize("num_frames", [1, 2, 3]) | ||
def test_get_num_frames(num_frames): | ||
imaging_extractors = [generate_dummy_imaging_extractor(num_frames=num_frames)] | ||
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
assert volumetric_imaging_extractor.get_num_frames() == num_frames | ||
|
||
|
||
@pytest.mark.parametrize("sampling_frequency", [1, 2, 3]) | ||
def test_get_sampling_frequency(sampling_frequency): | ||
imaging_extractors = [generate_dummy_imaging_extractor(sampling_frequency=sampling_frequency)] | ||
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
assert volumetric_imaging_extractor.get_sampling_frequency() == sampling_frequency | ||
|
||
|
||
@pytest.mark.parametrize("channel_names", [["Channel 1"], [" Channel 1 ", "Channel 2"]]) | ||
def test_get_channel_names(channel_names): | ||
imaging_extractors = [ | ||
generate_dummy_imaging_extractor(channel_names=channel_names, num_channels=len(channel_names)) | ||
] | ||
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
assert volumetric_imaging_extractor.get_channel_names() == channel_names | ||
|
||
|
||
@pytest.mark.parametrize("num_channels", [1, 2, 3]) | ||
def test_get_num_channels(num_channels): | ||
imaging_extractors = [generate_dummy_imaging_extractor(num_channels=num_channels)] | ||
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
assert volumetric_imaging_extractor.get_num_channels() == num_channels | ||
|
||
|
||
@pytest.mark.parametrize("dtype", [np.float64, np.int16, np.uint8]) | ||
def test_get_dtype(dtype): | ||
imaging_extractors = [generate_dummy_imaging_extractor(dtype=dtype)] | ||
volumetric_imaging_extractor = VolumetricImagingExtractor(imaging_extractors=imaging_extractors) | ||
assert volumetric_imaging_extractor.get_dtype() == dtype |