Skip to content

Commit

Permalink
Refactor local magnitude
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Dec 31, 2023
1 parent 25a428b commit 39595ac
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 306 deletions.
8 changes: 4 additions & 4 deletions src/qseek/features/ground_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from qseek.features.base import EventFeature, FeatureExtractor, ReceiverFeature
from qseek.features.utils import TraceSelectors
from qseek.features.utils import ChannelSelectors

if TYPE_CHECKING:
from pyrocko.squirrel import Squirrel
Expand Down Expand Up @@ -66,9 +66,9 @@ async def add_features(
seconds_before=self.seconds_before,
quantity="velocity",
)
pga = _get_maximum(TraceSelectors.All(traces_acc))
pha = _get_maximum(TraceSelectors.Horizontal(traces_acc))
pgv = _get_maximum(TraceSelectors.All(traces_vel))
pga = _get_maximum(ChannelSelectors.All(traces_acc))
pha = _get_maximum(ChannelSelectors.Horizontal(traces_acc))
pgv = _get_maximum(ChannelSelectors.All(traces_vel))

ground_motion = ReceiverGroundMotion(
seconds_before=self.seconds_before,
Expand Down
14 changes: 12 additions & 2 deletions src/qseek/features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
from pyrocko.trace import Trace

Expand All @@ -11,6 +13,7 @@
class ChannelSelector:
channels: str
number_channels: int
normalize: bool = False

def get_traces(self, traces: list[Trace]) -> list[Trace]:
traces = [tr for tr in traces if tr.channel[-1] in self.channels]
Expand All @@ -20,12 +23,19 @@ def get_traces(self, traces: list[Trace]) -> list[Trace]:
f" for selector {self.channels}"
f" available: {', '.join('.'.join(tr.nslc_id) for tr in traces)}"
)
if self.normalize:
traces_norm = traces[0].copy()
traces_norm.ydata = np.linalg.norm(
np.array([tr.ydata for tr in traces]), axis=0
)
return [traces_norm]
return traces

__call__ = get_traces


class TraceSelectors:
class ChannelSelectors:
All = ChannelSelector("ENZ0123", 3)
Horizontal = ChannelSelector("EN23", 2)
HorizontalAbs = ChannelSelector("EN23", 2, normalize=True)
Vertical = ChannelSelector("Z0", 1)
NorthEast = ChannelSelector("NE", 2)
19 changes: 10 additions & 9 deletions src/qseek/magnitudes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel
from pydantic import BaseModel, Field

if TYPE_CHECKING:
from pyrocko.squirrel import Squirrel
Expand All @@ -13,6 +13,15 @@
class EventMagnitude(BaseModel):
magnitude: Literal["EventMagnitude"] = "EventMagnitude"

average: float = Field(
default=0.0,
description="Average local magnitude.",
)
error: float = Field(
default=0.0,
description="Average error of local magnitude.",
)

@classmethod
def get_subclasses(cls) -> tuple[type[EventMagnitude], ...]:
"""Get the subclasses of this class.
Expand All @@ -26,14 +35,6 @@ def get_subclasses(cls) -> tuple[type[EventMagnitude], ...]:
def name(self) -> str:
return self.__class__.__name__

@property
def average(self) -> float:
raise NotImplementedError

@property
def error(self) -> float:
raise NotImplementedError

def csv_row(self) -> dict[str, float]:
return {
"magnitude": self.average,
Expand Down
Loading

0 comments on commit 39595ac

Please sign in to comment.