Model for frame data handling
t-sasatani committed Dec 6, 2024
1 parent 0c07c52 commit 1eb6688
Showing 4 changed files with 146 additions and 57 deletions.
85 changes: 85 additions & 0 deletions miniscope_io/models/frames.py
@@ -0,0 +1,85 @@
"""
Pydantic models for storing frames and videos.
"""

from typing import List, Optional, TypeVar

import numpy as np
from pydantic import BaseModel, Field, model_validator

T = TypeVar("T", np.ndarray, List[np.ndarray], List[List[np.ndarray]])


class NamedFrame(BaseModel):
"""
Pydantic model to store an array (frame/video/video list) together with a name.
"""

name: str = Field(
...,
description="Name of the video.",
)
static_frame: Optional[np.ndarray] = Field(
None,
description="Frame data, if provided.",
)
video_frame: Optional[List[np.ndarray]] = Field(
None,
description="Video data, if provided.",
)
video_list_frame: Optional[List[List[np.ndarray]]] = Field(
None,
description="List of video data, if provided.",
)
frame_type: Optional[str] = Field(
None,
description="Type of frame data.",
)

@model_validator(mode="before")
def check_frame_type(cls, values: dict) -> dict:
"""
Ensure that exactly one of static_frame, video_frame, or video_list_frame is provided.
"""
static = values.get("static_frame")
video = values.get("video_frame")
video_list = values.get("video_list_frame")

# Identify which fields are present
present_fields = [
(field_name, field_value)
for field_name, field_value in zip(
("static_frame", "video_frame", "video_list_frame"), (static, video, video_list)
)
if field_value is not None
]

if len(present_fields) != 1:
raise ValueError(
"Exactly one of static_frame, video_frame, or video_list_frame must be provided."
)

# Record which frame type is present
values["frame_type"] = present_fields[0][0]

return values

@property
def data(self) -> T:
"""Return the content of the populated field."""
if self.frame_type == "static_frame":
return self.static_frame
elif self.frame_type == "video_frame":
return self.video_frame
elif self.frame_type == "video_list_frame":
return self.video_list_frame
else:
raise ValueError("Unknown frame type or no frame data provided.")

class Config:
"""
Pydantic config for allowing np.ndarray types.
Could be a Numpydantic situation, so will look into it later.
"""

arbitrary_types_allowed = True
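
For illustration, a minimal usage sketch of this model (not part of the commit): the validator requires exactly one of the three data fields, records which one was set in frame_type, and the data property returns that field. The array shapes below are arbitrary.

import numpy as np
from miniscope_io.models.frames import NamedFrame

# Wrap a single 2-D frame
mask_frame = NamedFrame(name="Freq mask", static_frame=np.zeros((200, 200), dtype=np.uint8))
print(mask_frame.frame_type)  # "static_frame"
print(mask_frame.data.shape)  # (200, 200)

# Wrap a video as a list of 2-D frames
video = NamedFrame(name="RAW", video_frame=[np.zeros((200, 200), dtype=np.uint8) for _ in range(10)])
print(video.frame_type)  # "video_frame"

# Setting zero or more than one data field fails validation
try:
    NamedFrame(name="bad", static_frame=np.zeros((2, 2)), video_frame=[np.zeros((2, 2))])
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print(err)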
27 changes: 16 additions & 11 deletions miniscope_io/plots/video.py
@@ -2,40 +2,45 @@
Plotting functions for video streams and frames.
"""

from typing import Union
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from matplotlib.backend_bases import KeyEvent
from matplotlib.widgets import Button, Slider

from miniscope_io.models.frames import NamedFrame


class VideoPlotter:
"""
Class to display video streams and static images.
"""

@staticmethod
def show_video_with_controls(
video_frames: Union[list[np.ndarray], np.ndarray], titles: list[str] = None, fps: int = 20
) -> None:
def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None:
"""
Plot multiple video streams or static images side-by-side.
Can play/pause and navigate frames.
Parameters
----------
video_frames : list[np.ndarray] or np.ndarray
List of video streams or static images to display.
Each element of the list should be a list of frames or a single frame.
titles : list[str], optional
List of titles for each stream, by default None
videos : List[NamedFrame]
List of NamedFrame objects containing video data and names.
fps : int, optional
Frames per second for playback, by default 20
Frames per second for the video, by default 20
"""

if any(frame.frame_type == "video_list_frame" for frame in videos):
raise NotImplementedError("Only single videos or frames are supported for now.")

# Wrap static images in lists to handle them uniformly
video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames]
video_frames = [
frame.data if frame.frame_type == "video_frame" else [frame.data] for frame in videos
]

titles = [video.name for video in videos]

num_streams = len(video_frames)
num_frames = max(len(stream) for stream in video_frames)
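
As a rough usage sketch of the new interface (not part of the commit; the data here is synthetic), the plotter now takes NamedFrame objects directly and derives the stream titles from their names:

import numpy as np
from miniscope_io.models.frames import NamedFrame
from miniscope_io.plots.video import VideoPlotter

frames = [np.random.randint(0, 255, (200, 200), dtype=np.uint8) for _ in range(30)]
raw = NamedFrame(name="RAW", video_frame=frames)
mean_frame = NamedFrame(name="Mean", static_frame=np.mean(frames, axis=0).astype(np.uint8))

# Static frames are wrapped into single-frame streams internally
VideoPlotter.show_video_with_controls([raw, mean_frame], fps=20)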
File renamed without changes.
91 changes: 45 additions & 46 deletions miniscope_io/processing/video.py → miniscope_io/process/video.py
@@ -9,6 +9,7 @@
import numpy as np

from miniscope_io import init_logger
from miniscope_io.models.frames import NamedFrame
from miniscope_io.plots.video import VideoPlotter

logger = init_logger("video")
@@ -42,6 +43,9 @@ def __init__(self, video_path: str):
def read_frames(self) -> Iterator[np.ndarray]:
"""
Read frames from the video file.
Yields:
np.ndarray: The next frame in the video.
"""
while self.cap.isOpened():
ret, frame = self.cap.read()
@@ -60,18 +64,6 @@ def __del__(self):
self.release()


def show_frame(frame: np.ndarray) -> None:
"""
Display a single frame using OpenCV.
"""
cv2.imshow("Mask", frame * np.iinfo(np.uint8).max)
while True:
if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization
break

cv2.destroyAllWindows()


def gen_freq_mask(
width: int,
height: int,
@@ -96,14 +88,13 @@ def gen_freq_mask(
mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0

# Define spatial low pass filter
radius = center_LPF
y, x = np.ogrid[:height, :width]
center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius**2
center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2

# Restore the center circular area to allow low frequencies to pass
mask[center_mask] = 1

# Visualize the mask if needed
# Visualize the mask if needed. Might delete later.
if show_mask:
cv2.imshow("Mask", mask * np.iinfo(np.uint8).max)
while True:
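
The mask produced here is a 2-D array of zeros and ones laid out for an fftshift-ed spectrum: the band-elimination bands are zeroed and a central disc of radius center_LPF is kept. A minimal sketch of how such a mask would typically be applied to one frame (this helper is illustrative and not part of the module):

import numpy as np

def apply_freq_mask(frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Filter a single frame with a centered frequency-domain mask (illustrative helper)."""
    # Shift the spectrum so low frequencies sit at the center, matching the mask layout
    freq = np.fft.fftshift(np.fft.fft2(frame.astype(np.float32)))
    filtered = freq * mask
    # Back to the spatial domain; keep the magnitude only
    return np.abs(np.fft.ifft2(np.fft.ifftshift(filtered)))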
@@ -118,7 +109,7 @@ class FrameProcessor:
A class to process video frames.
"""

def __init__(self, height: int, width: int, buffer_size: int = 5032, block_size: int = 32):
def __init__(self, height: int, width: int, buffer_size: int = 5032, buffer_split: int = 1):
"""
Initialize the FrameProcessor object.
Block size/buffer size will be set by dev config later.
@@ -133,7 +124,7 @@ def __init__(self, height: int, width: int, buffer_size: int = 5032, block_size:
self.height = height
self.width = width
self.buffer_size = buffer_size
self.buffer_split = 1
self.buffer_split = buffer_split

def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.ndarray]:
"""
@@ -176,8 +167,12 @@ def patch_noisy_buffer(
serialized_current = current_frame.flatten().astype(np.int16)
serialized_previous = previous_frame.flatten().astype(np.int16)

split_current = self.split_by_length(serialized_current, self.buffer_size // 5)
split_previous = self.split_by_length(serialized_previous, self.buffer_size // 5)
split_current = self.split_by_length(
serialized_current, self.buffer_size // self.buffer_split
)
split_previous = self.split_by_length(
serialized_previous, self.buffer_size // self.buffer_split
)

split_output = split_current.copy()
noisy_parts = split_current.copy()
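
To make the effect of buffer_split concrete, here is a rough illustration (the stand-in below only assumes the behavior suggested by split_by_length's name and signature; its body is not shown in this hunk): the flattened frame is cut into segments of buffer_size // buffer_split samples, so a larger buffer_split makes noise detection operate on shorter chunks.

import numpy as np

# Illustrative stand-in for FrameProcessor.split_by_length (actual implementation not shown here)
def split_by_length(array: np.ndarray, segment_length: int) -> list[np.ndarray]:
    return [array[i : i + segment_length] for i in range(0, len(array), segment_length)]

serialized = np.zeros(200 * 200, dtype=np.int16)  # a flattened 200x200 frame
buffer_size = 5032
for buffer_split in (1, 2, 4):
    segments = split_by_length(serialized, buffer_size // buffer_split)
    print(buffer_split, buffer_size // buffer_split, len(segments))
# segment lengths: 5032, 2516, and 1258 samples respectively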
@@ -301,10 +296,12 @@ def denoise(
video_path: str,
slider_plot: bool = True,
end_frame: int = 100,
noise_threshold: float = 10,
spatial_LPF: int = 5,
noise_threshold: float = 20,
spatial_LPF: int = 10,
vertical_BEF: int = 2,
horizontal_BEF: int = 0,
diff_mag: int = 10,
buffer_split: int = 1,
) -> None:
"""
Process a video file and display the results.
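
For reference, a hedged example of calling the updated entry point with the parameters this commit adds or retunes (assuming denoise is a module-level function in miniscope_io.process.video; the video path is a placeholder):

from miniscope_io.process.video import denoise

denoise(
    video_path="data/example_recording.avi",  # placeholder path
    slider_plot=True,
    end_frame=100,
    noise_threshold=20,
    spatial_LPF=10,
    vertical_BEF=2,
    horizontal_BEF=0,
    diff_mag=10,
    buffer_split=1,
)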
@@ -324,6 +321,7 @@
processor = FrameProcessor(
height=reader.height,
width=reader.width,
buffer_split=buffer_split,
)

freq_mask = gen_freq_mask(
@@ -360,7 +358,7 @@
freq_domain_frames.append(frame_freq_domain)
noise_patchs.append(noise_patch * np.iinfo(np.uint8).max)
freq_filtered_frames.append(freq_filtered_frame)
diff_frames.append(diff_frame * 10)
diff_frames.append(diff_frame * diff_mag)

index += 1
finally:
@@ -374,31 +372,32 @@

subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum)

raw_video = NamedFrame(name="RAW", video_frame=raw_frames)
patched_video = NamedFrame(name="Patched", video_frame=patched_frames)
diff_video = NamedFrame(name=f"Diff {diff_mag}x", video_frame=diff_frames)
noise_patch = NamedFrame(name="Noisy area", video_frame=noise_patchs)
freq_mask_frame = NamedFrame(
name="Freq mask", static_frame=freq_mask * np.iinfo(np.uint8).max
)
freq_domain_video = NamedFrame(name="Freq domain", video_frame=freq_domain_frames)
freq_filtered_video = NamedFrame(name="Freq filtered", video_frame=freq_filtered_frames)
normalized_video = NamedFrame(name="Normalized", video_frame=normalized_frames)
min_proj_frame = NamedFrame(name="Min Proj", static_frame=minimum_projection)
subtract_video = NamedFrame(name="Subtracted", video_frame=subtract_minimum)

if slider_plot:
video_frames = [
raw_frames,
patched_frames,
diff_frames,
noise_patchs,
freq_mask * np.iinfo(np.uint8).max,
freq_domain_frames,
freq_filtered_frames,
normalized_frames,
minimum_projection,
subtract_minimum,
videos = [
raw_video,
patched_video,
diff_video,
noise_patch,
freq_mask_frame,
freq_domain_video,
freq_filtered_video,
normalized_video,
min_proj_frame,
subtract_video,
]
VideoPlotter.show_video_with_controls(
video_frames,
titles=[
"RAW",
"Patched",
"Diff",
"Noisy area",
"Freq mask",
"Freq domain",
"Freq filtered",
"Normalized",
"Min Proj",
"Subtracted",
],
videos,
)
