From 5ac1293f98bd7f101acfe8d70883cf0fd01c41f9 Mon Sep 17 00:00:00 2001 From: t-sasatani <33111879+t-sasatani@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:15:33 -0800 Subject: [PATCH] Side-by-side comparison setup --- miniscope_io/processing/video.py | 54 +++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/miniscope_io/processing/video.py b/miniscope_io/processing/video.py index 3dc5fd99..6c23731b 100644 --- a/miniscope_io/processing/video.py +++ b/miniscope_io/processing/video.py @@ -13,6 +13,30 @@ logger = init_logger("video") +def plot_frames_side_by_side( + fig: plt.Figure, + frames: list[np.ndarray], + titles: str =None + ) -> None: + """ + Plot a list of frames side by side using matplotlib. + + :param frames: List of frames (images) to be plotted + :param titles: Optional list of titles for each subplot + """ + num_frames = len(frames) + plt.clf() # Clear current figure + + for i, frame in enumerate(frames): + plt.subplot(1, num_frames, i + 1) + plt.imshow(frame, cmap='gray') + if titles: + plt.title(titles[i]) + + plt.axis('off') # Turn off axis labels + + plt.tight_layout() + fig.canvas.draw() class AnnotatedFrameModel(BaseModel): """ A class to represent video data. @@ -85,6 +109,7 @@ def __del__(self): def gen_freq_mask( width: int = 200, height: int = 200, + center_radius: int = 6, show_mask: bool = True ) -> np.ndarray: """ @@ -107,7 +132,7 @@ def gen_freq_mask( mask[crow - horizontal_band_width:crow + horizontal_band_width, :] = 0 # Define the radius of the circular region to retain at the center - radius = 6 + radius = center_radius y, x = np.ogrid[:height, :width] center_mask = (x - ccol) ** 2 + (y - crow) ** 2 <= radius ** 2 @@ -222,6 +247,7 @@ def remove_stripes( frames = [] index = 0 + fig = plt.figure(figsize=(12, 4)) processor = FrameProcessor( height=200, @@ -231,7 +257,7 @@ def remove_stripes( freq_mask = gen_freq_mask( width=200, height=200, - show_mask=True + show_mask=False ) try: for frame in reader.read_frames(): @@ -253,12 +279,24 @@ def remove_stripes( ) frames.append(filtered_frame) - # Display the results for visualization - for frame in frames: - cv2.imshow('Video', frame) - if cv2.waitKey(100) & 0xFF == ord('q'): - break + frames_to_plot = [ + freq_mask, + gray_frame, + processed_frame, + filtered_frame, + ] + plot_frames_side_by_side( + fig, + frames_to_plot, + titles=[ + 'Frequency Mask', + 'Original Frame', + 'Processed Frame', + 'Filtered Frame', + ] + ) + plt.pause(0.01) finally: reader.release() - cv2.destroyAllWindows() \ No newline at end of file + plt.close(fig) \ No newline at end of file