Skip to content

Commit

Permalink
Refactor phase arrival classes and update type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Jan 15, 2024
1 parent 3dffe24 commit d962e22
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 62 deletions.
13 changes: 7 additions & 6 deletions src/qseek/images/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, NamedTuple

import numpy as np
from pydantic import BaseModel, Field

from qseek.models.phase_arrival import PhaseArrival
from qseek.models.station import Stations
from qseek.utils import PhaseDescription, downsample

Expand All @@ -16,8 +15,10 @@
from pyrocko.trace import Trace


class PickedArrival(PhaseArrival):
provider: Literal["PickedArrival"] = "PickedArrival"
class ObservedArrival(NamedTuple):
phase: str
time: datetime
detection_value: float


class ImageFunction(BaseModel):
Expand Down Expand Up @@ -120,7 +121,7 @@ def search_phase_arrival(
modelled_arrival: datetime,
search_length_seconds: float = 5,
threshold: float = 0.1,
) -> PickedArrival | None:
) -> ObservedArrival | None:
"""Search for a peak in all station's image functions.
Args:
Expand All @@ -140,7 +141,7 @@ def search_phase_arrivals(
modelled_arrivals: list[datetime | None],
search_length_seconds: float = 5,
threshold: float = 0.1,
) -> list[PickedArrival | None]:
) -> list[ObservedArrival | None]:
"""Search for a peak in all station's image functions.
Args:
Expand Down
10 changes: 2 additions & 8 deletions src/qseek/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from pydantic import Field, PositiveInt, PrivateAttr, RootModel, computed_field

from qseek.images.base import ImageFunction, PickedArrival
from qseek.images.phase_net import PhaseNet, PhaseNetPick
from qseek.images.base import ImageFunction
from qseek.images.phase_net import PhaseNet
from qseek.stats import Stats
from qseek.utils import PhaseDescription, datetime_now, human_readable_bytes

Expand All @@ -31,12 +31,6 @@
Field(..., discriminator="image"),
]

# Make this a Union when more picks are implemented
ImageFunctionPick = Annotated[
Union[PhaseNetPick, PickedArrival],
Field(..., discriminator="provider"),
]


class ImageFunctionsStats(Stats):
time_per_batch: timedelta = timedelta()
Expand Down
12 changes: 3 additions & 9 deletions src/qseek/images/phase_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pyrocko import obspy_compat
from seisbench import logger as seisbench_logger

from qseek.images.base import ImageFunction, PickedArrival, WaveformImage
from qseek.images.base import ImageFunction, ObservedArrival, WaveformImage
from qseek.utils import alog_call, to_datetime

obspy_compat.plant()
Expand Down Expand Up @@ -40,20 +40,14 @@
StackMethod = Literal["avg", "max"]


class PhaseNetPick(PickedArrival):
provider: Literal["PhaseNetPick"] = "PhaseNetPick"
phase: str
detection_value: float


class PhaseNetImage(WaveformImage):
def search_phase_arrival(
self,
trace_idx: int,
modelled_arrival: datetime,
search_length_seconds: float = 5.0,
threshold: float = 0.1,
) -> PhaseNetPick | None:
) -> ObservedArrival | None:
"""Search for a peak in all station's image functions.
Args:
Expand All @@ -76,7 +70,7 @@ def search_phase_arrival(
time_seconds, value = search_trace.max()
if value < threshold:
return None
return PhaseNetPick(
return ObservedArrival(
time=to_datetime(time_seconds),
detection_value=float(value),
phase=self.phase,
Expand Down
8 changes: 4 additions & 4 deletions src/qseek/models/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@

from qseek.console import console
from qseek.features import EventFeaturesType
from qseek.images.images import ImageFunctionPick
from qseek.images.base import ObservedArrival
from qseek.magnitudes import EventMagnitudeType
from qseek.models.detection_uncertainty import DetectionUncertainty
from qseek.models.location import Location
from qseek.models.station import Station, Stations
from qseek.stats import Stats
from qseek.tracers.tracers import RayTracerArrival
from qseek.tracers.base import ModelledArrival
from qseek.utils import (
NSL,
MeasurementUnit,
Expand Down Expand Up @@ -84,8 +84,8 @@ def get_row(self, row_index: int) -> str:

class PhaseDetection(BaseModel):
phase: PhaseDescription
model: RayTracerArrival
observed: ImageFunctionPick | None = None
model: ModelledArrival
observed: ObservedArrival | None = None

@property
def traveltime_delay(self) -> timedelta | None:
Expand Down
7 changes: 0 additions & 7 deletions src/qseek/models/phase_arrival.py

This file was deleted.

12 changes: 4 additions & 8 deletions src/qseek/tracers/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Sequence, TypeVar
from typing import TYPE_CHECKING, Literal, NamedTuple, Sequence, TypeVar

import numpy as np
from pydantic import BaseModel

from qseek.models.location import Location
from qseek.models.phase_arrival import PhaseArrival

if TYPE_CHECKING:
from datetime import datetime
Expand All @@ -18,12 +17,9 @@
_LocationType = TypeVar("_LocationType", bound=Location)


class ModelledArrival(PhaseArrival):
tracer: Literal["ModelledArrival"] = "ModelledArrival"

@classmethod
def get_subclasses(cls) -> tuple[type[ModelledArrival], ...]:
return tuple(cls.__subclasses__())
class ModelledArrival(NamedTuple):
phase: str
time: datetime


class RayTracer(BaseModel):
Expand Down
12 changes: 5 additions & 7 deletions src/qseek/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,6 @@
DEFAULT_VELOCITY_MODEL_FILE.write_text(DEFAULT_VELOCITY_MODEL)


class CakeArrival(ModelledArrival):
tracer: Literal["CakeArrival"] = "CakeArrival"
phase: str


class EarthModel(BaseModel):
filename: FilePath | None = Field(
default=DEFAULT_VELOCITY_MODEL_FILE,
Expand Down Expand Up @@ -644,7 +639,7 @@ def get_arrivals(
event_time: datetime,
source: Location,
receivers: Sequence[Location],
) -> list[CakeArrival | None]:
) -> list[ModelledArrival | None]:
traveltimes = self.get_travel_times_locations(
phase,
source=source,
Expand All @@ -657,6 +652,9 @@ def get_arrivals(
continue

arrivaltime = event_time + timedelta(seconds=traveltime)
arrival = CakeArrival(time=arrivaltime, phase=phase)
arrival = ModelledArrival(
time=arrivaltime,
phase=phase,
)
arrivals.append(arrival)
return arrivals
12 changes: 5 additions & 7 deletions src/qseek/tracers/constant_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
from qseek.octree import Octree


class ConstantVelocityArrival(ModelledArrival):
tracer: Literal["ConstantVelocityArrival"] = "ConstantVelocityArrival"
phase: str


class ConstantVelocityTracer(RayTracer):
tracer: Literal["ConstantVelocityTracer"] = "ConstantVelocityTracer"
phase: PhaseDescription = Field(
Expand Down Expand Up @@ -66,7 +61,7 @@ def get_arrivals(
event_time: datetime,
source: Location,
receivers: Sequence[Location],
) -> list[ConstantVelocityArrival]:
) -> list[ModelledArrival]:
self._check_phase(phase)

traveltimes = self.get_travel_times_locations(
Expand All @@ -77,6 +72,9 @@ def get_arrivals(
arrivals = []
for traveltime in traveltimes:
arrivaltime = event_time + timedelta(seconds=traveltime)
arrival = ConstantVelocityArrival(time=arrivaltime, phase=phase)
arrival = ModelledArrival(
phase=phase,
time=arrivaltime,
)
arrivals.append(arrival)
return arrivals
7 changes: 1 addition & 6 deletions src/qseek/tracers/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
cake, # noqa: F401
constant_velocity, # noqa: F401
)
from qseek.tracers.base import ModelledArrival, RayTracer
from qseek.tracers.base import RayTracer

if TYPE_CHECKING:
from qseek.models.station import Stations
Expand All @@ -24,11 +24,6 @@
Field(..., discriminator="tracer"),
]

RayTracerArrival = Annotated[
Union[(ModelledArrival, *ModelledArrival.get_subclasses())],
Field(..., discriminator="tracer"),
]


class RayTracers(RootModel):
root: list[RayTracerType] = []
Expand Down

0 comments on commit d962e22

Please sign in to comment.