Skip to content

Commit

Permalink
Push modules into associated classes
Browse files Browse the repository at this point in the history
  • Loading branch information
t-sasatani committed Dec 6, 2024
1 parent 499f109 commit c0fcb1a
Showing 1 changed file with 90 additions and 102 deletions.
192 changes: 90 additions & 102 deletions miniscope_io/process/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,6 @@
from miniscope_io.plots.video import VideoPlotter

logger = init_logger("video")

def gen_freq_mask(
width: int,
height: int,
center_LPF: int,
vertical_BEF: int,
horizontal_BEF: int,
show_mask: bool = False,
) -> np.ndarray:
"""
Generate a mask to filter out horizontal and vertical frequencies.
A central circular region can be removed to allow low frequencies to pass.
"""
crow, ccol = height // 2, width // 2

# Create an initial mask filled with ones (pass all frequencies)
mask = np.ones((height, width), np.uint8)

# Zero out a vertical stripe at the frequency center
mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0

# Zero out a horizontal stripe at the frequency center
mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0

# Define spacial low pass filter
y, x = np.ogrid[:height, :width]
center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2

# Restore the center circular area to allow low frequencies to pass
mask[center_mask] = 1

# Visualize the mask if needed. Might delete later.
if show_mask:
cv2.imshow("Mask", mask * np.iinfo(np.uint8).max)
while True:
if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization
break
cv2.destroyAllWindows()
return mask


class FrameProcessor:
"""
A class to process video frames.
Expand Down Expand Up @@ -174,75 +133,51 @@ def remove_stripes(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
img_back = np.abs(img_back)

return np.uint8(img_back), np.uint8(magnitude_spectrum)


class FrameListProcessor:
"""
A class to process a list of video frames.
"""

@staticmethod
def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray:

def gen_freq_mask(
self,
center_LPF: int,
vertical_BEF: int,
horizontal_BEF: int,
show_mask: bool = False,
) -> np.ndarray:
"""
Get the minimum projection of a list of images.
Parameters:
image_list (list[np.ndarray]): A list of images to project.
Returns:
np.ndarray: The minimum projection of the images.
Generate a mask to filter out horizontal and vertical frequencies.
A central circular region can be removed to allow low frequencies to pass.
"""
stacked_images = np.stack(image_list, axis=0)
min_projection = np.min(stacked_images, axis=0)
return min_projection
crow, ccol = self.height // 2, self.width // 2

@staticmethod
def normalize_video_stack(image_list: list[np.ndarray]) -> list[np.ndarray]:
"""
Normalize a stack of images to 0-255 using max and minimum values of the entire stack.
Return a list of images.
# Create an initial mask filled with ones (pass all frequencies)
mask = np.ones((self.height, self.width), np.uint8)

Parameters:
image_list (list[np.ndarray]): A list of images to normalize.
# Zero out a vertical stripe at the frequency center
mask[:, ccol - vertical_BEF : ccol + vertical_BEF] = 0

Returns:
list[np.ndarray]: The normalized images as a list.
"""
# Zero out a horizontal stripe at the frequency center
mask[crow - horizontal_BEF : crow + horizontal_BEF, :] = 0

# Stack images along a new axis (axis=0)
stacked_images = np.stack(image_list, axis=0)
# Define spacial low pass filter
y, x = np.ogrid[:self.height, :self.width]
center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= center_LPF**2

# Find the global min and max across the entire stack
global_min = stacked_images.min()
global_max = stacked_images.max()

# Normalize each frame using the global min and max
normalized_images = []
for i in range(stacked_images.shape[0]):
normalized_image = cv2.normalize(
stacked_images[i],
None,
0,
np.iinfo(np.uint8).max,
cv2.NORM_MINMAX,
dtype=cv2.CV_32F,
)
# Apply global normalization
normalized_image = (
(stacked_images[i] - global_min)
/ (global_max - global_min)
* np.iinfo(np.uint8).max
)
normalized_images.append(normalized_image.astype(np.uint8))
# Restore the center circular area to allow low frequencies to pass
mask[center_mask] = 1

return normalized_images
# Visualize the mask if needed. Might delete later.
if show_mask:
cv2.imshow("Mask", mask * np.iinfo(np.uint8).max)
while True:
if cv2.waitKey(1) == 27: # Press 'Esc' key to exit visualization
break
cv2.destroyAllWindows()
return mask


class VideoProcessor:
"""
A class to process video files.
"""

@staticmethod
def denoise(
video_path: str,
slider_plot: bool = True,
Expand Down Expand Up @@ -275,9 +210,7 @@ def denoise(
buffer_split=buffer_split,
)

freq_mask = gen_freq_mask(
width=reader.width,
height=reader.width,
freq_mask = processor.gen_freq_mask(
center_LPF=spatial_LPF,
vertical_BEF=vertical_BEF,
horizontal_BEF=horizontal_BEF,
Expand Down Expand Up @@ -316,12 +249,12 @@ def denoise(
reader.release()
plt.close(fig)

normalized_frames = FrameListProcessor.normalize_video_stack(freq_filtered_frames)
minimum_projection = FrameListProcessor.get_minimum_projection(normalized_frames)
normalized_frames = VideoProcessor.normalize_video_stack(freq_filtered_frames)
minimum_projection = VideoProcessor.get_minimum_projection(normalized_frames)

subtract_minimum = [(frame - minimum_projection) for frame in normalized_frames]

subtract_minimum = FrameListProcessor.normalize_video_stack(subtract_minimum)
subtract_minimum = VideoProcessor.normalize_video_stack(subtract_minimum)

raw_video = NamedFrame(name="RAW", video_frame=raw_frames)
patched_video = NamedFrame(name="Patched", video_frame=patched_frames)
Expand Down Expand Up @@ -352,3 +285,58 @@ def denoise(
VideoPlotter.show_video_with_controls(
videos,
)
@staticmethod
def get_minimum_projection(image_list: list[np.ndarray]) -> np.ndarray:
"""
Get the minimum projection of a list of images.
Parameters:
image_list (list[np.ndarray]): A list of images to project.
Returns:
np.ndarray: The minimum projection of the images.
"""
stacked_images = np.stack(image_list, axis=0)
min_projection = np.min(stacked_images, axis=0)
return min_projection

@staticmethod
def normalize_video_stack(image_list: list[np.ndarray]) -> list[np.ndarray]:
"""
Normalize a stack of images to 0-255 using max and minimum values of the entire stack.
Return a list of images.
Parameters:
image_list (list[np.ndarray]): A list of images to normalize.
Returns:
list[np.ndarray]: The normalized images as a list.
"""

# Stack images along a new axis (axis=0)
stacked_images = np.stack(image_list, axis=0)

# Find the global min and max across the entire stack
global_min = stacked_images.min()
global_max = stacked_images.max()

# Normalize each frame using the global min and max
normalized_images = []
for i in range(stacked_images.shape[0]):
normalized_image = cv2.normalize(
stacked_images[i],
None,
0,
np.iinfo(np.uint8).max,
cv2.NORM_MINMAX,
dtype=cv2.CV_32F,
)
# Apply global normalization
normalized_image = (
(stacked_images[i] - global_min)
/ (global_max - global_min)
* np.iinfo(np.uint8).max
)
normalized_images.append(normalized_image.astype(np.uint8))

return normalized_images

0 comments on commit c0fcb1a

Please sign in to comment.