Add optional rundir parameter to station preparation
Apply exponent only if it's not equal to 1.0

Apply normalization factor only if it's not equal to 1.0

Export detections every 100 detections

Fix weights calculation in SearchTraces

Fix weighted_median function in utils.py
Marius Isken committed Jan 22, 2024
1 parent 476ef11 commit 2f47cca
Showing 5 changed files with 73 additions and 20 deletions.
2 changes: 2 additions & 0 deletions src/qseek/corrections/base.py
@@ -76,13 +76,15 @@ async def prepare(
         stations: Stations,
         octree: Octree,
         phases: Iterable[PhaseDescription],
+        rundir: Path | None = None,
     ) -> None:
         """Prepare the stations for the corrections.

         Args:
             stations: The stations to prepare.
             octree: The octree to use for the preparation.
             phases: The phases to prepare the stations for.
+            rundir: The rundir to use for the delays. Defaults to None.
         """
         ...

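The new parameter only threads a path through the interface; as a self-contained illustration (not qseek's actual code — the delays.json file name and the dict payload are invented), this is the pattern an optional rundir enables: an async prepare() that can persist what it computed when a run directory is available.

import asyncio
import json
import tempfile
from pathlib import Path

async def prepare(delays: dict[str, float], rundir: Path | None = None) -> None:
    # ... the expensive preparation work would happen here ...
    if rundir is not None:
        # Persist the computed delays so a later run can reload them.
        (rundir / "delays.json").write_text(json.dumps(delays))

with tempfile.TemporaryDirectory() as tmp:
    asyncio.run(prepare({"ST01": 0.12}, rundir=Path(tmp)))
    print((Path(tmp) / "delays.json").read_text())  # {"ST01": 0.12}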
2 changes: 2 additions & 0 deletions src/qseek/images/base.py
@@ -114,6 +114,8 @@ def apply_exponent(self, exponent: float) -> None:
         Args:
             exponent (float): Exponent to apply.
         """
+        if exponent == 1.0:
+            return
         for tr in self.traces:
             tr.ydata **= exponent

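Why guard at all? A throwaway benchmark (mine, not part of the commit) shows that numpy still makes a full pass over the samples for an in-place power of 1.0, so the early return is a genuine fast path:

import timeit

import numpy as np

ydata = np.random.default_rng(0).standard_normal(1_000_000).astype(np.float32)

def apply_exponent(data: np.ndarray, exponent: float) -> None:
    if exponent == 1.0:  # the guard added above
        return
    data **= exponent

guarded = timeit.timeit(lambda: apply_exponent(ydata, 1.0), number=100)
unguarded = timeit.timeit(lambda: np.power(ydata, 1.0, out=ydata), number=100)
print(f"guarded no-op: {guarded:.4f}s  full pass: {unguarded:.4f}s")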
2 changes: 2 additions & 0 deletions src/qseek/models/semblance.py
@@ -340,5 +340,7 @@ def normalize(self, factor: int | float) -> None:
         Args:
             factor (int | float): Normalization factor.
         """
+        if factor == 1.0:
+            return
         self.semblance_unpadded /= factor
         self._clear_cache()
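In normalize the guard has a second payoff: a factor of 1.0 no longer clears the cached semblance. A toy sketch of that behavior (hypothetical class, not qseek's Semblance):

import numpy as np

class TinySemblance:
    """Hypothetical stand-in for qseek's Semblance, just to show the guard."""

    def __init__(self, data: np.ndarray) -> None:
        self.semblance_unpadded = data
        self._cache: np.ndarray | None = data.copy()

    def _clear_cache(self) -> None:
        self._cache = None

    def normalize(self, factor: int | float) -> None:
        if factor == 1.0:
            return  # no-op: data unchanged, cached results stay valid
        self.semblance_unpadded /= factor
        self._clear_cache()

s = TinySemblance(np.ones(4))
s.normalize(1.0)
assert s._cache is not None  # the no-op kept the cache
s.normalize(2.0)
assert s._cache is None  # real work invalidated it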
25 changes: 16 additions & 9 deletions src/qseek/search.py
@@ -267,6 +267,7 @@ class Search(BaseModel):
     _travel_time_ranges: dict[
         PhaseDescription, tuple[timedelta, timedelta]
     ] = PrivateAttr({})
+    _last_detection_export: int = 0

     _detections: EventCatalog = PrivateAttr()
     _config_stem: str = PrivateAttr("")
@@ -394,6 +395,13 @@ async def prepare(self) -> None:
         """
         logger.info("preparing search...")
         self.data_provider.prepare(self.stations)
+        if self.station_corrections:
+            await self.station_corrections.prepare(
+                self.stations,
+                self.octree,
+                self.image_functions.get_phases(),
+                self._rundir,
+            )
         await self.ray_tracers.prepare(
             self.octree,
             self.stations,
@@ -402,12 +410,6 @@ async def prepare(self) -> None:
         )
         for magnitude in self.magnitudes:
             await magnitude.prepare(self.octree, self.stations)
-        if self.station_corrections:
-            await self.station_corrections.prepare(
-                self.stations,
-                self.octree,
-                self.image_functions.get_phases(),
-            )
         self.init_boundaries()

     async def start(self, force_rundir: bool = False) -> None:
@@ -480,10 +482,14 @@ async def new_detections(self, detections: list[EventDetection]) -> None:
             await self._detections.add(detection)
             await self._new_detection.emit(detection)

-        if self._detections.n_detections and self._detections.n_detections % 100 == 0:
+        if (
+            self._detections.n_detections
+            and self._detections.n_detections - self._last_detection_export > 100
+        ):
             await self._detections.export_detections(
                 jitter_location=self.octree.smallest_node_size()
             )
+        self._last_detection_export = self._detections.n_detections

     async def add_magnitude_and_features(self, event: EventDetection) -> EventDetection:
         """
@@ -610,11 +616,11 @@ async def calculate_semblance(

         shifts = np.round(-traveltimes / image.delta_t).astype(np.int32)
         weights = np.full_like(shifts, fill_value=image.weight, dtype=np.float32)
-        weights[traveltimes_bad] = 0.0

         # Normalize by the number of contributing stations
         with np.errstate(divide="ignore", invalid="ignore"):
             weights /= station_contribution[:, np.newaxis]
+        weights[traveltimes_bad] = 0.0

         if semblance_cache:
             cache_mask = semblance.get_cache_mask(semblance_cache)
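Order matters here: if a station contributes nothing, its entry in station_contribution is zero, and zeroing the bad weights before the division lets 0/0 produce NaN under the suppressed error state. Zeroing after the division guarantees clean zeros. A minimal reproduction (shapes invented for illustration):

import numpy as np

weights = np.ones((2, 3), dtype=np.float32)  # 2 stations, 3 nodes
traveltimes_bad = np.array([[True, True, True],
                            [False, False, False]])
station_contribution = np.array([0.0, 3.0], dtype=np.float32)  # station 0 contributes nothing

# Old order: zero first, then divide -> 0/0 turns the zeros into NaN.
w_old = weights.copy()
w_old[traveltimes_bad] = 0.0
with np.errstate(divide="ignore", invalid="ignore"):
    w_old /= station_contribution[:, np.newaxis]
print(w_old[0])  # [nan nan nan]

# New order: divide first (inf is harmless), then force bad entries to 0.0.
w_new = weights.copy()
with np.errstate(divide="ignore", invalid="ignore"):
    w_new /= station_contribution[:, np.newaxis]
w_new[traveltimes_bad] = 0.0
print(w_new[0])  # [0. 0. 0.]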
@@ -691,8 +697,9 @@ async def search(
                 semblance_cache=semblance_cache,
             )

-        semblance.apply_exponent(1.0 / parent.image_mean_p)
+        # Applying the generalized mean to the semblance
         semblance.normalize(images.cumulative_weight())
+        semblance.apply_exponent(1.0 / parent.image_mean_p)

         semblance.apply_cache(semblance_cache or {})  # Apply after normalization

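The swap fixes the order of operations for the generalized (power) mean, which is (sum(w * x**p) / sum(w)) ** (1/p): normalization by the cumulative weight must happen before the 1/p root is taken. A numeric sketch with toy values:

import numpy as np

x = np.array([1.0, 2.0, 4.0])  # per-image values, invented
w = np.array([1.0, 1.0, 2.0])  # image weights, invented
p = 2.0                        # plays the role of image_mean_p

stacked = np.sum(w * x**p)  # what the semblance stack holds: 37.0
total_weight = w.sum()      # analogue of images.cumulative_weight(): 4.0

old = stacked ** (1.0 / p) / total_weight    # exponent first: ~1.521 (wrong)
new = (stacked / total_weight) ** (1.0 / p)  # normalize first: ~3.041, the true power mean
print(old, new)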
62 changes: 51 additions & 11 deletions src/qseek/utils.py
@@ -225,17 +225,57 @@ def weighted_median(data: np.ndarray, weights: np.ndarray | None = None) -> float:
     if weights is None:
         return float(np.median(data))

-    data = np.atleast_1d(np.array(data).squeeze())
-    weights = np.atleast_1d(np.array(weights).squeeze())
-    try:
-        s_data, s_weights = map(
-            np.array, zip(*sorted(zip(data, weights, strict=True)), strict=True)
-        )
-    except TypeError as exc:
-        raise exc
-    midpoint = 0.5 * sum(s_weights)
-    if any(weights > midpoint):
-        w_median = (data[weights == np.max(weights)])[0]
+    data = np.atleast_1d(data.squeeze())
+    weights = np.atleast_1d(weights.squeeze())
+
+    sorted_indices = np.argsort(data)
+    s_data = data[sorted_indices]
+    s_weights = weights[sorted_indices]
+
+    midpoint = 0.5 * s_weights.sum()
+    if np.any(weights > midpoint):
+        w_median = (data[weights == weights.max()])[0]
     else:
         cs_weights = np.cumsum(s_weights)
         idx = np.where(cs_weights <= midpoint)[0][-1]
         if cs_weights[idx] == midpoint:
             w_median = np.mean(s_data[idx : idx + 2])
         else:
             w_median = s_data[idx + 1]
     return float(w_median)


+async def async_weighted_median(
+    data: np.ndarray, weights: np.ndarray | None = None
+) -> float:
+    """
+    Asynchronously calculate the weighted median of an array/list using numpy.
+
+    Parameters:
+        data (np.ndarray): The input array/list.
+        weights (np.ndarray | None): The weights corresponding to each
+            element in the data array/list. If None, the function
+            calculates the regular median.
+
+    Returns:
+        float: The weighted median.
+    """
+    if weights is None:
+        return float(await asyncio.to_thread(np.median, data))
+
+    data = np.atleast_1d(data.squeeze())
+    weights = np.atleast_1d(weights.squeeze())
+
+    sorted_indices = await asyncio.to_thread(np.argsort, data)
+    s_data = data[sorted_indices]
+    s_weights = weights[sorted_indices]
+
+    midpoint = 0.5 * s_weights.sum()
+    if np.any(weights > midpoint):
+        w_median = (data[weights == weights.max()])[0]
+    else:
+        cs_weights = np.cumsum(s_weights)
+        idx = np.where(cs_weights <= midpoint)[0][-1]
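A quick usage sketch of the two helpers above (hand-picked values; assumes they are imported from qseek.utils):

import asyncio

import numpy as np

from qseek.utils import async_weighted_median, weighted_median

data = np.array([1.0, 2.0, 3.0, 4.0])
weights = np.array([1.0, 1.0, 1.0, 5.0])

# One weight exceeds half the total weight (5 > 4), so its sample wins outright.
print(weighted_median(data, weights))            # 4.0
# Without weights, the async variant falls back to the plain median in a thread.
print(asyncio.run(async_weighted_median(data)))  # 2.5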
