From 2f47ccaef0e42855d75d503faf5811c68eb734a8 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Mon, 22 Jan 2024 22:26:59 +0100 Subject: [PATCH] Add optional rundir parameter to station preparation Apply exponent only if it's not equal to 1.0 Apply normalization factor only if it's not equal to 1.0 Export detections every 100 detections Fix weights calculation in SearchTraces Fix weighted_median function in utils.py --- src/qseek/corrections/base.py | 2 ++ src/qseek/images/base.py | 2 ++ src/qseek/models/semblance.py | 2 ++ src/qseek/search.py | 25 +++++++++----- src/qseek/utils.py | 62 ++++++++++++++++++++++++++++------- 5 files changed, 73 insertions(+), 20 deletions(-) diff --git a/src/qseek/corrections/base.py b/src/qseek/corrections/base.py index d6784c6c..0f46295e 100644 --- a/src/qseek/corrections/base.py +++ b/src/qseek/corrections/base.py @@ -76,6 +76,7 @@ async def prepare( stations: Stations, octree: Octree, phases: Iterable[PhaseDescription], + rundir: Path | None = None, ) -> None: """Prepare the station for the corrections. @@ -83,6 +84,7 @@ async def prepare( station: The station to prepare. octree: The octree to use for the preparation. phases: The phases to prepare the station for. + rundir: The rundir to use for the delay. Defaults to None. """ ... diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py index 924720a7..ffcc6623 100644 --- a/src/qseek/images/base.py +++ b/src/qseek/images/base.py @@ -114,6 +114,8 @@ def apply_exponent(self, exponent: float) -> None: Args: exponent (float): Exponent to apply. """ + if exponent == 1.0: + return for tr in self.traces: tr.ydata **= exponent diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py index 1508d9cc..d91c2633 100644 --- a/src/qseek/models/semblance.py +++ b/src/qseek/models/semblance.py @@ -340,5 +340,7 @@ def normalize(self, factor: int | float) -> None: Args: factor (int | float): Normalization factor. """ + if factor == 1.0: + return self.semblance_unpadded /= factor self._clear_cache() diff --git a/src/qseek/search.py b/src/qseek/search.py index fb731165..434616d6 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -267,6 +267,7 @@ class Search(BaseModel): _travel_time_ranges: dict[ PhaseDescription, tuple[timedelta, timedelta] ] = PrivateAttr({}) + _last_detection_export: int = 0 _detections: EventCatalog = PrivateAttr() _config_stem: str = PrivateAttr("") @@ -394,6 +395,13 @@ async def prepare(self) -> None: """ logger.info("preparing search...") self.data_provider.prepare(self.stations) + if self.station_corrections: + await self.station_corrections.prepare( + self.stations, + self.octree, + self.image_functions.get_phases(), + self._rundir, + ) await self.ray_tracers.prepare( self.octree, self.stations, @@ -402,12 +410,6 @@ async def prepare(self) -> None: ) for magnitude in self.magnitudes: await magnitude.prepare(self.octree, self.stations) - if self.station_corrections: - await self.station_corrections.prepare( - self.stations, - self.octree, - self.image_functions.get_phases(), - ) self.init_boundaries() async def start(self, force_rundir: bool = False) -> None: @@ -480,10 +482,14 @@ async def new_detections(self, detections: list[EventDetection]) -> None: await self._detections.add(detection) await self._new_detection.emit(detection) - if self._detections.n_detections and self._detections.n_detections % 100 == 0: + if ( + self._detections.n_detections + and self._detections.n_detections - self._last_detection_export > 100 + ): await self._detections.export_detections( jitter_location=self.octree.smallest_node_size() ) + self._last_detection_export = self._detections.n_detections async def add_magnitude_and_features(self, event: EventDetection) -> EventDetection: """ @@ -610,11 +616,11 @@ async def calculate_semblance( shifts = np.round(-traveltimes / image.delta_t).astype(np.int32) weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) + weights[traveltimes_bad] = 0.0 # Normalize by number of station contribution with np.errstate(divide="ignore", invalid="ignore"): weights /= station_contribution[:, np.newaxis] - weights[traveltimes_bad] = 0.0 if semblance_cache: cache_mask = semblance.get_cache_mask(semblance_cache) @@ -691,8 +697,9 @@ async def search( semblance_cache=semblance_cache, ) - semblance.apply_exponent(1.0 / parent.image_mean_p) + # Applying the generalized mean to the semblance semblance.normalize(images.cumulative_weight()) + semblance.apply_exponent(1.0 / parent.image_mean_p) semblance.apply_cache(semblance_cache or {}) # Apply after normalization diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 0ae96610..1dd6c9ea 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -225,17 +225,57 @@ def weighted_median(data: np.ndarray, weights: np.ndarray | None = None) -> floa if weights is None: return float(np.median(data)) - data = np.atleast_1d(np.array(data).squeeze()) - weights = np.atleast_1d(np.array(weights).squeeze()) - try: - s_data, s_weights = map( - np.array, zip(*sorted(zip(data, weights, strict=True)), strict=True) - ) - except TypeError as exc: - raise exc - midpoint = 0.5 * sum(s_weights) - if any(weights > midpoint): - w_median = (data[weights == np.max(weights)])[0] + data = np.atleast_1d(data.squeeze()) + weights = np.atleast_1d(weights.squeeze()) + + sorted_indices = np.argsort(data) + s_data = data[sorted_indices] + s_weights = weights[sorted_indices] + + midpoint = 0.5 * s_weights.sum() + if np.any(weights > midpoint): + w_median = (data[weights == weights.max()])[0] + else: + cs_weights = np.cumsum(s_weights) + idx = np.where(cs_weights <= midpoint)[0][-1] + if cs_weights[idx] == midpoint: + w_median = np.mean(s_data[idx : idx + 2]) + else: + w_median = s_data[idx + 1] + return float(w_median) + + +async def async_weighted_median( + data: np.ndarray, weights: np.ndarray | None = None +) -> float: + """ + Asynchronously calculate the weighted median of an array/list using numpy. + + Parameters: + data (np.ndarray): The input array/list. + weights (np.ndarray | None): The weights corresponding to each + element in the data array/list. + If None, the function calculates the regular median. + + Returns: + float: The weighted median. + + Raises: + TypeError: If the data and weights arrays/lists cannot be sorted together. + """ + if weights is None: + return float(await asyncio.to_thread(np.median, data)) + + data = np.atleast_1d(data.squeeze()) + weights = np.atleast_1d(weights.squeeze()) + + sorted_indices = await asyncio.to_thread(np.argsort, data) + s_data = data[sorted_indices] + s_weights = weights[sorted_indices] + + midpoint = 0.5 * s_weights.sum() + if np.any(weights > midpoint): + w_median = (data[weights == weights.max()])[0] else: cs_weights = np.cumsum(s_weights) idx = np.where(cs_weights <= midpoint)[0][-1]