diff --git a/miniscope_io/processing/video.py b/miniscope_io/processing/video.py index 1a5743fb..5bccd66d 100644 --- a/miniscope_io/processing/video.py +++ b/miniscope_io/processing/video.py @@ -2,44 +2,35 @@ This module contains functions for pre-processing video data. """ -import copy -from typing import Iterator, Optional, Tuple +from typing import Iterator, Tuple import cv2 import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np from matplotlib.widgets import Button, Slider -from pydantic import BaseModel, Field from miniscope_io import init_logger logger = init_logger("video") -import cv2 -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.widgets import Slider, Button -from matplotlib import animation - def plot_video_streams_with_controls( video_frames: list[list[np.ndarray] or np.ndarray], titles: list[str] = None, fps: int = 20 ) -> None: """ - Plot multiple video streams or static images side-by-side with controls to play/pause and navigate frames. + Plot multiple video streams or static images side-by-side. + Can play/pause and navigate frames. """ # Wrap static images in lists to handle them uniformly video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames] num_streams = len(video_frames) - num_frames = max(len(stream) for stream in video_frames) # Use max to account for static images with 1 frame + num_frames = max(len(stream) for stream in video_frames) - # Initialize plots fig, axes = plt.subplots(1, num_streams, figsize=(20, 5)) - # Initial display of the first frame from each stream frame_displays = [] for idx, ax in enumerate(axes): # Adjust static images to display them consistently @@ -54,7 +45,6 @@ def plot_video_streams_with_controls( ax_slider = plt.axes([0.1, 0.1, 0.65, 0.05], facecolor='lightgoldenrodyellow') slider = Slider(ax=ax_slider, label='Frame', valmin=0, valmax=num_frames - 1, valinit=0, valstep=1) - # Define the play/pause button playing = [False] # Use a mutable object to track play state ax_button = plt.axes([0.8, 0.1, 0.1, 0.05]) button = Button(ax_button, 'Play/Pause') @@ -168,7 +158,7 @@ def show_frame(frame: np.ndarray) -> None: def gen_freq_mask( width: int = 200, height: int = 200, - center_radius: int = 15, + center_radius: int = 5, show_mask: bool = True ) -> np.ndarray: """ @@ -220,11 +210,18 @@ def __init__(self, """ Initialize the FrameProcessor object. Block size/buffer size will be set by dev config later. + + Parameters: + height (int): Height of the video frame. + width (int): Width of the video frame. + buffer_size (int): Size of the buffer to process. + block_size (int): Size of the blocks to process. Not used now. + """ self.height = height self.width = width self.buffer_size = buffer_size - self.block_size = block_size + self.buffer_split = 1 def split_by_length( self, @@ -233,6 +230,13 @@ def split_by_length( ) -> list[np.ndarray]: """ Split an array into sub-arrays of a specified length. + + Parameters: + array (np.ndarray): The array to split. + segment_length (int): The length of each sub-array. + + Returns: + list[np.ndarray]: A list of sub-arrays. """ num_segments = len(array) // segment_length @@ -255,12 +259,20 @@ def patch_noisy_buffer( ) -> Tuple[np.ndarray, np.ndarray]: """ Process the frame, replacing noisy blocks with those from the previous frame. + + Parameters: + current_frame (np.ndarray): The current frame to process. + previous_frame (np.ndarray): The previous frame to compare against. + noise_threshold (float): The threshold for mean error to consider a block noisy. + + Returns: + Tuple[np.ndarray, np.ndarray]: The processed frame and the noise patch """ serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - split_current = self.split_by_length(serialized_current, self.buffer_size) - split_previous = self.split_by_length(serialized_previous, self.buffer_size) + split_current = self.split_by_length(serialized_current, self.buffer_size//5) + split_previous = self.split_by_length(serialized_previous, self.buffer_size//5) split_output = split_current.copy() noisy_parts = split_current.copy() @@ -295,6 +307,13 @@ def remove_stripes( )-> np.ndarray: """ Perform FFT/IFFT to remove horizontal stripes from a single frame. + + Parameters: + img (np.ndarray): The image to process. + mask (np.ndarray): The frequency mask to apply. + + Returns: + np.ndarray: The filtered image """ f = np.fft.fft2(img) fshift = np.fft.fftshift(f) @@ -311,10 +330,65 @@ def remove_stripes( return np.uint8(img_back), np.uint8(magnitude_spectrum) -def get_minimum_projection(image_list): - stacked_images = np.stack(image_list, axis=0) - min_projection = np.min(stacked_images, axis=0) - return min_projection +class FrameListProcessor: + """ + A class to process a list of video frames. + """ + @staticmethod + def get_minimum_projection( + image_list: list[np.ndarray] + )-> np.ndarray: + """ + Get the minimum projection of a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to project. + + Returns: + np.ndarray: The minimum projection of the images. + """ + stacked_images = np.stack(image_list, axis=0) + min_projection = np.min(stacked_images, axis=0) + return min_projection + + @staticmethod + def normalize_video_stack( + image_list: list[np.ndarray] + ) -> list[np.ndarray]: + """ + Normalize a stack of images to 0-255 using max and minimum values throughout the entire stack. + Return a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to normalize. + + Returns: + list[np.ndarray]: The normalized images as a list. + """ + + # Stack images along a new axis (axis=0) + stacked_images = np.stack(image_list, axis=0) + + # Find the global min and max across the entire stack + global_min = stacked_images.min() + global_max = stacked_images.max() + + # Normalize each frame using the global min and max + normalized_images = [] + for i in range(stacked_images.shape[0]): + normalized_image = cv2.normalize( + stacked_images[i], + None, + 0, + 255, + cv2.NORM_MINMAX, + dtype=cv2.CV_32F + ) + # Apply global normalization + normalized_image = (stacked_images[i] - global_min) / (global_max - global_min) * 255 + normalized_images.append(normalized_image.astype(np.uint8)) + + return normalized_images if __name__ == "__main__": """ @@ -360,13 +434,12 @@ def get_minimum_projection(image_list): processed_frame, noise_patch = processor.patch_noisy_buffer( gray_frame, previous_frame, - noise_threshold=20 + noise_threshold=15 ) filtered_frame, freq_domain_frame = processor.remove_stripes( img=processed_frame, mask=freq_mask - ) - + ) diff_frame = cv2.absdiff(gray_frame, filtered_frame) gray_frames.append(gray_frame) processed_frames.append(processed_frame) @@ -378,10 +451,12 @@ def get_minimum_projection(image_list): reader.release() plt.close(fig) - minimum_projection = get_minimum_projection(filtered_frames) - show_frame(minimum_projection) + normalized_frames = FrameListProcessor.normalize_video_stack(filtered_frames) + minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames) + + subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames] - subtract_minimum = [(frame - minimum_projection) for frame in filtered_frames] + subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) if slider_plot: video_frames = [ @@ -392,6 +467,7 @@ def get_minimum_projection(image_list): freq_mask * 255, freq_domain_frames, filtered_frames, + normalized_frames, minimum_projection, subtract_minimum, ] @@ -405,6 +481,7 @@ def get_minimum_projection(image_list): 'Freq mask', 'Freq domain', 'Freq filtered', + 'Normalized', 'Min Proj', 'Subtracted', ]