diff --git a/src/qseek/images/images.py b/src/qseek/images/images.py index e2396734..9c3a85cc 100644 --- a/src/qseek/images/images.py +++ b/src/qseek/images/images.py @@ -127,6 +127,7 @@ async def worker() -> None: break yield ret + logger.debug("waiting for image function to finish") await task def get_phases(self) -> tuple[PhaseDescription, ...]: diff --git a/src/qseek/models/detection_uncertainty.py b/src/qseek/models/detection_uncertainty.py index 769c3e3c..bdd0146f 100644 --- a/src/qseek/models/detection_uncertainty.py +++ b/src/qseek/models/detection_uncertainty.py @@ -11,7 +11,7 @@ # Equivalent to one standard deviation -THRESHOLD = 1.0 / np.sqrt(np.e) +PERCENTILE = 0.02 class DetectionUncertainty(BaseModel): @@ -30,7 +30,7 @@ class DetectionUncertainty(BaseModel): @classmethod def from_event( - cls, source_node: Node, octree: Octree, width: float = THRESHOLD + cls, source_node: Node, octree: Octree, percentile: float = PERCENTILE ) -> Self: """ Calculate the uncertainty of an event detection. @@ -38,6 +38,8 @@ def from_event( Args: event: The event detection to calculate the uncertainty for. octree: The octree to use for the calculation. + percentile: The percentile to use for the calculation. + Defaults to 0.02 (2%). Returns: The calculated uncertainty. @@ -45,7 +47,9 @@ def from_event( if not source_node.semblance: raise ValueError("Source node must have semblance value.") - nodes = octree.get_nodes(semblance_threshold=source_node.semblance * width) + nodes = octree.get_nodes( + semblance_threshold=source_node.semblance * (1.0 - percentile) + ) vicinity_coords = np.array( [(node.east, node.north, node.depth) for node in nodes] ) diff --git a/src/qseek/octree.py b/src/qseek/octree.py index 74162f01..d134aede 100644 --- a/src/qseek/octree.py +++ b/src/qseek/octree.py @@ -8,6 +8,7 @@ from functools import cached_property, reduce from hashlib import sha1 from operator import mul +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence import numpy as np @@ -519,6 +520,18 @@ def copy(self, deep=False) -> Self: node.set_tree(tree) return tree + def save_pickle(self, filename: Path) -> None: + """Save the octree to a pickle file. + + Args: + filename (Path): Filename to save to. + """ + import pickle + + logger.info("saving octree pickle to %s", filename) + with filename.open("wb") as f: + pickle.dump(self, f) + def __hash__(self) -> int: return hash( ( diff --git a/src/qseek/plot/octree.py b/src/qseek/plot/octree.py index 5f8cb8cf..a89a556d 100644 --- a/src/qseek/plot/octree.py +++ b/src/qseek/plot/octree.py @@ -51,15 +51,17 @@ def octree_to_rectangles( height=size, ) rectangles.append(rect) + if normalize: semblances /= semblances.max() colors = cmap(semblances) + edge_colors = cm.get_cmap("binary")(semblances**2, alpha=0.8) return PatchCollection( patches=rectangles, facecolors=colors, - edgecolors=(0, 0, 0, 0.3), - linewidths=0.5, + edgecolors=edge_colors, + linewidths=0.1, ) diff --git a/src/qseek/search.py b/src/qseek/search.py index eb0b0387..3ebacb25 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import cProfile import logging from collections import deque from datetime import datetime, timedelta, timezone @@ -54,7 +53,6 @@ logger = logging.getLogger(__name__) SamplingRate = Literal[10, 20, 25, 50, 100] -p = cProfile.Profile() class SearchStats(Stats): @@ -445,7 +443,6 @@ async def start(self, force_rundir: bool = False) -> None: ) console = asyncio.create_task(RuntimeStats.live_view()) - p.enable() async for images, batch in self.image_functions.iter_images(waveform_iterator): batch_processing_start = datetime_now() @@ -471,12 +468,11 @@ async def start(self, force_rundir: bool = False) -> None: ) self.set_progress(batch.end_time) - p.dump_stats("qseek.prof") - p.disable() - # await BackgroundTasks.wait_all() - console.cancel() + await BackgroundTasks.wait_all() + await self._catalog.save() await self._catalog.export_detections(jitter_location=self.octree.size_limit) + console.cancel() logger.info("finished search in %s", datetime_now() - processing_start) logger.info("found %d detections", self._catalog.n_events) diff --git a/src/qseek/stats.py b/src/qseek/stats.py index adf1e07e..419491e7 100644 --- a/src/qseek/stats.py +++ b/src/qseek/stats.py @@ -66,7 +66,10 @@ def generate_grid() -> Table: ) as live: while True: live.update(generate_grid()) - await asyncio.sleep(0.4) + try: + await asyncio.sleep(0.2) + except asyncio.CancelledError: + break class Stats(BaseModel): diff --git a/src/qseek/tracers/cake.py b/src/qseek/tracers/cake.py index b631536c..839171fc 100644 --- a/src/qseek/tracers/cake.py +++ b/src/qseek/tracers/cake.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging import re import struct @@ -476,6 +477,7 @@ async def _interpolate_travel_times( for coords in coordinates: travel_times.append(self._interpolate_traveltimes_sptree(coords)) PROGRESS.update(status, advance=1) + await asyncio.sleep(0.0) PROGRESS.remove_task(status) diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 1dd6c9ea..7cf5f196 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -83,6 +83,9 @@ def cancel_all(cls) -> None: @classmethod async def wait_all(cls) -> None: + if not cls.tasks: + return + logger.debug("waiting for %d tasks to finish", len(cls.tasks)) await asyncio.gather(*cls.tasks) @@ -472,8 +475,8 @@ def load_insights() -> None: import qseek.insights # noqa: F401 logger.info("loaded qseek.insights package") - except ImportError as exc: - logger.warning("package qseek.insights not installed", exc_info=exc) + except ImportError: + logger.debug("package qseek.insights not installed") MeasurementUnit = Literal[ @@ -512,6 +515,14 @@ def get_traces(self, traces_flt: list[Trace]) -> list[Trace]: traces_flt = [tr for tr in traces_flt if tr.channel[-1] in self.channels] + tmins = {tr.tmin for tr in traces_flt} + tmaxs = {tr.tmax for tr in traces_flt} + if len(tmins) != 1 or len(tmaxs) != 1: + raise KeyError( + f"unhealthy timing on channels {self.channels}", + f" for: {', '.join('.'.join(tr.nslc_id) for tr in traces_flt)}", + ) + if len(traces_flt) != self.number_channels: raise KeyError( f"cannot get {self.number_channels} channels" @@ -520,10 +531,9 @@ def get_traces(self, traces_flt: list[Trace]) -> list[Trace]: ) if self.normalize: traces_norm = traces_flt[0].copy() - traces_norm.ydata = np.linalg.norm( - np.atleast_2d(np.array([tr.ydata for tr in traces_flt])), - axis=0, - ) + data = np.atleast_2d(np.array([tr.ydata for tr in traces_flt])) + + traces_norm.ydata = np.linalg.norm(data, axis=0) return [traces_norm] return traces_flt diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py index 212819c5..2004a707 100644 --- a/src/qseek/waveforms/squirrel.py +++ b/src/qseek/waveforms/squirrel.py @@ -104,6 +104,7 @@ async def load_data() -> None | Batch: start = datetime_now() batch = await asyncio.to_thread(next, self.iterator, None) if batch is None: + await self._load_queue.put(None) return logger.debug("read waveform batch in %s", datetime_now() - start) self._fetched_batches += 1 @@ -123,6 +124,7 @@ async def post_process() -> None: await load_task await post_process_task + logger.debug("waiting for waveform batches to finish") await self.queue.put(None)