diff --git a/src/qseek/corrections/base.py b/src/qseek/corrections/base.py index d6784c6c..0f46295e 100644 --- a/src/qseek/corrections/base.py +++ b/src/qseek/corrections/base.py @@ -76,6 +76,7 @@ async def prepare( stations: Stations, octree: Octree, phases: Iterable[PhaseDescription], + rundir: Path | None = None, ) -> None: """Prepare the station for the corrections. @@ -83,6 +84,7 @@ async def prepare( station: The station to prepare. octree: The octree to use for the preparation. phases: The phases to prepare the station for. + rundir: The rundir to use for the delay. Defaults to None. """ ... diff --git a/src/qseek/images/base.py b/src/qseek/images/base.py index 924720a7..ffcc6623 100644 --- a/src/qseek/images/base.py +++ b/src/qseek/images/base.py @@ -114,6 +114,8 @@ def apply_exponent(self, exponent: float) -> None: Args: exponent (float): Exponent to apply. """ + if exponent == 1.0: + return for tr in self.traces: tr.ydata **= exponent diff --git a/src/qseek/models/semblance.py b/src/qseek/models/semblance.py index 1508d9cc..d91c2633 100644 --- a/src/qseek/models/semblance.py +++ b/src/qseek/models/semblance.py @@ -340,5 +340,7 @@ def normalize(self, factor: int | float) -> None: Args: factor (int | float): Normalization factor. """ + if factor == 1.0: + return self.semblance_unpadded /= factor self._clear_cache() diff --git a/src/qseek/search.py b/src/qseek/search.py index fb731165..434616d6 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -267,6 +267,7 @@ class Search(BaseModel): _travel_time_ranges: dict[ PhaseDescription, tuple[timedelta, timedelta] ] = PrivateAttr({}) + _last_detection_export: int = 0 _detections: EventCatalog = PrivateAttr() _config_stem: str = PrivateAttr("") @@ -394,6 +395,13 @@ async def prepare(self) -> None: """ logger.info("preparing search...") self.data_provider.prepare(self.stations) + if self.station_corrections: + await self.station_corrections.prepare( + self.stations, + self.octree, + self.image_functions.get_phases(), + self._rundir, + ) await self.ray_tracers.prepare( self.octree, self.stations, @@ -402,12 +410,6 @@ async def prepare(self) -> None: ) for magnitude in self.magnitudes: await magnitude.prepare(self.octree, self.stations) - if self.station_corrections: - await self.station_corrections.prepare( - self.stations, - self.octree, - self.image_functions.get_phases(), - ) self.init_boundaries() async def start(self, force_rundir: bool = False) -> None: @@ -480,10 +482,14 @@ async def new_detections(self, detections: list[EventDetection]) -> None: await self._detections.add(detection) await self._new_detection.emit(detection) - if self._detections.n_detections and self._detections.n_detections % 100 == 0: + if ( + self._detections.n_detections + and self._detections.n_detections - self._last_detection_export > 100 + ): await self._detections.export_detections( jitter_location=self.octree.smallest_node_size() ) + self._last_detection_export = self._detections.n_detections async def add_magnitude_and_features(self, event: EventDetection) -> EventDetection: """ @@ -610,11 +616,11 @@ async def calculate_semblance( shifts = np.round(-traveltimes / image.delta_t).astype(np.int32) weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32) + weights[traveltimes_bad] = 0.0 # Normalize by number of station contribution with np.errstate(divide="ignore", invalid="ignore"): weights /= station_contribution[:, np.newaxis] - weights[traveltimes_bad] = 0.0 if semblance_cache: cache_mask = semblance.get_cache_mask(semblance_cache) @@ -691,8 +697,9 @@ async def search( semblance_cache=semblance_cache, ) - semblance.apply_exponent(1.0 / parent.image_mean_p) + # Applying the generalized mean to the semblance semblance.normalize(images.cumulative_weight()) + semblance.apply_exponent(1.0 / parent.image_mean_p) semblance.apply_cache(semblance_cache or {}) # Apply after normalization diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 0ae96610..1dd6c9ea 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -225,17 +225,57 @@ def weighted_median(data: np.ndarray, weights: np.ndarray | None = None) -> floa if weights is None: return float(np.median(data)) - data = np.atleast_1d(np.array(data).squeeze()) - weights = np.atleast_1d(np.array(weights).squeeze()) - try: - s_data, s_weights = map( - np.array, zip(*sorted(zip(data, weights, strict=True)), strict=True) - ) - except TypeError as exc: - raise exc - midpoint = 0.5 * sum(s_weights) - if any(weights > midpoint): - w_median = (data[weights == np.max(weights)])[0] + data = np.atleast_1d(data.squeeze()) + weights = np.atleast_1d(weights.squeeze()) + + sorted_indices = np.argsort(data) + s_data = data[sorted_indices] + s_weights = weights[sorted_indices] + + midpoint = 0.5 * s_weights.sum() + if np.any(weights > midpoint): + w_median = (data[weights == weights.max()])[0] + else: + cs_weights = np.cumsum(s_weights) + idx = np.where(cs_weights <= midpoint)[0][-1] + if cs_weights[idx] == midpoint: + w_median = np.mean(s_data[idx : idx + 2]) + else: + w_median = s_data[idx + 1] + return float(w_median) + + +async def async_weighted_median( + data: np.ndarray, weights: np.ndarray | None = None +) -> float: + """ + Asynchronously calculate the weighted median of an array/list using numpy. + + Parameters: + data (np.ndarray): The input array/list. + weights (np.ndarray | None): The weights corresponding to each + element in the data array/list. + If None, the function calculates the regular median. + + Returns: + float: The weighted median. + + Raises: + TypeError: If the data and weights arrays/lists cannot be sorted together. + """ + if weights is None: + return float(await asyncio.to_thread(np.median, data)) + + data = np.atleast_1d(data.squeeze()) + weights = np.atleast_1d(weights.squeeze()) + + sorted_indices = await asyncio.to_thread(np.argsort, data) + s_data = data[sorted_indices] + s_weights = weights[sorted_indices] + + midpoint = 0.5 * s_weights.sum() + if np.any(weights > midpoint): + w_median = (data[weights == weights.max()])[0] else: cs_weights = np.cumsum(s_weights) idx = np.where(cs_weights <= midpoint)[0][-1]