Refactor code to improve performance and readability
Marius Isken committed Jan 24, 2024
1 parent 2f47cca commit b827efa
Showing 12 changed files with 147 additions and 64 deletions.
8 changes: 4 additions & 4 deletions src/qseek/apps/qseek.py
@@ -224,23 +224,23 @@ async def extract() -> None:
     iterator = asyncio.as_completed(
         tuple(
             search.add_magnitude_and_features(detection)
-            for detection in search._detections
+            for detection in search._catalog
         )
     )

     for result in track(
         iterator,
         description="Extracting features",
-        total=search._detections.n_detections,
+        total=search._catalog.n_events,
     ):
         event = await result
         if event.magnitudes:
             for mag in event.magnitudes:
                 print(f"{mag.magnitude} {mag.average:.2f}±{mag.error:.2f}")
             print("--")

-    await search._detections.save()
-    await search._detections.export_detections(
+    await search._catalog.save()
+    await search._catalog.export_detections(
         jitter_location=search.octree.smallest_node_size()
     )
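Note on the pattern above: wrapping a tuple of coroutines in `asyncio.as_completed` and draining the iterator through rich's `track` yields each result as soon as it finishes, not in submission order. A minimal self-contained sketch of that pattern (the `process` coroutine, its timings, and `main` are illustrative stand-ins, not qseek code):

import asyncio

from rich.progress import track


async def process(item: int) -> int:
    # Stand-in for search.add_magnitude_and_features(): any awaitable work.
    await asyncio.sleep(0.1 * (item % 3))
    return item * item


async def main() -> None:
    items = range(10)
    # as_completed yields awaitables in completion order, not input order.
    iterator = asyncio.as_completed(tuple(process(item) for item in items))
    for result in track(iterator, description="Processing", total=len(items)):
        print(await result)


asyncio.run(main())

`track` cannot infer a length from the plain iterator that `as_completed` returns, hence the explicit `total=`, mirroring the `n_events` count in the diff.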
28 changes: 14 additions & 14 deletions src/qseek/models/catalog.py
@@ -40,16 +40,16 @@ def _populate_table(self, table: Table) -> None:

 class EventCatalog(BaseModel):
     rundir: Path
-    detections: list[EventDetection] = []
+    events: list[EventDetection] = []
     _stats: EventCatalogStats = PrivateAttr(default_factory=EventCatalogStats)

     def model_post_init(self, __context: Any) -> None:
         EventDetection.set_rundir(self.rundir)

     @property
-    def n_detections(self) -> int:
+    def n_events(self) -> int:
         """Number of detections"""
-        return len(self.detections)
+        return len(self.events)

     @property
     def markers_dir(self) -> Path:
@@ -64,18 +64,18 @@ def csv_dir(self) -> Path:
         return dir

     async def add(self, detection: EventDetection) -> None:
-        detection.set_index(self.n_detections)
+        detection.set_index(self.n_events)

         markers_file = self.markers_dir / f"{time_to_path(detection.time)}.list"
         self.markers_dir.mkdir(exist_ok=True)
         detection.export_pyrocko_markers(markers_file)

-        self.detections.append(detection)
+        self.events.append(detection)
         logger.info(
             "%s event detection %d %s: %.5f°, %.5f°, depth %.1f m, "
             "border distance %.1f m, semblance %.3f, magnitude %.2f",
             Symbols.Target,
-            self.n_detections,
+            self.n_events,
             detection.time,
             *detection.effective_lat_lon,
             detection.depth,
@@ -116,24 +116,24 @@ def load_rundir(cls, rundir: Path) -> EventCatalog:
             for idx, line in enumerate(f):
                 detection = EventDetection.model_validate_json(line)
                 detection.set_index(idx)
-                catalog.detections.append(detection)
+                catalog.events.append(detection)

-        logger.info("loaded %d detections", catalog.n_detections)
+        logger.info("loaded %d detections", catalog.n_events)

         stats = catalog._stats
-        stats.n_detections = catalog.n_detections
+        stats.n_detections = catalog.n_events
         if catalog:
             stats.max_semblance = max(detection.semblance for detection in catalog)
         return catalog

     async def save(self) -> None:
         """Save catalog to current rundir."""
-        logger.debug("saving %d detections", self.n_detections)
+        logger.debug("saving %d detections", self.n_events)

         lines_events = []
         lines_recv = []
         # Has to be the unsorted
-        for detection in self.detections:
+        for detection in self.events:
             lines_events.append(f"{detection.model_dump_json(exclude={'receivers'})}\n")
             lines_recv.append(f"{detection.receivers.model_dump_json()}\n")
@@ -179,7 +179,7 @@ async def export_csv(self, file: Path, jitter_location: float = 0.0) -> None:
         if jitter_location:
             detections = [det.jitter_location(jitter_location) for det in self]
         else:
-            detections = self.detections
+            detections = self.events

         csv_dicts: list[dict] = []
         for detection in detections:
@@ -207,7 +207,7 @@ def export_pyrocko_events(
             filename (Path): output filename
         """
         logger.info("saving Pyrocko events to %s", filename)
-        detections = self.detections
+        detections = self.events
         if jitter_location:
             detections = [det.jitter_location(jitter_location) for det in detections]
         dump_events(
@@ -229,4 +229,4 @@ def export_pyrocko_markers(self, filename: Path) -> None:
         marker.save_markers(pyrocko_markers, str(filename))

     def __iter__(self) -> Iterator[EventDetection]:
-        return iter(sorted(self.detections, key=lambda d: d.time))
+        return iter(sorted(self.events, key=lambda d: d.time))
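The catalog above persists one JSON document per line and re-sorts by time only on iteration; insertion order is preserved on disk (see the "Has to be the unsorted" comment). A hedged round-trip sketch of that JSON-lines scheme with a stand-in pydantic model (the `Event` class and file name are illustrative, not the qseek API):

from datetime import datetime, timezone
from pathlib import Path

from pydantic import BaseModel


class Event(BaseModel):  # stand-in for EventDetection
    time: datetime
    semblance: float


def save_jsonl(events: list[Event], file: Path) -> None:
    # One JSON document per line, written in insertion (unsorted) order.
    with file.open("w") as f:
        f.writelines(f"{event.model_dump_json()}\n" for event in events)


def load_jsonl(file: Path) -> list[Event]:
    # model_validate_json parses each line independently.
    with file.open() as f:
        return [Event.model_validate_json(line) for line in f]


events = [
    Event(time=datetime(2024, 1, 24, 12, tzinfo=timezone.utc), semblance=0.8),
    Event(time=datetime(2024, 1, 24, 9, tzinfo=timezone.utc), semblance=0.5),
]
path = Path("catalog.jsonl")
save_jsonl(events, path)
# Iterate time-sorted, as EventCatalog.__iter__ does.
for event in sorted(load_jsonl(path), key=lambda e: e.time):
    print(event.time, event.semblance)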
63 changes: 58 additions & 5 deletions src/qseek/models/semblance.py
@@ -93,10 +93,12 @@ def __init__(
         start_time: datetime,
         sampling_rate: float,
         padding_samples: int = 0,
+        exponent: float = 1.0,
     ) -> None:
         self.sampling_rate = sampling_rate
         self.padding_samples = padding_samples
         self.n_samples_unpadded = n_samples
+        self.exponent = exponent

         self._start_time = start_time
         self._node_hashes = [node.hash() for node in nodes]
@@ -149,6 +151,8 @@ def maximum_semblance(self) -> np.ndarray:
"""
if self._max_semblance is None:
self._max_semblance = self.semblance.max(axis=0)
if self.exponent != 1.0:
self._max_semblance **= self.exponent
return self._max_semblance

@property
@@ -164,10 +168,46 @@ def get_cache(self) -> dict[bytes, np.ndarray]:
             for i, node_hash in enumerate(self._node_hashes)
         }

+    def get_semblance(self, time_idx: int) -> np.ndarray:
+        """
+        Get the semblance values at a specific time index.
+        Parameters:
+            time_idx (int): The index of the desired time.
+        Returns:
+            np.ndarray: The semblance values at the specified time index.
+        """
+        semblance = self.semblance[:, time_idx]
+        if self.exponent != 1.0:
+            semblance **= self.exponent
+        return semblance
+
     def get_cache_mask(self, cache: dict[bytes, np.ndarray]) -> np.ndarray:
+        """
+        Returns a boolean mask indicating whether each node hash
+        in self._node_hashes is present in the cache.
+        Args:
+            cache (dict[bytes, np.ndarray]): The cache dictionary containing node
+                hashes as keys.
+        Returns:
+            np.ndarray: A boolean mask indicating whether each node hash is
+                present in the cache.
+        """
         return np.array([hash in cache for hash in self._node_hashes])

-    def apply_cache(self, cache: dict[bytes, np.ndarray]):
+    def apply_cache(self, cache: dict[bytes, np.ndarray]) -> None:
+        """
+        Applies the cached data to the `semblance_unpadded` array.
+        Args:
+            cache (dict[bytes, np.ndarray]): The cache containing the cached data.
+        Returns:
+            None
+        """
         if not cache:
             return
         mask = self.get_cache_mask(cache)
@@ -176,7 +216,10 @@ def apply_cache(self, cache: dict[bytes, np.ndarray]):
         self.semblance_unpadded[mask, :] = np.stack(data)

     def maximum_node_semblance(self) -> np.ndarray:
-        return self.semblance.max(axis=1)
+        semblance = self.semblance.max(axis=1)
+        if self.exponent != 1.0:
+            semblance **= self.exponent
+        return semblance

     async def maxima_node_idx(self, nparallel: int = 6) -> np.ndarray:
         """Indices of maximum semblance at any time step.
@@ -219,7 +262,7 @@ def apply_exponent(self, exponent: float) -> None:
"""
if exponent == 1.0:
return
self.semblance_unpadded **= exponent
np.power(self.semblance_unpadded, exponent, out=self.semblance_unpadded)
self._clear_cache()

def median_mask(self, level: float = 3.0) -> np.ndarray:
@@ -256,9 +299,16 @@ async def find_peaks(
         Returns:
             tuple[np.ndarray, np.ndarray]: Indices of peaks and peak values.
         """
+        max_semblance_unpadded = await asyncio.to_thread(
+            self.semblance_unpadded.max,
+            axis=0,
+        )
+        if self.exponent != 1.0:
+            max_semblance_unpadded **= self.exponent
+
         detection_idx, _ = await asyncio.to_thread(
             signal.find_peaks,
-            self.semblance_unpadded.max(axis=0),
+            max_semblance_unpadded,
             height=height,
             prominence=prominence,
             distance=distance,
@@ … @@
             detection_idx = detection_idx[detection_idx < self.maximum_semblance.size]
             semblance = self.maximum_semblance[detection_idx]
         else:
-            maximum_semblance = self.semblance_unpadded.max(axis=0)
+            maximum_semblance = await asyncio.to_thread(
+                self.semblance_unpadded.max,
+                axis=0,
+            )
             semblance = maximum_semblance[detection_idx]

         return detection_idx, semblance
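Two general techniques recur in this file: ufuncs writing their result back into the input via `out=` so no second full-size array is allocated, and `asyncio.to_thread` moving blocking numpy reductions onto a worker thread so the event loop stays responsive. A minimal sketch under those assumptions (array shape and seed are arbitrary, not qseek values):

import asyncio

import numpy as np


async def main() -> None:
    semblance = np.random.default_rng(seed=42).random((512, 20_000))

    # In-place exponentiation: the result is written back into `semblance`,
    # so no temporary (512, 20000) array is allocated.
    np.power(semblance, 2.0, out=semblance)

    # Blocking reduction off the event loop: to_thread runs the bound
    # method in a worker thread and awaits its result.
    max_trace = await asyncio.to_thread(semblance.max, axis=0)
    print(max_trace.shape)  # (20000,)


asyncio.run(main())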
16 changes: 15 additions & 1 deletion src/qseek/models/station.py
@@ -43,7 +43,21 @@ def from_pyrocko_station(cls, station: PyrockoStation) -> Station:
         )

     def as_pyrocko_station(self) -> PyrockoStation:
-        return PyrockoStation(**self.model_dump(exclude={"effective_lat_lon"}))
+        return PyrockoStation(
+            **self.model_dump(
+                include={
+                    "network",
+                    "station",
+                    "location",
+                    "lat",
+                    "lon",
+                    "north_shift",
+                    "east_shift",
+                    "depth",
+                    "elevation",
+                }
+            )
+        )

     @property
     def nsl(self) -> NSL:
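The move from an `exclude=` blacklist to an `include=` whitelist in `model_dump` hardens the `PyrockoStation(**...)` call: fields or computed properties added to `Station` later cannot leak into the constructor as unexpected keywords. A toy illustration of the difference with stand-in models (neither class is the qseek or pyrocko original):

from dataclasses import dataclass

from pydantic import BaseModel, computed_field


@dataclass
class Target:
    # Stand-in for PyrockoStation: a strict constructor that rejects
    # unexpected keyword arguments with a TypeError.
    lat: float
    lon: float


class Source(BaseModel):
    # Stand-in for qseek's Station.
    lat: float
    lon: float

    @computed_field
    @property
    def label(self) -> str:
        return f"{self.lat:.2f}/{self.lon:.2f}"


src = Source(lat=52.5, lon=13.4)

# Blacklist dump: the computed `label` field is included in model_dump(),
# so Target(**src.model_dump()) raises TypeError for the unexpected keyword.
# An exclude set stays safe only while it is kept in sync with every new field.

# Whitelist dump: only the named fields can ever reach the constructor.
print(Target(**src.model_dump(include={"lat", "lon"})))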
2 changes: 1 addition & 1 deletion src/qseek/octree.py
@@ -363,8 +363,8 @@ def set_level(self, level: int):
             raise ValueError(
                 f"invalid level {level}, expected level <= {self.n_levels()}"
             )
-        logger.debug("setting tree to level %d", level)
         self.reset()
+        logger.debug("setting tree to level %d", level)
         for _ in range(level):
             for node in self:
                 node.split()