Skip to content

Commit

Permalink
Adding PreProcessing and DeepDenoiser
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Feb 5, 2024
1 parent 4fbeeaf commit 6d1ffe2
Show file tree
Hide file tree
Showing 17 changed files with 483 additions and 109 deletions.
8 changes: 7 additions & 1 deletion src/qseek/apps/qseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ async def extract() -> None:
from qseek.corrections.base import TravelTimeCorrections
from qseek.features.base import FeatureExtractor
from qseek.magnitudes.base import EventMagnitudeCalculator
from qseek.pre_processing.base import BatchPreProcessing
from qseek.tracers.base import RayTracer
from qseek.waveforms.base import WaveformProvider

Expand All @@ -303,10 +304,11 @@ async def extract() -> None:
table.add_column("Description")

module_classes = (
WaveformProvider,
BatchPreProcessing,
RayTracer,
FeatureExtractor,
EventMagnitudeCalculator,
WaveformProvider,
TravelTimeCorrections,
)

Expand All @@ -332,6 +334,10 @@ def is_insight(module: type) -> bool:
table.add_section()

console.print(table)
console.print("🔑 indicates an insight module\n")
console.print(
"Use `qseek modules --json <module_name>` to print the JSON schema"
)

case "dump-schemas":
import json
Expand Down
11 changes: 8 additions & 3 deletions src/qseek/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,14 @@ class ImageFunctionsStats(Stats):
time_per_batch: timedelta = timedelta()
bytes_per_second: float = 0.0

_queue: asyncio.Queue[WaveformImages | None] | None = PrivateAttr(None)
_queue: asyncio.Queue[
Tuple[WaveformImages | WaveformBatch] | None
] | None = PrivateAttr(None)

def set_queue(self, queue: asyncio.Queue[WaveformImages | None]) -> None:
def set_queue(
self,
queue: asyncio.Queue[Tuple[WaveformImages | WaveformBatch] | None],
) -> None:
self._queue = queue

@computed_field
Expand Down Expand Up @@ -69,7 +74,7 @@ class ImageFunctions(RootModel):

_queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr()
_processed_images: int = PrivateAttr(0)
_stats = PrivateAttr(default_factory=ImageFunctionsStats)
_stats: ImageFunctionsStats = PrivateAttr(default_factory=ImageFunctionsStats)

def model_post_init(self, __context: Any) -> None:
# Check if phases are provided twice
Expand Down
16 changes: 10 additions & 6 deletions src/qseek/images/phase_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ class PhaseNet(ImageFunction):
le=3000,
description="Window overlap in samples.",
)
torch_use_cuda: bool = Field(
torch_use_cuda: bool | int = Field(
default=False,
description="Use CUDA for inference.",
description="Use CUDA for inference. If `True` use default device, if `int` use"
" the specified device.",
)
torch_cpu_threads: PositiveInt = Field(
default=4,
Expand Down Expand Up @@ -147,17 +148,20 @@ def _prepare(self) -> None:

torch.set_num_threads(self.torch_cpu_threads)
self._phase_net = PhaseNetSeisBench.from_pretrained(self.model)
if self.torch_use_cuda:
if isinstance(self.torch_use_cuda, bool):
self._phase_net.cuda()
else:
self._phase_net.cuda(self.torch_use_cuda)
self._phase_net.eval()
try:
logger.info("compiling PhaseNet model...")
self._phase_net = torch.compile(self._phase_net, mode="reduce-overhead")
self._phase_net = torch.compile(self._phase_net, mode="max-autotune")
except RuntimeError as exc:
logger.warning(
"failed to compile PhaseNet model, using uncompiled model.",
exc_info=exc,
)
if self.torch_use_cuda:
self._phase_net.cuda()
self._phase_net.eval()

@property
def blinding(self) -> timedelta:
Expand Down
2 changes: 2 additions & 0 deletions src/qseek/magnitudes/local_magnitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def plot(self) -> None:


class LocalMagnitudeExtractor(EventMagnitudeCalculator):
"""Local magnitude calculator for different regional models."""

magnitude: Literal["LocalMagnitude"] = "LocalMagnitude"

seconds_before: PositiveFloat = Field(
Expand Down
2 changes: 2 additions & 0 deletions src/qseek/magnitudes/moment_magnitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ async def add_traces(


class MomentMagnitudeExtractor(EventMagnitudeCalculator):
"""Moment magnitude calculator from peak amplitudes."""

magnitude: Literal["MomentMagnitude"] = "MomentMagnitude"

seconds_before: PositiveFloat = Field(
Expand Down
4 changes: 1 addition & 3 deletions src/qseek/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@
from pyrocko.model import Station as PyrockoStation
from pyrocko.model import dump_stations_yaml, load_stations

from qseek.utils import NSL
from qseek.utils import NSL, NSL_RE

if TYPE_CHECKING:
from pyrocko.squirrel import Squirrel
from pyrocko.trace import Trace

from qseek.models.location import CoordSystem, Location

NSL_RE = r"^[a-zA-Z0-9]{0,2}\.[a-zA-Z0-9]{0,5}\.[a-zA-Z0-9]{0,3}$"

logger = logging.getLogger(__name__)


Expand Down
Empty file.
74 changes: 74 additions & 0 deletions src/qseek/pre_processing/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, Field, field_validator

from qseek.utils import NSL

if TYPE_CHECKING:
from pyrocko.trace import Trace

from qseek.waveforms.base import WaveformBatch


class BatchPreProcessing(BaseModel):
    """Base class for pre-processing modules operating on waveform batches."""

    process: Literal["BasePreProcessing"] = "BasePreProcessing"

    stations: set[NSL] = Field(
        default=set(),
        description="List of station codes to process. E.g. ['6E.BFO', '6E.BHZ']. "
        "If empty, all stations are processed.",
    )

    @field_validator("stations")
    @classmethod
    def validate_stations(cls, v) -> set[NSL]:
        """Parse every configured station code into an `NSL` instance.

        Args:
            v: Iterable of station codes as accepted by `NSL.parse`.

        Returns:
            set[NSL]: The parsed station codes.
        """
        return {NSL.parse(station) for station in v}

    @classmethod
    def get_subclasses(cls) -> tuple[type[BatchPreProcessing], ...]:
        """Return a tuple of all direct subclasses of this class."""
        return tuple(cls.__subclasses__())

    def select_traces(self, batch: WaveformBatch) -> list[Trace]:
        """Select the traces of a batch that belong to the configured stations.

        Args:
            batch (WaveformBatch): The batch to select traces from.

        Returns:
            list[Trace]: All traces of the batch if `stations` is empty,
                otherwise only the traces whose parsed NSL code is contained
                in `stations`.
        """
        if not self.stations:
            return batch.traces
        # `stations` holds NSL instances, so the membership test has to be done
        # with the parsed NSL itself. Testing a single component (the station
        # code) against the set can never match.
        # NOTE(review): this is an exact NSL match; confirm whether configured
        # codes without a location should also match traces that carry one.
        return [
            trace
            for trace in batch.traces
            if NSL.parse(trace.nslc_id) in self.stations
        ]

    async def prepare(self) -> None:
        """Prepare the pre-processing module.

        The base implementation does nothing.
        """

    async def process_batch(self, batch: WaveformBatch) -> WaveformBatch:
        """Process the traces of a waveform batch.

        Args:
            batch (WaveformBatch): The batch to be processed.

        Returns:
            WaveformBatch: The processed batch.

        Raises:
            NotImplementedError: Always, subclasses have to implement this method.
        """
        raise NotImplementedError
73 changes: 73 additions & 0 deletions src/qseek/pre_processing/deep_denoiser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING, Literal

from obspy import Stream
from pydantic import Field, PrivateAttr
from pyrocko import obspy_compat

from qseek.pre_processing.base import BatchPreProcessing

if TYPE_CHECKING:
from seisbench.models import DeepDenoiser as SeisBenchDeepDenoiser

from qseek.waveforms.base import WaveformBatch


logger = logging.getLogger(__name__)
# Model names accepted by the `model` field; the chosen name is handed to
# `from_pretrained()` of the SeisBench DeepDenoiser.
DenoiserModels = Literal["original", "instance"]
# Import-time side effect: installs Pyrocko's ObsPy compatibility layer.
# NOTE(review): presumably this is what provides the `to_obspy_trace()` /
# `to_pyrocko_trace()` conversions used in this module - confirm against the
# Pyrocko documentation.
obspy_compat.plant()


class DeepDenoiser(BatchPreProcessing):
    """De-noise the traces using the DeepDenoiser neural network (Zhu et al., 2019)."""

    process: Literal["deep-denoiser"] = "deep-denoiser"

    model: DenoiserModels = Field(
        "original",
        description="The model to use for denoising.",
    )
    torch_use_cuda: bool | str = Field(
        False,
        description="Whether to use CUDA for the PyTorch model. "
        "A string can be used to specify the device.",
    )

    # Defaults to None so that `process_batch` can report a missing `prepare()`
    # call with a clear error instead of failing on an unset attribute.
    _denoiser: SeisBenchDeepDenoiser | None = PrivateAttr(None)

    async def prepare(self) -> None:
        """Load the pretrained model, move it to the device and compile it.

        If compiling raises a `RuntimeError`, a warning is logged and the
        uncompiled model is used instead.
        """
        # Imported lazily so the heavy dependencies are only loaded when the
        # module is actually used.
        import torch
        from seisbench.models import DeepDenoiser as SeisBenchDeepDenoiser

        denoiser = SeisBenchDeepDenoiser.from_pretrained(self.model)
        if self.torch_use_cuda:
            if isinstance(self.torch_use_cuda, bool):
                # `True` selects the default CUDA device.
                denoiser.cuda()
            else:
                denoiser.cuda(self.torch_use_cuda)

        denoiser.eval()
        try:
            logger.info("compiling DeepDenoiser model...")
            denoiser = torch.compile(denoiser, mode="max-autotune")
        except RuntimeError as exc:
            logger.warning(
                "failed to compile DeepDenoiser model, using uncompiled model.",
                exc_info=exc,
            )
        self._denoiser = denoiser

    async def process_batch(self, batch: WaveformBatch) -> WaveformBatch:
        """De-noise the selected traces and put them back into the batch.

        Args:
            batch (WaveformBatch): The batch whose traces are de-noised.

        Returns:
            WaveformBatch: The same batch object. Its selected traces are
                replaced by the traces returned from the model, traces that
                were not selected are kept untouched.

        Raises:
            RuntimeError: If `prepare()` has not been called before.
        """
        if self._denoiser is None:
            raise RuntimeError("DeepDenoiser model not initialized.")

        selected = self.select_traces(batch)
        if not selected:
            # Nothing to de-noise, do not invoke the model on an empty stream.
            return batch

        stream = Stream([tr.to_obspy_trace() for tr in selected])
        # The inference is blocking, run it in a thread to keep the loop free.
        stream = await asyncio.to_thread(self._denoiser.annotate, stream)

        denoised_traces = [tr.to_pyrocko_trace() for tr in stream]
        for tr in denoised_traces:
            # Strip the prefix the model puts on the channel code so the traces
            # keep their original channel names.
            tr.channel = tr.channel.replace("DeepDenoiser_", "")

        # Write the result back, otherwise the de-noised data is discarded and
        # the batch is returned unchanged. Selection is tracked by identity
        # because several traces may share the same NSLC code.
        selected_ids = {id(tr) for tr in selected}
        untouched = [tr for tr in batch.traces if id(tr) not in selected_ids]
        batch.traces[:] = untouched + denoised_traces
        return batch
32 changes: 32 additions & 0 deletions src/qseek/pre_processing/downsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Literal

from pydantic import Field, PositiveFloat

from qseek.pre_processing.base import BatchPreProcessing

if TYPE_CHECKING:
from qseek.waveforms.base import WaveformBatch


class Downsample(BatchPreProcessing):
    """Downsample the traces to a new sampling frequency."""

    process: Literal["downsample"] = "downsample"
    sampling_frequency: PositiveFloat = Field(
        100.0,
        description="The new sampling frequency in Hz.",
    )

    async def process_batch(self, batch: WaveformBatch) -> WaveformBatch:
        """Downsample the selected traces of the batch in place.

        Only traces whose sampling interval is shorter than the target
        interval are touched. The work is done in a thread.

        Args:
            batch (WaveformBatch): The batch whose traces are downsampled.

        Returns:
            WaveformBatch: The same batch object that was passed in.
        """
        target_deltat = 1 / self.sampling_frequency

        def resample_selected() -> None:
            # Traces already at or below the target rate are left as they are.
            oversampled = (
                tr for tr in self.select_traces(batch) if tr.deltat < target_deltat
            )
            for tr in oversampled:
                tr.downsample_to(deltat=target_deltat, allow_upsample_max=3)

        await asyncio.to_thread(resample_selected)
        return batch
Loading

0 comments on commit 6d1ffe2

Please sign in to comment.