Skip to content

Commit

Permalink
Add: init video preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
t-sasatani committed Dec 5, 2024
1 parent 87487a1 commit 28fd048
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 22 deletions.
3 changes: 3 additions & 0 deletions miniscope_io/processing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Pre-processing module for miniscope data.
"""
196 changes: 196 additions & 0 deletions miniscope_io/processing/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""
This module contains functions for pre-processing video data.
"""

from typing import Iterator

import cv2
import matplotlib.pyplot as plt
import numpy as np

from miniscope_io import init_logger

logger = init_logger("video")

class VideoReader:
"""
A class to read video files.
"""
def __init__(self, video_path: str):
"""
Initialize the VideoReader object.
"""
self.video_path = video_path
self.cap = cv2.VideoCapture(str(video_path))

if not self.cap.isOpened():
raise ValueError(f"Could not open video at {video_path}")

logger.info(f"Opened video at {video_path}")

def read_frames(self) -> Iterator[np.ndarray]:
"""
Read frames from the video file.
"""
while self.cap.isOpened():
ret, frame = self.cap.read()
logger.debug(f"Reading frame {self.cap.get(cv2.CAP_PROP_POS_FRAMES)}")
if not ret:
break
yield frame

def release(self)-> None:
"""
Release the video capture object.
"""
self.cap.release()

def __del__(self):
self.release()

def make_mask(img):
"""
Create a mask to filter out horizontal and vertical frequencies except for a central circular region.
"""
rows, cols = img.shape
crow, ccol = rows // 2, cols // 2

# Create an initial mask filled with ones (allowing all frequencies)
mask = np.ones((rows, cols), np.uint8)

# Define band widths for vertical and horizontal suppression
vertical_band_width = 5
horizontal_band_width = 0

# Zero out a vertical stripe at the frequency center
mask[:, ccol - vertical_band_width:ccol + vertical_band_width] = 0

# Zero out a horizontal stripe at the frequency center
mask[crow - horizontal_band_width:crow + horizontal_band_width, :] = 0

# Define the radius of the circular region to retain at the center
radius = 6
y, x = np.ogrid[:rows, :cols]
center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius ** 2

# Restore the center circular area to allow low frequencies to pass
mask[center_mask] = 1

# Visualize the mask if needed
cv2.imshow('Mask', mask * 255)
while True:
if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization
break
cv2.destroyAllWindows()

return mask

class FrameProcessor:
"""
A class to process video frames.
"""
def __init__(self, img, block_size=400):
"""
Initialize the FrameProcessor object.
"""
self.mask = make_mask(img)
self.previous_frame = img
self.block_size = block_size


def is_block_noisy(
self, block: np.ndarray, prev_block: np.ndarray, noise_threshold: float
) -> bool:
"""
Determine if a block is noisy by comparing it with the previous frame's block.
"""
# Calculate the mean squared difference between the current and previous blocks
block_diff = cv2.absdiff(block, prev_block)
mean_diff = np.mean(block_diff)

# Consider noisy if the mean difference exceeds the threshold
return mean_diff > noise_threshold

def process_frame(
self, current_frame: np.ndarray, noise_threshold: float
) -> np.ndarray:
"""
Process the frame, replacing noisy blocks with those from the previous frame.
Args:
current_frame: The current frame to process.
noise_threshold: The threshold for determining noisy blocks.
"""
processed_frame = np.copy(current_frame)
h, w = current_frame.shape

for y in range(0, h, self.block_size):
for x in range(0, w, self.block_size):
current_block = current_frame[y:y+self.block_size, x:x+self.block_size]
prev_block = self.previous_frame[y:y+self.block_size, x:x+self.block_size]

if self.is_block_noisy(current_block, prev_block, noise_threshold):
# Replace current block with previous block if noisy
processed_frame[y:y+self.block_size, x:x+self.block_size] = prev_block

self.previous_frame = processed_frame
return processed_frame

def remove_stripes(self, img):
"""Perform FFT/IFFT to remove horizontal stripes from a single frame."""
f = np.fft.fft2(img)
fshift = np.fft.fftshift(f)
magnitude_spectrum = np.log(np.abs(fshift) + 1) # Use log for better visualization

# Normalize the magnitude spectrum for visualization
magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)

if False:
# Display the magnitude spectrum
cv2.imshow('Magnitude Spectrum', np.uint8(magnitude_spectrum))
while True:
if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization
break
cv2.destroyAllWindows()
logger.debug(f"FFT shape: {fshift.shape}")

# Apply mask and inverse FFT
fshift *= self.mask
f_ishift = np.fft.ifftshift(fshift)
img_back = np.fft.ifft2(f_ishift)
img_back = np.abs(img_back)

# Normalize the result:
img_back = cv2.normalize(img_back, None, 0, 255, cv2.NORM_MINMAX)

return np.uint8(img_back) # Convert to 8-bit image for display and storage

if __name__ == "__main__":
video_path = 'output_001_test.avi'
reader = VideoReader(video_path)

frames = []
index = 0

try:
for frame in reader.read_frames():
gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
if index == 0:
processor = FrameProcessor(gray_frame)
if index > 100:
break
index += 1
logger.info(f"Processing frame {index}")
processed_frame = processor.process_frame(gray_frame, noise_threshold=10)
filtered_frame = processor.remove_stripes(processed_frame)
frames.append(filtered_frame)

# Display the results for visualization
for frame in frames:
cv2.imshow('Video', frame)
if cv2.waitKey(100) & 0xFF == ord('q'):
break

finally:
reader.release()
cv2.destroyAllWindows()
44 changes: 22 additions & 22 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 28fd048

Please sign in to comment.