Skip to content

Commit

Permalink
Add: stack editing class, chunked error detection units
Browse files Browse the repository at this point in the history
  • Loading branch information
t-sasatani committed Dec 6, 2024
1 parent 0b40004 commit 13fe13d
Showing 1 changed file with 105 additions and 28 deletions.
133 changes: 105 additions & 28 deletions miniscope_io/processing/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,35 @@
This module contains functions for pre-processing video data.
"""

import copy
from typing import Iterator, Optional, Tuple
from typing import Iterator, Tuple

import cv2
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, Slider
from pydantic import BaseModel, Field

from miniscope_io import init_logger

logger = init_logger("video")

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
from matplotlib import animation

def plot_video_streams_with_controls(
video_frames: list[list[np.ndarray] or np.ndarray],
titles: list[str] = None,
fps: int = 20
) -> None:
"""
Plot multiple video streams or static images side-by-side with controls to play/pause and navigate frames.
Plot multiple video streams or static images side-by-side.
Can play/pause and navigate frames.
"""
# Wrap static images in lists to handle them uniformly
video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames]

num_streams = len(video_frames)
num_frames = max(len(stream) for stream in video_frames) # Use max to account for static images with 1 frame
num_frames = max(len(stream) for stream in video_frames)

# Initialize plots
fig, axes = plt.subplots(1, num_streams, figsize=(20, 5))

# Initial display of the first frame from each stream
frame_displays = []
for idx, ax in enumerate(axes):
# Adjust static images to display them consistently
Expand All @@ -54,7 +45,6 @@ def plot_video_streams_with_controls(
ax_slider = plt.axes([0.1, 0.1, 0.65, 0.05], facecolor='lightgoldenrodyellow')
slider = Slider(ax=ax_slider, label='Frame', valmin=0, valmax=num_frames - 1, valinit=0, valstep=1)

# Define the play/pause button
playing = [False] # Use a mutable object to track play state
ax_button = plt.axes([0.8, 0.1, 0.1, 0.05])
button = Button(ax_button, 'Play/Pause')
Expand Down Expand Up @@ -168,7 +158,7 @@ def show_frame(frame: np.ndarray) -> None:
def gen_freq_mask(
width: int = 200,
height: int = 200,
center_radius: int = 15,
center_radius: int = 5,
show_mask: bool = True
) -> np.ndarray:
"""
Expand Down Expand Up @@ -220,11 +210,18 @@ def __init__(self,
"""
Initialize the FrameProcessor object.
Block size/buffer size will be set by dev config later.
Parameters:
height (int): Height of the video frame.
width (int): Width of the video frame.
buffer_size (int): Size of the buffer to process.
block_size (int): Size of the blocks to process. Not used now.
"""
self.height = height
self.width = width
self.buffer_size = buffer_size
self.block_size = block_size
self.buffer_split = 1

def split_by_length(
self,
Expand All @@ -233,6 +230,13 @@ def split_by_length(
) -> list[np.ndarray]:
"""
Split an array into sub-arrays of a specified length.
Parameters:
array (np.ndarray): The array to split.
segment_length (int): The length of each sub-array.
Returns:
list[np.ndarray]: A list of sub-arrays.
"""
num_segments = len(array) // segment_length

Expand All @@ -255,12 +259,20 @@ def patch_noisy_buffer(
) -> Tuple[np.ndarray, np.ndarray]:
"""
Process the frame, replacing noisy blocks with those from the previous frame.
Parameters:
current_frame (np.ndarray): The current frame to process.
previous_frame (np.ndarray): The previous frame to compare against.
noise_threshold (float): The threshold for mean error to consider a block noisy.
Returns:
Tuple[np.ndarray, np.ndarray]: The processed frame and the noise patch
"""
serialized_current = current_frame.flatten().astype(np.int16)
serialized_previous = previous_frame.flatten().astype(np.int16)

split_current = self.split_by_length(serialized_current, self.buffer_size)
split_previous = self.split_by_length(serialized_previous, self.buffer_size)
split_current = self.split_by_length(serialized_current, self.buffer_size//5)
split_previous = self.split_by_length(serialized_previous, self.buffer_size//5)

split_output = split_current.copy()
noisy_parts = split_current.copy()
Expand Down Expand Up @@ -295,6 +307,13 @@ def remove_stripes(
)-> np.ndarray:
"""
Perform FFT/IFFT to remove horizontal stripes from a single frame.
Parameters:
img (np.ndarray): The image to process.
mask (np.ndarray): The frequency mask to apply.
Returns:
np.ndarray: The filtered image
"""
f = np.fft.fft2(img)
fshift = np.fft.fftshift(f)
Expand All @@ -311,10 +330,65 @@ def remove_stripes(

return np.uint8(img_back), np.uint8(magnitude_spectrum)

def get_minimum_projection(image_list):
stacked_images = np.stack(image_list, axis=0)
min_projection = np.min(stacked_images, axis=0)
return min_projection
class FrameListProcessor:
"""
A class to process a list of video frames.
"""
@staticmethod
def get_minimum_projection(
image_list: list[np.ndarray]
)-> np.ndarray:
"""
Get the minimum projection of a list of images.
Parameters:
image_list (list[np.ndarray]): A list of images to project.
Returns:
np.ndarray: The minimum projection of the images.
"""
stacked_images = np.stack(image_list, axis=0)
min_projection = np.min(stacked_images, axis=0)
return min_projection

@staticmethod
def normalize_video_stack(
image_list: list[np.ndarray]
) -> list[np.ndarray]:
"""
Normalize a stack of images to 0-255 using max and minimum values throughout the entire stack.
Return a list of images.
Parameters:
image_list (list[np.ndarray]): A list of images to normalize.
Returns:
list[np.ndarray]: The normalized images as a list.
"""

# Stack images along a new axis (axis=0)
stacked_images = np.stack(image_list, axis=0)

# Find the global min and max across the entire stack
global_min = stacked_images.min()
global_max = stacked_images.max()

# Normalize each frame using the global min and max
normalized_images = []
for i in range(stacked_images.shape[0]):
normalized_image = cv2.normalize(
stacked_images[i],
None,
0,
255,
cv2.NORM_MINMAX,
dtype=cv2.CV_32F
)
# Apply global normalization
normalized_image = (stacked_images[i] - global_min) / (global_max - global_min) * 255
normalized_images.append(normalized_image.astype(np.uint8))

return normalized_images

if __name__ == "__main__":
"""
Expand Down Expand Up @@ -360,13 +434,12 @@ def get_minimum_projection(image_list):
processed_frame, noise_patch = processor.patch_noisy_buffer(
gray_frame,
previous_frame,
noise_threshold=20
noise_threshold=15
)
filtered_frame, freq_domain_frame = processor.remove_stripes(
img=processed_frame,
mask=freq_mask
)

)
diff_frame = cv2.absdiff(gray_frame, filtered_frame)
gray_frames.append(gray_frame)
processed_frames.append(processed_frame)
Expand All @@ -378,10 +451,12 @@ def get_minimum_projection(image_list):
reader.release()
plt.close(fig)

minimum_projection = get_minimum_projection(filtered_frames)
show_frame(minimum_projection)
normalized_frames = FrameListProcessor.normalize_video_stack(filtered_frames)
minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames)

subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames]

subtract_minimum = [(frame - minimum_projection) for frame in filtered_frames]
subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum)

if slider_plot:
video_frames = [
Expand All @@ -392,6 +467,7 @@ def get_minimum_projection(image_list):
freq_mask * 255,
freq_domain_frames,
filtered_frames,
normalized_frames,
minimum_projection,
subtract_minimum,
]
Expand All @@ -405,6 +481,7 @@ def get_minimum_projection(image_list):
'Freq mask',
'Freq domain',
'Freq filtered',
'Normalized',
'Min Proj',
'Subtracted',
]
Expand Down

0 comments on commit 13fe13d

Please sign in to comment.