From 28fd04869007d325b2bb82a297d4e2e7ba7b0292 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Wed, 4 Dec 2024 19:04:18 -0800 Subject: [PATCH] Add: init video preprocessing --- miniscope_io/processing/__init__.py | 3 + miniscope_io/processing/video.py | 196 ++++++++++++++++++++++++++++ pdm.lock | 44 +++---- 3 files changed, 221 insertions(+), 22 deletions(-) create mode 100644 miniscope_io/processing/__init__.py create mode 100644 miniscope_io/processing/video.py diff --git a/miniscope_io/processing/__init__.py b/miniscope_io/processing/__init__.py new file mode 100644 index 00000000..032d6d9e --- /dev/null +++ b/miniscope_io/processing/__init__.py @@ -0,0 +1,3 @@ +""" +Pre-processing module for miniscope data. +""" diff --git a/miniscope_io/processing/video.py b/miniscope_io/processing/video.py new file mode 100644 index 00000000..2cdce59f --- /dev/null +++ b/miniscope_io/processing/video.py @@ -0,0 +1,196 @@ +""" +This module contains functions for pre-processing video data. +""" + +from typing import Iterator + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + +from miniscope_io import init_logger + +logger = init_logger("video") + +class VideoReader: + """ + A class to read video files. + """ + def __init__(self, video_path: str): + """ + Initialize the VideoReader object. + """ + self.video_path = video_path + self.cap = cv2.VideoCapture(str(video_path)) + + if not self.cap.isOpened(): + raise ValueError(f"Could not open video at {video_path}") + + logger.info(f"Opened video at {video_path}") + + def read_frames(self) -> Iterator[np.ndarray]: + """ + Read frames from the video file. + """ + while self.cap.isOpened(): + ret, frame = self.cap.read() + logger.debug(f"Reading frame {self.cap.get(cv2.CAP_PROP_POS_FRAMES)}") + if not ret: + break + yield frame + + def release(self)-> None: + """ + Release the video capture object. + """ + self.cap.release() + + def __del__(self): + self.release() + +def make_mask(img): + """ + Create a mask to filter out horizontal and vertical frequencies except for a central circular region. + """ + rows, cols = img.shape + crow, ccol = rows // 2, cols // 2 + + # Create an initial mask filled with ones (allowing all frequencies) + mask = np.ones((rows, cols), np.uint8) + + # Define band widths for vertical and horizontal suppression + vertical_band_width = 5 + horizontal_band_width = 0 + + # Zero out a vertical stripe at the frequency center + mask[:, ccol - vertical_band_width:ccol + vertical_band_width] = 0 + + # Zero out a horizontal stripe at the frequency center + mask[crow - horizontal_band_width:crow + horizontal_band_width, :] = 0 + + # Define the radius of the circular region to retain at the center + radius = 6 + y, x = np.ogrid[:rows, :cols] + center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius ** 2 + + # Restore the center circular area to allow low frequencies to pass + mask[center_mask] = 1 + + # Visualize the mask if needed + cv2.imshow('Mask', mask * 255) + while True: + if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization + break + cv2.destroyAllWindows() + + return mask + +class FrameProcessor: + """ + A class to process video frames. + """ + def __init__(self, img, block_size=400): + """ + Initialize the FrameProcessor object. + """ + self.mask = make_mask(img) + self.previous_frame = img + self.block_size = block_size + + + def is_block_noisy( + self, block: np.ndarray, prev_block: np.ndarray, noise_threshold: float + ) -> bool: + """ + Determine if a block is noisy by comparing it with the previous frame's block. + """ + # Calculate the mean squared difference between the current and previous blocks + block_diff = cv2.absdiff(block, prev_block) + mean_diff = np.mean(block_diff) + + # Consider noisy if the mean difference exceeds the threshold + return mean_diff > noise_threshold + + def process_frame( + self, current_frame: np.ndarray, noise_threshold: float + ) -> np.ndarray: + """ + Process the frame, replacing noisy blocks with those from the previous frame. + + Args: + current_frame: The current frame to process. + noise_threshold: The threshold for determining noisy blocks. + """ + processed_frame = np.copy(current_frame) + h, w = current_frame.shape + + for y in range(0, h, self.block_size): + for x in range(0, w, self.block_size): + current_block = current_frame[y:y+self.block_size, x:x+self.block_size] + prev_block = self.previous_frame[y:y+self.block_size, x:x+self.block_size] + + if self.is_block_noisy(current_block, prev_block, noise_threshold): + # Replace current block with previous block if noisy + processed_frame[y:y+self.block_size, x:x+self.block_size] = prev_block + + self.previous_frame = processed_frame + return processed_frame + + def remove_stripes(self, img): + """Perform FFT/IFFT to remove horizontal stripes from a single frame.""" + f = np.fft.fft2(img) + fshift = np.fft.fftshift(f) + magnitude_spectrum = np.log(np.abs(fshift) + 1) # Use log for better visualization + + # Normalize the magnitude spectrum for visualization + magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX) + + if False: + # Display the magnitude spectrum + cv2.imshow('Magnitude Spectrum', np.uint8(magnitude_spectrum)) + while True: + if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization + break + cv2.destroyAllWindows() + logger.debug(f"FFT shape: {fshift.shape}") + + # Apply mask and inverse FFT + fshift *= self.mask + f_ishift = np.fft.ifftshift(fshift) + img_back = np.fft.ifft2(f_ishift) + img_back = np.abs(img_back) + + # Normalize the result: + img_back = cv2.normalize(img_back, None, 0, 255, cv2.NORM_MINMAX) + + return np.uint8(img_back) # Convert to 8-bit image for display and storage + +if __name__ == "__main__": + video_path = 'output_001_test.avi' + reader = VideoReader(video_path) + + frames = [] + index = 0 + + try: + for frame in reader.read_frames(): + gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + if index == 0: + processor = FrameProcessor(gray_frame) + if index > 100: + break + index += 1 + logger.info(f"Processing frame {index}") + processed_frame = processor.process_frame(gray_frame, noise_threshold=10) + filtered_frame = processor.remove_stripes(processed_frame) + frames.append(filtered_frame) + + # Display the results for visualization + for frame in frames: + cv2.imshow('Video', frame) + if cv2.waitKey(100) & 0xFF == ord('q'): + break + + finally: + reader.release() + cv2.destroyAllWindows() \ No newline at end of file diff --git a/pdm.lock b/pdm.lock index 3008efc3..4a533fdb 100644 --- a/pdm.lock +++ b/pdm.lock @@ -1722,29 +1722,29 @@ files = [ [[package]] name = "ruff" -version = "0.7.2" +version = "0.7.4" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["all", "dev"] files = [ - {file = "ruff-0.7.2-py3-none-linux_armv6l.whl", hash = "sha256:b73f873b5f52092e63ed540adefc3c36f1f803790ecf2590e1df8bf0a9f72cb8"}, - {file = "ruff-0.7.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5b813ef26db1015953daf476202585512afd6a6862a02cde63f3bafb53d0b2d4"}, - {file = "ruff-0.7.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:853277dbd9675810c6826dad7a428d52a11760744508340e66bf46f8be9701d9"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21aae53ab1490a52bf4e3bf520c10ce120987b047c494cacf4edad0ba0888da2"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ccc7e0fc6e0cb3168443eeadb6445285abaae75142ee22b2b72c27d790ab60ba"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd77877a4e43b3a98e5ef4715ba3862105e299af0c48942cc6d51ba3d97dc859"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e00163fb897d35523c70d71a46fbaa43bf7bf9af0f4534c53ea5b96b2e03397b"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3c54b538633482dc342e9b634d91168fe8cc56b30a4b4f99287f4e339103e88"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b792468e9804a204be221b14257566669d1db5c00d6bb335996e5cd7004ba80"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dba53ed84ac19ae4bfb4ea4bf0172550a2285fa27fbb13e3746f04c80f7fa088"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b19fafe261bf741bca2764c14cbb4ee1819b67adb63ebc2db6401dcd652e3748"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:28bd8220f4d8f79d590db9e2f6a0674f75ddbc3847277dd44ac1f8d30684b828"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9fd67094e77efbea932e62b5d2483006154794040abb3a5072e659096415ae1e"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:576305393998b7bd6c46018f8104ea3a9cb3fa7908c21d8580e3274a3b04b691"}, - {file = "ruff-0.7.2-py3-none-win32.whl", hash = "sha256:fa993cfc9f0ff11187e82de874dfc3611df80852540331bc85c75809c93253a8"}, - {file = "ruff-0.7.2-py3-none-win_amd64.whl", hash = "sha256:dd8800cbe0254e06b8fec585e97554047fb82c894973f7ff18558eee33d1cb88"}, - {file = "ruff-0.7.2-py3-none-win_arm64.whl", hash = "sha256:bb8368cd45bba3f57bb29cbb8d64b4a33f8415d0149d2655c5c8539452ce7760"}, - {file = "ruff-0.7.2.tar.gz", hash = "sha256:2b14e77293380e475b4e3a7a368e14549288ed2931fce259a6f99978669e844f"}, + {file = "ruff-0.7.4-py3-none-linux_armv6l.whl", hash = "sha256:a4919925e7684a3f18e18243cd6bea7cfb8e968a6eaa8437971f681b7ec51478"}, + {file = "ruff-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfb365c135b830778dda8c04fb7d4280ed0b984e1aec27f574445231e20d6c63"}, + {file = "ruff-0.7.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:63a569b36bc66fbadec5beaa539dd81e0527cb258b94e29e0531ce41bacc1f20"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d06218747d361d06fd2fdac734e7fa92df36df93035db3dc2ad7aa9852cb109"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e0cea28d0944f74ebc33e9f934238f15c758841f9f5edd180b5315c203293452"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80094ecd4793c68b2571b128f91754d60f692d64bc0d7272ec9197fdd09bf9ea"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:997512325c6620d1c4c2b15db49ef59543ef9cd0f4aa8065ec2ae5103cedc7e7"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00b4cf3a6b5fad6d1a66e7574d78956bbd09abfd6c8a997798f01f5da3d46a05"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7dbdc7d8274e1422722933d1edddfdc65b4336abf0b16dfcb9dedd6e6a517d06"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e92dfb5f00eaedb1501b2f906ccabfd67b2355bdf117fea9719fc99ac2145bc"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3bd726099f277d735dc38900b6a8d6cf070f80828877941983a57bca1cd92172"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2e32829c429dd081ee5ba39aef436603e5b22335c3d3fff013cd585806a6486a"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:662a63b4971807623f6f90c1fb664613f67cc182dc4d991471c23c541fee62dd"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:876f5e09eaae3eb76814c1d3b68879891d6fde4824c015d48e7a7da4cf066a3a"}, + {file = "ruff-0.7.4-py3-none-win32.whl", hash = "sha256:75c53f54904be42dd52a548728a5b572344b50d9b2873d13a3f8c5e3b91f5cac"}, + {file = "ruff-0.7.4-py3-none-win_amd64.whl", hash = "sha256:745775c7b39f914238ed1f1b0bebed0b9155a17cd8bc0b08d3c87e4703b990d6"}, + {file = "ruff-0.7.4-py3-none-win_arm64.whl", hash = "sha256:11bff065102c3ae9d3ea4dc9ecdfe5a5171349cdd0787c1fc64761212fc9cf1f"}, + {file = "ruff-0.7.4.tar.gz", hash = "sha256:cd12e35031f5af6b9b93715d8c4f40360070b2041f81273d0527683d5708fce2"}, ] [[package]] @@ -1920,7 +1920,7 @@ files = [ [[package]] name = "tqdm" -version = "4.66.6" +version = "4.67.0" requires_python = ">=3.7" summary = "Fast, Extensible Progress Meter" groups = ["default"] @@ -1928,8 +1928,8 @@ dependencies = [ "colorama; platform_system == \"Windows\"", ] files = [ - {file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, - {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, + {file = "tqdm-4.67.0-py3-none-any.whl", hash = "sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be"}, + {file = "tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a"}, ] [[package]]