-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
87487a1
commit 28fd048
Showing
3 changed files
with
221 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" | ||
Pre-processing module for miniscope data. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
""" | ||
This module contains functions for pre-processing video data. | ||
""" | ||
|
||
from typing import Iterator | ||
|
||
import cv2 | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from miniscope_io import init_logger | ||
|
||
logger = init_logger("video") | ||
|
||
class VideoReader:
    """
    A class to read video files frame by frame.

    Wraps ``cv2.VideoCapture`` and exposes frames as an iterator. Also
    usable as a context manager so the capture handle is always released.
    """

    def __init__(self, video_path: str):
        """
        Open the video file.

        Args:
            video_path: Path to the video file to read.

        Raises:
            ValueError: If the video cannot be opened.
        """
        self.video_path = video_path
        self.cap = cv2.VideoCapture(str(video_path))

        if not self.cap.isOpened():
            raise ValueError(f"Could not open video at {video_path}")

        logger.info(f"Opened video at {video_path}")

    def read_frames(self) -> Iterator[np.ndarray]:
        """
        Yield frames from the video until the stream is exhausted.
        """
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            logger.debug(f"Reading frame {self.cap.get(cv2.CAP_PROP_POS_FRAMES)}")
            if not ret:
                break
            yield frame

    def release(self) -> None:
        """
        Release the video capture object (safe to call more than once).
        """
        self.cap.release()

    def __enter__(self) -> "VideoReader":
        """Enter a ``with`` block; returns self."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Release the capture when leaving a ``with`` block."""
        self.release()

    def __del__(self):
        # Guard: ``cap`` may not exist if VideoCapture construction itself
        # raised, or during interpreter teardown — don't mask the original
        # error with an AttributeError.
        if hasattr(self, "cap"):
            self.release()
||
def make_mask(img: np.ndarray, show: bool = False) -> np.ndarray:
    """
    Create a frequency-domain mask that suppresses stripe frequencies
    except for a central circular low-frequency region.

    Args:
        img: 2-D grayscale image; its shape determines the mask size.
        show: If True, display the mask in an OpenCV window (blocks until
            the Esc key is pressed). Defaults to False so the function is
            safe to call in headless/batch environments.

    Returns:
        A uint8 array of ``img``'s shape: 0 where frequencies are
        suppressed, 1 where they pass.
    """
    rows, cols = img.shape
    crow, ccol = rows // 2, cols // 2

    # Start from an all-pass mask (ones everywhere).
    mask = np.ones((rows, cols), np.uint8)

    # Half-widths (pixels) of the stripes to zero out around the
    # frequency-domain center.
    vertical_band_width = 5
    horizontal_band_width = 0

    # Zero a vertical stripe through the center of the spectrum.
    mask[:, ccol - vertical_band_width:ccol + vertical_band_width] = 0

    # NOTE: with horizontal_band_width == 0 this slice is empty, so the
    # horizontal suppression below is currently a no-op.
    mask[crow - horizontal_band_width:crow + horizontal_band_width, :] = 0

    # Re-open a small circular window at the center so low frequencies
    # (overall image structure) still pass.
    radius = 6
    y, x = np.ogrid[:rows, :cols]
    center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius ** 2
    mask[center_mask] = 1

    if show:
        # Optional debug visualization; blocks until Esc (key code 27).
        cv2.imshow('Mask', mask * 255)
        while True:
            if cv2.waitKey(1) == 27:
                break
        cv2.destroyAllWindows()

    return mask
|
||
class FrameProcessor:
    """
    Process video frames: replace noisy blocks with the corresponding
    blocks of the previous frame, and remove stripe noise via FFT
    filtering with a precomputed frequency mask.
    """

    def __init__(self, img: np.ndarray, block_size: int = 400):
        """
        Initialize the processor from the first frame.

        Args:
            img: First (grayscale, 2-D) frame; its shape fixes the FFT
                mask size and it seeds the previous-frame buffer.
            block_size: Side length in pixels of the square blocks used
                for noise detection/replacement.
        """
        self.mask = make_mask(img)
        self.previous_frame = img
        self.block_size = block_size

    def is_block_noisy(
        self, block: np.ndarray, prev_block: np.ndarray, noise_threshold: float
    ) -> bool:
        """
        Determine if a block is noisy by comparing it with the previous
        frame's block.

        Args:
            block: Current frame's block.
            prev_block: Previous frame's block (same shape).
            noise_threshold: Mean absolute difference above which the
                block is considered noisy.

        Returns:
            True if the mean absolute difference exceeds the threshold.
        """
        block_diff = cv2.absdiff(block, prev_block)
        mean_diff = np.mean(block_diff)
        return mean_diff > noise_threshold

    def process_frame(
        self, current_frame: np.ndarray, noise_threshold: float
    ) -> np.ndarray:
        """
        Process the frame, replacing noisy blocks with those from the
        previous frame.

        Args:
            current_frame: The current (grayscale, 2-D) frame to process.
            noise_threshold: The threshold for determining noisy blocks.

        Returns:
            The processed frame; also stored as the new previous frame.
        """
        processed_frame = np.copy(current_frame)
        h, w = current_frame.shape

        # Walk the frame in block_size tiles; edge tiles may be smaller
        # (numpy slicing clips past the end), which is fine for absdiff.
        for y in range(0, h, self.block_size):
            for x in range(0, w, self.block_size):
                current_block = current_frame[y:y + self.block_size, x:x + self.block_size]
                prev_block = self.previous_frame[y:y + self.block_size, x:x + self.block_size]

                if self.is_block_noisy(current_block, prev_block, noise_threshold):
                    # Replace current block with previous block if noisy.
                    processed_frame[y:y + self.block_size, x:x + self.block_size] = prev_block

        self.previous_frame = processed_frame
        return processed_frame

    def remove_stripes(self, img: np.ndarray, show_spectrum: bool = False) -> np.ndarray:
        """
        Perform FFT/IFFT to remove stripe noise from a single frame.

        Args:
            img: Grayscale 2-D frame to filter.
            show_spectrum: If True, display the magnitude spectrum in an
                OpenCV window (blocks until Esc). Defaults to False;
                replaces a previous permanently-disabled debug branch.

        Returns:
            The filtered frame as a uint8 image.
        """
        f = np.fft.fft2(img)
        fshift = np.fft.fftshift(f)
        # Log scale so the spectrum is visible when displayed.
        magnitude_spectrum = np.log(np.abs(fshift) + 1)

        # Normalize the magnitude spectrum for visualization.
        magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)

        if show_spectrum:
            # Debug visualization; blocks until Esc (key code 27).
            cv2.imshow('Magnitude Spectrum', np.uint8(magnitude_spectrum))
            while True:
                if cv2.waitKey(1) == 27:
                    break
            cv2.destroyAllWindows()
            logger.debug(f"FFT shape: {fshift.shape}")

        # Apply the precomputed mask and invert the FFT.
        fshift *= self.mask
        f_ishift = np.fft.ifftshift(fshift)
        img_back = np.fft.ifft2(f_ishift)
        img_back = np.abs(img_back)

        # Normalize the result back into display range.
        img_back = cv2.normalize(img_back, None, 0, 255, cv2.NORM_MINMAX)

        return np.uint8(img_back)  # 8-bit image for display and storage
|
||
if __name__ == "__main__": | ||
video_path = 'output_001_test.avi' | ||
reader = VideoReader(video_path) | ||
|
||
frames = [] | ||
index = 0 | ||
|
||
try: | ||
for frame in reader.read_frames(): | ||
gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | ||
if index == 0: | ||
processor = FrameProcessor(gray_frame) | ||
if index > 100: | ||
break | ||
index += 1 | ||
logger.info(f"Processing frame {index}") | ||
processed_frame = processor.process_frame(gray_frame, noise_threshold=10) | ||
filtered_frame = processor.remove_stripes(processed_frame) | ||
frames.append(filtered_frame) | ||
|
||
# Display the results for visualization | ||
for frame in frames: | ||
cv2.imshow('Video', frame) | ||
if cv2.waitKey(100) & 0xFF == ord('q'): | ||
break | ||
|
||
finally: | ||
reader.release() | ||
cv2.destroyAllWindows() |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.