diff --git a/miniscope_io/process/video.py b/miniscope_io/process/video.py index 5e9ee9c4..00a7ff8a 100644 --- a/miniscope_io/process/video.py +++ b/miniscope_io/process/video.py @@ -14,47 +14,6 @@ from miniscope_io.plots.video import VideoPlotter logger = init_logger("video") - -def gen_freq_mask( - width: int, - height: int, - center_LPF: int, - vertical_BEF: int, - horizontal_BEF: int, - show_mask: bool = False, -) -> np.ndarray: - """ - Generate a mask to filter out horizontal and vertical frequencies. - A central circular region can be removed to allow low frequencies to pass. - """ - crow, ccol = height // 2, width // 2 - - # Create an initial mask filled with ones (pass all frequencies) - mask = np.ones((height, width), np.uint8) - - # Zero out a vertical stripe at the frequency center - mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0 - - # Zero out a horizontal stripe at the frequency center - mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 - - # Define spacial low pass filter - y, x = np.ogrid[:height, :width] - center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 - - # Restore the center circular area to allow low frequencies to pass - mask[center_mask] = 1 - - # Visualize the mask if needed. Might delete later. - if show_mask: - cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) - while True: - if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization - break - cv2.destroyAllWindows() - return mask - - class FrameProcessor: """ A class to process video frames. @@ -174,75 +133,51 @@ def remove_stripes(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray: img_back = np.abs(img_back) return np.uint8(img_back), np.uint8(magnitude_spectrum) - - -class FrameListProcessor: - """ - A class to process a list of video frames. - """ - - @staticmethod - def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray: + + def gen_freq_mask( + self, + center_LPF: int, + vertical_BEF: int, + horizontal_BEF: int, + show_mask: bool = False, + ) -> np.ndarray: """ - Get the minimum projection of a list of images. - - Parameters: - image_list (list[np.ndarray]): A list of images to project. - - Returns: - np.ndarray: The minimum projection of the images. + Generate a mask to filter out horizontal and vertical frequencies. + A central circular region can be removed to allow low frequencies to pass. """ - stacked_images = np.stack(image_list, axis=0) - min_projection = np.min(stacked_images, axis=0) - return min_projection + crow, ccol = self.height // 2, self.width // 2 - @staticmethod - def normalize_video_stack(image_list: list[np.ndarray]) -> list[np.ndarray]: - """ - Normalize a stack of images to 0-255 using max and minimum values of the entire stack. - Return a list of images. + # Create an initial mask filled with ones (pass all frequencies) + mask = np.ones((self.height, self.width), np.uint8) - Parameters: - image_list (list[np.ndarray]): A list of images to normalize. + # Zero out a vertical stripe at the frequency center + mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0 - Returns: - list[np.ndarray]: The normalized images as a list. - """ + # Zero out a horizontal stripe at the frequency center + mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0 - # Stack images along a new axis (axis=0) - stacked_images = np.stack(image_list, axis=0) + # Define spacial low pass filter + y, x = np.ogrid[:self.height, :self.width] + center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2 - # Find the global min and max across the entire stack - global_min = stacked_images.min() - global_max = stacked_images.max() - - # Normalize each frame using the global min and max - normalized_images = [] - for i in range(stacked_images.shape[0]): - normalized_image = cv2.normalize( - stacked_images[i], - None, - 0, - np.iinfo(np.uint8).max, - cv2.NORM_MINMAX, - dtype=cv2.CV_32F, - ) - # Apply global normalization - normalized_image = ( - (stacked_images[i] - global_min) - / (global_max - global_min) - * np.iinfo(np.uint8).max - ) - normalized_images.append(normalized_image.astype(np.uint8)) + # Restore the center circular area to allow low frequencies to pass + mask[center_mask] = 1 - return normalized_images + # Visualize the mask if needed. Might delete later. + if show_mask: + cv2.imshow("Mask", mask * np.iinfo(np.uint8).max) + while True: + if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization + break + cv2.destroyAllWindows() + return mask class VideoProcessor: """ A class to process video files. """ - + @staticmethod def denoise( video_path: str, slider_plot: bool = True, @@ -275,9 +210,7 @@ def denoise( buffer_split=buffer_split, ) - freq_mask = gen_freq_mask( - width=reader.width, - height=reader.width, + freq_mask = processor.gen_freq_mask( center_LPF=spatial_LPF, vertical_BEF=vertical_BEF, horizontal_BEF=horizontal_BEF, @@ -316,12 +249,12 @@ def denoise( reader.release() plt.close(fig) - normalized_frames = FrameListProcessor.normalize_video_stack(freq_filtered_frames) - minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames) + normalized_frames = VideoProcessor.normalize_video_stack(freq_filtered_frames) + minimum_projection = VideoProcessor.get_minimum_projection(normalized_frames) subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames] - subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum) + subtract_minimum = VideoProcessor.normalize_video_stack(subtract_minimum) raw_video = NamedFrame(name="RAW", video_frame=raw_frames) patched_video = NamedFrame(name="Patched", video_frame=patched_frames) @@ -352,3 +285,58 @@ def denoise( VideoPlotter.show_video_with_controls( videos, ) + @staticmethod + def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray: + """ + Get the minimum projection of a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to project. + + Returns: + np.ndarray: The minimum projection of the images. + """ + stacked_images = np.stack(image_list, axis=0) + min_projection = np.min(stacked_images, axis=0) + return min_projection + + @staticmethod + def normalize_video_stack(image_list: list[np.ndarray]) -> list[np.ndarray]: + """ + Normalize a stack of images to 0-255 using max and minimum values of the entire stack. + Return a list of images. + + Parameters: + image_list (list[np.ndarray]): A list of images to normalize. + + Returns: + list[np.ndarray]: The normalized images as a list. + """ + + # Stack images along a new axis (axis=0) + stacked_images = np.stack(image_list, axis=0) + + # Find the global min and max across the entire stack + global_min = stacked_images.min() + global_max = stacked_images.max() + + # Normalize each frame using the global min and max + normalized_images = [] + for i in range(stacked_images.shape[0]): + normalized_image = cv2.normalize( + stacked_images[i], + None, + 0, + np.iinfo(np.uint8).max, + cv2.NORM_MINMAX, + dtype=cv2.CV_32F, + ) + # Apply global normalization + normalized_image = ( + (stacked_images[i] - global_min) + / (global_max - global_min) + * np.iinfo(np.uint8).max + ) + normalized_images.append(normalized_image.astype(np.uint8)) + + return normalized_images