Skip to content

Commit

Permalink
Fix config, add start/end of display
Browse files Browse the repository at this point in the history
  • Loading branch information
t-sasatani committed Dec 10, 2024
1 parent bc9a268 commit fa9a036
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 17 deletions.
11 changes: 6 additions & 5 deletions miniscope_io/data/config/process/denoise_example.yml
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
interactive_display:
enable: True
end_frame: 1000
start_frame: 40
end_frame: 90
noise_patch:
enable: True
method: "mean_error"
threshold: 20
threshold: 30
buffer_size: 5032
buffer_split: 1
buffer_split: 10
diff_multiply: 1
output_result: True
output_noise_patch: True
output_diff: True
frequency_masking:
enable: True
spacial_LPF_cutoff_radius: 10
vertical_BEF_cutoff: 1
spatial_LPF_cutoff_radius: 15
vertical_BEF_cutoff: 2
horizontal_BEF_cutoff: 0
display_mask: False
output_mask: True
Expand Down
8 changes: 6 additions & 2 deletions miniscope_io/models/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ class InteractiveDisplayConfig(BaseModel):
default=False,
description="Whether to plot the output .",
)
start_frame: Optional[int] = Field(
default=...,
description="Frame to start processing at.",
)
end_frame: Optional[int] = Field(
default=100,
default=...,
description="Frame to end processing at.",
)

Expand Down Expand Up @@ -78,7 +82,7 @@ class FreqencyMaskingConfig(BaseModel):
description="Whether to use frequency filtering.",
)
spatial_LPF_cutoff_radius: int = Field(
default=5,
default=...,
description="Radius for the spatial cutoff.",
)
vertical_BEF_cutoff: int = Field(
Expand Down
25 changes: 23 additions & 2 deletions miniscope_io/plots/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List

from miniscope_io.models.frames import NamedFrame
from miniscope_io import init_logger

try:
import matplotlib.pyplot as plt
Expand All @@ -18,14 +19,20 @@
Slider = None
KeyEvent = None

logger = init_logger("videoplot")

class VideoPlotter:
"""
Class to display video streams and static images.
"""

@staticmethod
def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None:
def show_video_with_controls(
videos: List[NamedFrame],
start_frame: int,
end_frame: int,
fps: int = 20,
) -> None:
"""
Plot multiple video streams or static images side-by-side.
Can play/pause and navigate frames.
Expand All @@ -34,6 +41,10 @@ def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None:
----------
videos : NamedFrame
NamedFrame object containing video data and names.
start_frame : int
Starting frame index for the video display.
end_frame : int
Ending frame index for the video display.
fps : int, optional
Frames per second for the video, by default 20
"""
Expand All @@ -45,7 +56,6 @@ def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None:

if any(frame.frame_type == "video_list_frame" for frame in videos):
raise NotImplementedError("Only single videos or frames are supported for now.")

# Wrap static images in lists to handle them uniformly
video_frames = [
frame.data if frame.frame_type == "video_frame" else [frame.data] for frame in videos
Expand All @@ -54,7 +64,18 @@ def show_video_with_controls(videos: List[NamedFrame], fps: int = 20) -> None:
titles = [video.name for video in videos]

num_streams = len(video_frames)

logger.info(f"Displaying {num_streams} video streams.")
if end_frame > start_frame:
logger.info(f"Displaying frames {start_frame} to {end_frame}.")
for stream_index in range(len(video_frames)):
logger.info(f"Stream length: {len(video_frames[stream_index])}")
if len(video_frames[stream_index]) > 1:
video_frames[stream_index] = video_frames[stream_index][start_frame:end_frame]
logger.info(f"Trimmed stream length: {len(video_frames[stream_index])}")

num_frames = max(len(stream) for stream in video_frames)
logger.info(f"Max stream length: {num_frames}")

fig, axes = plt.subplots(1, num_streams, figsize=(20, 5))

Expand Down
44 changes: 36 additions & 8 deletions miniscope_io/process/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,19 @@ def patch_noisy_buffer(
serialized_current = current_frame.flatten().astype(np.int16)
serialized_previous = previous_frame.flatten().astype(np.int16)

split_current = self.split_by_length(serialized_current, buffer_size // buffer_split)
split_previous = self.split_by_length(serialized_previous, buffer_size // buffer_split)
buffer_per_frame = len(serialized_current) // buffer_size + 1

split_current = self.split_by_length(
serialized_current,
buffer_size // buffer_split)
split_previous = self.split_by_length(
serialized_previous,
buffer_size // buffer_split)

split_output = split_current.copy()
noisy_parts = split_current.copy()

'''
for i in range(len(split_current)):
mean_error = abs(split_current[i] - split_previous[i]).mean()
if mean_error > noise_threshold:
Expand All @@ -103,6 +110,25 @@ def patch_noisy_buffer(
else:
split_output[i] = split_current[i]
noisy_parts[i] = np.zeros_like(split_current[i], np.uint8)
'''
buffer_has_noise = False
for buffer_index in range(buffer_per_frame):
for split_index in range(buffer_split):
i = buffer_index * buffer_split + split_index
mean_error = abs(split_current[i] - split_previous[i]).mean()
if mean_error > noise_threshold:
logger.info(f"Replacing buffer {i} with mean error {mean_error}")
buffer_has_noise = True
break
else:
split_output[i] = split_current[i]
noisy_parts[i] = np.zeros_like(split_current[i], np.uint8)
if buffer_has_noise:
for split_index in range(buffer_split):
i = buffer_index * buffer_split + split_index
split_output[i] = split_previous[i]
noisy_parts[i] = np.ones_like(split_current[i], np.uint8)
buffer_has_noise = False

serialized_output = np.concatenate(split_output)[: self.height * self.width]
noise_output = np.concatenate(noisy_parts)[: self.height * self.width]
Expand Down Expand Up @@ -314,18 +340,20 @@ def denoise(
if config.interactive_display.enable:
videos = [
raw_video,
patched_video,
diff_video,
noise_patch,
freq_mask_frame,
freq_domain_video,
patched_video,
freq_filtered_video,
normalized_video,
freq_domain_video,
min_proj_frame,
subtract_video,
freq_mask_frame,
#diff_video,
#normalized_video,
#subtract_video,
]
VideoPlotter.show_video_with_controls(
videos,
start_frame=config.interactive_display.start_frame,
end_frame=config.interactive_display.end_frame,
)

@staticmethod
Expand Down

0 comments on commit fa9a036

Please sign in to comment.