Skip to content

Commit

Permalink
configure denoise with yaml file
Browse files Browse the repository at this point in the history
  • Loading branch information
t-sasatani committed Dec 6, 2024
1 parent c0fcb1a commit f01fb88
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 26 deletions.
13 changes: 12 additions & 1 deletion miniscope_io/cli/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import click

from miniscope_io.models.process import DenoiseConfig
from miniscope_io.process.video import VideoProcessor


Expand All @@ -23,10 +24,20 @@ def process() -> None:
type=click.Path(exists=True, dir_okay=False),
help="Path to the video file to process.",
)
@click.option(
"-c",
"--denoise_config",
required=True,
type=click.Path(exists=True, dir_okay=False),
help="Path to the YAML processing configuration file.",
)
def denoise(
input: str,
denoise_config: str,
) -> None:
"""
Denoise a video file.
"""
VideoProcessor.denoise(input)
denoise_config_parsed = DenoiseConfig.from_yaml(denoise_config)
VideoProcessor.denoise(input, denoise_config_parsed)

File renamed without changes.
17 changes: 17 additions & 0 deletions miniscope_io/data/config/process/denoise_example.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
interactive_display:
enable: True
end_frame: 1000
noise_patch:
enable: True
method: "mean_error"
threshold: 20
buffer_size: 5032
buffer_split: 1
diff_multiply: 10
frequency_masking:
enable: True
spacial_LPF_cutoff_radius: 10
vertical_BEF_cutoff: 5
horizontal_BEF_cutoff: 0
display_mask: False
end_frame: 1000
99 changes: 99 additions & 0 deletions miniscope_io/models/process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Module for preprocessing data.
"""

from typing import Optional

from pydantic import BaseModel, Field

from miniscope_io.models.mixins import YAMLMixin


class InteractiveDisplayConfig(BaseModel):
"""
Configuration for displaying a video.
"""
enable: bool = Field(
default=False,
description="Whether to plot the output .",
)
end_frame: Optional[int] = Field(
default=100,
description="Frame to end processing at.",
)

class NoisePatchConfig(BaseModel):
"""
Configuration for patch based noise handling.
"""
enable: bool = Field(
default=True,
description="Whether to use patch based noise handling.",
)
method: str = Field(
default="mean_error",
description="Method for handling noise.",
)
threshold: float = Field(
default=20,
description="Threshold for detecting noise.",
)
buffer_size: int = Field(
default=5032,
description="Size of the buffers composing the image."
"This premises that the noisy area will appear in units of buffer_size.",
)
buffer_split: int = Field(
default=1,
description="Number of splits to make in the buffer when detecting noisy areas.",
)
diff_multiply: int = Field(
default=1,
description="Multiplier for the difference between the mean and the pixel value.",
)

class FreqencyMaskingConfig(BaseModel):
"""
Configuration for frequency filtering.
"""
enable: bool = Field(
default=True,
description="Whether to use frequency filtering.",
)
spatial_LPF_cutoff_radius: int = Field(
default=5,
description="Radius for the spatial cutoff.",
)
vertical_BEF_cutoff: int = Field(
default=5,
description="Cutoff for the vertical band elimination filter.",
)
horizontal_BEF_cutoff: int = Field(
default=0,
description="Cutoff for the horizontal band elimination filter.",
)
display_mask: bool = Field(
default=False,
description="Whether to display the mask.",
)

class DenoiseConfig(BaseModel, YAMLMixin):
"""
Configuration for denoising a video.
"""
interactive_display: Optional[InteractiveDisplayConfig] = Field(
default=None,
description="Configuration for displaying the video.",
)
noise_patch: Optional[NoisePatchConfig] = Field(
default=None,
description="Configuration for patch based noise handling.",
)
frequency_masking: Optional[FreqencyMaskingConfig] = Field(
default=None,
description="Configuration for frequency filtering.",
)
end_frame: Optional[int] = Field(
default=None,
description="Frame to end processing at.",
)
53 changes: 28 additions & 25 deletions miniscope_io/process/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from miniscope_io import init_logger
from miniscope_io.io import VideoReader
from miniscope_io.models.frames import NamedFrame
from miniscope_io.models.process import DenoiseConfig
from miniscope_io.plots.video import VideoPlotter

logger = init_logger("video")
Expand All @@ -33,8 +34,6 @@ def __init__(self, height: int, width: int, buffer_size: int = 5032, buffer_spli
"""
self.height = height
self.width = width
self.buffer_size = buffer_size
self.buffer_split = buffer_split

def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.ndarray]:
"""
Expand All @@ -61,7 +60,12 @@ def split_by_length(self, array: np.ndarray, segment_length: int) -> list[np.nda
return split_arrays

def patch_noisy_buffer(
self, current_frame: np.ndarray, previous_frame: np.ndarray, noise_threshold: float
self,
current_frame: np.ndarray,
previous_frame: np.ndarray,
buffer_size: int,
buffer_split: int,
noise_threshold: float
) -> Tuple[np.ndarray, np.ndarray]:
"""
Process the frame, replacing noisy blocks with those from the previous frame.
Expand All @@ -78,10 +82,10 @@ def patch_noisy_buffer(
serialized_previous = previous_frame.flatten().astype(np.int16)

split_current = self.split_by_length(
serialized_current, self.buffer_size // self.buffer_split
serialized_current, buffer_size // buffer_split
)
split_previous = self.split_by_length(
serialized_previous, self.buffer_size // self.buffer_split
serialized_previous, buffer_size // buffer_split
)

split_output = split_current.copy()
Expand Down Expand Up @@ -180,14 +184,7 @@ class VideoProcessor:
@staticmethod
def denoise(
video_path: str,
slider_plot: bool = True,
end_frame: int = 100,
noise_threshold: float = 20,
spatial_LPF: int = 10,
vertical_BEF: int = 2,
horizontal_BEF: int = 0,
diff_mag: int = 10,
buffer_split: int = 1,
config: DenoiseConfig,
) -> None:
"""
Process a video file and display the results.
Expand All @@ -207,21 +204,21 @@ def denoise(
processor = FrameProcessor(
height=reader.height,
width=reader.width,
buffer_split=buffer_split,
)

freq_mask = processor.gen_freq_mask(
center_LPF=spatial_LPF,
vertical_BEF=vertical_BEF,
horizontal_BEF=horizontal_BEF,
show_mask=False,
)
if config.noise_patch.enable:
freq_mask = processor.gen_freq_mask(
center_LPF=config.frequency_masking.spatial_LPF_cutoff_radius,
vertical_BEF=config.frequency_masking.vertical_BEF_cutoff,
horizontal_BEF=config.frequency_masking.horizontal_BEF_cutoff,
show_mask=config.frequency_masking.display_mask,
)

try:
for frame in reader.read_frames():
raw_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

if index > end_frame:
if config.end_frame and index > config.end_frame:
break

logger.debug(f"Processing frame {index}")
Expand All @@ -230,7 +227,11 @@ def denoise(
previous_frame = raw_frame

patched_frame, noise_patch = processor.patch_noisy_buffer(
raw_frame, previous_frame, noise_threshold=noise_threshold
raw_frame,
previous_frame,
buffer_size=config.noise_patch.buffer_size,
buffer_split=config.noise_patch.buffer_split,
noise_threshold=config.noise_patch.threshold
)
freq_filtered_frame, frame_freq_domain = processor.remove_stripes(
img=patched_frame, mask=freq_mask
Expand All @@ -242,7 +243,7 @@ def denoise(
freq_domain_frames.append(frame_freq_domain)
noise_patchs.append(noise_patch * np.iinfo(np.uint8).max)
freq_filtered_frames.append(freq_filtered_frame)
diff_frames.append(diff_frame * diff_mag)
diff_frames.append(diff_frame * config.noise_patch.diff_multiply)

index += 1
finally:
Expand All @@ -258,7 +259,9 @@ def denoise(

raw_video = NamedFrame(name="RAW", video_frame=raw_frames)
patched_video = NamedFrame(name="Patched", video_frame=patched_frames)
diff_video = NamedFrame(name=f"Diff {diff_mag}x", video_frame=diff_frames)
diff_video = NamedFrame(
name=f"Diff {config.noise_patch.diff_multiply}x",
video_frame=diff_frames)
noise_patch = NamedFrame(name="Noisy area", video_frame=noise_patchs)
freq_mask_frame = NamedFrame(
name="Freq mask", static_frame=freq_mask * np.iinfo(np.uint8).max
Expand All @@ -269,7 +272,7 @@ def denoise(
min_proj_frame = NamedFrame(name="Min Proj", static_frame=minimum_projection)
subtract_video = NamedFrame(name="Subtracted", video_frame=subtract_minimum)

if slider_plot:
if config.interactive_display.enable:
videos = [
raw_video,
patched_video,
Expand Down

0 comments on commit f01fb88

Please sign in to comment.