Skip to content

Commit

Permalink
Refactor amplitude interpolation and add error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Jan 5, 2024
1 parent 09fa4d9 commit c4b1200
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 23 deletions.
124 changes: 104 additions & 20 deletions src/qseek/magnitudes/moment_magnitude_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,50 @@ class ModelledAmplitude(NamedTuple):
amplitude_median: float
std: float

def combine(
self,
amplitude: ModelledAmplitude,
weight: float = 1.0,
) -> ModelledAmplitude:
"""
Combines with another ModelledAmplitude using a weighted average.
Args:
amplitude (ModelledAmplitude): The ModelledAmplitude to be combined with.
weight (float, optional): The weight of the amplitude being combined.
Defaults to 1.0.
Returns:
Self: A new instance of the ModelledAmplitude class with the combined values.
Raises:
ValueError: If the weight is not between 0.0 and 1.0 (inclusive).
ValueError: If the distances of the amplitudes are different.
ValueError: If the peak amplitudes of the amplitudes are different.
"""
if not 0.0 <= weight <= 1.0:
raise ValueError(f"Invalid weight {weight}.")
if self.distance != amplitude.distance:
raise ValueError(
f"Cannot add amplitudes with different distances "
f"{self.distance} and {amplitude.distance}."
)
if self.peak_amplitude != amplitude.peak_amplitude:
raise ValueError(
f"Cannot add amplitudes with different peak amplitudes "
f"{self.peak_amplitude} and {amplitude.peak_amplitude}."
)
inv_weight = 1.0 - weight
return ModelledAmplitude(
distance=self.distance,
peak_amplitude=self.peak_amplitude,
amplitude_avg=self.amplitude_avg * inv_weight
+ amplitude.amplitude_avg * weight,
amplitude_median=self.amplitude_median * inv_weight
+ amplitude.amplitude_median * weight,
std=self.std * inv_weight + amplitude.std * weight,
)


class SiteAmplitudesCollection(BaseModel):
source_depth: float
Expand Down Expand Up @@ -297,7 +341,7 @@ def plot(
interp_std = np.array([amp.std for amp in interp_amplitudes])

site_amplitudes = getattr(self, f"_{peak_amplitude.replace('peak_', '')}")
dynamic = Range.from_array(site_amplitudes)
dynamic = Range.from_list(site_amplitudes)

ax.scatter(
self._distances,
Expand Down Expand Up @@ -428,6 +472,10 @@ def from_selector(cls, selector: PeakAmplitudesBase) -> Self:
kwargs["frequency_range"] = selector.frequency_range or store_frequency_range
return cls(**kwargs)

@property
def source_depth_range(self) -> Range:
return Range.from_list([sa.source_depth for sa in self.site_amplitudes])

def get_store(self) -> gf.Store:
"""
Load the GF store for the given store ID.
Expand Down Expand Up @@ -614,7 +662,7 @@ async def get_modelled_waveforms() -> tuple[gf.Response, list[gf.Target]]:
self.save()
return collection

def get_collection(self, depth: float) -> SiteAmplitudesCollection:
def get_collection(self, source_depth: float) -> SiteAmplitudesCollection:
"""
Get the site amplitudes collection for the given source depth.
Expand All @@ -625,9 +673,9 @@ def get_collection(self, depth: float) -> SiteAmplitudesCollection:
SiteAmplitudesCollection: The site amplitudes collection.
"""
for site_amplitudes in self.site_amplitudes:
if site_amplitudes.source_depth == depth:
if site_amplitudes.source_depth == source_depth:
return site_amplitudes
raise KeyError(f"No site amplitudes for depth {depth}.")
raise KeyError(f"No site amplitudes for depth {source_depth}.")

def new_collection(self, depth: float) -> SiteAmplitudesCollection:
"""
Expand Down Expand Up @@ -658,6 +706,7 @@ async def get_amplitude(
max_distance: float = 1.0 * KM,
peak_amplitude: PeakAmplitude = "absolute",
auto_fill: bool = True,
interpolation: Literal["nearest", "linear"] = "linear",
) -> ModelledAmplitude:
"""
Retrieves the amplitude for a given depth and distance.
Expand All @@ -677,27 +726,62 @@ async def get_amplitude(
Returns:
ModelledAmplitude: The modelled amplitude for the given depth and distance.
"""
collection = self.get_collection(source_depth)
try:
return collection.get_amplitude(
distance=distance,
n_amplitudes=n_amplitudes,
max_distance=max_distance,
peak_amplitude=peak_amplitude,
)
except ValueError:
if auto_fill:
await self.fill_source_depth(source_depth)
logger.info("auto-filling site amplitudes for depth %f", source_depth)
return await self.get_amplitude(
source_depth=source_depth,
if not self.source_depth_range.inside(source_depth):
raise ValueError(f"Source depth {source_depth} outside range.")

source_depths = np.array([sa.source_depth for sa in self.site_amplitudes])
match interpolation:
case "nearest":
idx = [np.abs(source_depths - source_depth).argmin()]
case "linear":
idx = np.argsort(np.abs(source_depths - source_depth))[:2]
case _:
raise ValueError(f"Unknown interpolation method {interpolation}.")

collections = [self.site_amplitudes[i] for i in idx]

amplitudes: list[ModelledAmplitude] = []
for collection in collections:
try:
amplitude = collection.get_amplitude(
distance=distance,
n_amplitudes=n_amplitudes,
max_distance=max_distance,
peak_amplitude=peak_amplitude,
auto_fill=True,
)
raise
amplitudes.append(amplitude)
except ValueError:
if auto_fill:
await self.fill_source_depth(source_depth)
logger.info("auto-filling amplitudes for depth %f", source_depth)
return await self.get_amplitude(
source_depth=source_depth,
distance=distance,
n_amplitudes=n_amplitudes,
max_distance=max_distance,
peak_amplitude=peak_amplitude,
interpolation=interpolation,
auto_fill=True,
)
raise

if not amplitudes:
raise ValueError(f"No site amplitudes for depth {source_depth}.")

if interpolation == "nearest":
return amplitudes[0]

if interpolation == "linear":
if len(amplitudes) != 2:
raise ValueError(
f"Cannot interpolate amplitudes with {len(amplitudes)} "
f"source depths."
)
depths = source_depths[idx]
weight = abs((source_depth - depths[0]) / abs(depths[1] - depths[0]))
return amplitudes[0].combine(amplitudes[1], weight=weight)

raise ValueError(f"Unknown interpolation method {interpolation}.")

def hash(self) -> str:
"""
Expand Down
3 changes: 2 additions & 1 deletion src/qseek/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def weed_stations(self) -> None:
seen_nsls = set()
for sta in self.stations.copy():
if sta.lat == 0.0 or sta.lon == 0.0:
self.blacklist_station(sta, reason="bad geographical coordinates")
logger.warning("removing station %s: bad coordinates", sta.nsl_pretty)
self.stations.remove(sta)
continue

if sta.nsl_pretty in seen_nsls:
Expand Down
23 changes: 21 additions & 2 deletions src/qseek/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,30 @@ class _Range(NamedTuple):
max: float

def inside(self, value: float) -> bool:
"""
Check if a value is inside the range.
Args:
value (float): The value to check.
Returns:
bool: True if the value is inside the range, False otherwise.
"""
return self.min <= value <= self.max

@classmethod
def from_array(cls, array: np.ndarray) -> _Range:
return cls(array.min(), array.max())
def from_list(cls, array: np.ndarray | list[float]) -> _Range:
"""
Create a Range object from a numpy array.
Parameters:
- array: numpy.ndarray
The array from which to create the Range object.
Returns:
- _Range: The created Range object.
"""
return cls(min=np.min(array), max=np.max(array))


def _range_validator(v: _Range) -> _Range:
Expand Down

0 comments on commit c4b1200

Please sign in to comment.