From c4b1200c14316f51a94126dcb1975caacea06ab8 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Fri, 5 Jan 2024 22:12:29 +0100 Subject: [PATCH] Refactor amplitude interpolation and add error handling --- .../magnitudes/moment_magnitude_store.py | 124 +++++++++++++++--- src/qseek/models/station.py | 3 +- src/qseek/utils.py | 23 +++- 3 files changed, 127 insertions(+), 23 deletions(-) diff --git a/src/qseek/magnitudes/moment_magnitude_store.py b/src/qseek/magnitudes/moment_magnitude_store.py index 3d6d485e..d148e2cf 100644 --- a/src/qseek/magnitudes/moment_magnitude_store.py +++ b/src/qseek/magnitudes/moment_magnitude_store.py @@ -164,6 +164,50 @@ class ModelledAmplitude(NamedTuple): amplitude_median: float std: float + def combine( + self, + amplitude: ModelledAmplitude, + weight: float = 1.0, + ) -> ModelledAmplitude: + """ + Combines with another ModelledAmplitude using a weighted average. + + Args: + amplitude (ModelledAmplitude): The ModelledAmplitude to be combined with. + weight (float, optional): The weight of the amplitude being combined. + Defaults to 1.0. + + Returns: + Self: A new instance of the ModelledAmplitude class with the combined values. + + Raises: + ValueError: If the weight is not between 0.0 and 1.0 (inclusive). + ValueError: If the distances of the amplitudes are different. + ValueError: If the peak amplitudes of the amplitudes are different. + """ + if not 0.0 <= weight <= 1.0: + raise ValueError(f"Invalid weight {weight}.") + if self.distance != amplitude.distance: + raise ValueError( + f"Cannot add amplitudes with different distances " + f"{self.distance} and {amplitude.distance}." + ) + if self.peak_amplitude != amplitude.peak_amplitude: + raise ValueError( + f"Cannot add amplitudes with different peak amplitudes " + f"{self.peak_amplitude} and {amplitude.peak_amplitude}." + ) + inv_weight = 1.0 - weight + return ModelledAmplitude( + distance=self.distance, + peak_amplitude=self.peak_amplitude, + amplitude_avg=self.amplitude_avg * inv_weight + + amplitude.amplitude_avg * weight, + amplitude_median=self.amplitude_median * inv_weight + + amplitude.amplitude_median * weight, + std=self.std * inv_weight + amplitude.std * weight, + ) + class SiteAmplitudesCollection(BaseModel): source_depth: float @@ -297,7 +341,7 @@ def plot( interp_std = np.array([amp.std for amp in interp_amplitudes]) site_amplitudes = getattr(self, f"_{peak_amplitude.replace('peak_', '')}") - dynamic = Range.from_array(site_amplitudes) + dynamic = Range.from_list(site_amplitudes) ax.scatter( self._distances, @@ -428,6 +472,10 @@ def from_selector(cls, selector: PeakAmplitudesBase) -> Self: kwargs["frequency_range"] = selector.frequency_range or store_frequency_range return cls(**kwargs) + @property + def source_depth_range(self) -> Range: + return Range.from_list([sa.source_depth for sa in self.site_amplitudes]) + def get_store(self) -> gf.Store: """ Load the GF store for the given store ID. @@ -614,7 +662,7 @@ async def get_modelled_waveforms() -> tuple[gf.Response, list[gf.Target]]: self.save() return collection - def get_collection(self, depth: float) -> SiteAmplitudesCollection: + def get_collection(self, source_depth: float) -> SiteAmplitudesCollection: """ Get the site amplitudes collection for the given source depth. @@ -625,9 +673,9 @@ def get_collection(self, depth: float) -> SiteAmplitudesCollection: SiteAmplitudesCollection: The site amplitudes collection. """ for site_amplitudes in self.site_amplitudes: - if site_amplitudes.source_depth == depth: + if site_amplitudes.source_depth == source_depth: return site_amplitudes - raise KeyError(f"No site amplitudes for depth {depth}.") + raise KeyError(f"No site amplitudes for depth {source_depth}.") def new_collection(self, depth: float) -> SiteAmplitudesCollection: """ @@ -658,6 +706,7 @@ async def get_amplitude( max_distance: float = 1.0 * KM, peak_amplitude: PeakAmplitude = "absolute", auto_fill: bool = True, + interpolation: Literal["nearest", "linear"] = "linear", ) -> ModelledAmplitude: """ Retrieves the amplitude for a given depth and distance. @@ -677,27 +726,62 @@ async def get_amplitude( Returns: ModelledAmplitude: The modelled amplitude for the given depth and distance. """ - collection = self.get_collection(source_depth) - try: - return collection.get_amplitude( - distance=distance, - n_amplitudes=n_amplitudes, - max_distance=max_distance, - peak_amplitude=peak_amplitude, - ) - except ValueError: - if auto_fill: - await self.fill_source_depth(source_depth) - logger.info("auto-filling site amplitudes for depth %f", source_depth) - return await self.get_amplitude( - source_depth=source_depth, + if not self.source_depth_range.inside(source_depth): + raise ValueError(f"Source depth {source_depth} outside range.") + + source_depths = np.array([sa.source_depth for sa in self.site_amplitudes]) + match interpolation: + case "nearest": + idx = [np.abs(source_depths - source_depth).argmin()] + case "linear": + idx = np.argsort(np.abs(source_depths - source_depth))[:2] + case _: + raise ValueError(f"Unknown interpolation method {interpolation}.") + + collections = [self.site_amplitudes[i] for i in idx] + + amplitudes: list[ModelledAmplitude] = [] + for collection in collections: + try: + amplitude = collection.get_amplitude( distance=distance, n_amplitudes=n_amplitudes, max_distance=max_distance, peak_amplitude=peak_amplitude, - auto_fill=True, ) - raise + amplitudes.append(amplitude) + except ValueError: + if auto_fill: + await self.fill_source_depth(source_depth) + logger.info("auto-filling amplitudes for depth %f", source_depth) + return await self.get_amplitude( + source_depth=source_depth, + distance=distance, + n_amplitudes=n_amplitudes, + max_distance=max_distance, + peak_amplitude=peak_amplitude, + interpolation=interpolation, + auto_fill=True, + ) + raise + + if not amplitudes: + raise ValueError(f"No site amplitudes for depth {source_depth}.") + + if interpolation == "nearest": + return amplitudes[0] + + if interpolation == "linear": + if len(amplitudes) != 2: + raise ValueError( + f"Cannot interpolate amplitudes with {len(amplitudes)} " + f"source depths." + ) + depths = source_depths[idx] + weight = abs((source_depth - depths[0]) / abs(depths[1] - depths[0])) + return amplitudes[0].combine(amplitudes[1], weight=weight) + + raise ValueError(f"Unknown interpolation method {interpolation}.") def hash(self) -> str: """ diff --git a/src/qseek/models/station.py b/src/qseek/models/station.py index 6dbfbc35..11466726 100644 --- a/src/qseek/models/station.py +++ b/src/qseek/models/station.py @@ -102,7 +102,8 @@ def weed_stations(self) -> None: seen_nsls = set() for sta in self.stations.copy(): if sta.lat == 0.0 or sta.lon == 0.0: - self.blacklist_station(sta, reason="bad geographical coordinates") + logger.warning("removing station %s: bad coordinates", sta.nsl_pretty) + self.stations.remove(sta) continue if sta.nsl_pretty in seen_nsls: diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 3fcd9217..90ef3a74 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -90,11 +90,30 @@ class _Range(NamedTuple): max: float def inside(self, value: float) -> bool: + """ + Check if a value is inside the range. + + Args: + value (float): The value to check. + + Returns: + bool: True if the value is inside the range, False otherwise. + """ return self.min <= value <= self.max @classmethod - def from_array(cls, array: np.ndarray) -> _Range: - return cls(array.min(), array.max()) + def from_list(cls, array: np.ndarray | list[float]) -> _Range: + """ + Create a Range object from a numpy array. + + Parameters: + - array: numpy.ndarray + The array from which to create the Range object. + + Returns: + - _Range: The created Range object. + """ + return cls(min=np.min(array), max=np.max(array)) def _range_validator(v: _Range) -> _Range: