From 372c9bd417769c2e01a059996bb601f2c7c57b53 Mon Sep 17 00:00:00 2001
From: t-sasatani <33111879+t-sasatani@users.noreply.github.com>
Date: Wed, 4 Dec 2024 19:04:18 -0800
Subject: [PATCH 01/20] Add: initial video preprocessing

---
 miniscope_io/processing/__init__.py |   3 +
 miniscope_io/processing/video.py    | 367 ++++++++++++++++++++++++++++
 pdm.lock                            |  44 ++--
 3 files changed, 392 insertions(+), 22 deletions(-)
 create mode 100644 miniscope_io/processing/__init__.py
 create mode 100644 miniscope_io/processing/video.py

diff --git a/miniscope_io/processing/__init__.py b/miniscope_io/processing/__init__.py
new file mode 100644
index 00000000..032d6d9e
--- /dev/null
+++ b/miniscope_io/processing/__init__.py
@@ -0,0 +1,3 @@
+"""
+Pre-processing module for miniscope data.
+"""
diff --git a/miniscope_io/processing/video.py b/miniscope_io/processing/video.py
new file mode 100644
index 00000000..9adeefe0
--- /dev/null
+++ b/miniscope_io/processing/video.py
@@ -0,0 +1,367 @@
+"""
+This module contains functions for pre-processing video data.
+"""
+
+import copy
+from typing import Iterator, Optional
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+from pydantic import BaseModel, Field
+
+from miniscope_io import init_logger
+
+logger = init_logger("video")
+
+def serialize_image(image: np.ndarray) -> np.ndarray:
+    """
+    Serializes a 2D image into a 1D array.
+    """
+    return image.flatten()
+
+def deserialize_image(
+    serialized_image: np.ndarray,
+    height: int,
+    width: int)-> np.ndarray:
+    """
+    Deserializes a 1D array back into a 2D image.
+    """
+    return serialized_image.reshape((height, width))
+
+def detect_noisy_parts(
+    current_frame: np.ndarray,
+    previous_frame: np.ndarray,
+    noise_threshold: float,
+    buffer_size: int,
+    block_size: int = 32
+    ) -> np.ndarray:
+    """
+    Detect noisy parts in the current frame by comparing it with the previous frame.
+    """
+    current_frame_serialized = serialize_image(current_frame)
+    previous_frame_serialized = serialize_image(previous_frame)
+
+
+
+    noisy_parts = np.zeros_like(current_frame_serialized)
+    noize_mask = deserialize_image(noisy_parts, current_frame.shape[0], current_frame.shape[1])
+
+def plot_frames_side_by_side(
+    fig: plt.Figure,
+    frames: list[np.ndarray],
+    titles: str =None
+    ) -> None:
+    """
+    Plot a list of frames side by side using matplotlib.
+
+    :param frames: List of frames (images) to be plotted
+    :param titles: Optional list of titles for each subplot
+    """
+    num_frames = len(frames)
+    plt.clf()  # Clear current figure
+
+    for i, frame in enumerate(frames):
+        plt.subplot(1, num_frames, i + 1)
+        plt.imshow(frame, cmap='gray')
+        if titles:
+            plt.title(titles[i])
+
+        plt.axis('off')  # Turn off axis labels
+
+    plt.tight_layout()
+    fig.canvas.draw()
+class AnnotatedFrameModel(BaseModel):
+    """
+    A class to represent video data.
+    """
+    data: np.ndarray = Field(
+        ...,
+        description="The video data as a NumPy array."
+    )
+    status_tag: Optional[str] = Field(
+        None,
+        description="A tag indicating the status of the video data."
+    )
+    index: Optional[int] = Field(
+        None,
+        description="The index of the video data."
+    )
+    fps: Optional[float] = Field(
+        None,
+        description="The frames per second of the video."
+    )
+
+    # Might be a numpydantic situation? Need to look later but will skip.
+    class Config:
+        arbitrary_types_allowed = True
+class AnnotatedFrameListModel(BaseModel):
+    """
+    A class to represent a list of annotated video frames.
+    Not used yet
+    """
+    frames: list[AnnotatedFrameModel] = Field(
+        ...,
+        description="A list of annotated video frames."
+    )
+class VideoReader:
+    """
+    A class to read video files.
+    """
+    def __init__(self, video_path: str):
+        """
+        Initialize the VideoReader object.
+        """
+        self.video_path = video_path
+        self.cap = cv2.VideoCapture(str(video_path))
+
+        if not self.cap.isOpened():
+            raise ValueError(f"Could not open video at {video_path}")
+
+        logger.info(f"Opened video at {video_path}")
+
+    def read_frames(self) -> Iterator[np.ndarray]:
+        """
+        Read frames from the video file.
+        """
+        while self.cap.isOpened():
+            ret, frame = self.cap.read()
+            logger.debug(f"Reading frame {self.cap.get(cv2.CAP_PROP_POS_FRAMES)}")
+            if not ret:
+                break
+            yield frame
+
+    def release(self)-> None:
+        """
+        Release the video capture object.
+        """
+        self.cap.release()
+
+    def __del__(self):
+        self.release()
+
+def gen_freq_mask(
+    width: int = 200,
+    height: int = 200,
+    center_radius: int = 6,
+    show_mask: bool = True
+    ) -> np.ndarray:
+    """
+    Generate a mask to filter out horizontal and vertical frequencies.
+    A central circular region can be removed to allow low frequencies to pass.
+    """
+    crow, ccol = height // 2, width // 2
+
+    # Create an initial mask filled with ones (pass all frequencies)
+    mask = np.ones((height, width), np.uint8)
+
+    # Define band widths for vertical and horizontal suppression
+    vertical_band_width = 2
+    horizontal_band_width = 0
+
+    # Zero out a vertical stripe at the frequency center
+    mask[:, ccol - vertical_band_width:ccol + vertical_band_width] = 0
+
+    # Zero out a horizontal stripe at the frequency center
+    mask[crow - horizontal_band_width:crow + horizontal_band_width, :] = 0
+
+    # Define the radius of the circular region to retain at the center
+    radius = center_radius
+    y, x = np.ogrid[:height, :width]
+    center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius ** 2
+
+    # Restore the center circular area to allow low frequencies to pass
+    mask[center_mask] = 1
+
+    # Visualize the mask if needed
+    if show_mask:
+        cv2.imshow('Mask', mask * 255)
+        while True:
+            if cv2.waitKey(1) == 27:  # Press 'Esc' key to exit visualization
+                break
+        cv2.destroyAllWindows()
+
+    return mask
+
+class FrameProcessor:
+    """
+    A class to process video frames.
+    """
+    def __init__(self,
+                 height: int,
+                 width: int,
+                 buffer_size: int=5032,
+                 block_size: int=32
+                 ):
+        """
+        Initialize the FrameProcessor object.
+        Block size/buffer size will be set by dev config later.
+        """
+        self.height = height
+        self.width = width
+        self.buffer_size = buffer_size
+        self.block_size = block_size
+
+    def split_by_length(
+        self,
+        array: np.ndarray,
+        segment_length: int
+        ) -> list[np.ndarray]:
+        """
+        Split an array into sub-arrays of a specified length.
+        """
+        num_segments = len(array) // segment_length
+
+        # Create sub-arrays of the specified segment length
+        split_arrays = [
+            array[i * segment_length: (i + 1) * segment_length] for i in range(num_segments)
+        ]
+
+        # Add the remaining elements as a final shorter segment, if any
+        if len(array) % segment_length != 0:
+            split_arrays.append(array[num_segments * segment_length:])
+
+        return split_arrays
+
+    def patch_noisy_buffer(
+        self,
+        current_frame: np.ndarray,
+        previous_frame: np.ndarray,
+        noise_threshold: float
+        ) -> np.ndarray:
+        """
+        Process the frame, replacing noisy blocks with those from the previous frame.
+ """ + + serialized_current = current_frame.flatten() + serialized_previous = previous_frame.flatten() + + split_current = self.split_by_length(serialized_current, self.buffer_size) + split_previous = self.split_by_length(serialized_previous, self.buffer_size) + + # Not best to deepcopy this. Just doing for now to take care of + # inhomogeneous array sizes. + split_output = copy.deepcopy(split_current) + noisy_parts = copy.deepcopy(split_current) + + for i in range(len(split_current)): + mean_error = abs(split_current[i] - split_previous[i]).mean() + if mean_error > noise_threshold: + logger.debug(f"Replacing buffer {i} with mean error {mean_error}") + split_output[i] = split_previous[i] + noisy_parts[i] = np.ones_like(split_current[i]) * 255 + else: + split_output[i] = split_current[i] + noisy_parts[i] = np.zeros_like(split_current[i]) + + serialized_output = np.concatenate(split_output)[:self.height * self.width] + noise_output = np.concatenate(noisy_parts)[:self.height * self.width] + + # Deserialize processed frame + processed_frame = serialized_output.reshape( + self.width, + self.height) + noise_patch = noise_output.reshape( + self.width, + self.height) + + return processed_frame, noise_patch + + def remove_stripes( + self, + img: np.ndarray, + mask: np.ndarray + )-> np.ndarray: + """Perform FFT/IFFT to remove horizontal stripes from a single frame.""" + f = np.fft.fft2(img) + fshift = np.fft.fftshift(f) + magnitude_spectrum = np.log(np.abs(fshift) + 1) # Use log for better visualization + + # Normalize the magnitude spectrum for visualization + magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX) + + if False: + # Display the magnitude spectrum + cv2.imshow('Magnitude Spectrum', np.uint8(magnitude_spectrum)) + while True: + if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization + break + cv2.destroyAllWindows() + logger.debug(f"FFT shape: {fshift.shape}") + + # Apply mask and inverse FFT + fshift *= mask + f_ishift = np.fft.ifftshift(fshift) + img_back = np.fft.ifft2(f_ishift) + img_back = np.abs(img_back) + + # Normalize the result: + img_back = cv2.normalize(img_back, None, 0, 255, cv2.NORM_MINMAX) + + return np.uint8(img_back) # Convert to 8-bit image for display and storage + +if __name__ == "__main__": + """ + For inital debugging. + Will be removed later. 
+ """ + video_path = 'output_001_test.avi' + reader = VideoReader(video_path) + + frames = [] + index = 0 + fig = plt.figure(figsize=(12, 4)) + + processor = FrameProcessor( + height=200, + width=200, + ) + + freq_mask = gen_freq_mask( + width=200, + height=200, + show_mask=False + ) + try: + for frame in reader.read_frames(): + gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + index += 1 + if index > 100: + break + logger.info(f"Processing frame {index}") + if index == 1: + previous_frame = gray_frame + processed_frame, noise_patch = processor.patch_noisy_buffer( + gray_frame, + previous_frame, + noise_threshold=250 + ) + filtered_frame = processor.remove_stripes( + img=processed_frame, + mask=freq_mask + ) + frames.append(filtered_frame) + + frames_to_plot = [ + freq_mask, + gray_frame, + processed_frame, + noise_patch, + filtered_frame, + ] + plot_frames_side_by_side( + fig, + frames_to_plot, + titles=[ + 'Frequency Mask', + 'Original Frame', + 'Processed Frame', + 'Noisy Patch', + 'Filtered Frame', + ] + ) + plt.pause(0.01) + + finally: + reader.release() + plt.close(fig) \ No newline at end of file diff --git a/pdm.lock b/pdm.lock index 010af761..c8222c64 100644 --- a/pdm.lock +++ b/pdm.lock @@ -1722,29 +1722,29 @@ files = [ [[package]] name = "ruff" -version = "0.7.2" +version = "0.7.4" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["all", "dev"] files = [ - {file = "ruff-0.7.2-py3-none-linux_armv6l.whl", hash = "sha256:b73f873b5f52092e63ed540adefc3c36f1f803790ecf2590e1df8bf0a9f72cb8"}, - {file = "ruff-0.7.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5b813ef26db1015953daf476202585512afd6a6862a02cde63f3bafb53d0b2d4"}, - {file = "ruff-0.7.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:853277dbd9675810c6826dad7a428d52a11760744508340e66bf46f8be9701d9"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21aae53ab1490a52bf4e3bf520c10ce120987b047c494cacf4edad0ba0888da2"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ccc7e0fc6e0cb3168443eeadb6445285abaae75142ee22b2b72c27d790ab60ba"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd77877a4e43b3a98e5ef4715ba3862105e299af0c48942cc6d51ba3d97dc859"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e00163fb897d35523c70d71a46fbaa43bf7bf9af0f4534c53ea5b96b2e03397b"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3c54b538633482dc342e9b634d91168fe8cc56b30a4b4f99287f4e339103e88"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b792468e9804a204be221b14257566669d1db5c00d6bb335996e5cd7004ba80"}, - {file = "ruff-0.7.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dba53ed84ac19ae4bfb4ea4bf0172550a2285fa27fbb13e3746f04c80f7fa088"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b19fafe261bf741bca2764c14cbb4ee1819b67adb63ebc2db6401dcd652e3748"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:28bd8220f4d8f79d590db9e2f6a0674f75ddbc3847277dd44ac1f8d30684b828"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9fd67094e77efbea932e62b5d2483006154794040abb3a5072e659096415ae1e"}, - {file = "ruff-0.7.2-py3-none-musllinux_1_2_x86_64.whl", hash = 
"sha256:576305393998b7bd6c46018f8104ea3a9cb3fa7908c21d8580e3274a3b04b691"}, - {file = "ruff-0.7.2-py3-none-win32.whl", hash = "sha256:fa993cfc9f0ff11187e82de874dfc3611df80852540331bc85c75809c93253a8"}, - {file = "ruff-0.7.2-py3-none-win_amd64.whl", hash = "sha256:dd8800cbe0254e06b8fec585e97554047fb82c894973f7ff18558eee33d1cb88"}, - {file = "ruff-0.7.2-py3-none-win_arm64.whl", hash = "sha256:bb8368cd45bba3f57bb29cbb8d64b4a33f8415d0149d2655c5c8539452ce7760"}, - {file = "ruff-0.7.2.tar.gz", hash = "sha256:2b14e77293380e475b4e3a7a368e14549288ed2931fce259a6f99978669e844f"}, + {file = "ruff-0.7.4-py3-none-linux_armv6l.whl", hash = "sha256:a4919925e7684a3f18e18243cd6bea7cfb8e968a6eaa8437971f681b7ec51478"}, + {file = "ruff-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfb365c135b830778dda8c04fb7d4280ed0b984e1aec27f574445231e20d6c63"}, + {file = "ruff-0.7.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:63a569b36bc66fbadec5beaa539dd81e0527cb258b94e29e0531ce41bacc1f20"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d06218747d361d06fd2fdac734e7fa92df36df93035db3dc2ad7aa9852cb109"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e0cea28d0944f74ebc33e9f934238f15c758841f9f5edd180b5315c203293452"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80094ecd4793c68b2571b128f91754d60f692d64bc0d7272ec9197fdd09bf9ea"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:997512325c6620d1c4c2b15db49ef59543ef9cd0f4aa8065ec2ae5103cedc7e7"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00b4cf3a6b5fad6d1a66e7574d78956bbd09abfd6c8a997798f01f5da3d46a05"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7dbdc7d8274e1422722933d1edddfdc65b4336abf0b16dfcb9dedd6e6a517d06"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e92dfb5f00eaedb1501b2f906ccabfd67b2355bdf117fea9719fc99ac2145bc"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3bd726099f277d735dc38900b6a8d6cf070f80828877941983a57bca1cd92172"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2e32829c429dd081ee5ba39aef436603e5b22335c3d3fff013cd585806a6486a"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:662a63b4971807623f6f90c1fb664613f67cc182dc4d991471c23c541fee62dd"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:876f5e09eaae3eb76814c1d3b68879891d6fde4824c015d48e7a7da4cf066a3a"}, + {file = "ruff-0.7.4-py3-none-win32.whl", hash = "sha256:75c53f54904be42dd52a548728a5b572344b50d9b2873d13a3f8c5e3b91f5cac"}, + {file = "ruff-0.7.4-py3-none-win_amd64.whl", hash = "sha256:745775c7b39f914238ed1f1b0bebed0b9155a17cd8bc0b08d3c87e4703b990d6"}, + {file = "ruff-0.7.4-py3-none-win_arm64.whl", hash = "sha256:11bff065102c3ae9d3ea4dc9ecdfe5a5171349cdd0787c1fc64761212fc9cf1f"}, + {file = "ruff-0.7.4.tar.gz", hash = "sha256:cd12e35031f5af6b9b93715d8c4f40360070b2041f81273d0527683d5708fce2"}, ] [[package]] @@ -1959,7 +1959,7 @@ files = [ [[package]] name = "tqdm" -version = "4.66.6" +version = "4.67.0" requires_python = ">=3.7" summary = "Fast, Extensible Progress Meter" groups = ["default"] @@ -1967,8 +1967,8 @@ dependencies = [ "colorama; platform_system == \"Windows\"", ] files = [ - {file = "tqdm-4.66.6-py3-none-any.whl", hash = 
"sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, - {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, + {file = "tqdm-4.67.0-py3-none-any.whl", hash = "sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be"}, + {file = "tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a"}, ] [[package]] From df49809e0a10d1dd75199494eaf020606450ae7f Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:27:28 -0800 Subject: [PATCH 02/20] Add: noise detection, freq filter, min projection --- miniscope_io/processing/video.py | 280 ++++++++++++++++++------------- 1 file changed, 162 insertions(+), 118 deletions(-) diff --git a/miniscope_io/processing/video.py b/miniscope_io/processing/video.py index 9adeefe0..1a5743fb 100644 --- a/miniscope_io/processing/video.py +++ b/miniscope_io/processing/video.py @@ -3,49 +3,98 @@ """ import copy -from typing import Iterator, Optional +from typing import Iterator, Optional, Tuple import cv2 +import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np +from matplotlib.widgets import Button, Slider from pydantic import BaseModel, Field from miniscope_io import init_logger logger = init_logger("video") -def serialize_image(image: np.ndarray) -> np.ndarray: - """ - Serializes a 2D image into a 1D array. - """ - return image.flatten() +import cv2 +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.widgets import Slider, Button +from matplotlib import animation -def deserialize_image( - serialized_image: np.ndarray, - height: int, - width: int)-> np.ndarray: - """ - Deserializes a 1D array back into a 2D image. - """ - return serialized_image.reshape((height, width)) - -def detect_noisy_parts( - current_frame: np.ndarray, - previous_frame: np.ndarray, - noise_threshold: float, - buffer_size: int, - block_size: int = 32 - ) -> np.ndarray: +def plot_video_streams_with_controls( + video_frames: list[list[np.ndarray] or np.ndarray], + titles: list[str] = None, + fps: int = 20 + ) -> None: """ - Detect noisy parts in the current frame by comparing it with the previous frame. + Plot multiple video streams or static images side-by-side with controls to play/pause and navigate frames. 
""" - current_frame_serialized = serialize_image(current_frame) - previous_frame_serialized = serialize_image(previous_frame) + # Wrap static images in lists to handle them uniformly + video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames] + + num_streams = len(video_frames) + num_frames = max(len(stream) for stream in video_frames) # Use max to account for static images with 1 frame + + # Initialize plots + fig, axes = plt.subplots(1, num_streams, figsize=(20, 5)) + + # Initial display of the first frame from each stream + frame_displays = [] + for idx, ax in enumerate(axes): + # Adjust static images to display them consistently + initial_frame = video_frames[idx][0] + frame_display = ax.imshow(initial_frame, cmap='gray', vmin=0, vmax=255) + frame_displays.append(frame_display) + if titles: + ax.set_title(titles[idx]) + ax.axis('off') + # Define the slider + ax_slider = plt.axes([0.1, 0.1, 0.65, 0.05], facecolor='lightgoldenrodyellow') + slider = Slider(ax=ax_slider, label='Frame', valmin=0, valmax=num_frames - 1, valinit=0, valstep=1) + + # Define the play/pause button + playing = [False] # Use a mutable object to track play state + ax_button = plt.axes([0.8, 0.1, 0.1, 0.05]) + button = Button(ax_button, 'Play/Pause') + + # Callback to toggle play/pause + def toggle_play(event): + playing[0] = not playing[0] + + button.on_clicked(toggle_play) + + # Update function for the slider and frame displays + def update_frame(index): + for idx, frame_display in enumerate(frame_displays): + # Repeat last frame for static images or when the index is larger than stream length + if index < len(video_frames[idx]): + frame = video_frames[idx][index] + else: + frame = video_frames[idx][-1] # Keep showing last frame for shorter streams + frame_display.set_data(frame) + fig.canvas.draw_idle() + # Slider update callback + def on_slider_change(val): + index = int(slider.val) + update_frame(index) - noisy_parts = np.zeros_like(current_frame_serialized) - noize_mask = deserialize_image(noisy_parts, current_frame.shape[0], current_frame.shape[1]) + # Connect the slider update function + slider.on_changed(on_slider_change) + + # Animation function + def animate(i): + if playing[0]: + current_frame = int(slider.val) + next_frame = (current_frame + 1) % num_frames + slider.set_val(next_frame) # This will also trigger on_slider_change + + # Use FuncAnimation to update the figure at the specified FPS + ani = animation.FuncAnimation(fig, animate, frames=num_frames, interval=1000//fps, blit=False) + + plt.show() def plot_frames_side_by_side( fig: plt.Figure, @@ -54,9 +103,6 @@ def plot_frames_side_by_side( ) -> None: """ Plot a list of frames side by side using matplotlib. - - :param frames: List of frames (images) to be plotted - :param titles: Optional list of titles for each subplot """ num_frames = len(frames) plt.clf() # Clear current figure @@ -71,39 +117,7 @@ def plot_frames_side_by_side( plt.tight_layout() fig.canvas.draw() -class AnnotatedFrameModel(BaseModel): - """ - A class to represent video data. - """ - data: np.ndarray = Field( - ..., - description="The video data as a NumPy array." - ) - status_tag: Optional[str] = Field( - None, - description="A tag indicating the status of the video data." - ) - index: Optional[int] = Field( - None, - description="The index of the video data." - ) - fps: Optional[float] = Field( - None, - description="The frames per second of the video." - ) - # Might be a numpydantic situation? Need to look later but will skip. 
- class Config: - arbitrary_types_allowed = True -class AnnotatedFrameListModel(BaseModel): - """ - A class to represent a list of annotated video frames. - Not used yet - """ - frames: list[AnnotatedFrameModel] = Field( - ..., - description="A list of annotated video frames." - ) class VideoReader: """ A class to read video files. @@ -140,10 +154,21 @@ def release(self)-> None: def __del__(self): self.release() +def show_frame(frame: np.ndarray) -> None: + """ + Display a single frame using OpenCV. + """ + cv2.imshow('Mask', frame * 255) + while True: + if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization + break + + cv2.destroyAllWindows() + def gen_freq_mask( width: int = 200, height: int = 200, - center_radius: int = 6, + center_radius: int = 15, show_mask: bool = True ) -> np.ndarray: """ @@ -180,7 +205,6 @@ def gen_freq_mask( if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization break cv2.destroyAllWindows() - return mask class FrameProcessor: @@ -228,31 +252,28 @@ def patch_noisy_buffer( current_frame: np.ndarray, previous_frame: np.ndarray, noise_threshold: float - ) -> np.ndarray: + ) -> Tuple[np.ndarray, np.ndarray]: """ Process the frame, replacing noisy blocks with those from the previous frame. """ - - serialized_current = current_frame.flatten() - serialized_previous = previous_frame.flatten() + serialized_current = current_frame.flatten().astype(np.int16) + serialized_previous = previous_frame.flatten().astype(np.int16) split_current = self.split_by_length(serialized_current, self.buffer_size) split_previous = self.split_by_length(serialized_previous, self.buffer_size) - # Not best to deepcopy this. Just doing for now to take care of - # inhomogeneous array sizes. - split_output = copy.deepcopy(split_current) - noisy_parts = copy.deepcopy(split_current) + split_output = split_current.copy() + noisy_parts = split_current.copy() for i in range(len(split_current)): mean_error = abs(split_current[i] - split_previous[i]).mean() if mean_error > noise_threshold: - logger.debug(f"Replacing buffer {i} with mean error {mean_error}") + logger.info(f"Replacing buffer {i} with mean error {mean_error}") split_output[i] = split_previous[i] - noisy_parts[i] = np.ones_like(split_current[i]) * 255 + noisy_parts[i] = np.ones_like(split_current[i], np.uint8) else: split_output[i] = split_current[i] - noisy_parts[i] = np.zeros_like(split_current[i]) + noisy_parts[i] = np.zeros_like(split_current[i], np.uint8) serialized_output = np.concatenate(split_output)[:self.height * self.width] noise_output = np.concatenate(noisy_parts)[:self.height * self.width] @@ -265,14 +286,16 @@ def patch_noisy_buffer( self.width, self.height) - return processed_frame, noise_patch + return np.uint8(processed_frame), np.uint8(noise_patch) def remove_stripes( self, img: np.ndarray, mask: np.ndarray )-> np.ndarray: - """Perform FFT/IFFT to remove horizontal stripes from a single frame.""" + """ + Perform FFT/IFFT to remove horizontal stripes from a single frame. 
+ """ f = np.fft.fft2(img) fshift = np.fft.fftshift(f) magnitude_spectrum = np.log(np.abs(fshift) + 1) # Use log for better visualization @@ -280,25 +303,18 @@ def remove_stripes( # Normalize the magnitude spectrum for visualization magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX) - if False: - # Display the magnitude spectrum - cv2.imshow('Magnitude Spectrum', np.uint8(magnitude_spectrum)) - while True: - if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization - break - cv2.destroyAllWindows() - logger.debug(f"FFT shape: {fshift.shape}") - # Apply mask and inverse FFT fshift *= mask f_ishift = np.fft.ifftshift(fshift) img_back = np.fft.ifft2(f_ishift) img_back = np.abs(img_back) - # Normalize the result: - img_back = cv2.normalize(img_back, None, 0, 255, cv2.NORM_MINMAX) - - return np.uint8(img_back) # Convert to 8-bit image for display and storage + return np.uint8(img_back), np.uint8(magnitude_spectrum) + +def get_minimum_projection(image_list): + stacked_images = np.stack(image_list, axis=0) + min_projection = np.min(stacked_images, axis=0) + return min_projection if __name__ == "__main__": """ @@ -307,10 +323,19 @@ def remove_stripes( """ video_path = 'output_001_test.avi' reader = VideoReader(video_path) + streaming_plot = False + slider_plot = True + freq_masks = [] + gray_frames = [] + processed_frames = [] + freq_domain_frames = [] + noise_patchs = [] + filtered_frames = [] + diff_frames = [] - frames = [] index = 0 - fig = plt.figure(figsize=(12, 4)) + + fig = plt.figure(figsize=(16, 4)) processor = FrameProcessor( height=200, @@ -322,46 +347,65 @@ def remove_stripes( height=200, show_mask=False ) + try: for frame in reader.read_frames(): gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) index += 1 if index > 100: break - logger.info(f"Processing frame {index}") + logger.debug(f"Processing frame {index}") if index == 1: previous_frame = gray_frame processed_frame, noise_patch = processor.patch_noisy_buffer( gray_frame, previous_frame, - noise_threshold=250 + noise_threshold=20 ) - filtered_frame = processor.remove_stripes( + filtered_frame, freq_domain_frame = processor.remove_stripes( img=processed_frame, mask=freq_mask ) - frames.append(filtered_frame) - - frames_to_plot = [ - freq_mask, - gray_frame, - processed_frame, - noise_patch, - filtered_frame, - ] - plot_frames_side_by_side( - fig, - frames_to_plot, - titles=[ - 'Frequency Mask', - 'Original Frame', - 'Processed Frame', - 'Noisy Patch', - 'Filtered Frame', - ] - ) - plt.pause(0.01) - + + diff_frame = cv2.absdiff(gray_frame, filtered_frame) + gray_frames.append(gray_frame) + processed_frames.append(processed_frame) + freq_domain_frames.append(freq_domain_frame) + noise_patchs.append(noise_patch*255) + filtered_frames.append(filtered_frame) + diff_frames.append(diff_frame*10) finally: reader.release() - plt.close(fig) \ No newline at end of file + plt.close(fig) + + minimum_projection = get_minimum_projection(filtered_frames) + show_frame(minimum_projection) + + subtract_minimum = [(frame - minimum_projection) for frame in filtered_frames] + + if slider_plot: + video_frames = [ + gray_frames, + processed_frames, + diff_frames, + noise_patchs, + freq_mask * 255, + freq_domain_frames, + filtered_frames, + minimum_projection, + subtract_minimum, + ] + plot_video_streams_with_controls( + video_frames, + titles=[ + 'Original', + 'Patched', + 'Diff', + 'Noisy area', + 'Freq mask', + 'Freq domain', + 'Freq filtered', + 'Min Proj', + 'Subtracted', + ] + ) \ No newline at end of 
file From 46da23fe4cd8c52da6258e084ae5c29eee88351a Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:11:38 -0800 Subject: [PATCH 03/20] Add: stack editing class, chunked error detection units --- miniscope_io/processing/video.py | 133 ++++++++++++++++++++++++------- 1 file changed, 105 insertions(+), 28 deletions(-) diff --git a/miniscope_io/processing/video.py b/miniscope_io/processing/video.py index 1a5743fb..5bccd66d 100644 --- a/miniscope_io/processing/video.py +++ b/miniscope_io/processing/video.py @@ -2,44 +2,35 @@ This module contains functions for pre-processing video data. """ -import copy -from typing import Iterator, Optional, Tuple +from typing import Iterator, Tuple import cv2 import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np from matplotlib.widgets import Button, Slider -from pydantic import BaseModel, Field from miniscope_io import init_logger logger = init_logger("video") -import cv2 -import numpy as np -import matplotlib.pyplot as plt -from matplotlib.widgets import Slider, Button -from matplotlib import animation - def plot_video_streams_with_controls( video_frames: list[list[np.ndarray] or np.ndarray], titles: list[str] = None, fps: int = 20 ) -> None: """ - Plot multiple video streams or static images side-by-side with controls to play/pause and navigate frames. + Plot multiple video streams or static images side-by-side. + Can play/pause and navigate frames. """ # Wrap static images in lists to handle them uniformly video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames] num_streams = len(video_frames) - num_frames = max(len(stream) for stream in video_frames) # Use max to account for static images with 1 frame + num_frames = max(len(stream) for stream in video_frames) - # Initialize plots fig, axes = plt.subplots(1, num_streams, figsize=(20, 5)) - # Initial display of the first frame from each stream frame_displays = [] for idx, ax in enumerate(axes): # Adjust static images to display them consistently @@ -54,7 +45,6 @@ def plot_video_streams_with_controls( ax_slider = plt.axes([0.1, 0.1, 0.65, 0.05], facecolor='lightgoldenrodyellow') slider = Slider(ax=ax_slider, label='Frame', valmin=0, valmax=num_frames - 1, valinit=0, valstep=1) - # Define the play/pause button playing = [False] # Use a mutable object to track play state ax_button = plt.axes([0.8, 0.1, 0.1, 0.05]) button = Button(ax_button, 'Play/Pause') @@ -168,7 +158,7 @@ def show_frame(frame: np.ndarray) -> None: def gen_freq_mask( width: int = 200, height: int = 200, - center_radius: int = 15, + center_radius: int = 5, show_mask: bool = True ) -> np.ndarray: """ @@ -220,11 +210,18 @@ def __init__(self, """ Initialize the FrameProcessor object. Block size/buffer size will be set by dev config later. + + Parameters: + height (int): Height of the video frame. + width (int): Width of the video frame. + buffer_size (int): Size of the buffer to process. + block_size (int): Size of the blocks to process. Not used now. + """ self.height = height self.width = width self.buffer_size = buffer_size - self.block_size = block_size + self.buffer_split = 1 def split_by_length( self, @@ -233,6 +230,13 @@ def split_by_length( ) -> list[np.ndarray]: """ Split an array into sub-arrays of a specified length. + + Parameters: + array (np.ndarray): The array to split. + segment_length (int): The length of each sub-array. + + Returns: + list[np.ndarray]: A list of sub-arrays. 
""" num_segments = len(array) // segment_length @@ -255,12 +259,20 @@ def patch_noisy_buffer( ) -> Tuple[np.ndarray, np.ndarray]: """ Process the frame, replacing noisy blocks with those from the previous frame. + + Parameters: + current_frame (np.ndarray): The current frame to process. + previous_frame (np.ndarray): The previous frame to compare against. + noise_threshold (float): The threshold for mean error to consider a block noisy. + + Returns: + Tuple[np.ndarray, np.ndarray]: The processed frame and the noise patch """ serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - split_current = self.split_by_length(serialized_current, self.buffer_size) - split_previous = self.split_by_length(serialized_previous, self.buffer_size) + split_current = self.split_by_length(serialized_current, self.buffer_size//5) + split_previous = self.split_by_length(serialized_previous, self.buffer_size//5) split_output = split_current.copy() noisy_parts = split_current.copy() @@ -295,6 +307,13 @@ def remove_stripes( )-> np.ndarray: """ Perform FFT/IFFT to remove horizontal stripes from a single frame. + + Parameters: + img (np.ndarray): The image to process. + mask (np.ndarray): The frequency mask to apply. + + Returns: + np.ndarray: The filtered image """ f = np.fft.fft2(img) fshift = np.fft.fftshift(f) @@ -311,10 +330,65 @@ def remove_stripes( return np.uint8(img_back), np.uint8(magnitude_spectrum) -def get_minimum_projection(image_list): - stacked_images = np.stack(image_list, axis=0) - min_projection = np.min(stacked_images, axis=0) - return min_projection +class FrameListProcessor: + """ + A class to process a list of video frames. + """ + @staticmethod + def get_minimum_projection( + image_list: list[np.ndarray] + )-> np.ndarray: + """ + Get the minimum projection of a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to project. + + Returns: + np.ndarray: The minimum projection of the images. + """ + stacked_images = np.stack(image_list, axis=0) + min_projection = np.min(stacked_images, axis=0) + return min_projection + + @staticmethod + def normalize_video_stack( + image_list: list[np.ndarray] + ) -> list[np.ndarray]: + """ + Normalize a stack of images to 0-255 using max and minimum values throughout the entire stack. + Return a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to normalize. + + Returns: + list[np.ndarray]: The normalized images as a list. 
+ """ + + # Stack images along a new axis (axis=0) + stacked_images = np.stack(image_list, axis=0) + + # Find the global min and max across the entire stack + global_min = stacked_images.min() + global_max = stacked_images.max() + + # Normalize each frame using the global min and max + normalized_images = [] + for i in range(stacked_images.shape[0]): + normalized_image = cv2.normalize( + stacked_images[i], + None, + 0, + 255, + cv2.NORM_MINMAX, + dtype=cv2.CV_32F + ) + # Apply global normalization + normalized_image = (stacked_images[i] - global_min) / (global_max - global_min) * 255 + normalized_images.append(normalized_image.astype(np.uint8)) + + return normalized_images if __name__ == "__main__": """ @@ -360,13 +434,12 @@ def get_minimum_projection(image_list): processed_frame, noise_patch = processor.patch_noisy_buffer( gray_frame, previous_frame, - noise_threshold=20 + noise_threshold=15 ) filtered_frame, freq_domain_frame = processor.remove_stripes( img=processed_frame, mask=freq_mask - ) - + ) diff_frame = cv2.absdiff(gray_frame, filtered_frame) gray_frames.append(gray_frame) processed_frames.append(processed_frame) @@ -378,10 +451,12 @@ def get_minimum_projection(image_list): reader.release() plt.close(fig) - minimum_projection = get_minimum_projection(filtered_frames) - show_frame(minimum_projection) + normalized_frames = FrameListProcessor.normalize_video_stack(filtered_frames) + minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames) + + subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames] - subtract_minimum = [(frame - minimum_projection) for frame in filtered_frames] + subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) if slider_plot: video_frames = [ @@ -392,6 +467,7 @@ def get_minimum_projection(image_list): freq_mask * 255, freq_domain_frames, filtered_frames, + normalized_frames, minimum_projection, subtract_minimum, ] @@ -405,6 +481,7 @@ def get_minimum_projection(image_list): 'Freq mask', 'Freq domain', 'Freq filtered', + 'Normalized', 'Min Proj', 'Subtracted', ] From 808400310686dc60738e68b314589bf22316d7e8 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Thu, 5 Dec 2024 23:37:04 -0800 Subject: [PATCH 04/20] Restructure, format, cli --- miniscope_io/processing/video.py | 444 +++++++++++++------------------ mio/cli/main.py | 1 + mio/cli/process.py | 24 ++ mio/plots/video.py | 137 ++++++++++ 4 files changed, 342 insertions(+), 264 deletions(-) create mode 100644 mio/cli/process.py create mode 100644 mio/plots/video.py diff --git a/miniscope_io/processing/video.py b/miniscope_io/processing/video.py index 5bccd66d..a8c6bfb0 100644 --- a/miniscope_io/processing/video.py +++ b/miniscope_io/processing/video.py @@ -5,125 +5,40 @@ from typing import Iterator, Tuple import cv2 -import matplotlib.animation as animation import matplotlib.pyplot as plt import numpy as np -from matplotlib.widgets import Button, Slider from miniscope_io import init_logger +from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") -def plot_video_streams_with_controls( - video_frames: list[list[np.ndarray] or np.ndarray], - titles: list[str] = None, - fps: int = 20 - ) -> None: - """ - Plot multiple video streams or static images side-by-side. - Can play/pause and navigate frames. 
- """ - # Wrap static images in lists to handle them uniformly - video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames] - - num_streams = len(video_frames) - num_frames = max(len(stream) for stream in video_frames) - - fig, axes = plt.subplots(1, num_streams, figsize=(20, 5)) - - frame_displays = [] - for idx, ax in enumerate(axes): - # Adjust static images to display them consistently - initial_frame = video_frames[idx][0] - frame_display = ax.imshow(initial_frame, cmap='gray', vmin=0, vmax=255) - frame_displays.append(frame_display) - if titles: - ax.set_title(titles[idx]) - ax.axis('off') - - # Define the slider - ax_slider = plt.axes([0.1, 0.1, 0.65, 0.05], facecolor='lightgoldenrodyellow') - slider = Slider(ax=ax_slider, label='Frame', valmin=0, valmax=num_frames - 1, valinit=0, valstep=1) - - playing = [False] # Use a mutable object to track play state - ax_button = plt.axes([0.8, 0.1, 0.1, 0.05]) - button = Button(ax_button, 'Play/Pause') - - # Callback to toggle play/pause - def toggle_play(event): - playing[0] = not playing[0] - - button.on_clicked(toggle_play) - - # Update function for the slider and frame displays - def update_frame(index): - for idx, frame_display in enumerate(frame_displays): - # Repeat last frame for static images or when the index is larger than stream length - if index < len(video_frames[idx]): - frame = video_frames[idx][index] - else: - frame = video_frames[idx][-1] # Keep showing last frame for shorter streams - frame_display.set_data(frame) - fig.canvas.draw_idle() - - # Slider update callback - def on_slider_change(val): - index = int(slider.val) - update_frame(index) - - # Connect the slider update function - slider.on_changed(on_slider_change) - - # Animation function - def animate(i): - if playing[0]: - current_frame = int(slider.val) - next_frame = (current_frame + 1) % num_frames - slider.set_val(next_frame) # This will also trigger on_slider_change - - # Use FuncAnimation to update the figure at the specified FPS - ani = animation.FuncAnimation(fig, animate, frames=num_frames, interval=1000//fps, blit=False) - - plt.show() - -def plot_frames_side_by_side( - fig: plt.Figure, - frames: list[np.ndarray], - titles: str =None - ) -> None: - """ - Plot a list of frames side by side using matplotlib. - """ - num_frames = len(frames) - plt.clf() # Clear current figure - - for i, frame in enumerate(frames): - plt.subplot(1, num_frames, i + 1) - plt.imshow(frame, cmap='gray') - if titles: - plt.title(titles[i]) - - plt.axis('off') # Turn off axis labels - - plt.tight_layout() - fig.canvas.draw() class VideoReader: """ A class to read video files. """ + def __init__(self, video_path: str): """ Initialize the VideoReader object. + + Parameters: + video_path (str): The path to the video file. + + Raises: + ValueError: If the video file cannot be opened. """ self.video_path = video_path self.cap = cv2.VideoCapture(str(video_path)) + self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) if not self.cap.isOpened(): raise ValueError(f"Could not open video at {video_path}") - + logger.info(f"Opened video at {video_path}") - + def read_frames(self) -> Iterator[np.ndarray]: """ Read frames from the video file. @@ -134,8 +49,8 @@ def read_frames(self) -> Iterator[np.ndarray]: if not ret: break yield frame - - def release(self)-> None: + + def release(self) -> None: """ Release the video capture object. 
""" @@ -144,69 +59,66 @@ def release(self)-> None: def __del__(self): self.release() + def show_frame(frame: np.ndarray) -> None: """ Display a single frame using OpenCV. """ - cv2.imshow('Mask', frame * 255) + cv2.imshow("Mask", frame * np.iinfo(np.uint8).max) while True: if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization break cv2.destroyAllWindows() + def gen_freq_mask( - width: int = 200, - height: int = 200, - center_radius: int = 5, - show_mask: bool = True - ) -> np.ndarray: + width: int, + height: int, + center_LPF: int, + vertical_BEF: int, + horizontal_BEF: int, + show_mask: bool = False, +) -> np.ndarray: """ Generate a mask to filter out horizontal and vertical frequencies. A central circular region can be removed to allow low frequencies to pass. """ crow, ccol = height // 2, width // 2 - + # Create an initial mask filled with ones (pass all frequencies) mask = np.ones((height, width), np.uint8) - - # Define band widths for vertical and horizontal suppression - vertical_band_width = 2 - horizontal_band_width = 0 - + # Zero out a vertical stripe at the frequency center - mask[:, ccol - vertical_band_width:ccol + vertical_band_width] = 0 - + mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0 + # Zero out a horizontal stripe at the frequency center - mask[crow - horizontal_band_width:crow + horizontal_band_width, :] = 0 - - # Define the radius of the circular region to retain at the center - radius = center_radius + mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 + + # Define spacial low pass filter + radius = center_LPF y, x = np.ogrid[:height, :width] - center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius ** 2 - + center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius**2 + # Restore the center circular area to allow low frequencies to pass mask[center_mask] = 1 # Visualize the mask if needed if show_mask: - cv2.imshow('Mask', mask * 255) + cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) while True: if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization break cv2.destroyAllWindows() return mask - + + class FrameProcessor: """ A class to process video frames. """ - def __init__(self, - height: int, - width: int, - buffer_size: int=5032, - block_size: int=32 - ): + + def __init__(self, height: int, width: int, buffer_size: int = 5032, block_size: int = 32): """ Initialize the FrameProcessor object. Block size/buffer size will be set by dev config later. @@ -222,12 +134,8 @@ def __init__(self, self.width = width self.buffer_size = buffer_size self.buffer_split = 1 - - def split_by_length( - self, - array: np.ndarray, - segment_length: int - ) -> list[np.ndarray]: + + def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.ndarray]: """ Split an array into sub-arrays of a specified length. 
@@ -242,21 +150,18 @@ def split_by_length( # Create sub-arrays of the specified segment length split_arrays = [ - array[i * segment_length: (i + 1) * segment_length] for i in range(num_segments) - ] + array[i * segment_length : (i + 1) * segment_length] for i in range(num_segments) + ] # Add the remaining elements as a final shorter segment, if any if len(array) % segment_length != 0: - split_arrays.append(array[num_segments * segment_length:]) + split_arrays.append(array[num_segments * segment_length :]) return split_arrays - + def patch_noisy_buffer( - self, - current_frame: np.ndarray, - previous_frame: np.ndarray, - noise_threshold: float - ) -> Tuple[np.ndarray, np.ndarray]: + self, current_frame: np.ndarray, previous_frame: np.ndarray, noise_threshold: float + ) -> Tuple[np.ndarray, np.ndarray]: """ Process the frame, replacing noisy blocks with those from the previous frame. @@ -271,8 +176,8 @@ def patch_noisy_buffer( serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - split_current = self.split_by_length(serialized_current, self.buffer_size//5) - split_previous = self.split_by_length(serialized_previous, self.buffer_size//5) + split_current = self.split_by_length(serialized_current, self.buffer_size // 5) + split_previous = self.split_by_length(serialized_previous, self.buffer_size // 5) split_output = split_current.copy() noisy_parts = split_current.copy() @@ -287,24 +192,16 @@ def patch_noisy_buffer( split_output[i] = split_current[i] noisy_parts[i] = np.zeros_like(split_current[i], np.uint8) - serialized_output = np.concatenate(split_output)[:self.height * self.width] - noise_output = np.concatenate(noisy_parts)[:self.height * self.width] - + serialized_output = np.concatenate(split_output)[: self.height * self.width] + noise_output = np.concatenate(noisy_parts)[: self.height * self.width] + # Deserialize processed frame - processed_frame = serialized_output.reshape( - self.width, - self.height) - noise_patch = noise_output.reshape( - self.width, - self.height) + processed_frame = serialized_output.reshape(self.width, self.height) + noise_patch = noise_output.reshape(self.width, self.height) return np.uint8(processed_frame), np.uint8(noise_patch) - - def remove_stripes( - self, - img: np.ndarray, - mask: np.ndarray - )-> np.ndarray: + + def remove_stripes(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray: """ Perform FFT/IFFT to remove horizontal stripes from a single frame. @@ -320,7 +217,9 @@ def remove_stripes( magnitude_spectrum = np.log(np.abs(fshift) + 1) # Use log for better visualization # Normalize the magnitude spectrum for visualization - magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX) + magnitude_spectrum = cv2.normalize( + magnitude_spectrum, None, 0, np.iinfo(np.uint8).max, cv2.NORM_MINMAX + ) # Apply mask and inverse FFT fshift *= mask @@ -329,15 +228,15 @@ def remove_stripes( img_back = np.abs(img_back) return np.uint8(img_back), np.uint8(magnitude_spectrum) - + + class FrameListProcessor: """ A class to process a list of video frames. """ + @staticmethod - def get_minimum_projection( - image_list: list[np.ndarray] - )-> np.ndarray: + def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray: """ Get the minimum projection of a list of images. 
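
The re-indented patch_noisy_buffer above still works buffer by buffer: flatten both frames, cut them into transfer-buffer-sized segments, and copy a segment wholesale from the previous frame when its mean absolute error crosses the threshold. A minimal sketch of that check on synthetic data, assuming the 5032-sample default buffer from the constructor and an arbitrary threshold of 20:

import numpy as np

height, width, buffer_size, noise_threshold = 200, 200, 5032, 20

rng = np.random.default_rng(seed=0)
previous = rng.integers(0, 255, size=(height, width), dtype=np.uint8)
current = previous.copy()
current[0:10, :] = 255  # simulate a corrupted run of rows in the first buffer

# Cast to int16 first, as the patch does, so the subtraction cannot wrap around.
cur = current.flatten().astype(np.int16)
prev = previous.flatten().astype(np.int16)
out = cur.copy()
for start in range(0, cur.size, buffer_size):
    segment = slice(start, start + buffer_size)
    if np.abs(cur[segment] - prev[segment]).mean() > noise_threshold:
        out[segment] = prev[segment]  # replace the whole noisy buffer

patched = out[: height * width].reshape(height, width).astype(np.uint8)
assert np.array_equal(patched, previous)  # the corrupted rows were patched over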
@@ -350,13 +249,11 @@ def get_minimum_projection( stacked_images = np.stack(image_list, axis=0) min_projection = np.min(stacked_images, axis=0) return min_projection - + @staticmethod - def normalize_video_stack( - image_list: list[np.ndarray] - ) -> list[np.ndarray]: + def normalize_video_stack(image_list: list[np.ndarray]) -> list[np.ndarray]: """ - Normalize a stack of images to 0-255 using max and minimum values throughout the entire stack. + Normalize a stack of images to 0-255 using max and minimum values of the entire stack. Return a list of images. Parameters: @@ -380,109 +277,128 @@ def normalize_video_stack( stacked_images[i], None, 0, - 255, + np.iinfo(np.uint8).max, cv2.NORM_MINMAX, - dtype=cv2.CV_32F - ) + dtype=cv2.CV_32F, + ) # Apply global normalization - normalized_image = (stacked_images[i] - global_min) / (global_max - global_min) * 255 + normalized_image = ( + (stacked_images[i] - global_min) + / (global_max - global_min) + * np.iinfo(np.uint8).max + ) normalized_images.append(normalized_image.astype(np.uint8)) return normalized_images -if __name__ == "__main__": + +class VideoProcessor: """ - For inital debugging. - Will be removed later. + A class to process video files. """ - video_path = 'output_001_test.avi' - reader = VideoReader(video_path) - streaming_plot = False - slider_plot = True - freq_masks = [] - gray_frames = [] - processed_frames = [] - freq_domain_frames = [] - noise_patchs = [] - filtered_frames = [] - diff_frames = [] - - index = 0 - - fig = plt.figure(figsize=(16, 4)) - - processor = FrameProcessor( - height=200, - width=200, - ) - - freq_mask = gen_freq_mask( - width=200, - height=200, - show_mask=False + + def denoise( + video_path: str, + slider_plot: bool = True, + end_frame: int = 100, + noise_threshold: float = 10, + spatial_LPF: int = 5, + vertical_BEF: int = 2, + horizontal_BEF: int = 0, + ) -> None: + """ + Process a video file and display the results. + Might be useful to define some using environment variables. 
+ """ + reader = VideoReader(video_path) + raw_frames = [] + patched_frames = [] + freq_domain_frames = [] + noise_patchs = [] + freq_filtered_frames = [] + diff_frames = [] + + index = 0 + fig = plt.figure() + + processor = FrameProcessor( + height=reader.height, + width=reader.width, ) - - try: - for frame in reader.read_frames(): - gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - index += 1 - if index > 100: - break - logger.debug(f"Processing frame {index}") - if index == 1: - previous_frame = gray_frame - processed_frame, noise_patch = processor.patch_noisy_buffer( - gray_frame, - previous_frame, - noise_threshold=15 + + freq_mask = gen_freq_mask( + width=reader.width, + height=reader.width, + center_LPF=spatial_LPF, + vertical_BEF=vertical_BEF, + horizontal_BEF=horizontal_BEF, + show_mask=False, + ) + + try: + for frame in reader.read_frames(): + raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + if index > end_frame: + break + + logger.debug(f"Processing frame {index}") + + if index == 0: + previous_frame = raw_frame + + patched_frame, noise_patch = processor.patch_noisy_buffer( + raw_frame, previous_frame, noise_threshold=noise_threshold + ) + freq_filtered_frame, frame_freq_domain = processor.remove_stripes( + img=patched_frame, mask=freq_mask ) - filtered_frame, freq_domain_frame = processor.remove_stripes( - img=processed_frame, - mask=freq_mask - ) - diff_frame = cv2.absdiff(gray_frame, filtered_frame) - gray_frames.append(gray_frame) - processed_frames.append(processed_frame) - freq_domain_frames.append(freq_domain_frame) - noise_patchs.append(noise_patch*255) - filtered_frames.append(filtered_frame) - diff_frames.append(diff_frame*10) - finally: - reader.release() - plt.close(fig) - - normalized_frames = FrameListProcessor.normalize_video_stack(filtered_frames) - minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames) - - subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames] - - subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) - - if slider_plot: - video_frames = [ - gray_frames, - processed_frames, - diff_frames, - noise_patchs, - freq_mask * 255, - freq_domain_frames, - filtered_frames, - normalized_frames, - minimum_projection, - subtract_minimum, - ] - plot_video_streams_with_controls( - video_frames, - titles=[ - 'Original', - 'Patched', - 'Diff', - 'Noisy area', - 'Freq mask', - 'Freq domain', - 'Freq filtered', - 'Normalized', - 'Min Proj', - 'Subtracted', + diff_frame = cv2.absdiff(raw_frame, freq_filtered_frame) + + raw_frames.append(raw_frame) + patched_frames.append(patched_frame) + freq_domain_frames.append(frame_freq_domain) + noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) + freq_filtered_frames.append(freq_filtered_frame) + diff_frames.append(diff_frame * 10) + + index += 1 + finally: + reader.release() + plt.close(fig) + + normalized_frames = FrameListProcessor.normalize_video_stack(freq_filtered_frames) + minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames) + + subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames] + + subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) + + if slider_plot: + video_frames = [ + raw_frames, + patched_frames, + diff_frames, + noise_patchs, + freq_mask * np.iinfo(np.uint8).max, + freq_domain_frames, + freq_filtered_frames, + normalized_frames, + minimum_projection, + subtract_minimum, ] - ) \ No newline at end of file + 
VideoPlotter.show_video_with_controls( + video_frames, + titles=[ + "RAW", + "Patched", + "Diff", + "Noisy area", + "Freq mask", + "Freq domain", + "Freq filtered", + "Normalized", + "Min Proj", + "Subtracted", + ], + ) diff --git a/mio/cli/main.py b/mio/cli/main.py index f9c85413..3272b695 100644 --- a/mio/cli/main.py +++ b/mio/cli/main.py @@ -23,3 +23,4 @@ def cli(ctx: click.Context) -> None: cli.add_command(update) cli.add_command(device) cli.add_command(config) +cli.add_command(denoise) diff --git a/mio/cli/process.py b/mio/cli/process.py new file mode 100644 index 00000000..ffa5d3f4 --- /dev/null +++ b/mio/cli/process.py @@ -0,0 +1,24 @@ +""" +Command line interface for +""" + +import click + +from miniscope_io.processing.video import VideoProcessor + + +@click.command() +@click.option( + "-i", + "--input", + required=True, + type=click.Path(exists=True, dir_okay=False), + help="Path to the video file to process.", +) +def denoise( + input: str, +) -> None: + """ + Denoise a video file. + """ + VideoProcessor.denoise(input) diff --git a/mio/plots/video.py b/mio/plots/video.py new file mode 100644 index 00000000..60ccd03b --- /dev/null +++ b/mio/plots/video.py @@ -0,0 +1,137 @@ +""" +Plotting functions for video streams and frames. +""" + +from typing import Union + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import animation +from matplotlib.backend_bases import KeyEvent +from matplotlib.widgets import Button, Slider + + +class VideoPlotter: + """ + Class to display video streams and static images. + """ + + @staticmethod + def show_video_with_controls( + video_frames: Union[list[np.ndarray], np.ndarray], titles: list[str] = None, fps: int = 20 + ) -> None: + """ + Plot multiple video streams or static images side-by-side. + Can play/pause and navigate frames. + + Parameters + ---------- + video_frames : list[np.ndarray] or np.ndarray + List of video streams or static images to display. + Each element of the list should be a list of frames or a single frame. 
+        titles : list[str], optional
+            List of titles for each stream, by default None
+        fps : int, optional
+            Frames per second for playback, by default 20
+        """
+        # Wrap static images in lists to handle them uniformly
+        video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames]
+
+        num_streams = len(video_frames)
+        num_frames = max(len(stream) for stream in video_frames)
+
+        fig, axes = plt.subplots(1, num_streams, figsize=(20, 5))
+
+        frame_displays = []
+        for idx, ax in enumerate(axes):
+            initial_frame = video_frames[idx][0]
+            frame_display = ax.imshow(initial_frame, cmap="gray", vmin=0, vmax=255)
+            frame_displays.append(frame_display)
+            if titles:
+                ax.set_title(titles[idx])
+            ax.axis("off")
+
+        # Slider
+        ax_slider = plt.axes([0.1, 0.1, 0.65, 0.05], facecolor="lightgoldenrodyellow")
+        slider = Slider(
+            ax=ax_slider, label="Frame", valmin=0, valmax=num_frames - 1, valinit=0, valstep=1
+        )
+
+        playing = [False]  # Use a mutable object to track play state
+        ax_button = plt.axes([0.8, 0.1, 0.1, 0.05])
+        button = Button(ax_button, "Play/Pause")
+
+        # Callback to toggle play/pause
+        def toggle_play(event: KeyEvent) -> None:
+            playing[0] = not playing[0]
+
+        button.on_clicked(toggle_play)
+
+        # Update function for the slider and frame displays
+        def update_frame(index: int) -> None:
+            for idx, frame_display in enumerate(frame_displays):
+                # Repeat last frame for static images or when the index is larger than stream length
+                if index < len(video_frames[idx]):
+                    frame = video_frames[idx][index]
+                else:
+                    frame = video_frames[idx][-1]  # Keep showing last frame for shorter streams
+                frame_display.set_data(frame)
+            fig.canvas.draw_idle()
+
+        # Slider update callback
+        def on_slider_change(val: float) -> None:
+            index = int(slider.val)
+            update_frame(index)
+
+        # Connect the slider update function
+        slider.on_changed(on_slider_change)
+
+        # Animation function
+        def animate(i: int) -> None:
+            if playing[0]:
+                current_frame = int(slider.val)
+                next_frame = (current_frame + 1) % num_frames
+                slider.set_val(next_frame)  # This will also trigger on_slider_change
+
+        # Use FuncAnimation to update the figure at the specified FPS
+        # This needs to be stored in a variable to prevent animation getting deleted
+        ani = animation.FuncAnimation(  # noqa: F841
+            fig, animate, frames=num_frames, interval=1000 // fps, blit=False
+        )
+
+        plt.show()
+
+    @staticmethod
+    def show_video_side_by_side(
+        fig: plt.Figure, frames: list[np.ndarray], titles: str = None
+    ) -> None:
+        """
+        Plot a list of frames side by side using matplotlib.
+ + Parameters + ---------- + fig : plt.Figure + Figure to plot on + frames : list[np.ndarray] + List of frames to plot + titles : str, optional + List of titles for each frame, by default None + + Raises + ------ + ValueError + If the number of frames and titles do not match + """ + num_frames = len(frames) + plt.clf() # Clear current figure + + for i, frame in enumerate(frames): + plt.subplot(1, num_frames, i + 1) + plt.imshow(frame, cmap="gray") + if titles: + plt.title(titles[i]) + + plt.axis("off") # Turn off axis labels + + plt.tight_layout() + fig.canvas.draw() From b50c4525e7c06178e747c7a13ce38c28f6d37660 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:43:00 -0800 Subject: [PATCH 05/20] Model for frame data handling --- .../{processing => process}/__init__.py | 0 miniscope_io/{processing => process}/video.py | 91 +++++++++---------- mio/models/frames.py | 85 +++++++++++++++++ mio/plots/video.py | 27 +++--- 4 files changed, 146 insertions(+), 57 deletions(-) rename miniscope_io/{processing => process}/__init__.py (100%) rename miniscope_io/{processing => process}/video.py (84%) create mode 100644 mio/models/frames.py diff --git a/miniscope_io/processing/__init__.py b/miniscope_io/process/__init__.py similarity index 100% rename from miniscope_io/processing/__init__.py rename to miniscope_io/process/__init__.py diff --git a/miniscope_io/processing/video.py b/miniscope_io/process/video.py similarity index 84% rename from miniscope_io/processing/video.py rename to miniscope_io/process/video.py index a8c6bfb0..4b013f17 100644 --- a/miniscope_io/processing/video.py +++ b/miniscope_io/process/video.py @@ -9,6 +9,7 @@ import numpy as np from miniscope_io import init_logger +from miniscope_io.models.frames import NamedFrame from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") @@ -42,6 +43,9 @@ def __init__(self, video_path: str): def read_frames(self) -> Iterator[np.ndarray]: """ Read frames from the video file. + + Yields: + np.ndarray: The next frame in the video. """ while self.cap.isOpened(): ret, frame = self.cap.read() @@ -60,18 +64,6 @@ def __del__(self): self.release() -def show_frame(frame: np.ndarray) -> None: - """ - Display a single frame using OpenCV. - """ - cv2.imshow("Mask", frame * np.iinfo(np.uint8).max) - while True: - if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization - break - - cv2.destroyAllWindows() - - def gen_freq_mask( width: int, height: int, @@ -96,14 +88,13 @@ def gen_freq_mask( mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 # Define spacial low pass filter - radius = center_LPF y, x = np.ogrid[:height, :width] - center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius**2 + center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 # Restore the center circular area to allow low frequencies to pass mask[center_mask] = 1 - # Visualize the mask if needed + # Visualize the mask if needed. Might delete later. if show_mask: cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) while True: @@ -118,7 +109,7 @@ class FrameProcessor: A class to process video frames. """ - def __init__(self, height: int, width: int, buffer_size: int = 5032, block_size: int = 32): + def __init__(self, height: int, width: int, buffer_size: int = 5032, buffer_split: int = 1): """ Initialize the FrameProcessor object. Block size/buffer size will be set by dev config later. 
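The hunk above swaps the unused block_size constructor argument for buffer_split, so noise detection can operate on fractions of a transfer buffer. A minimal sketch of the segmentation arithmetic this relies on (assuming the 200x200 frame size and 5032-sample buffers used elsewhere in this series; all names are illustrative):

    import numpy as np

    height, width = 200, 200            # frame geometry from the device config
    buffer_size, buffer_split = 5032, 1  # defaults shown in the hunk above

    frame = np.zeros((height, width), dtype=np.uint8)
    serialized = frame.flatten()         # 40000 samples, as in patch_noisy_buffer
    segment_length = buffer_size // buffer_split
    num_full_segments = len(serialized) // segment_length  # 40000 // 5032 == 7
    # split_by_length() returns those 7 full segments plus one short remainder,
    # so the mean-error comparison runs per transfer buffer rather than per frame.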
@@ -133,7 +124,7 @@ def __init__(self, height: int, width: int, buffer_size: int = 5032, block_size: self.height = height self.width = width self.buffer_size = buffer_size - self.buffer_split = 1 + self.buffer_split = buffer_split def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.ndarray]: """ @@ -176,8 +167,12 @@ def patch_noisy_buffer( serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - split_current = self.split_by_length(serialized_current, self.buffer_size // 5) - split_previous = self.split_by_length(serialized_previous, self.buffer_size // 5) + split_current = self.split_by_length( + serialized_current, self.buffer_size // self.buffer_split + ) + split_previous = self.split_by_length( + serialized_previous, self.buffer_size // self.buffer_split + ) split_output = split_current.copy() noisy_parts = split_current.copy() @@ -301,10 +296,12 @@ def denoise( video_path: str, slider_plot: bool = True, end_frame: int = 100, - noise_threshold: float = 10, - spatial_LPF: int = 5, + noise_threshold: float = 20, + spatial_LPF: int = 10, vertical_BEF: int = 2, horizontal_BEF: int = 0, + diff_mag: int = 10, + buffer_split: int = 1, ) -> None: """ Process a video file and display the results. @@ -324,6 +321,7 @@ def denoise( processor = FrameProcessor( height=reader.height, width=reader.width, + buffer_split=buffer_split, ) freq_mask = gen_freq_mask( @@ -360,7 +358,7 @@ def denoise( freq_domain_frames.append(frame_freq_domain) noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) freq_filtered_frames.append(freq_filtered_frame) - diff_frames.append(diff_frame * 10) + diff_frames.append(diff_frame * diff_mag) index += 1 finally: @@ -374,31 +372,32 @@ def denoise( subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) + raw_video = NamedFrame(name="RAW", video_frame=raw_frames) + patched_video = NamedFrame(name="Patched", video_frame=patched_frames) + diff_video = NamedFrame(name=f"Diff {diff_mag}x", video_frame=diff_frames) + noise_patch = NamedFrame(name="Noisy area", video_frame=noise_patchs) + freq_mask_frame = NamedFrame( + name="Freq mask", static_frame=freq_mask * np.iinfo(np.uint8).max + ) + freq_domain_video = NamedFrame(name="Freq domain", video_frame=freq_domain_frames) + freq_filtered_video = NamedFrame(name="Freq filtered", video_frame=freq_filtered_frames) + normalized_video = NamedFrame(name="Normalized", video_frame=normalized_frames) + min_proj_frame = NamedFrame(name="Min Proj", static_frame=minimum_projection) + subtract_video = NamedFrame(name="Subtracted", video_frame=subtract_minimum) + if slider_plot: - video_frames = [ - raw_frames, - patched_frames, - diff_frames, - noise_patchs, - freq_mask * np.iinfo(np.uint8).max, - freq_domain_frames, - freq_filtered_frames, - normalized_frames, - minimum_projection, - subtract_minimum, + videos = [ + raw_video, + patched_video, + diff_video, + noise_patch, + freq_mask_frame, + freq_domain_video, + freq_filtered_video, + normalized_video, + min_proj_frame, + subtract_video, ] VideoPlotter.show_video_with_controls( - video_frames, - titles=[ - "RAW", - "Patched", - "Diff", - "Noisy area", - "Freq mask", - "Freq domain", - "Freq filtered", - "Normalized", - "Min Proj", - "Subtracted", - ], + videos, ) diff --git a/mio/models/frames.py b/mio/models/frames.py new file mode 100644 index 00000000..990e8436 --- /dev/null +++ b/mio/models/frames.py @@ -0,0 +1,85 @@ +""" +Pydantic models for storing frames and videos. 
+""" + +from typing import List, Optional, TypeVar + +import numpy as np +from pydantic import BaseModel, Field, model_validator + +T = TypeVar("T", np.ndarray, List[np.ndarray], List[List[np.ndarray]]) + + +class NamedFrame(BaseModel): + """ + Pydantic model to store an array (frame/video/video list) together with a name. + """ + + name: str = Field( + ..., + description="Name of the video.", + ) + static_frame: Optional[np.ndarray] = Field( + None, + description="Frame data, if provided.", + ) + video_frame: Optional[List[np.ndarray]] = Field( + None, + description="Video data, if provided.", + ) + video_list_frame: Optional[List[List[np.ndarray]]] = Field( + None, + description="List of video data, if provided.", + ) + frame_type: Optional[str] = Field( + None, + description="Type of frame data.", + ) + + @model_validator(mode="before") + def check_frame_type(cls, values: dict) -> dict: + """ + Ensure that exactly one of static_frame, video_frame, or video_list_frame is provided. + """ + static = values.get("static_frame") + video = values.get("video_frame") + video_list = values.get("video_list_frame") + + # Identify which fields are present + present_fields = [ + (field_name, field_value) + for field_name, field_value in zip( + ("static_frame", "video_frame", "video_list_frame"), (static, video, video_list) + ) + if field_value is not None + ] + + if len(present_fields) != 1: + raise ValueError( + "Exactly one of static_frame, video_frame, or video_list_frame must be provided." + ) + + # Record which frame type is present + values["frame_type"] = present_fields[0][0] + + return values + + @property + def data(self) -> T: + """Return the content of the populated field.""" + if self.frame_type == "static_frame": + return self.static_frame + elif self.frame_type == "video_frame": + return self.video_frame + elif self.frame_type == "video_list_frame": + return self.video_list_frame + else: + raise ValueError("Unknown frame type or no frame data provided.") + + class Config: + """ + Pydantic config for allowing np.ndarray types. + Could be an Numpydantic situation so will look into it later. + """ + + arbitrary_types_allowed = True diff --git a/mio/plots/video.py b/mio/plots/video.py index 60ccd03b..541f35b6 100644 --- a/mio/plots/video.py +++ b/mio/plots/video.py @@ -2,7 +2,7 @@ Plotting functions for video streams and frames. """ -from typing import Union +from typing import List import matplotlib.pyplot as plt import numpy as np @@ -10,6 +10,8 @@ from matplotlib.backend_bases import KeyEvent from matplotlib.widgets import Button, Slider +from miniscope_io.models.frames import NamedFrame + class VideoPlotter: """ @@ -17,25 +19,28 @@ class VideoPlotter: """ @staticmethod - def show_video_with_controls( - video_frames: Union[list[np.ndarray], np.ndarray], titles: list[str] = None, fps: int = 20 - ) -> None: + def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None: """ Plot multiple video streams or static images side-by-side. Can play/pause and navigate frames. Parameters ---------- - video_frames : list[np.ndarray] or np.ndarray - List of video streams or static images to display. - Each element of the list should be a list of frames or a single frame. - titles : list[str], optional - List of titles for each stream, by default None + videos : NamedFrame + NamedFrame object containing video data and names. 
fps : int, optional - Frames per second for playback, by default 20 + Frames per second for the video, by default 20 """ + + if any(frame.frame_type == "video_list_frame" for frame in videos): + raise NotImplementedError("Only single videos or frames are supported for now.") + # Wrap static images in lists to handle them uniformly - video_frames = [frame if isinstance(frame, list) else [frame] for frame in video_frames] + video_frames = [ + frame.data if frame.frame_type == "video_frame" else [frame.data] for frame in videos + ] + + titles = [video.name for video in videos] num_streams = len(video_frames) num_frames = max(len(stream) for stream in video_frames) From 9eef110a08e7b7f1f917b6ce97955269586f4208 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:43:11 -0800 Subject: [PATCH 06/20] update CLI --- mio/cli/process.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mio/cli/process.py b/mio/cli/process.py index ffa5d3f4..a607472c 100644 --- a/mio/cli/process.py +++ b/mio/cli/process.py @@ -1,13 +1,21 @@ """ -Command line interface for +Command line interface for offline video pre-processing. """ import click -from miniscope_io.processing.video import VideoProcessor +from miniscope_io.process.video import VideoProcessor -@click.command() +@click.group() +def process() -> None: + """ + Command group for video processing. + """ + pass + + +@process.command() @click.option( "-i", "--input", From d2df326606d88baded00ecc91060b12ade91c1ef Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 11:43:24 -0800 Subject: [PATCH 07/20] Add user directory, ignore logs --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 12f8a6ad..e8953907 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,7 @@ cython_debug/ #.idea/ wirefree_example.mp4 wirefree_example.avi -.pdm-python \ No newline at end of file +.pdm-python +user_dir/ + +~/.config/miniscope_io/logs/ From a49180328506c8c984b81563fab05659a133e249 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:06:17 -0800 Subject: [PATCH 08/20] Move out videowriter from stream class/add videoreader class --- miniscope_io/process/video.py | 53 +------------------ mio/io.py | 96 ++++++++++++++++++++++++++++++++++- mio/stream_daq.py | 37 +++----------- 3 files changed, 103 insertions(+), 83 deletions(-) diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index 4b013f17..5e9ee9c4 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -2,68 +2,19 @@ This module contains functions for pre-processing video data. """ -from typing import Iterator, Tuple +from typing import Tuple import cv2 import matplotlib.pyplot as plt import numpy as np from miniscope_io import init_logger +from miniscope_io.io import VideoReader from miniscope_io.models.frames import NamedFrame from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") - -class VideoReader: - """ - A class to read video files. - """ - - def __init__(self, video_path: str): - """ - Initialize the VideoReader object. - - Parameters: - video_path (str): The path to the video file. - - Raises: - ValueError: If the video file cannot be opened. 
- """ - self.video_path = video_path - self.cap = cv2.VideoCapture(str(video_path)) - self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - - if not self.cap.isOpened(): - raise ValueError(f"Could not open video at {video_path}") - - logger.info(f"Opened video at {video_path}") - - def read_frames(self) -> Iterator[np.ndarray]: - """ - Read frames from the video file. - - Yields: - np.ndarray: The next frame in the video. - """ - while self.cap.isOpened(): - ret, frame = self.cap.read() - logger.debug(f"Reading frame {self.cap.get(cv2.CAP_PROP_POS_FRAMES)}") - if not ret: - break - yield frame - - def release(self) -> None: - """ - Release the video capture object. - """ - self.cap.release() - - def __del__(self): - self.release() - - def gen_freq_mask( width: int, height: int, diff --git a/mio/io.py b/mio/io.py index e6227775..357289e5 100644 --- a/mio/io.py +++ b/mio/io.py @@ -6,7 +6,7 @@ import contextlib import csv from pathlib import Path -from typing import Any, BinaryIO, List, Literal, Optional, Union, overload +from typing import Any, BinaryIO, Iterator, List, Literal, Optional, Union, overload import cv2 import numpy as np @@ -19,6 +19,100 @@ from mio.types import ConfigSource +class VideoWriter: + """ + Write data to a video file using OpenCV. + """ + @staticmethod + def init_video( + path: Union[Path, str], + width: int, + height: int, + fps: int, + fourcc: str = "Y800", + **kwargs: dict + ) -> cv2.VideoWriter: + """ + Create a parameterized video writer + + Parameters + ---------- + frame_buffer_queue : multiprocessing.Queue[list[bytes]] + Input buffer queue. + path : Union[Path, str] + Video file to write to + width : int + Width of video + height : int + Height of video + frame_rate : int + Frame rate of video + fourcc : str + Fourcc code to use + kwargs : dict + passed to :class:`cv2.VideoWriter` + + Returns: + --------- + :class:`cv2.VideoWriter` + """ + if isinstance(path, str): + path = Path(path) + + fourcc = cv2.VideoWriter_fourcc(*fourcc) + frame_rate = fps + frame_size = (width, height) + out = cv2.VideoWriter(str(path), fourcc, frame_rate, frame_size, **kwargs) + return out + +class VideoReader: + """ + A class to read video files. + """ + + def __init__(self, video_path: str): + """ + Initialize the VideoReader object. + + Parameters: + video_path (str): The path to the video file. + + Raises: + ValueError: If the video file cannot be opened. + """ + self.video_path = video_path + self.cap = cv2.VideoCapture(str(video_path)) + self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.logger = init_logger("VideoReader") + + if not self.cap.isOpened(): + raise ValueError(f"Could not open video at {video_path}") + + self.logger.info(f"Opened video at {video_path}") + + def read_frames(self) -> Iterator[np.ndarray]: + """ + Read frames from the video file. + + Yields: + np.ndarray: The next frame in the video. + """ + while self.cap.isOpened(): + ret, frame = self.cap.read() + self.logger.debug(f"Reading frame {self.cap.get(cv2.CAP_PROP_POS_FRAMES)}") + if not ret: + break + yield frame + + def release(self) -> None: + """ + Release the video capture object. + """ + self.cap.release() + + def __del__(self): + self.release() class BufferedCSVWriter: """ Write data to a CSV file in buffered mode. 
diff --git a/mio/stream_daq.py b/mio/stream_daq.py index cd46a742..17202da4 100644 --- a/mio/stream_daq.py +++ b/mio/stream_daq.py @@ -565,36 +565,6 @@ def _format_frame( except queue.Full: locallogs.error("Image array queue full, Could not put sentinel.") - def init_video( - self, path: Union[Path, str], fourcc: str = "Y800", **kwargs: dict - ) -> cv2.VideoWriter: - """ - Create a parameterized video writer - - Parameters - ---------- - frame_buffer_queue : multiprocessing.Queue[list[bytes]] - Input buffer queue. - path : Union[Path, str] - Video file to write to - fourcc : str - Fourcc code to use - kwargs : dict - passed to :class:`cv2.VideoWriter` - - Returns: - --------- - :class:`cv2.VideoWriter` - """ - if isinstance(path, str): - path = Path(path) - - fourcc = cv2.VideoWriter_fourcc(*fourcc) - frame_rate = self.config.fs - frame_size = (self.config.frame_width, self.config.frame_height) - out = cv2.VideoWriter(str(path), fourcc, frame_rate, frame_size, **kwargs) - return out - def alive_processes(self) -> List[multiprocessing.Process]: """ Return a list of alive processes. @@ -684,7 +654,12 @@ def capture( if video: if video_kwargs is None: video_kwargs = {} - writer = self.init_video(video, **video_kwargs) + writer = VideoWriter.init_video( + path=video, + width=self.config.frame_width, + height=self.config.frame_height, + fps=self.config.fs, + **video_kwargs) p_buffer_to_frame = multiprocessing.Process( target=self._buffer_to_frame, From dadef1bf9135f1a12b7d890df15b3526996a7ca8 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:15:37 -0800 Subject: [PATCH 09/20] Push modules into associated classes --- miniscope_io/process/video.py | 192 ++++++++++++++++------------------ 1 file changed, 90 insertions(+), 102 deletions(-) diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index 5e9ee9c4..00a7ff8a 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -14,47 +14,6 @@ from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") - -def gen_freq_mask( - width: int, - height: int, - center_LPF: int, - vertical_BEF: int, - horizontal_BEF: int, - show_mask: bool = False, -) -> np.ndarray: - """ - Generate a mask to filter out horizontal and vertical frequencies. - A central circular region can be removed to allow low frequencies to pass. - """ - crow, ccol = height // 2, width // 2 - - # Create an initial mask filled with ones (pass all frequencies) - mask = np.ones((height, width), np.uint8) - - # Zero out a vertical stripe at the frequency center - mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0 - - # Zero out a horizontal stripe at the frequency center - mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 - - # Define spacial low pass filter - y, x = np.ogrid[:height, :width] - center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 - - # Restore the center circular area to allow low frequencies to pass - mask[center_mask] = 1 - - # Visualize the mask if needed. Might delete later. - if show_mask: - cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) - while True: - if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization - break - cv2.destroyAllWindows() - return mask - - class FrameProcessor: """ A class to process video frames. 
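The hunk above removes the module-level gen_freq_mask; the next hunk folds it into FrameProcessor. For orientation, the band-elimination idea in plain NumPy, a reduced sketch rather than the library code (frame contents and cutoff values are placeholders):

    import numpy as np

    img = np.random.rand(200, 200).astype(np.float32)   # stand-in frame
    f = np.fft.fftshift(np.fft.fft2(img))               # centered spectrum

    mask = np.ones_like(img)
    crow, ccol = img.shape[0] // 2, img.shape[1] // 2
    mask[:, ccol - 2 : ccol + 2] = 0                    # vertical band elimination
    y, x = np.ogrid[: img.shape[0], : img.shape[1]]
    mask[(x - ccol) ** 2 + (y - crow) ** 2 <= 10**2] = 1  # keep the low-pass disc

    filtered = np.abs(np.fft.ifft2(np.fft.ifftshift(f * mask)))

Stripe noise concentrates along thin bands of the centered 2-D spectrum, so zeroing those bands while keeping the central disc suppresses stripes without discarding the low-frequency image content.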
@@ -174,75 +133,51 @@ def remove_stripes(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray: img_back = np.abs(img_back) return np.uint8(img_back), np.uint8(magnitude_spectrum) - - -class FrameListProcessor: - """ - A class to process a list of video frames. - """ - - @staticmethod - def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray: + + def gen_freq_mask( + self, + center_LPF: int, + vertical_BEF: int, + horizontal_BEF: int, + show_mask: bool = False, + ) -> np.ndarray: """ - Get the minimum projection of a list of images. - - Parameters: - image_list (list[np.ndarray]): A list of images to project. - - Returns: - np.ndarray: The minimum projection of the images. + Generate a mask to filter out horizontal and vertical frequencies. + A central circular region can be removed to allow low frequencies to pass. """ - stacked_images = np.stack(image_list, axis=0) - min_projection = np.min(stacked_images, axis=0) - return min_projection + crow, ccol = self.height // 2, self.width // 2 - @staticmethod - def normalize_video_stack(image_list: list[np.ndarray]) -> list[np.ndarray]: - """ - Normalize a stack of images to 0-255 using max and minimum values of the entire stack. - Return a list of images. + # Create an initial mask filled with ones (pass all frequencies) + mask = np.ones((self.height, self.width), np.uint8) - Parameters: - image_list (list[np.ndarray]): A list of images to normalize. + # Zero out a vertical stripe at the frequency center + mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0 - Returns: - list[np.ndarray]: The normalized images as a list. - """ + # Zero out a horizontal stripe at the frequency center + mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 - # Stack images along a new axis (axis=0) - stacked_images = np.stack(image_list, axis=0) + # Define spacial low pass filter + y, x = np.ogrid[:self.height, :self.width] + center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 - # Find the global min and max across the entire stack - global_min = stacked_images.min() - global_max = stacked_images.max() - - # Normalize each frame using the global min and max - normalized_images = [] - for i in range(stacked_images.shape[0]): - normalized_image = cv2.normalize( - stacked_images[i], - None, - 0, - np.iinfo(np.uint8).max, - cv2.NORM_MINMAX, - dtype=cv2.CV_32F, - ) - # Apply global normalization - normalized_image = ( - (stacked_images[i] - global_min) - / (global_max - global_min) - * np.iinfo(np.uint8).max - ) - normalized_images.append(normalized_image.astype(np.uint8)) + # Restore the center circular area to allow low frequencies to pass + mask[center_mask] = 1 - return normalized_images + # Visualize the mask if needed. Might delete later. + if show_mask: + cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) + while True: + if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization + break + cv2.destroyAllWindows() + return mask class VideoProcessor: """ A class to process video files. 
""" - + @staticmethod def denoise( video_path: str, slider_plot: bool = True, @@ -275,9 +210,7 @@ def denoise( buffer_split=buffer_split, ) - freq_mask = gen_freq_mask( - width=reader.width, - height=reader.width, + freq_mask = processor.gen_freq_mask( center_LPF=spatial_LPF, vertical_BEF=vertical_BEF, horizontal_BEF=horizontal_BEF, @@ -316,12 +249,12 @@ def denoise( reader.release() plt.close(fig) - normalized_frames = FrameListProcessor.normalize_video_stack(freq_filtered_frames) - minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames) + normalized_frames = VideoProcessor.normalize_video_stack(freq_filtered_frames) + minimum_projection = VideoProcessor.get_minimum_projection(normalized_frames) subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames] - subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) + subtract_minimum = VideoProcessor.normalize_video_stack(subtract_minimum) raw_video = NamedFrame(name="RAW", video_frame=raw_frames) patched_video = NamedFrame(name="Patched", video_frame=patched_frames) @@ -352,3 +285,58 @@ def denoise( VideoPlotter.show_video_with_controls( videos, ) + @staticmethod + def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray: + """ + Get the minimum projection of a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to project. + + Returns: + np.ndarray: The minimum projection of the images. + """ + stacked_images = np.stack(image_list, axis=0) + min_projection = np.min(stacked_images, axis=0) + return min_projection + + @staticmethod + def normalize_video_stack(image_list: list[np.ndarray]) -> list[np.ndarray]: + """ + Normalize a stack of images to 0-255 using max and minimum values of the entire stack. + Return a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to normalize. + + Returns: + list[np.ndarray]: The normalized images as a list. 
+ """ + + # Stack images along a new axis (axis=0) + stacked_images = np.stack(image_list, axis=0) + + # Find the global min and max across the entire stack + global_min = stacked_images.min() + global_max = stacked_images.max() + + # Normalize each frame using the global min and max + normalized_images = [] + for i in range(stacked_images.shape[0]): + normalized_image = cv2.normalize( + stacked_images[i], + None, + 0, + np.iinfo(np.uint8).max, + cv2.NORM_MINMAX, + dtype=cv2.CV_32F, + ) + # Apply global normalization + normalized_image = ( + (stacked_images[i] - global_min) + / (global_max - global_min) + * np.iinfo(np.uint8).max + ) + normalized_images.append(normalized_image.astype(np.uint8)) + + return normalized_images From 846fd5c1f34b1d56a860a41f22bda14d685ca694 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:13:52 -0800 Subject: [PATCH 10/20] configure denoise with yaml file --- .../data/config/device/WLMS_v02_200px.yml | 56 +++++++++++ .../data/config/process/denoise_example.yml | 17 ++++ miniscope_io/process/video.py | 53 +++++----- mio/cli/process.py | 13 ++- mio/models/process.py | 99 +++++++++++++++++++ 5 files changed, 212 insertions(+), 26 deletions(-) create mode 100644 miniscope_io/data/config/device/WLMS_v02_200px.yml create mode 100644 miniscope_io/data/config/process/denoise_example.yml create mode 100644 mio/models/process.py diff --git a/miniscope_io/data/config/device/WLMS_v02_200px.yml b/miniscope_io/data/config/device/WLMS_v02_200px.yml new file mode 100644 index 00000000..75aa4ad3 --- /dev/null +++ b/miniscope_io/data/config/device/WLMS_v02_200px.yml @@ -0,0 +1,56 @@ +id: wireless-200px +mio_model: mio.models.stream.StreamDevConfig +mio_version: "v5.0.0" + +# capture device. "OK" (Opal Kelly) or "UART" +device: "OK" + +# bitstream file to upload to Opal Kelly board +bitstream: "XEM7310-A75/USBInterface-8_33mhz-J2_2-3v3-IEEE.bit" + +# COM port and baud rate is only required for UART mode +port: null +baudrate: null + +# Preamble for each data buffer. +preamble: 0x12345678 + +# Image format. StreamDaq will calculate buffer size, etc. based on these parameters +frame_width: 200 +frame_height: 200 +pix_depth: 8 + +# Buffer data format. These have to match the firmware value +header_len: 384 # 12 * 32 (in bits) +buffer_block_length: 10 +block_size: 512 +num_buffers: 32 +dummy_words: 10 + +# Flags to flip bit/byte order when recovering headers and data. See model document for details. 
+reverse_header_bits: True +reverse_header_bytes: True +reverse_payload_bits: True +reverse_payload_bytes: True + +adc_scale: + ref_voltage: 1.1 + bitdepth: 8 + battery_div_factor: 5 + vin_div_factor: 11.3 + +runtime: + serial_buffer_queue_size: 10 + frame_buffer_queue_size: 5 + image_buffer_queue_size: 5 + csv: + buffer: 100 + plot: + keys: + - timestamp + - buffer_count + - frame_buffer_count + - battery_voltage + - input_voltage + update_ms: 1000 + history: 500 diff --git a/miniscope_io/data/config/process/denoise_example.yml b/miniscope_io/data/config/process/denoise_example.yml new file mode 100644 index 00000000..93f4ad94 --- /dev/null +++ b/miniscope_io/data/config/process/denoise_example.yml @@ -0,0 +1,17 @@ +interactive_display: + enable: True + end_frame: 1000 +noise_patch: + enable: True + method: "mean_error" + threshold: 20 + buffer_size: 5032 + buffer_split: 1 + diff_multiply: 10 +frequency_masking: + enable: True + spacial_LPF_cutoff_radius: 10 + vertical_BEF_cutoff: 5 + horizontal_BEF_cutoff: 0 + display_mask: False +end_frame: 1000 \ No newline at end of file diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index 00a7ff8a..f68a6734 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -11,6 +11,7 @@ from miniscope_io import init_logger from miniscope_io.io import VideoReader from miniscope_io.models.frames import NamedFrame +from miniscope_io.models.process import DenoiseConfig from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") @@ -33,8 +34,6 @@ def __init__(self, height: int, width: int, buffer_size: int = 5032, buffer_spli """ self.height = height self.width = width - self.buffer_size = buffer_size - self.buffer_split = buffer_split def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.ndarray]: """ @@ -61,7 +60,12 @@ def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.nda return split_arrays def patch_noisy_buffer( - self, current_frame: np.ndarray, previous_frame: np.ndarray, noise_threshold: float + self, + current_frame: np.ndarray, + previous_frame: np.ndarray, + buffer_size: int, + buffer_split: int, + noise_threshold: float ) -> Tuple[np.ndarray, np.ndarray]: """ Process the frame, replacing noisy blocks with those from the previous frame. @@ -78,10 +82,10 @@ def patch_noisy_buffer( serialized_previous = previous_frame.flatten().astype(np.int16) split_current = self.split_by_length( - serialized_current, self.buffer_size // self.buffer_split + serialized_current, buffer_size // buffer_split ) split_previous = self.split_by_length( - serialized_previous, self.buffer_size // self.buffer_split + serialized_previous, buffer_size // buffer_split ) split_output = split_current.copy() @@ -180,14 +184,7 @@ class VideoProcessor: @staticmethod def denoise( video_path: str, - slider_plot: bool = True, - end_frame: int = 100, - noise_threshold: float = 20, - spatial_LPF: int = 10, - vertical_BEF: int = 2, - horizontal_BEF: int = 0, - diff_mag: int = 10, - buffer_split: int = 1, + config: DenoiseConfig, ) -> None: """ Process a video file and display the results. 
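With the new signature above, denoise is driven entirely by a DenoiseConfig object. A usage sketch mirroring what the updated CLI later in this patch does (file paths are placeholders; the bundled denoise_example.yml is the shipped starting point):

    from miniscope_io.models.process import DenoiseConfig
    from miniscope_io.process.video import VideoProcessor

    config = DenoiseConfig.from_yaml(
        "miniscope_io/data/config/process/denoise_example.yml"
    )
    VideoProcessor.denoise("recording.avi", config)

The same flow is reachable from the command line via the denoise command, e.g. something like `mio process denoise -i recording.avi -c denoise_example.yml`, assuming the installed entry point is named `mio`.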
@@ -207,21 +204,21 @@ def denoise( processor = FrameProcessor( height=reader.height, width=reader.width, - buffer_split=buffer_split, ) - freq_mask = processor.gen_freq_mask( - center_LPF=spatial_LPF, - vertical_BEF=vertical_BEF, - horizontal_BEF=horizontal_BEF, - show_mask=False, - ) + if config.noise_patch.enable: + freq_mask = processor.gen_freq_mask( + center_LPF=config.frequency_masking.spatial_LPF_cutoff_radius, + vertical_BEF=config.frequency_masking.vertical_BEF_cutoff, + horizontal_BEF=config.frequency_masking.horizontal_BEF_cutoff, + show_mask=config.frequency_masking.display_mask, + ) try: for frame in reader.read_frames(): raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - if index > end_frame: + if config.end_frame and index > config.end_frame: break logger.debug(f"Processing frame {index}") @@ -230,7 +227,11 @@ def denoise( previous_frame = raw_frame patched_frame, noise_patch = processor.patch_noisy_buffer( - raw_frame, previous_frame, noise_threshold=noise_threshold + raw_frame, + previous_frame, + buffer_size=config.noise_patch.buffer_size, + buffer_split=config.noise_patch.buffer_split, + noise_threshold=config.noise_patch.threshold ) freq_filtered_frame, frame_freq_domain = processor.remove_stripes( img=patched_frame, mask=freq_mask @@ -242,7 +243,7 @@ def denoise( freq_domain_frames.append(frame_freq_domain) noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) freq_filtered_frames.append(freq_filtered_frame) - diff_frames.append(diff_frame * diff_mag) + diff_frames.append(diff_frame * config.noise_patch.diff_multiply) index += 1 finally: @@ -258,7 +259,9 @@ def denoise( raw_video = NamedFrame(name="RAW", video_frame=raw_frames) patched_video = NamedFrame(name="Patched", video_frame=patched_frames) - diff_video = NamedFrame(name=f"Diff {diff_mag}x", video_frame=diff_frames) + diff_video = NamedFrame( + name=f"Diff {config.noise_patch.diff_multiply}x", + video_frame=diff_frames) noise_patch = NamedFrame(name="Noisy area", video_frame=noise_patchs) freq_mask_frame = NamedFrame( name="Freq mask", static_frame=freq_mask * np.iinfo(np.uint8).max @@ -269,7 +272,7 @@ def denoise( min_proj_frame = NamedFrame(name="Min Proj", static_frame=minimum_projection) subtract_video = NamedFrame(name="Subtracted", video_frame=subtract_minimum) - if slider_plot: + if config.interactive_display.enable: videos = [ raw_video, patched_video, diff --git a/mio/cli/process.py b/mio/cli/process.py index a607472c..72b06cec 100644 --- a/mio/cli/process.py +++ b/mio/cli/process.py @@ -4,6 +4,7 @@ import click +from miniscope_io.models.process import DenoiseConfig from miniscope_io.process.video import VideoProcessor @@ -23,10 +24,20 @@ def process() -> None: type=click.Path(exists=True, dir_okay=False), help="Path to the video file to process.", ) +@click.option( + "-c", + "--denoise_config", + required=True, + type=click.Path(exists=True, dir_okay=False), + help="Path to the YAML processing configuration file.", +) def denoise( input: str, + denoise_config: str, ) -> None: """ Denoise a video file. """ - VideoProcessor.denoise(input) + denoise_config_parsed = DenoiseConfig.from_yaml(denoise_config) + VideoProcessor.denoise(input, denoise_config_parsed) + diff --git a/mio/models/process.py b/mio/models/process.py new file mode 100644 index 00000000..d14c80c7 --- /dev/null +++ b/mio/models/process.py @@ -0,0 +1,99 @@ +""" +Module for preprocessing data. 
+""" + +from typing import Optional + +from pydantic import BaseModel, Field + +from miniscope_io.models.mixins import YAMLMixin + + +class InteractiveDisplayConfig(BaseModel): + """ + Configuration for displaying a video. + """ + enable: bool = Field( + default=False, + description="Whether to plot the output .", + ) + end_frame: Optional[int] = Field( + default=100, + description="Frame to end processing at.", + ) + +class NoisePatchConfig(BaseModel): + """ + Configuration for patch based noise handling. + """ + enable: bool = Field( + default=True, + description="Whether to use patch based noise handling.", + ) + method: str = Field( + default="mean_error", + description="Method for handling noise.", + ) + threshold: float = Field( + default=20, + description="Threshold for detecting noise.", + ) + buffer_size: int = Field( + default=5032, + description="Size of the buffers composing the image." + "This premises that the noisy area will appear in units of buffer_size.", + ) + buffer_split: int = Field( + default=1, + description="Number of splits to make in the buffer when detecting noisy areas.", + ) + diff_multiply: int = Field( + default=1, + description="Multiplier for the difference between the mean and the pixel value.", + ) + +class FreqencyMaskingConfig(BaseModel): + """ + Configuration for frequency filtering. + """ + enable: bool = Field( + default=True, + description="Whether to use frequency filtering.", + ) + spatial_LPF_cutoff_radius: int = Field( + default=5, + description="Radius for the spatial cutoff.", + ) + vertical_BEF_cutoff: int = Field( + default=5, + description="Cutoff for the vertical band elimination filter.", + ) + horizontal_BEF_cutoff: int = Field( + default=0, + description="Cutoff for the horizontal band elimination filter.", + ) + display_mask: bool = Field( + default=False, + description="Whether to display the mask.", + ) + +class DenoiseConfig(BaseModel, YAMLMixin): + """ + Configuration for denoising a video. 
+ """ + interactive_display: Optional[InteractiveDisplayConfig] = Field( + default=None, + description="Configuration for displaying the video.", + ) + noise_patch: Optional[NoisePatchConfig] = Field( + default=None, + description="Configuration for patch based noise handling.", + ) + frequency_masking: Optional[FreqencyMaskingConfig] = Field( + default=None, + description="Configuration for frequency filtering.", + ) + end_frame: Optional[int] = Field( + default=None, + description="Frame to end processing at.", + ) From 2becfa65e21e68d80f3adef33d3d0bd9a4027953 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:55:15 -0800 Subject: [PATCH 11/20] add .gitkeep to gitignore --- .gitignore | 3 ++- user_dir/.gitkeep | 0 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 user_dir/.gitkeep diff --git a/.gitignore b/.gitignore index e8953907..0404861f 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,7 @@ cython_debug/ wirefree_example.mp4 wirefree_example.avi .pdm-python -user_dir/ +user_dir/* +!user_dir/.gitkeep ~/.config/miniscope_io/logs/ diff --git a/user_dir/.gitkeep b/user_dir/.gitkeep new file mode 100644 index 00000000..e69de29b From f5c8eddcc91090600c87be61fba25ca8955ababb Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:00:05 -0800 Subject: [PATCH 12/20] add video export method to NamedFrame model --- .../data/config/process/denoise_example.yml | 14 +- miniscope_io/process/video.py | 120 +++++++++++------- mio/cli/process.py | 1 - mio/io.py | 6 +- mio/models/frames.py | 51 ++++++++ mio/models/process.py | 49 ++++++- mio/stream_daq.py | 3 +- 7 files changed, 188 insertions(+), 56 deletions(-) diff --git a/miniscope_io/data/config/process/denoise_example.yml b/miniscope_io/data/config/process/denoise_example.yml index 93f4ad94..3a0e4bea 100644 --- a/miniscope_io/data/config/process/denoise_example.yml +++ b/miniscope_io/data/config/process/denoise_example.yml @@ -7,11 +7,19 @@ noise_patch: threshold: 20 buffer_size: 5032 buffer_split: 1 - diff_multiply: 10 + diff_multiply: 1 + output_result: True + output_noise_patch: True + output_diff: True frequency_masking: enable: True spacial_LPF_cutoff_radius: 10 - vertical_BEF_cutoff: 5 + vertical_BEF_cutoff: 1 horizontal_BEF_cutoff: 0 display_mask: False -end_frame: 1000 \ No newline at end of file + output_mask: True + output_result: True + output_freq_domain: True +end_frame: 100 +output_result: True +output_dir: 'user_dir/output' \ No newline at end of file diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index f68a6734..c12e39a0 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -2,6 +2,7 @@ This module contains functions for pre-processing video data. """ +from pathlib import Path from typing import Tuple import cv2 @@ -15,12 +16,14 @@ from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") + + class FrameProcessor: """ A class to process video frames. """ - def __init__(self, height: int, width: int, buffer_size: int = 5032, buffer_split: int = 1): + def __init__(self, height: int, width: int): """ Initialize the FrameProcessor object. Block size/buffer size will be set by dev config later. 
@@ -65,7 +68,7 @@ def patch_noisy_buffer( previous_frame: np.ndarray, buffer_size: int, buffer_split: int, - noise_threshold: float + noise_threshold: float, ) -> Tuple[np.ndarray, np.ndarray]: """ Process the frame, replacing noisy blocks with those from the previous frame. @@ -81,12 +84,8 @@ def patch_noisy_buffer( serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - split_current = self.split_by_length( - serialized_current, buffer_size // buffer_split - ) - split_previous = self.split_by_length( - serialized_previous, buffer_size // buffer_split - ) + split_current = self.split_by_length(serialized_current, buffer_size // buffer_split) + split_previous = self.split_by_length(serialized_previous, buffer_size // buffer_split) split_output = split_current.copy() noisy_parts = split_current.copy() @@ -137,7 +136,7 @@ def remove_stripes(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray: img_back = np.abs(img_back) return np.uint8(img_back), np.uint8(magnitude_spectrum) - + def gen_freq_mask( self, center_LPF: int, @@ -161,7 +160,7 @@ def gen_freq_mask( mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 # Define spacial low pass filter - y, x = np.ogrid[:self.height, :self.width] + y, x = np.ogrid[: self.height, : self.width] center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 # Restore the center circular area to allow low frequencies to pass @@ -181,6 +180,7 @@ class VideoProcessor: """ A class to process video files. """ + @staticmethod def denoise( video_path: str, @@ -191,12 +191,20 @@ def denoise( Might be useful to define some using environment variables. """ reader = VideoReader(video_path) + pathstem = Path(video_path).stem + output_dir = Path.cwd() / config.output_dir + if not output_dir.exists(): + output_dir.mkdir(parents=True) raw_frames = [] - patched_frames = [] - freq_domain_frames = [] - noise_patchs = [] - freq_filtered_frames = [] - diff_frames = [] + output_frames = [] + + if config.noise_patch.enable: + patched_frames = [] + noise_patchs = [] + diff_frames = [] + if config.frequency_masking.enable: + freq_domain_frames = [] + freq_filtered_frames = [] index = 0 fig = plt.figure() @@ -216,59 +224,80 @@ def denoise( try: for frame in reader.read_frames(): - raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) - if config.end_frame and index > config.end_frame: break - logger.debug(f"Processing frame {index}") + raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + raw_frames.append(raw_frame) + if index == 0: previous_frame = raw_frame - patched_frame, noise_patch = processor.patch_noisy_buffer( - raw_frame, - previous_frame, - buffer_size=config.noise_patch.buffer_size, - buffer_split=config.noise_patch.buffer_split, - noise_threshold=config.noise_patch.threshold - ) - freq_filtered_frame, frame_freq_domain = processor.remove_stripes( - img=patched_frame, mask=freq_mask - ) - diff_frame = cv2.absdiff(raw_frame, freq_filtered_frame) - - raw_frames.append(raw_frame) - patched_frames.append(patched_frame) - freq_domain_frames.append(frame_freq_domain) - noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) - freq_filtered_frames.append(freq_filtered_frame) - diff_frames.append(diff_frame * config.noise_patch.diff_multiply) - + output_frame = raw_frame.copy() + + if config.noise_patch.enable: + patched_frame, noise_patch = processor.patch_noisy_buffer( + output_frame, + previous_frame, + buffer_size=config.noise_patch.buffer_size, + 
buffer_split=config.noise_patch.buffer_split, + noise_threshold=config.noise_patch.threshold, + ) + diff_frame = cv2.absdiff(raw_frame, previous_frame) + patched_frames.append(patched_frame) + noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) + diff_frames.append(diff_frame * config.noise_patch.diff_multiply) + output_frame = patched_frame + + if config.frequency_masking.enable: + freq_filtered_frame, frame_freq_domain = processor.remove_stripes( + img=patched_frame, mask=freq_mask + ) + freq_domain_frames.append(frame_freq_domain) + freq_filtered_frames.append(freq_filtered_frame) + output_frame = freq_filtered_frame + output_frames.append(output_frame) index += 1 finally: reader.release() plt.close(fig) - normalized_frames = VideoProcessor.normalize_video_stack(freq_filtered_frames) - minimum_projection = VideoProcessor.get_minimum_projection(normalized_frames) + minimum_projection = VideoProcessor.get_minimum_projection(output_frames) - subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames] + subtract_minimum = [(frame - minimum_projection) for frame in output_frames] subtract_minimum = VideoProcessor.normalize_video_stack(subtract_minimum) raw_video = NamedFrame(name="RAW", video_frame=raw_frames) patched_video = NamedFrame(name="Patched", video_frame=patched_frames) diff_video = NamedFrame( - name=f"Diff {config.noise_patch.diff_multiply}x", - video_frame=diff_frames) + name=f"Diff {config.noise_patch.diff_multiply}x", video_frame=diff_frames + ) noise_patch = NamedFrame(name="Noisy area", video_frame=noise_patchs) freq_mask_frame = NamedFrame( name="Freq mask", static_frame=freq_mask * np.iinfo(np.uint8).max ) - freq_domain_video = NamedFrame(name="Freq domain", video_frame=freq_domain_frames) - freq_filtered_video = NamedFrame(name="Freq filtered", video_frame=freq_filtered_frames) - normalized_video = NamedFrame(name="Normalized", video_frame=normalized_frames) + + if config.frequency_masking.enable: + freq_domain_video = NamedFrame(name="freq_domain", video_frame=freq_domain_frames) + freq_filtered_video = NamedFrame( + name="freq_filtered", video_frame=freq_filtered_frames + ) + if config.frequency_masking.output_freq_domain: + freq_domain_video.export( + output_dir / f"{pathstem}", + suffix=True, + fps=20, + ) + if config.frequency_masking.output_result: + freq_filtered_video.export( + (output_dir / f"{pathstem}"), + suffix=True, + fps=20, + ) + + normalized_video = NamedFrame(name="Normalized", video_frame=output_frames) min_proj_frame = NamedFrame(name="Min Proj", static_frame=minimum_projection) subtract_video = NamedFrame(name="Subtracted", video_frame=subtract_minimum) @@ -288,6 +317,7 @@ def denoise( VideoPlotter.show_video_with_controls( videos, ) + @staticmethod def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray: """ diff --git a/mio/cli/process.py b/mio/cli/process.py index 72b06cec..026a984f 100644 --- a/mio/cli/process.py +++ b/mio/cli/process.py @@ -40,4 +40,3 @@ def denoise( """ denoise_config_parsed = DenoiseConfig.from_yaml(denoise_config) VideoProcessor.denoise(input, denoise_config_parsed) - diff --git a/mio/io.py b/mio/io.py index 357289e5..e95968f3 100644 --- a/mio/io.py +++ b/mio/io.py @@ -23,6 +23,7 @@ class VideoWriter: """ Write data to a video file using OpenCV. 
""" + @staticmethod def init_video( path: Union[Path, str], @@ -30,7 +31,7 @@ def init_video( height: int, fps: int, fourcc: str = "Y800", - **kwargs: dict + **kwargs: dict, ) -> cv2.VideoWriter: """ Create a parameterized video writer @@ -65,6 +66,7 @@ def init_video( out = cv2.VideoWriter(str(path), fourcc, frame_rate, frame_size, **kwargs) return out + class VideoReader: """ A class to read video files. @@ -113,6 +115,8 @@ def release(self) -> None: def __del__(self): self.release() + + class BufferedCSVWriter: """ Write data to a CSV file in buffered mode. diff --git a/mio/models/frames.py b/mio/models/frames.py index 990e8436..ffe41897 100644 --- a/mio/models/frames.py +++ b/mio/models/frames.py @@ -2,13 +2,20 @@ Pydantic models for storing frames and videos. """ +from pathlib import Path from typing import List, Optional, TypeVar +import cv2 import numpy as np from pydantic import BaseModel, Field, model_validator +from miniscope_io.io import VideoWriter +from miniscope_io.logging import init_logger + T = TypeVar("T", np.ndarray, List[np.ndarray], List[List[np.ndarray]]) +logger = init_logger("model.frames") + class NamedFrame(BaseModel): """ @@ -76,6 +83,50 @@ def data(self) -> T: else: raise ValueError("Unknown frame type or no frame data provided.") + def export(self, output_path: Path, fps: int, suffix: bool) -> None: + """ + Export the frame data to a file. + + Parameters + ---------- + output_path : str + Path to the output file. + fps : int + Frames per second for the + + Raises + ------ + NotImplementedError + If the frame type is video_list_frame. + """ + if suffix: + output_path = output_path.with_name(output_path.stem + f"_{self.name}") + if self.frame_type == "static_frame": + # write PNG out + cv2.imwrite(str(output_path.with_suffix(".png")), self.static_frame) + elif self.frame_type == "video_frame": + writer = VideoWriter.init_video( + path=output_path.with_suffix(".avi"), + width=self.video_frame[0].shape[1], + height=self.video_frame[0].shape[0], + fps=20, + ) + logger.info( + f"Writing video to {output_path}.avi:" + f"{self.video_frame[0].shape[1]}x{self.video_frame[0].shape[0]}" + ) + try: + for frame in self.video_frame: + picture = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) + writer.write(picture) + finally: + writer.release() + + elif self.frame_type == "video_list_frame": + raise NotImplementedError("Exporting video list frames is not yet supported.") + else: + raise ValueError("Unknown frame type or no frame data provided.") + class Config: """ Pydantic config for allowing np.ndarray types. diff --git a/mio/models/process.py b/mio/models/process.py index d14c80c7..ef285ab5 100644 --- a/mio/models/process.py +++ b/mio/models/process.py @@ -1,6 +1,6 @@ """ Module for preprocessing data. -""" +""" from typing import Optional @@ -10,9 +10,10 @@ class InteractiveDisplayConfig(BaseModel): - """ - Configuration for displaying a video. """ + Configuration for displaying a video. + """ + enable: bool = Field( default=False, description="Whether to plot the output .", @@ -22,10 +23,12 @@ class InteractiveDisplayConfig(BaseModel): description="Frame to end processing at.", ) + class NoisePatchConfig(BaseModel): """ Configuration for patch based noise handling. 
""" + enable: bool = Field( default=True, description="Whether to use patch based noise handling.", @@ -51,11 +54,25 @@ class NoisePatchConfig(BaseModel): default=1, description="Multiplier for the difference between the mean and the pixel value.", ) + output_result: bool = Field( + default=False, + description="Whether to output the result.", + ) + output_noise_patch: bool = Field( + default=False, + description="Whether to output the noise patch.", + ) + output_diff: bool = Field( + default=False, + description="Whether to output the difference.", + ) + class FreqencyMaskingConfig(BaseModel): """ Configuration for frequency filtering. """ + enable: bool = Field( default=True, description="Whether to use frequency filtering.", @@ -76,11 +93,25 @@ class FreqencyMaskingConfig(BaseModel): default=False, description="Whether to display the mask.", ) + output_result: bool = Field( + default=False, + description="Whether to output the result.", + ) + output_mask: bool = Field( + default=False, + description="Whether to output the mask.", + ) + output_freq_domain: bool = Field( + default=False, + description="Whether to output the frequency domain.", + ) + class DenoiseConfig(BaseModel, YAMLMixin): - """ - Configuration for denoising a video. """ + Configuration for denoising a video. + """ + interactive_display: Optional[InteractiveDisplayConfig] = Field( default=None, description="Configuration for displaying the video.", @@ -97,3 +128,11 @@ class DenoiseConfig(BaseModel, YAMLMixin): default=None, description="Frame to end processing at.", ) + output_result: bool = Field( + default=True, + description="Whether to output the result.", + ) + output_dir: Optional[str] = Field( + default=None, + description="Directory to save the output in.", + ) diff --git a/mio/stream_daq.py b/mio/stream_daq.py index 17202da4..b2e71e12 100644 --- a/mio/stream_daq.py +++ b/mio/stream_daq.py @@ -659,7 +659,8 @@ def capture( width=self.config.frame_width, height=self.config.frame_height, fps=self.config.fs, - **video_kwargs) + **video_kwargs, + ) p_buffer_to_frame = multiprocessing.Process( target=self._buffer_to_frame, From c937b8037b2596e4ba80a8fa14a406e30f72373d Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:02:10 -0800 Subject: [PATCH 13/20] docs: error handling for plt import, check plt inside modules --- miniscope_io/process/video.py | 14 +++++++-- mio/plots/video.py | 58 ++++++++++------------------------- 2 files changed, 29 insertions(+), 43 deletions(-) diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index c12e39a0..759d43e8 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -6,7 +6,6 @@ from typing import Tuple import cv2 -import matplotlib.pyplot as plt import numpy as np from miniscope_io import init_logger @@ -17,6 +16,11 @@ logger = init_logger("video") +try: + import matplotlib.pyplot as plt +except ImportError: + plt = None + class FrameProcessor: """ @@ -190,6 +194,13 @@ def denoise( Process a video file and display the results. Might be useful to define some using environment variables. 
""" + if plt is None: + raise ModuleNotFoundError( + "matplotlib is not a required dependency of miniscope-io, to use it, " + "install it manually or install miniscope-io with `pip install miniscope-io[plot]`" + ) + fig = plt.figure() + reader = VideoReader(video_path) pathstem = Path(video_path).stem output_dir = Path.cwd() / config.output_dir @@ -207,7 +218,6 @@ def denoise( freq_filtered_frames = [] index = 0 - fig = plt.figure() processor = FrameProcessor( height=reader.height, diff --git a/mio/plots/video.py b/mio/plots/video.py index 541f35b6..042c4296 100644 --- a/mio/plots/video.py +++ b/mio/plots/video.py @@ -4,14 +4,20 @@ from typing import List -import matplotlib.pyplot as plt -import numpy as np -from matplotlib import animation -from matplotlib.backend_bases import KeyEvent -from matplotlib.widgets import Button, Slider - from miniscope_io.models.frames import NamedFrame +try: + import matplotlib.pyplot as plt + from matplotlib import animation + from matplotlib.backend_bases import KeyEvent + from matplotlib.widgets import Button, Slider +except ImportError: + plt = None + animation = None + Button = None + Slider = None + KeyEvent = None + class VideoPlotter: """ @@ -31,6 +37,11 @@ def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None: fps : int, optional Frames per second for the video, by default 20 """ + if plt is None: + raise ModuleNotFoundError( + "matplotlib is not a required dependency of miniscope-io, to use it, " + "install it manually or install miniscope-io with `pip install miniscope-io[plot]`" + ) if any(frame.frame_type == "video_list_frame" for frame in videos): raise NotImplementedError("Only single videos or frames are supported for now.") @@ -105,38 +116,3 @@ def animate(i: int) -> None: ) plt.show() - - @staticmethod - def show_video_side_by_side( - fig: plt.Figure, frames: list[np.ndarray], titles: str = None - ) -> None: - """ - Plot a list of frames side by side using matplotlib. 
- - Parameters - ---------- - fig : plt.Figure - Figure to plot on - frames : list[np.ndarray] - List of frames to plot - titles : str, optional - List of titles for each frame, by default None - - Raises - ------ - ValueError - If the number of frames and titles do not match - """ - num_frames = len(frames) - plt.clf() # Clear current figure - - for i, frame in enumerate(frames): - plt.subplot(1, num_frames, i + 1) - plt.imshow(frame, cmap="gray") - if titles: - plt.title(titles[i]) - - plt.axis("off") # Turn off axis labels - - plt.tight_layout() - fig.canvas.draw() From 44d06138bb9ad4709f60f968bb200ffbd4c45bb8 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Mon, 9 Dec 2024 19:29:42 -0800 Subject: [PATCH 14/20] Fix config, add start/end of display --- .../data/config/process/denoise_example.yml | 11 ++--- miniscope_io/process/video.py | 44 +++++++++++++++---- mio/models/process.py | 8 +++- mio/plots/video.py | 25 ++++++++++- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/miniscope_io/data/config/process/denoise_example.yml b/miniscope_io/data/config/process/denoise_example.yml index 3a0e4bea..f5488e95 100644 --- a/miniscope_io/data/config/process/denoise_example.yml +++ b/miniscope_io/data/config/process/denoise_example.yml @@ -1,20 +1,21 @@ interactive_display: enable: True - end_frame: 1000 + start_frame: 40 + end_frame: 90 noise_patch: enable: True method: "mean_error" - threshold: 20 + threshold: 30 buffer_size: 5032 - buffer_split: 1 + buffer_split: 10 diff_multiply: 1 output_result: True output_noise_patch: True output_diff: True frequency_masking: enable: True - spacial_LPF_cutoff_radius: 10 - vertical_BEF_cutoff: 1 + spatial_LPF_cutoff_radius: 15 + vertical_BEF_cutoff: 2 horizontal_BEF_cutoff: 0 display_mask: False output_mask: True diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index 759d43e8..29515a99 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -88,12 +88,19 @@ def patch_noisy_buffer( serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - split_current = self.split_by_length(serialized_current, buffer_size // buffer_split) - split_previous = self.split_by_length(serialized_previous, buffer_size // buffer_split) + buffer_per_frame = len(serialized_current) // buffer_size + 1 + + split_current = self.split_by_length( + serialized_current, + buffer_size // buffer_split) + split_previous = self.split_by_length( + serialized_previous, + buffer_size // buffer_split) split_output = split_current.copy() noisy_parts = split_current.copy() + ''' for i in range(len(split_current)): mean_error = abs(split_current[i] - split_previous[i]).mean() if mean_error > noise_threshold: @@ -103,6 +110,25 @@ def patch_noisy_buffer( else: split_output[i] = split_current[i] noisy_parts[i] = np.zeros_like(split_current[i], np.uint8) + ''' + buffer_has_noise = False + for buffer_index in range(buffer_per_frame): + for split_index in range(buffer_split): + i = buffer_index * buffer_split + split_index + mean_error = abs(split_current[i] - split_previous[i]).mean() + if mean_error > noise_threshold: + logger.info(f"Replacing buffer {i} with mean error {mean_error}") + buffer_has_noise = True + break + else: + split_output[i] = split_current[i] + noisy_parts[i] = np.zeros_like(split_current[i], np.uint8) + if buffer_has_noise: + for split_index in range(buffer_split): + i = buffer_index * buffer_split + 
split_index + split_output[i] = split_previous[i] + noisy_parts[i] = np.ones_like(split_current[i], np.uint8) + buffer_has_noise = False serialized_output = np.concatenate(split_output)[: self.height * self.width] noise_output = np.concatenate(noisy_parts)[: self.height * self.width] @@ -314,18 +340,20 @@ def denoise( if config.interactive_display.enable: videos = [ raw_video, - patched_video, - diff_video, noise_patch, - freq_mask_frame, - freq_domain_video, + patched_video, freq_filtered_video, - normalized_video, + freq_domain_video, min_proj_frame, - subtract_video, + freq_mask_frame, + #diff_video, + #normalized_video, + #subtract_video, ] VideoPlotter.show_video_with_controls( videos, + start_frame=config.interactive_display.start_frame, + end_frame=config.interactive_display.end_frame, ) @staticmethod diff --git a/mio/models/process.py b/mio/models/process.py index ef285ab5..63872867 100644 --- a/mio/models/process.py +++ b/mio/models/process.py @@ -18,8 +18,12 @@ class InteractiveDisplayConfig(BaseModel): default=False, description="Whether to plot the output .", ) + start_frame: Optional[int] = Field( + default=..., + description="Frame to start processing at.", + ) end_frame: Optional[int] = Field( - default=100, + default=..., description="Frame to end processing at.", ) @@ -78,7 +82,7 @@ class FreqencyMaskingConfig(BaseModel): description="Whether to use frequency filtering.", ) spatial_LPF_cutoff_radius: int = Field( - default=5, + default=..., description="Radius for the spatial cutoff.", ) vertical_BEF_cutoff: int = Field( diff --git a/mio/plots/video.py b/mio/plots/video.py index 042c4296..adc6d18c 100644 --- a/mio/plots/video.py +++ b/mio/plots/video.py @@ -5,6 +5,7 @@ from typing import List from miniscope_io.models.frames import NamedFrame +from miniscope_io import init_logger try: import matplotlib.pyplot as plt @@ -18,6 +19,7 @@ Slider = None KeyEvent = None +logger = init_logger("videoplot") class VideoPlotter: """ @@ -25,7 +27,12 @@ class VideoPlotter: """ @staticmethod - def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None: + def show_video_with_controls( + videos: List[NamedFrame], + start_frame: int, + end_frame: int, + fps: int = 20, + ) -> None: """ Plot multiple video streams or static images side-by-side. Can play/pause and navigate frames. @@ -34,6 +41,10 @@ def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None: ---------- videos : NamedFrame NamedFrame object containing video data and names. + start_frame : int + Starting frame index for the video display. + end_frame : int + Ending frame index for the video display. 
fps : int, optional Frames per second for the video, by default 20 """ @@ -45,7 +56,6 @@ def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None: if any(frame.frame_type == "video_list_frame" for frame in videos): raise NotImplementedError("Only single videos or frames are supported for now.") - # Wrap static images in lists to handle them uniformly video_frames = [ frame.data if frame.frame_type == "video_frame" else [frame.data] for frame in videos @@ -54,7 +64,18 @@ def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None: titles = [video.name for video in videos] num_streams = len(video_frames) + + logger.info(f"Displaying {num_streams} video streams.") + if end_frame > start_frame: + logger.info(f"Displaying frames {start_frame} to {end_frame}.") + for stream_index in range(len(video_frames)): + logger.info(f"Stream length: {len(video_frames[stream_index])}") + if len(video_frames[stream_index]) > 1: + video_frames[stream_index] = video_frames[stream_index][start_frame:end_frame] + logger.info(f"Trimmed stream length: {len(video_frames[stream_index])}") + num_frames = max(len(stream) for stream in video_frames) + logger.info(f"Max stream length: {num_frames}") fig, axes = plt.subplots(1, num_streams, figsize=(20, 5)) From 95eeb1a24d0883e0583b0078e9606dcdf1e3d8f1 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:43:01 -0800 Subject: [PATCH 15/20] Interface each processing method with pydantic models --- .../data/config/process/denoise_example.yml | 6 +- miniscope_io/process/video.py | 138 +++++++++--------- mio/io.py | 15 +- mio/plots/video.py | 7 +- 4 files changed, 85 insertions(+), 81 deletions(-) diff --git a/miniscope_io/data/config/process/denoise_example.yml b/miniscope_io/data/config/process/denoise_example.yml index f5488e95..5691e751 100644 --- a/miniscope_io/data/config/process/denoise_example.yml +++ b/miniscope_io/data/config/process/denoise_example.yml @@ -1,11 +1,11 @@ interactive_display: enable: True start_frame: 40 - end_frame: 90 + end_frame: 140 noise_patch: enable: True method: "mean_error" - threshold: 30 + threshold: 10 buffer_size: 5032 buffer_split: 10 diff_multiply: 1 @@ -14,7 +14,7 @@ noise_patch: output_diff: True frequency_masking: enable: True - spatial_LPF_cutoff_radius: 15 + spatial_LPF_cutoff_radius: 20 vertical_BEF_cutoff: 2 horizontal_BEF_cutoff: 0 display_mask: False diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index 29515a99..e7e69fbf 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -11,7 +11,7 @@ from miniscope_io import init_logger from miniscope_io.io import VideoReader from miniscope_io.models.frames import NamedFrame -from miniscope_io.models.process import DenoiseConfig +from miniscope_io.models.process import DenoiseConfig, FreqencyMaskingConfig, NoisePatchConfig from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") @@ -35,9 +35,9 @@ def __init__(self, height: int, width: int): Parameters: height (int): Height of the video frame. width (int): Width of the video frame. - buffer_size (int): Size of the buffer to process. - block_size (int): Size of the blocks to process. Not used now. + Returns: + FrameProcessor: A FrameProcessor object. 
""" self.height = height self.width = width @@ -45,6 +45,7 @@ def __init__(self, height: int, width: int): def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.ndarray]: """ Split an array into sub-arrays of a specified length. + Last sub-array may be shorter if the array length is not a multiple of the segment length. Parameters: array (np.ndarray): The array to split. @@ -55,27 +56,28 @@ def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.nda """ num_segments = len(array) // segment_length - # Create sub-arrays of the specified segment length - split_arrays = [ + # Split the array into segments of the specified length + sub_arrays = [ array[i * segment_length : (i + 1) * segment_length] for i in range(num_segments) ] # Add the remaining elements as a final shorter segment, if any if len(array) % segment_length != 0: - split_arrays.append(array[num_segments * segment_length :]) + sub_arrays.append(array[num_segments * segment_length :]) - return split_arrays + return sub_arrays def patch_noisy_buffer( self, current_frame: np.ndarray, previous_frame: np.ndarray, - buffer_size: int, - buffer_split: int, - noise_threshold: float, + noise_patch_config: NoisePatchConfig, ) -> Tuple[np.ndarray, np.ndarray]: """ - Process the frame, replacing noisy blocks with those from the previous frame. + Compare current frame with the previous frame to find noisy frames. + Replace noisy blocks with those from the previous frame. + The comparison is done in blocks of a specified size, + defined by the buffer_size divided by buffer_split. Parameters: current_frame (np.ndarray): The current frame to process. @@ -83,40 +85,31 @@ def patch_noisy_buffer( noise_threshold (float): The threshold for mean error to consider a block noisy. Returns: - Tuple[np.ndarray, np.ndarray]: The processed frame and the noise patch + Tuple[np.ndarray, np.ndarray]: The processed frame and the noise patch. 
""" serialized_current = current_frame.flatten().astype(np.int16) serialized_previous = previous_frame.flatten().astype(np.int16) - buffer_per_frame = len(serialized_current) // buffer_size + 1 + buffer_per_frame = len(serialized_current) // noise_patch_config.buffer_size + 1 split_current = self.split_by_length( serialized_current, - buffer_size // buffer_split) + noise_patch_config.buffer_size // noise_patch_config.buffer_split + 1, + ) split_previous = self.split_by_length( serialized_previous, - buffer_size // buffer_split) + noise_patch_config.buffer_size // noise_patch_config.buffer_split + 1, + ) split_output = split_current.copy() noisy_parts = split_current.copy() - ''' - for i in range(len(split_current)): - mean_error = abs(split_current[i] - split_previous[i]).mean() - if mean_error > noise_threshold: - logger.info(f"Replacing buffer {i} with mean error {mean_error}") - split_output[i] = split_previous[i] - noisy_parts[i] = np.ones_like(split_current[i], np.uint8) - else: - split_output[i] = split_current[i] - noisy_parts[i] = np.zeros_like(split_current[i], np.uint8) - ''' buffer_has_noise = False for buffer_index in range(buffer_per_frame): - for split_index in range(buffer_split): - i = buffer_index * buffer_split + split_index + for split_index in range(noise_patch_config.buffer_split): + i = buffer_index * noise_patch_config.buffer_split + split_index mean_error = abs(split_current[i] - split_previous[i]).mean() - if mean_error > noise_threshold: + if mean_error > noise_patch_config.threshold: logger.info(f"Replacing buffer {i} with mean error {mean_error}") buffer_has_noise = True break @@ -124,8 +117,8 @@ def patch_noisy_buffer( split_output[i] = split_current[i] noisy_parts[i] = np.zeros_like(split_current[i], np.uint8) if buffer_has_noise: - for split_index in range(buffer_split): - i = buffer_index * buffer_split + split_index + for split_index in range(noise_patch_config.buffer_split): + i = buffer_index * noise_patch_config.buffer_split + split_index split_output[i] = split_previous[i] noisy_parts[i] = np.ones_like(split_current[i], np.uint8) buffer_has_noise = False @@ -139,7 +132,7 @@ def patch_noisy_buffer( return np.uint8(processed_frame), np.uint8(noise_patch) - def remove_stripes(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray: + def apply_freq_mask(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray: """ Perform FFT/IFFT to remove horizontal stripes from a single frame. @@ -169,10 +162,7 @@ def remove_stripes(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray: def gen_freq_mask( self, - center_LPF: int, - vertical_BEF: int, - horizontal_BEF: int, - show_mask: bool = False, + freq_mask_config: FreqencyMaskingConfig, ) -> np.ndarray: """ Generate a mask to filter out horizontal and vertical frequencies. 
@@ -184,20 +174,32 @@ def gen_freq_mask( mask = np.ones((self.height, self.width), np.uint8) # Zero out a vertical stripe at the frequency center - mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0 + mask[ + :, + ccol + - freq_mask_config.vertical_BEF_cutoff : ccol + + freq_mask_config.vertical_BEF_cutoff, + ] = 0 # Zero out a horizontal stripe at the frequency center - mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 + mask[ + crow + - freq_mask_config.horizontal_BEF_cutoff : crow + + freq_mask_config.horizontal_BEF_cutoff, + :, + ] = 0 # Define spacial low pass filter y, x = np.ogrid[: self.height, : self.width] - center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 + center_mask = (x - ccol) ** 2 + ( + y - crow + ) ** 2 <= freq_mask_config.spatial_LPF_cutoff_radius**2 # Restore the center circular area to allow low frequencies to pass mask[center_mask] = 1 # Visualize the mask if needed. Might delete later. - if show_mask: + if freq_mask_config.display_mask: cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) while True: if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization @@ -217,24 +219,24 @@ def denoise( config: DenoiseConfig, ) -> None: """ - Process a video file and display the results. - Might be useful to define some using environment variables. + Preprocess a video file and display the results. """ if plt is None: raise ModuleNotFoundError( "matplotlib is not a required dependency of miniscope-io, to use it, " "install it manually or install miniscope-io with `pip install miniscope-io[plot]`" ) - fig = plt.figure() reader = VideoReader(video_path) + pathstem = Path(video_path).stem output_dir = Path.cwd() / config.output_dir if not output_dir.exists(): output_dir.mkdir(parents=True) + + # Initialize lists to store frames raw_frames = [] output_frames = [] - if config.noise_patch.enable: patched_frames = [] noise_patchs = [] @@ -243,8 +245,7 @@ def denoise( freq_domain_frames = [] freq_filtered_frames = [] - index = 0 - + # Initiate the frame processor processor = FrameProcessor( height=reader.height, width=reader.width, @@ -252,53 +253,52 @@ def denoise( if config.noise_patch.enable: freq_mask = processor.gen_freq_mask( - center_LPF=config.frequency_masking.spatial_LPF_cutoff_radius, - vertical_BEF=config.frequency_masking.vertical_BEF_cutoff, - horizontal_BEF=config.frequency_masking.horizontal_BEF_cutoff, - show_mask=config.frequency_masking.display_mask, + freq_mask_config=config.frequency_masking, ) + # index for frame number in original video try: - for frame in reader.read_frames(): + for index, frame in reader.read_frames(): if config.end_frame and index > config.end_frame: break - logger.debug(f"Processing frame {index}") + logger.info(f"Processing frame {index}") raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) raw_frames.append(raw_frame) - if index == 0: - previous_frame = raw_frame - + previous_frame = raw_frame.copy() output_frame = raw_frame.copy() if config.noise_patch.enable: + if index == 1: + previous_frame = raw_frame + patched_frame, noise_patch = processor.patch_noisy_buffer( - output_frame, + raw_frame, previous_frame, - buffer_size=config.noise_patch.buffer_size, - buffer_split=config.noise_patch.buffer_split, - noise_threshold=config.noise_patch.threshold, + config.noise_patch, ) - diff_frame = cv2.absdiff(raw_frame, previous_frame) patched_frames.append(patched_frame) noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) - diff_frames.append(diff_frame * config.noise_patch.diff_multiply) + + if 
config.noise_patch.output_diff: + diff_frame = cv2.absdiff(raw_frame, previous_frame) + diff_frames.append(diff_frame * config.noise_patch.diff_multiply) + output_frame = patched_frame if config.frequency_masking.enable: - freq_filtered_frame, frame_freq_domain = processor.remove_stripes( - img=patched_frame, mask=freq_mask + freq_filtered_frame, frame_freq_domain = processor.apply_freq_mask( + img=patched_frame, + mask=freq_mask, ) freq_domain_frames.append(frame_freq_domain) freq_filtered_frames.append(freq_filtered_frame) output_frame = freq_filtered_frame output_frames.append(output_frame) - index += 1 finally: reader.release() - plt.close(fig) - + logger.info(f"shape of output_frames: {output_frames[0].shape}") minimum_projection = VideoProcessor.get_minimum_projection(output_frames) subtract_minimum = [(frame - minimum_projection) for frame in output_frames] @@ -346,9 +346,9 @@ def denoise( freq_domain_video, min_proj_frame, freq_mask_frame, - #diff_video, - #normalized_video, - #subtract_video, + # diff_video, + # normalized_video, + # subtract_video, ] VideoPlotter.show_video_with_controls( videos, diff --git a/mio/io.py b/mio/io.py index e95968f3..fe0b7a22 100644 --- a/mio/io.py +++ b/mio/io.py @@ -6,7 +6,7 @@ import contextlib import csv from pathlib import Path -from typing import Any, BinaryIO, Iterator, List, Literal, Optional, Union, overload +from typing import Any, BinaryIO, Iterator, List, Literal, Optional, Tuple, Union, overload import cv2 import numpy as np @@ -93,19 +93,22 @@ def __init__(self, video_path: str): self.logger.info(f"Opened video at {video_path}") - def read_frames(self) -> Iterator[np.ndarray]: + def read_frames(self) -> Iterator[Tuple[int, np.ndarray]]: """ - Read frames from the video file. + Read frames from the video file along with their index. Yields: - np.ndarray: The next frame in the video. + Tuple[int, np.ndarray]: The index and the next frame in the video. """ while self.cap.isOpened(): ret, frame = self.cap.read() - self.logger.debug(f"Reading frame {self.cap.get(cv2.CAP_PROP_POS_FRAMES)}") + index = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) + self.logger.debug(f"Reading frame {index}") + if not ret: break - yield frame + + yield index, frame def release(self) -> None: """ diff --git a/mio/plots/video.py b/mio/plots/video.py index adc6d18c..f13f376f 100644 --- a/mio/plots/video.py +++ b/mio/plots/video.py @@ -21,6 +21,7 @@ logger = init_logger("videoplot") + class VideoPlotter: """ Class to display video streams and static images. @@ -32,7 +33,7 @@ def show_video_with_controls( start_frame: int, end_frame: int, fps: int = 20, - ) -> None: + ) -> None: """ Plot multiple video streams or static images side-by-side. Can play/pause and navigate frames. 
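Note: with read_frames now yielding (index, frame) pairs, callers can drop their manual counters. A minimal usage sketch (import path taken from the final mio layout of patch 18; the video path is a placeholder). Because cv2.CAP_PROP_POS_FRAMES is queried after cap.read(), the first yielded index is 1, which is why the denoise loop seeds previous_frame on index == 1:

    from mio.io import VideoReader

    # Placeholder path; any OpenCV-readable video works here.
    reader = VideoReader("user_dir/input.avi")
    try:
        for index, frame in reader.read_frames():
            # Indices start at 1, matching the `index == 1` seeding in denoise().
            print(f"frame {index}: shape={frame.shape}")
    finally:
        reader.release()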
@@ -64,7 +65,7 @@ def show_video_with_controls( titles = [video.name for video in videos] num_streams = len(video_frames) - + logger.info(f"Displaying {num_streams} video streams.") if end_frame > start_frame: logger.info(f"Displaying frames {start_frame} to {end_frame}.") @@ -73,7 +74,7 @@ def show_video_with_controls( if len(video_frames[stream_index]) > 1: video_frames[stream_index] = video_frames[stream_index][start_frame:end_frame] logger.info(f"Trimmed stream length: {len(video_frames[stream_index])}") - + num_frames = max(len(stream) for stream in video_frames) logger.info(f"Max stream length: {num_frames}") From 64366333167bfb68ab867e2c6fa624bf9dd7201e Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Tue, 10 Dec 2024 16:51:59 -0800 Subject: [PATCH 16/20] Fix optional stuff --- .../data/config/process/denoise_example.yml | 10 ++--- miniscope_io/process/video.py | 43 +++++++++++-------- mio/plots/video.py | 2 +- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/miniscope_io/data/config/process/denoise_example.yml b/miniscope_io/data/config/process/denoise_example.yml index 5691e751..901f65c1 100644 --- a/miniscope_io/data/config/process/denoise_example.yml +++ b/miniscope_io/data/config/process/denoise_example.yml @@ -1,17 +1,17 @@ interactive_display: - enable: True + enable: False start_frame: 40 end_frame: 140 noise_patch: enable: True method: "mean_error" - threshold: 10 + threshold: 30 buffer_size: 5032 buffer_split: 10 diff_multiply: 1 output_result: True output_noise_patch: True - output_diff: True + output_diff: False frequency_masking: enable: True spatial_LPF_cutoff_radius: 20 @@ -20,7 +20,7 @@ frequency_masking: display_mask: False output_mask: True output_result: True - output_freq_domain: True -end_frame: 100 + output_freq_domain: False +end_frame: 140 output_result: True output_dir: 'user_dir/output' \ No newline at end of file diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index e7e69fbf..e333cd92 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -109,6 +109,7 @@ def patch_noisy_buffer( for split_index in range(noise_patch_config.buffer_split): i = buffer_index * noise_patch_config.buffer_split + split_index mean_error = abs(split_current[i] - split_previous[i]).mean() + logger.debug(f"Mean error for buffer {i}: {mean_error}") if mean_error > noise_patch_config.threshold: logger.info(f"Replacing buffer {i} with mean error {mean_error}") buffer_has_noise = True @@ -212,7 +213,6 @@ class VideoProcessor: """ A class to process video files. 
""" - @staticmethod def denoise( video_path: str, @@ -261,23 +261,23 @@ def denoise( for index, frame in reader.read_frames(): if config.end_frame and index > config.end_frame: break - logger.info(f"Processing frame {index}") raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) raw_frames.append(raw_frame) - previous_frame = raw_frame.copy() output_frame = raw_frame.copy() if config.noise_patch.enable: if index == 1: - previous_frame = raw_frame + previous_frame = raw_frame.copy() + logger.debug(f"Processing frame {index}") patched_frame, noise_patch = processor.patch_noisy_buffer( raw_frame, previous_frame, config.noise_patch, ) + previous_frame = patched_frame patched_frames.append(patched_frame) noise_patchs.append(noise_patch * np.iinfo(np.uint8).max) @@ -298,7 +298,6 @@ def denoise( output_frames.append(output_frame) finally: reader.release() - logger.info(f"shape of output_frames: {output_frames[0].shape}") minimum_projection = VideoProcessor.get_minimum_projection(output_frames) subtract_minimum = [(frame - minimum_projection) for frame in output_frames] @@ -306,14 +305,25 @@ def denoise( subtract_minimum = VideoProcessor.normalize_video_stack(subtract_minimum) raw_video = NamedFrame(name="RAW", video_frame=raw_frames) - patched_video = NamedFrame(name="Patched", video_frame=patched_frames) - diff_video = NamedFrame( - name=f"Diff {config.noise_patch.diff_multiply}x", video_frame=diff_frames - ) - noise_patch = NamedFrame(name="Noisy area", video_frame=noise_patchs) - freq_mask_frame = NamedFrame( - name="Freq mask", static_frame=freq_mask * np.iinfo(np.uint8).max - ) + + if config.noise_patch.enable: + patched_video = NamedFrame(name="patched", video_frame=patched_frames) + if config.noise_patch.output_result: + patched_video.export( + output_dir / f"{pathstem}", + suffix=True, + fps=20, + ) + if config.noise_patch.output_diff: + diff_video = NamedFrame( + name=f"diff_{config.noise_patch.diff_multiply}x", video_frame=diff_frames + ) + if config.noise_patch.output_noise_patch: + noise_patch = NamedFrame(name="noise_patch", video_frame=noise_patchs) + if config.frequency_masking.output_mask: + freq_mask_frame = NamedFrame( + name="freq_mask", static_frame=freq_mask * np.iinfo(np.uint8).max + ) if config.frequency_masking.enable: freq_domain_video = NamedFrame(name="freq_domain", video_frame=freq_domain_frames) @@ -333,9 +343,7 @@ def denoise( fps=20, ) - normalized_video = NamedFrame(name="Normalized", video_frame=output_frames) - min_proj_frame = NamedFrame(name="Min Proj", static_frame=minimum_projection) - subtract_video = NamedFrame(name="Subtracted", video_frame=subtract_minimum) + min_proj_frame = NamedFrame(name="min_proj", static_frame=minimum_projection) if config.interactive_display.enable: videos = [ @@ -346,9 +354,6 @@ def denoise( freq_domain_video, min_proj_frame, freq_mask_frame, - # diff_video, - # normalized_video, - # subtract_video, ] VideoPlotter.show_video_with_controls( videos, diff --git a/mio/plots/video.py b/mio/plots/video.py index f13f376f..740a1620 100644 --- a/mio/plots/video.py +++ b/mio/plots/video.py @@ -4,8 +4,8 @@ from typing import List -from miniscope_io.models.frames import NamedFrame from miniscope_io import init_logger +from miniscope_io.models.frames import NamedFrame try: import matplotlib.pyplot as plt From 36b0167771b430159224df3d17c139a88f7d63d1 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:04:47 -0800 Subject: [PATCH 17/20] Move stuff in correct place for rebasing --- 
.../data/config/device/WLMS_v02_200px.yml | 56 ------------------- .../data/config/process/denoise_example.yml | 0 {miniscope_io => mio}/process/__init__.py | 0 {miniscope_io => mio}/process/video.py | 10 ++-- 4 files changed, 5 insertions(+), 61 deletions(-) delete mode 100644 miniscope_io/data/config/device/WLMS_v02_200px.yml rename {miniscope_io => mio}/data/config/process/denoise_example.yml (100%) rename {miniscope_io => mio}/process/__init__.py (100%) rename {miniscope_io => mio}/process/video.py (98%) diff --git a/miniscope_io/data/config/device/WLMS_v02_200px.yml b/miniscope_io/data/config/device/WLMS_v02_200px.yml deleted file mode 100644 index 75aa4ad3..00000000 --- a/miniscope_io/data/config/device/WLMS_v02_200px.yml +++ /dev/null @@ -1,56 +0,0 @@ -id: wireless-200px -mio_model: mio.models.stream.StreamDevConfig -mio_version: "v5.0.0" - -# capture device. "OK" (Opal Kelly) or "UART" -device: "OK" - -# bitstream file to upload to Opal Kelly board -bitstream: "XEM7310-A75/USBInterface-8_33mhz-J2_2-3v3-IEEE.bit" - -# COM port and baud rate is only required for UART mode -port: null -baudrate: null - -# Preamble for each data buffer. -preamble: 0x12345678 - -# Image format. StreamDaq will calculate buffer size, etc. based on these parameters -frame_width: 200 -frame_height: 200 -pix_depth: 8 - -# Buffer data format. These have to match the firmware value -header_len: 384 # 12 * 32 (in bits) -buffer_block_length: 10 -block_size: 512 -num_buffers: 32 -dummy_words: 10 - -# Flags to flip bit/byte order when recovering headers and data. See model document for details. -reverse_header_bits: True -reverse_header_bytes: True -reverse_payload_bits: True -reverse_payload_bytes: True - -adc_scale: - ref_voltage: 1.1 - bitdepth: 8 - battery_div_factor: 5 - vin_div_factor: 11.3 - -runtime: - serial_buffer_queue_size: 10 - frame_buffer_queue_size: 5 - image_buffer_queue_size: 5 - csv: - buffer: 100 - plot: - keys: - - timestamp - - buffer_count - - frame_buffer_count - - battery_voltage - - input_voltage - update_ms: 1000 - history: 500 diff --git a/miniscope_io/data/config/process/denoise_example.yml b/mio/data/config/process/denoise_example.yml similarity index 100% rename from miniscope_io/data/config/process/denoise_example.yml rename to mio/data/config/process/denoise_example.yml diff --git a/miniscope_io/process/__init__.py b/mio/process/__init__.py similarity index 100% rename from miniscope_io/process/__init__.py rename to mio/process/__init__.py diff --git a/miniscope_io/process/video.py b/mio/process/video.py similarity index 98% rename from miniscope_io/process/video.py rename to mio/process/video.py index e333cd92..75443173 100644 --- a/miniscope_io/process/video.py +++ b/mio/process/video.py @@ -8,11 +8,11 @@ import cv2 import numpy as np -from miniscope_io import init_logger -from miniscope_io.io import VideoReader -from miniscope_io.models.frames import NamedFrame -from miniscope_io.models.process import DenoiseConfig, FreqencyMaskingConfig, NoisePatchConfig -from miniscope_io.plots.video import VideoPlotter +from mio import init_logger +from mio.io import VideoReader +from mio.models.frames import NamedFrame +from mio.models.process import DenoiseConfig, FreqencyMaskingConfig, NoisePatchConfig +from mio.plots.video import VideoPlotter logger = init_logger("video") From dc0cea31d68f6bdfd85b7cd98fcb19a67c79052c Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:12:58 -0800 Subject: [PATCH 18/20] Fix imports for rebasing 
--- mio/cli/main.py | 3 ++- mio/cli/process.py | 4 ++-- mio/models/frames.py | 4 ++-- mio/models/process.py | 2 +- mio/plots/video.py | 4 ++-- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mio/cli/main.py b/mio/cli/main.py index 3272b695..c7c69129 100644 --- a/mio/cli/main.py +++ b/mio/cli/main.py @@ -5,6 +5,7 @@ import click from mio.cli.config import config +from mio.cli.process import process from mio.cli.stream import stream from mio.cli.update import device, update @@ -23,4 +24,4 @@ def cli(ctx: click.Context) -> None: cli.add_command(update) cli.add_command(device) cli.add_command(config) -cli.add_command(denoise) +cli.add_command(process) diff --git a/mio/cli/process.py b/mio/cli/process.py index 026a984f..2fc21001 100644 --- a/mio/cli/process.py +++ b/mio/cli/process.py @@ -4,8 +4,8 @@ import click -from miniscope_io.models.process import DenoiseConfig -from miniscope_io.process.video import VideoProcessor +from mio.models.process import DenoiseConfig +from mio.process.video import VideoProcessor @click.group() diff --git a/mio/models/frames.py b/mio/models/frames.py index ffe41897..1b6ba5d4 100644 --- a/mio/models/frames.py +++ b/mio/models/frames.py @@ -9,8 +9,8 @@ import numpy as np from pydantic import BaseModel, Field, model_validator -from miniscope_io.io import VideoWriter -from miniscope_io.logging import init_logger +from mio.io import VideoWriter +from mio.logging import init_logger T = TypeVar("T", np.ndarray, List[np.ndarray], List[List[np.ndarray]]) diff --git a/mio/models/process.py b/mio/models/process.py index 63872867..ffcf8df7 100644 --- a/mio/models/process.py +++ b/mio/models/process.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field -from miniscope_io.models.mixins import YAMLMixin +from mio.models.mixins import YAMLMixin class InteractiveDisplayConfig(BaseModel): diff --git a/mio/plots/video.py b/mio/plots/video.py index 740a1620..807b8dd5 100644 --- a/mio/plots/video.py +++ b/mio/plots/video.py @@ -4,8 +4,8 @@ from typing import List -from miniscope_io import init_logger -from miniscope_io.models.frames import NamedFrame +from mio import init_logger +from mio.models.frames import NamedFrame try: import matplotlib.pyplot as plt From 87e8a44ae0e999e9ed4529f520276c9a4dbab640 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:18:28 -0800 Subject: [PATCH 19/20] Add id to denoise config, formatting --- mio/data/config/process/denoise_example.yml | 29 ++++++++++++--------- mio/process/video.py | 3 +++ mio/stream_daq.py | 2 +- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/mio/data/config/process/denoise_example.yml b/mio/data/config/process/denoise_example.yml index 901f65c1..1973627e 100644 --- a/mio/data/config/process/denoise_example.yml +++ b/mio/data/config/process/denoise_example.yml @@ -1,26 +1,29 @@ +id: denoise_example +mio_model: tests.test_mixins.LoaderModel +mio_version: 0.6.1.dev16+g6436633.d20241211 interactive_display: - enable: False + enable: false start_frame: 40 end_frame: 140 noise_patch: - enable: True - method: "mean_error" + enable: true + method: mean_error threshold: 30 buffer_size: 5032 buffer_split: 10 diff_multiply: 1 - output_result: True - output_noise_patch: True - output_diff: False + output_result: true + output_noise_patch: true + output_diff: false frequency_masking: - enable: True + enable: true spatial_LPF_cutoff_radius: 20 vertical_BEF_cutoff: 2 horizontal_BEF_cutoff: 0 - display_mask: False - output_mask: True - output_result: 
True - output_freq_domain: False + display_mask: false + output_mask: true + output_result: true + output_freq_domain: false end_frame: 140 -output_result: True -output_dir: 'user_dir/output' \ No newline at end of file +output_result: true +output_dir: user_dir/output diff --git a/mio/process/video.py b/mio/process/video.py index 75443173..87a77f13 100644 --- a/mio/process/video.py +++ b/mio/process/video.py @@ -213,6 +213,7 @@ class VideoProcessor: """ A class to process video files. """ + @staticmethod def denoise( video_path: str, @@ -315,9 +316,11 @@ def denoise( fps=20, ) if config.noise_patch.output_diff: + """ diff_video = NamedFrame( name=f"diff_{config.noise_patch.diff_multiply}x", video_frame=diff_frames ) + """ if config.noise_patch.output_noise_patch: noise_patch = NamedFrame(name="noise_patch", video_frame=noise_patchs) if config.frequency_masking.output_mask: diff --git a/mio/stream_daq.py b/mio/stream_daq.py index b2e71e12..615ca3a7 100644 --- a/mio/stream_daq.py +++ b/mio/stream_daq.py @@ -20,7 +20,7 @@ from mio.bit_operation import BufferFormatter from mio.devices.mocks import okDevMock from mio.exceptions import EndOfRecordingException, StreamReadError -from mio.io import BufferedCSVWriter +from mio.io import BufferedCSVWriter, VideoWriter from mio.models.stream import ( StreamBufferHeader, StreamBufferHeaderFormat, From 9cd2bc3955cd36d37771ad8ea5cc0dda9c02befb Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Wed, 11 Dec 2024 15:04:09 -0800 Subject: [PATCH 20/20] allow endless file input --- mio/data/config/process/denoise_example.yml | 4 ++-- mio/process/video.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mio/data/config/process/denoise_example.yml b/mio/data/config/process/denoise_example.yml index 1973627e..40c1d52f 100644 --- a/mio/data/config/process/denoise_example.yml +++ b/mio/data/config/process/denoise_example.yml @@ -17,13 +17,13 @@ noise_patch: output_diff: false frequency_masking: enable: true - spatial_LPF_cutoff_radius: 20 + spatial_LPF_cutoff_radius: 15 vertical_BEF_cutoff: 2 horizontal_BEF_cutoff: 0 display_mask: false output_mask: true output_result: true output_freq_domain: false -end_frame: 140 +end_frame: -1 #-1 means all frames output_result: true output_dir: user_dir/output diff --git a/mio/process/video.py b/mio/process/video.py index 87a77f13..7914be86 100644 --- a/mio/process/video.py +++ b/mio/process/video.py @@ -260,7 +260,7 @@ def denoise( # index for frame number in original video try: for index, frame in reader.read_frames(): - if config.end_frame and index > config.end_frame: + if config.end_frame and config.end_frame != -1 and index > config.end_frame: break raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) @@ -316,17 +316,25 @@ def denoise( fps=20, ) if config.noise_patch.output_diff: - """ diff_video = NamedFrame( name=f"diff_{config.noise_patch.diff_multiply}x", video_frame=diff_frames ) - """ + diff_video.export( + output_dir / f"{pathstem}", + suffix=True, + fps=20, + ) if config.noise_patch.output_noise_patch: noise_patch = NamedFrame(name="noise_patch", video_frame=noise_patchs) if config.frequency_masking.output_mask: freq_mask_frame = NamedFrame( name="freq_mask", static_frame=freq_mask * np.iinfo(np.uint8).max ) + freq_mask_frame.export( + output_dir / f"{pathstem}", + suffix=True, + fps=20, + ) if config.frequency_masking.enable: freq_domain_video = NamedFrame(name="freq_domain", video_frame=freq_domain_frames)
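Note: with the series applied, denoising is fully config-driven through the pydantic models. A hypothetical end-to-end invocation, assuming YAMLMixin supplies a from_yaml constructor (not verified here) and using placeholder paths:

    from mio.models.process import DenoiseConfig
    from mio.process.video import VideoProcessor

    # Loader name assumed from YAMLMixin; verify against mio.models.mixins.
    config = DenoiseConfig.from_yaml("mio/data/config/process/denoise_example.yml")

    # end_frame == -1 now means "process every frame in the input video".
    VideoProcessor.denoise("user_dir/input.avi", config)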