diff --git a/miniscope_io/models/frames.py b/miniscope_io/models/frames.py new file mode 100644 index 00000000..990e8436 --- /dev/null +++ b/miniscope_io/models/frames.py @@ -0,0 +1,85 @@ +""" +Pydantic models for storing frames and videos. +""" + +from typing import List, Optional, TypeVar + +import numpy as np +from pydantic import BaseModel, Field, model_validator + +T = TypeVar("T", np.ndarray, List[np.ndarray], List[List[np.ndarray]]) + + +class NamedFrame(BaseModel): + """ + Pydantic model to store an array (frame/video/video list) together with a name. + """ + + name: str = Field( + ..., + description="Name of the video.", + ) + static_frame: Optional[np.ndarray] = Field( + None, + description="Frame data, if provided.", + ) + video_frame: Optional[List[np.ndarray]] = Field( + None, + description="Video data, if provided.", + ) + video_list_frame: Optional[List[List[np.ndarray]]] = Field( + None, + description="List of video data, if provided.", + ) + frame_type: Optional[str] = Field( + None, + description="Type of frame data.", + ) + + @model_validator(mode="before") + def check_frame_type(cls, values: dict) -> dict: + """ + Ensure that exactly one of static_frame, video_frame, or video_list_frame is provided. + """ + static = values.get("static_frame") + video = values.get("video_frame") + video_list = values.get("video_list_frame") + + # Identify which fields are present + present_fields = [ + (field_name, field_value) + for field_name, field_value in zip( + ("static_frame", "video_frame", "video_list_frame"), (static, video, video_list) + ) + if field_value is not None + ] + + if len(present_fields) != 1: + raise ValueError( + "Exactly one of static_frame, video_frame, or video_list_frame must be provided." + ) + + # Record which frame type is present + values["frame_type"] = present_fields[0][0] + + return values + + @property + def data(self) -> T: + """Return the content of the populated field.""" + if self.frame_type == "static_frame": + return self.static_frame + elif self.frame_type == "video_frame": + return self.video_frame + elif self.frame_type == "video_list_frame": + return self.video_list_frame + else: + raise ValueError("Unknown frame type or no frame data provided.") + + class Config: + """ + Pydantic config for allowing np.ndarray types. + Could be an Numpydantic situation so will look into it later. + """ + + arbitrary_types_allowed = True diff --git a/miniscope_io/plots/video.py b/miniscope_io/plots/video.py index 60ccd03b..541f35b6 100644 --- a/miniscope_io/plots/video.py +++ b/miniscope_io/plots/video.py @@ -2,7 +2,7 @@ Plotting functions for video streams and frames. """ -from typing import Union +from typing import List import matplotlib.pyplot as plt import numpy as np @@ -10,6 +10,8 @@ from matplotlib.backend_bases import KeyEvent from matplotlib.widgets import Button, Slider +from miniscope_io.models.frames import NamedFrame + class VideoPlotter: """ @@ -17,25 +19,28 @@ class VideoPlotter: """ @staticmethod - def show_video_with_controls( - video_frames: Union[list[np.ndarray], np.ndarray], titles: list[str] = None, fps: int = 20 - ) -> None: + def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None: """ Plot multiple video streams or static images side-by-side. Can play/pause and navigate frames. Parameters ---------- - video_frames : list[np.ndarray] or np.ndarray - List of video streams or static images to display. - Each element of the list should be a list of frames or a single frame. - titles : list[str], optional - List of titles for each stream, by default None + videos : NamedFrame + NamedFrame object containing video data and names. fps : int, optional - Frames per second for playback, by default 20 + Frames per second for the video, by default 20 """ + + if any(frame.frame_type == "video_list_frame" for frame in videos): + raise NotImplementedError("Only single videos or frames are supported for now.") + # Wrap static images in lists to handle them uniformly - video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames] + video_frames = [ + frame.data if frame.frame_type == "video_frame" else [frame.data] for frame in videos + ] + + titles = [video.name for video in videos] num_streams = len(video_frames) num_frames = max(len(stream) for stream in video_frames) diff --git a/miniscope_io/processing/__init__.py b/miniscope_io/process/__init__.py similarity index 100% rename from miniscope_io/processing/__init__.py rename to miniscope_io/process/__init__.py diff --git a/miniscope_io/processing/video.py b/miniscope_io/process/video.py similarity index 84% rename from miniscope_io/processing/video.py rename to miniscope_io/process/video.py index a8c6bfb0..4b013f17 100644 --- a/miniscope_io/processing/video.py +++ b/miniscope_io/process/video.py @@ -9,6 +9,7 @@ import numpy as np from miniscope_io import init_logger +from miniscope_io.models.frames import NamedFrame from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") @@ -42,6 +43,9 @@ def __init__(self, video_path: str): def read_frames(self) -> Iterator[np.ndarray]: """ Read frames from the video file. + + Yields: + np.ndarray: The next frame in the video. """ while self.cap.isOpened(): ret, frame = self.cap.read() @@ -60,18 +64,6 @@ def __del__(self): self.release() -def show_frame(frame: np.ndarray) -> None: - """ - Display a single frame using OpenCV. - """ - cv2.imshow("Mask", frame * np.iinfo(np.uint8).max) - while True: - if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization - break - - cv2.destroyAllWindows() - - def gen_freq_mask( width: int, height: int, @@ -96,14 +88,13 @@ def gen_freq_mask( mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 # Define spacial low pass filter - radius = center_LPF y, x = np.ogrid[:height, :width] - center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius**2 + center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 # Restore the center circular area to allow low frequencies to pass mask[center_mask] = 1 - # Visualize the mask if needed + # Visualize the mask if needed. Might delete later. if show_mask: cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) while True: @@ -118,7 +109,7 @@ class FrameProcessor: A class to process video frames. """ - def __init__(self, height: int, width: int, buffer_size: int = 5032, block_size: int = 32): + def __init__(self, height: int, width: int, buffer_size: int = 5032, buffer_split: int = 1): """ Initialize the FrameProcessor object. Block size/buffer size will be set by dev config later. @@ -133,7 +124,7 @@ def __init__(self, height: int, width: int, buffer_size: int = 5032, block_size: self.height = height self.width = width self.buffer_size = buffer_size - self.buffer_split = 1 + self.buffer_split = buffer_split def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.ndarray]: """ @@ -176,8 +167,12 @@ def patch_noisy_buffer( serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - split_current = self.split_by_length(serialized_current, self.buffer_size // 5) - split_previous = self.split_by_length(serialized_previous, self.buffer_size // 5) + split_current = self.split_by_length( + serialized_current, self.buffer_size // self.buffer_split + ) + split_previous = self.split_by_length( + serialized_previous, self.buffer_size // self.buffer_split + ) split_output = split_current.copy() noisy_parts = split_current.copy() @@ -301,10 +296,12 @@ def denoise( video_path: str, slider_plot: bool = True, end_frame: int = 100, - noise_threshold: float = 10, - spatial_LPF: int = 5, + noise_threshold: float = 20, + spatial_LPF: int = 10, vertical_BEF: int = 2, horizontal_BEF: int = 0, + diff_mag: int = 10, + buffer_split: int = 1, ) -> None: """ Process a video file and display the results. @@ -324,6 +321,7 @@ def denoise( processor = FrameProcessor( height=reader.height, width=reader.width, + buffer_split=buffer_split, ) freq_mask = gen_freq_mask( @@ -360,7 +358,7 @@ def denoise( freq_domain_frames.append(frame_freq_domain) noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) freq_filtered_frames.append(freq_filtered_frame) - diff_frames.append(diff_frame * 10) + diff_frames.append(diff_frame * diff_mag) index += 1 finally: @@ -374,31 +372,32 @@ def denoise( subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) + raw_video = NamedFrame(name="RAW", video_frame=raw_frames) + patched_video = NamedFrame(name="Patched", video_frame=patched_frames) + diff_video = NamedFrame(name=f"Diff {diff_mag}x", video_frame=diff_frames) + noise_patch = NamedFrame(name="Noisy area", video_frame=noise_patchs) + freq_mask_frame = NamedFrame( + name="Freq mask", static_frame=freq_mask * np.iinfo(np.uint8).max + ) + freq_domain_video = NamedFrame(name="Freq domain", video_frame=freq_domain_frames) + freq_filtered_video = NamedFrame(name="Freq filtered", video_frame=freq_filtered_frames) + normalized_video = NamedFrame(name="Normalized", video_frame=normalized_frames) + min_proj_frame = NamedFrame(name="Min Proj", static_frame=minimum_projection) + subtract_video = NamedFrame(name="Subtracted", video_frame=subtract_minimum) + if slider_plot: - video_frames = [ - raw_frames, - patched_frames, - diff_frames, - noise_patchs, - freq_mask * np.iinfo(np.uint8).max, - freq_domain_frames, - freq_filtered_frames, - normalized_frames, - minimum_projection, - subtract_minimum, + videos = [ + raw_video, + patched_video, + diff_video, + noise_patch, + freq_mask_frame, + freq_domain_video, + freq_filtered_video, + normalized_video, + min_proj_frame, + subtract_video, ] VideoPlotter.show_video_with_controls( - video_frames, - titles=[ - "RAW", - "Patched", - "Diff", - "Noisy area", - "Freq mask", - "Freq domain", - "Freq filtered", - "Normalized", - "Min Proj", - "Subtracted", - ], + videos, )