From 4772156f038c24e0289afd33956e5ee3d7444f91 Mon Sep 17 00:00:00 2001 From: Marius Isken Date: Tue, 2 Jan 2024 20:18:45 +0100 Subject: [PATCH] Refactor code and fix bugs --- src/qseek/apps/qseek.py | 16 +- src/qseek/features/ground_motion.py | 3 +- src/qseek/magnitudes/local_magnitude.py | 10 +- src/qseek/magnitudes/local_magnitude_model.py | 4 +- src/qseek/models/detection.py | 200 +++++++++--------- src/qseek/search.py | 25 +-- src/qseek/tracers/cake.py | 36 ++-- src/qseek/utils.py | 50 ++++- src/qseek/waveforms/squirrel.py | 15 +- 9 files changed, 196 insertions(+), 163 deletions(-) diff --git a/src/qseek/apps/qseek.py b/src/qseek/apps/qseek.py index bbfebf99..2bd8eb36 100644 --- a/src/qseek/apps/qseek.py +++ b/src/qseek/apps/qseek.py @@ -229,7 +229,7 @@ async def extract() -> None: total=search._detections.n_detections, ): detection = await result - await detection.dump_detection(update=True) + await detection.save(update=True) await search._detections.export_detections( jitter_location=search.octree.smallest_node_size() @@ -238,10 +238,11 @@ async def extract() -> None: asyncio.run(extract()) case "corrections": - rundir = Path(args.rundir) + import json + from qseek.corrections.base import StationCorrections - search = Search.load_rundir(rundir) + rundir = Path(args.rundir) corrections_modules = StationCorrections.get_subclasses() @@ -257,15 +258,16 @@ async def extract() -> None: ) corrections_class = corrections_modules[int(module_choice)] corrections = asyncio.run(corrections_class.prepare(rundir, console)) - search.corrections = corrections + + search = json.loads((rundir / "search.json").read_text()) + search["corrections"] = corrections.model_dump(mode="json") new_config_file = rundir.parent / f"{rundir.name}-corrections.json" console.print("writing new config file") console.print( - "to use this config file, run [bold]`qseek search %s`", - new_config_file, + f"to use this config file, run [bold]qseek search {new_config_file}" ) - new_config_file.write_text(search.model_dump_json(by_alias=False, indent=2)) + new_config_file.write_text(json.dumps(search, indent=2)) case "serve": search = Search.load_rundir(args.rundir) diff --git a/src/qseek/features/ground_motion.py b/src/qseek/features/ground_motion.py index 89d54254..0bb6178a 100644 --- a/src/qseek/features/ground_motion.py +++ b/src/qseek/features/ground_motion.py @@ -80,7 +80,6 @@ async def add_features( except Exception: continue receiver_motions.append(ground_motion) - receiver.add_feature(ground_motion) event_ground_motions = EventGroundMotion( seconds_before=self.seconds_before, @@ -95,4 +94,4 @@ async def add_features( gm.peak_ground_velocity for gm in receiver_motions ), ) - event.features.add_feature(event_ground_motions) + event.add_feature(event_ground_motions) diff --git a/src/qseek/magnitudes/local_magnitude.py b/src/qseek/magnitudes/local_magnitude.py index 243fbfe6..15a0b4e6 100644 --- a/src/qseek/magnitudes/local_magnitude.py +++ b/src/qseek/magnitudes/local_magnitude.py @@ -102,8 +102,8 @@ def n_observations(self) -> int: def csv_row(self) -> dict[str, float]: return { - f"ML_{self.model}": self.average, - f"ML_error_{self.model}": self.error, + f"ML-{self.model}": self.average, + f"ML-error-{self.model}": self.error, } def plot(self) -> None: @@ -245,10 +245,4 @@ async def add_magnitude(self, squirrel: Squirrel, event: EventDetection) -> None logger.warning("Local magnitude is NaN, skipping event %s", event.time) return - logger.info( - "Ml %.1f (±%.2f) for event %s", - local_magnitude.average, - local_magnitude.error, - event.time, - ) event.add_magnitude(local_magnitude) diff --git a/src/qseek/magnitudes/local_magnitude_model.py b/src/qseek/magnitudes/local_magnitude_model.py index b1e92ee1..ab26f700 100644 --- a/src/qseek/magnitudes/local_magnitude_model.py +++ b/src/qseek/magnitudes/local_magnitude_model.py @@ -384,7 +384,7 @@ class IcelandBardabunga(WoodAnderson, LocalMagnitudeModel): @staticmethod def get_amp_attenuation(dist_hypo_km: float, dist_epi_km: float) -> float: - return 1.2534 * np.log10(dist_hypo_km / 17) + 0.0032 * (dist_hypo_km - 17) + 2 + return 1.2534 * np.log10(dist_hypo_km / 17) - 0.0032 * (dist_hypo_km - 17) + 2 class IcelandAskjaBardabungaCombined(WoodAnderson, LocalMagnitudeModel): @@ -395,7 +395,7 @@ class IcelandAskjaBardabungaCombined(WoodAnderson, LocalMagnitudeModel): @staticmethod def get_amp_attenuation(dist_hypo_km: float, dist_epi_km: float) -> float: - return 1.1999 * np.log10(dist_hypo_km / 17) + 0.0016 * (dist_hypo_km - 17) + 2 + return 1.1999 * np.log10(dist_hypo_km / 17) - 0.0016 * (dist_hypo_km - 17) + 2 class IcelandReykjanes(WoodAnderson, LocalMagnitudeModel): diff --git a/src/qseek/models/detection.py b/src/qseek/models/detection.py index 819b4f1a..1f899ba4 100644 --- a/src/qseek/models/detection.py +++ b/src/qseek/models/detection.py @@ -6,7 +6,7 @@ from itertools import chain from pathlib import Path from random import uniform -from typing import TYPE_CHECKING, Any, ClassVar, Iterator, Literal, Type, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Iterator, Literal from uuid import UUID, uuid4 import aiofiles @@ -27,7 +27,7 @@ from typing_extensions import Self from qseek.console import console -from qseek.features import EventFeaturesType, ReceiverFeaturesType +from qseek.features import EventFeaturesType from qseek.images.images import ImageFunctionPick from qseek.magnitudes import EventMagnitudeType from qseek.models.detection_uncertainty import DetectionUncertainty @@ -38,31 +38,50 @@ from qseek.utils import PhaseDescription, Symbols, filter_clipped_traces, time_to_path if TYPE_CHECKING: - from pyrocko.squirrel import Response, Squirrel + from pyrocko.squirrel import Squirrel from pyrocko.trace import Trace - from qseek.features.base import EventFeature, ReceiverFeature + from qseek.features.base import EventFeature from qseek.magnitudes.base import EventMagnitude logger = logging.getLogger(__name__) -_ReceiverFeature = TypeVar("_ReceiverFeature", bound=ReceiverFeaturesType) - - MeasurementUnit = Literal[ "displacement", "velocity", "acceleration", ] - FILENAME_DETECTIONS = "detections.json" FILENAME_RECEIVERS = "detections_receivers.json" UPDATE_LOCK = asyncio.Lock() +class ReceiverCache: + file: Path + lines: list[str] = [] + mtime: float | None = None + + def __init__(self, file: Path) -> None: + self.file = file + self.load() + + def load(self) -> None: + if not self.file.exists(): + logger.debug("receiver cache %s does not exist", self.file) + return + logger.debug("loading receiver cache from %s", self.file) + self.lines = self.file.read_text().splitlines() + self.mtime = self.file.stat().st_mtime + + def get_row(self, row_index: int) -> str: + if self.mtime is None or self.mtime != self.file.stat().st_mtime: + self.load() + return self.lines[row_index] + + class PhaseDetection(BaseModel): phase: PhaseDescription model: RayTracerArrival @@ -131,24 +150,11 @@ def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: class Receiver(Station): - features: list[ReceiverFeaturesType] = [] phase_arrivals: dict[PhaseDescription, PhaseDetection] = {} def add_phase_detection(self, arrival: PhaseDetection) -> None: self.phase_arrivals[arrival.phase] = arrival - def add_feature(self, feature: ReceiverFeature) -> None: - self.features = [ - f for f in self.features if not isinstance(feature, f.__class__) - ] - self.features.append(feature) - - def get_feature(self, feature_type: Type[_ReceiverFeature]) -> _ReceiverFeature: - for feature in self.features: - if isinstance(feature, feature_type): - return feature - raise TypeError(f"cannot find feature of type {feature_type.__class__}") - def as_pyrocko_markers(self) -> list[marker.PhaseMarker]: """ Convert the phase arrivals to Pyrocko markers. @@ -186,27 +192,6 @@ def get_arrivals_time_window( times = [arrival.get_arrival_time() for arrival in self.phase_arrivals.values()] return min(times), max(times) - def get_waveforms( - self, - squirrel: Squirrel, - seconds_after: float = 5.0, - seconds_before: float = 3.0, - phase: PhaseDescription | None = None, - load_data: bool = True, - ) -> list[Trace]: - start_time, end_time = self.get_arrivals_time_window(phase) - - traces = squirrel.get_waveforms( - codes=[(*self.nsl, "*")], - tmin=(start_time - timedelta(seconds=seconds_before)).timestamp(), - tmax=(end_time + timedelta(seconds=seconds_after)).timestamp(), - want_incomplete=False, - load_data=load_data, - ) - if not traces: - raise KeyError - return traces - @classmethod def from_station(cls, station: Station) -> Self: return cls.model_construct(**station.model_dump()) @@ -254,20 +239,21 @@ def get_waveforms( tmin = min(times).timestamp() - seconds_before tmax = max(times).timestamp() + seconds_after nslc_ids = [(*receiver.nsl, "*") for receiver in self] - traces: list[Trace] = squirrel.get_waveforms( + traces = squirrel.get_waveforms( codes=nslc_ids, tmin=tmin, tmax=tmax, accessor_id=accessor_id, want_incomplete=False, ) + squirrel.clear_accessor(accessor_id, cache_id="waveform") + for tr in traces: # Crop to receiver's phase arrival window receiver = self.get_receiver(tr.nslc_id[:3]) tmin, tmax = receiver.get_arrivals_time_window(phase) tr.chop(tmin.timestamp() - seconds_before, tmax.timestamp() + seconds_after) - squirrel.advance_accessor(accessor_id, cache_id="waveform") return traces async def get_waveforms_restituted( @@ -288,16 +274,16 @@ async def get_waveforms_restituted( Args: squirrel (Squirrel): The squirrel waveform organizer. - seconds_before (float, optional): Number of seconds before the event + seconds_before (float, optional): Number of seconds before phase arrival to retrieve. Defaults to 2.0. - seconds_after (float, optional): Number of seconds after the event + seconds_after (float, optional): Number of seconds after phase arrival to retrieve. Defaults to 5.0. seconds_fade (float, optional): Number of seconds for fade in/out. Defaults to 5.0. cut_off_fade (bool, optional): Whether to cut off the fade in/out. Defaults to True. - phase (PhaseDescription | None, optional): The phase description. - Defaults to None. + phase (PhaseDescription | None, optional): The phase description. If None, + the whole time window is retrieved. Defaults to None. quantity (MeasurementUnit, optional): The measurement unit. Defaults to "velocity". demean (bool, optional): Whether to demean the waveforms. Defaults to True. @@ -321,7 +307,7 @@ async def get_waveforms_restituted( restituted_traces = [] for tr in traces: try: - response: Response = squirrel.get_response( + response = squirrel.get_response( tmin=tr.tmin, tmax=tr.tmax, codes=[tr.nslc_id], @@ -461,6 +447,7 @@ class EventDetection(Location): _detection_idx: int | None = PrivateAttr(None) _rundir: ClassVar[Path | None] = None + _receiver_cache: ClassVar[ReceiverCache | None] = None @field_validator("features", mode="before") @classmethod @@ -479,6 +466,7 @@ def set_rundir(cls, rundir: Path) -> None: rundir (Path): The path to the rundir. """ cls._rundir = rundir + cls._receiver_cache = ReceiverCache(rundir / FILENAME_RECEIVERS) @property def magnitude(self) -> EventMagnitude | None: @@ -489,9 +477,7 @@ def magnitude(self) -> EventMagnitude | None: """ return self.magnitudes[0] if self.magnitudes else None - async def dump_detection( - self, file: Path | None = None, update: bool = False - ) -> None: + async def save(self, file: Path | None = None, update: bool = False) -> None: """ Dump the detection data to a file. @@ -592,12 +578,12 @@ def receivers(self) -> EventReceivers: elif self._detection_idx is None: self._receivers = EventReceivers(event_uid=self.uid) elif self._rundir and self._detection_idx is not None: - logger.debug("fetching receiver information from file") - receiver_file = self._rundir / FILENAME_RECEIVERS - with receiver_file.open() as f: - for _ in range(self._detection_idx): # Seek to line - next(f) - receivers = EventReceivers.model_validate_json(next(f)) + if self._receiver_cache is None: + raise ValueError("cannot fetch receivers without set rundir") + + logger.debug("fetching receiver information from cache") + row = self._receiver_cache.get_row(self._detection_idx) + receivers = EventReceivers.model_validate_json(row) if receivers.event_uid != self.uid: raise ValueError(f"uid mismatch: {receivers.event_uid} != {self.uid}") @@ -634,12 +620,12 @@ def get_csv_dict(self) -> dict[str, Any]: """ csv_line = { "time": self.time, - "lat": round(self.effective_lat, 5), - "lon": round(self.effective_lon, 5), - "depth": round(self.effective_depth, 5), - "east_shift": round(self.east_shift, 5), - "north_shift": round(self.north_shift, 5), - "distance_border": round(self.distance_border, 5), + "lat": round(self.effective_lat, 6), + "lon": round(self.effective_lon, 6), + "depth": round(self.effective_depth, 2), + "east_shift": round(self.east_shift, 2), + "north_shift": round(self.north_shift, 2), + "distance_border": round(self.distance_border, 2), "in_bounds": self.in_bounds, "semblance": self.semblance, } @@ -763,8 +749,8 @@ async def add(self, detection: EventDetection) -> None: self.detections.append(detection) logger.info( - "%s event detection #%d %s: %.5f°, %.5f°, depth %.1f m, " - "border distance %.1f m, semblance %.3f", + "%s event detection %d %s: %.5f°, %.5f°, depth %.1f m, " + "border distance %.1f m, semblance %.3f, magnitude %.2f", Symbols.Target, self.n_detections, detection.time, @@ -772,30 +758,13 @@ async def add(self, detection: EventDetection) -> None: detection.depth, detection.distance_border, detection.semblance, + detection.magnitude.average if detection.magnitude else 0.0, ) self._stats.new_detection(detection) # This has to happen after the markers are saved, cache is cleared - await detection.dump_detection() + await detection.save() - async def export_detections(self, jitter_location: float = 0.0) -> None: - """Dump all detections to files in the detection directory.""" - - logger.debug("dumping detections") - - await self.export_csv(self.csv_dir / "detections.csv") - self.export_pyrocko_events(self.rundir / "pyrocko_detections.list") - - if jitter_location: - await self.export_csv( - self.csv_dir / "detections_jittered.csv", - jitter_location=jitter_location, - ) - self.export_pyrocko_events( - self.rundir / "pyrocko_detections_jittered.list", - jitter_location=jitter_location, - ) - - def add_semblance_trace(self, trace: Trace) -> None: + def save_semblance_trace(self, trace: Trace) -> None: """Add semblance trace to detection and save to file. Args: @@ -827,21 +796,52 @@ def load_rundir(cls, rundir: Path) -> EventDetections: detections.detections.append(detection) logger.info("loaded %d detections", detections.n_detections) - detections._stats.n_detections = detections.n_detections - detections._stats.max_semblance = max( - detection.semblance for detection in detections - ) + + stats = detections._stats + stats.n_detections = detections.n_detections + if detections: + stats.max_semblance = max(detection.semblance for detection in detections) return detections + async def save(self) -> None: + """Save detections to current rundir.""" + logger.debug("saving %d detections", self.n_detections) + async with aiofiles.open(self.rundir / FILENAME_DETECTIONS, "w") as f: + for detection in self: + await f.write(f"{detection.model_dump_json(exclude={'receivers'})}\n") + + async def export_detections(self, jitter_location: float = 0.0) -> None: + """ + Export detections to CSV and Pyrocko event lists in the current rundir. + + Args: + jitter_location (float): The amount of jitter in [m] to apply + to the detection locations. Defaults to 0.0. + """ + logger.debug("dumping detections") + + await self.export_csv(self.csv_dir / "detections.csv") + self.export_pyrocko_events(self.rundir / "pyrocko_detections.list") + + if jitter_location: + await self.export_csv( + self.csv_dir / "detections_jittered.csv", + jitter_location=jitter_location, + ) + self.export_pyrocko_events( + self.rundir / "pyrocko_detections_jittered.list", + jitter_location=jitter_location, + ) + async def export_csv(self, file: Path, jitter_location: float = 0.0) -> None: - """Export detections to a CSV file + """Export detections to a CSV file. Args: - file (Path): output filename - randomize_meters (float, optional): randomize the location of each detection + file (Path): The output filename. + jitter_location (float, optional): Randomize the location of each detection by this many meters. Defaults to 0.0. """ - header = set() + header = [] if jitter_location: detections = [det.jitter_location(jitter_location) for det in self] @@ -851,16 +851,19 @@ async def export_csv(self, file: Path, jitter_location: float = 0.0) -> None: csv_dicts: list[dict] = [] for detection in detections: csv = detection.get_csv_dict() - header.update(csv.keys()) + for key in csv: + if key not in header: + header.append(key) csv_dicts.append(csv) - lines = [ + header_line = [",".join(header) + "\n"] + rows = [ ",".join(str(csv.get(key, "")) for key in header) + "\n" for csv in csv_dicts ] - async with aiofiles.open(file) as f: - await f.writelines(lines) + async with aiofiles.open(file, "w") as f: + await f.writelines(header_line + rows) def export_pyrocko_events( self, filename: Path, jitter_location: float = 0.0 @@ -877,6 +880,7 @@ def export_pyrocko_events( dump_events( [det.as_pyrocko_event() for det in detections], filename=str(filename), + format="yaml", ) def export_pyrocko_markers(self, filename: Path) -> None: diff --git a/src/qseek/search.py b/src/qseek/search.py index 391c7502..62fb2780 100644 --- a/src/qseek/search.py +++ b/src/qseek/search.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import cProfile import logging from collections import deque @@ -33,6 +32,7 @@ from qseek.stats import RuntimeStats, Stats from qseek.tracers.tracers import RayTracer, RayTracers from qseek.utils import ( + BackgroundTasks, PhaseDescription, alog_call, datetime_now, @@ -393,9 +393,9 @@ async def start(self, force_rundir: bool = False) -> None: detections, semblance_trace = await search_block.search() - self._detections.add_semblance_trace(semblance_trace) + self._detections.save_semblance_trace(semblance_trace) if detections: - await self.new_detections(detections) + BackgroundTasks.create_task(self.new_detections(detections)) stats.add_processed_batch( batch, @@ -405,6 +405,7 @@ async def start(self, force_rundir: bool = False) -> None: self.set_progress(batch.end_time) + await BackgroundTasks.wait_all() console.cancel() await self._detections.export_detections(jitter_location=self.octree.size_limit) logger.info("finished search in %s", datetime_now() - processing_start) @@ -490,15 +491,15 @@ def from_config( def has_rundir(self) -> bool: return hasattr(self, "_rundir") and self._rundir.exists() - def __del__(self) -> None: - # FIXME: Replace with signal overserver? - if hasattr(self, "_detections"): - with contextlib.suppress(Exception): - asyncio.ensure_future( # noqa: RUF006 - self._detections.export_detections( - jitter_location=self.octree.size_limit - ) - ) + # def __del__(self) -> None: + # FIXME: Replace with signal overserver? + # if hasattr(self, "_detections"): + # with contextlib.suppress(Exception): + # asyncio.run( + # self._detections.export_detections( + # jitter_location=self.octree.size_limit + # ) + # ) class SearchTraces: diff --git a/src/qseek/tracers/cake.py b/src/qseek/tracers/cake.py index 5256e39c..4099e2c8 100644 --- a/src/qseek/tracers/cake.py +++ b/src/qseek/tracers/cake.py @@ -412,15 +412,15 @@ def get_travel_times(self, octree: Octree, stations: Stations) -> np.ndarray: "was the LUT initialized with `TravelTimeTree.init_lut`?" ) from exc - stations_traveltimes = [] + stations_travel_times = [] fill_nodes = [] for node in octree: try: - node_traveltimes = self._node_lut[node.hash()][station_indices] + node_travel_times = self._node_lut[node.hash()][station_indices] except KeyError: fill_nodes.append(node) continue - stations_traveltimes.append(node_traveltimes) + stations_travel_times.append(node_travel_times) if fill_nodes: self.fill_lut(fill_nodes) @@ -434,7 +434,7 @@ def get_travel_times(self, octree: Octree, stations: Stations) -> np.ndarray: ) return self.get_travel_times(octree, stations) - return np.asarray(stations_traveltimes).astype(float, copy=False) + return np.asarray(stations_travel_times).astype(float, copy=False) def interpolate_travel_times( self, @@ -476,14 +476,14 @@ def _interpolate_travel_times( f"for {n_nodes} nodes", total=len(coordinates), ) - traveltimes = [] + travel_times = [] for coords in coordinates: - traveltimes.append(self._interpolate_traveltimes_sptree(coords)) + travel_times.append(self._interpolate_traveltimes_sptree(coords)) PROGRESS.update(status, advance=1) PROGRESS.remove_task(status) - return np.asarray(traveltimes).astype(float) + return np.asarray(travel_times).astype(float) def get_travel_time(self, source: Location, receiver: Location) -> float: coordinates = [ @@ -492,10 +492,10 @@ def get_travel_time(self, source: Location, receiver: Location) -> float: receiver.surface_distance_to(source), ] try: - traveltime = self._get_sptree().interpolate(coordinates) or np.nan + travel_time = self._get_sptree().interpolate(coordinates) or np.nan except spit.OutOfBounds: - traveltime = np.nan - return float(traveltime) + travel_time = np.nan + return float(travel_time) class CakeTracer(RayTracer): @@ -520,7 +520,7 @@ class CakeTracer(RayTracer): description="Size of the LUT cache. Default is `2G`.", ) - _traveltime_trees: dict[PhaseDescription, TravelTimeTree] = PrivateAttr({}) + _travel_time_trees: dict[PhaseDescription, TravelTimeTree] = PrivateAttr({}) @property def cache_dir(self) -> Path: @@ -530,7 +530,7 @@ def cache_dir(self) -> Path: def clear_cache(self) -> None: """Clear cached SPTreeModels from user's cache.""" - logging.info("clearing traveltime cached trees in %s", self.cache_dir) + logging.info("clearing cached travel time trees in %s", self.cache_dir) for file in self.cache_dir.glob("*.sptree"): file.unlink() @@ -590,18 +590,18 @@ async def prepare( for phase_descr, timing in self.phases.items(): for tree in cached_trees: if tree.is_suited(timing=timing, **traveltime_tree_args): - logger.info("using cached traveltime tree for %s", phase_descr) + logger.info("using cached travel time tree for %s", phase_descr) break else: - logger.info("pre-calculating traveltime tree for %s", phase_descr) + logger.info("pre-calculating travel time tree for %s", phase_descr) tree = TravelTimeTree.new(timing=timing, **traveltime_tree_args) tree.save(self.cache_dir) tree.init_lut(octree, stations) - self._traveltime_trees[phase_descr] = tree + self._travel_time_trees[phase_descr] = tree def _get_sptree_model(self, phase: str) -> TravelTimeTree: - return self._traveltime_trees[phase] + return self._travel_time_trees[phase] def _load_cached_trees(self) -> list[TravelTimeTree]: trees = [] @@ -609,7 +609,7 @@ def _load_cached_trees(self) -> list[TravelTimeTree]: try: tree = TravelTimeTree.load(file) except ValidationError: - logger.warning("deleting invalid cached tree %s", file) + logger.warning("deleting invalid cached travel time tree %s", file) file.unlink() continue trees.append(tree) @@ -635,7 +635,7 @@ def get_travel_times( stations: Stations, ) -> np.ndarray: if phase not in self.phases: - raise ValueError(f"Phase {phase} is not defined.") + raise ValueError(f"Phase {phase} is not defined.") return self._get_sptree_model(phase).get_travel_times(octree, stations) def get_arrivals( diff --git a/src/qseek/utils.py b/src/qseek/utils.py index 0bc88927..d1a25e2b 100644 --- a/src/qseek/utils.py +++ b/src/qseek/utils.py @@ -1,12 +1,23 @@ from __future__ import annotations +import asyncio import logging import os +import re import time from datetime import datetime, timedelta, timezone from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Awaitable, Callable, ParamSpec, TypeVar +from typing import ( + TYPE_CHECKING, + Annotated, + Awaitable, + Callable, + ClassVar, + Coroutine, + ParamSpec, + TypeVar, +) import numpy as np from pydantic import ByteSize, constr @@ -14,6 +25,8 @@ from rich.logging import RichHandler if TYPE_CHECKING: + from contextvars import Context + from pyrocko.trace import Trace logger = logging.getLogger(__name__) @@ -32,13 +45,7 @@ class Symbols: Target = "🞋" Check = "✓" CheckerBoard = "🙾" - - -class ANSI: - Bold = "\033[1m" - Italic = "\033[3m" - Underline = "\033[4m" - Reset = "\033[0m" + Cross = "✗" def setup_rich_logging(level: int) -> None: @@ -50,6 +57,31 @@ def setup_rich_logging(level: int) -> None: ) +class BackgroundTasks: + tasks: ClassVar[set[asyncio.Task]] = set() + + @classmethod + def create_task( + cls, + coro: Coroutine, + name: str | None = None, + context: Context | None = None, + ) -> asyncio.Task: + task = asyncio.create_task(coro, name=name, context=context) + cls.tasks.add(task) + task.add_done_callback(cls.tasks.remove) + return task + + @classmethod + def cancel_all(cls) -> None: + for task in cls.tasks: + task.cancel() + + @classmethod + async def wait_all(cls) -> None: + await asyncio.gather(*cls.tasks) + + def time_to_path(datetime: datetime) -> str: """ Converts a datetime object to a string representation of a file path. @@ -232,7 +264,7 @@ def camel_case_to_snake_case(name: str) -> str: >>> camel_case_to_snake_case("camelCaseString") 'camel_case_string' """ - return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_") + return re.sub(r"(? None: diff --git a/src/qseek/waveforms/squirrel.py b/src/qseek/waveforms/squirrel.py index d03f07ed..dff71265 100644 --- a/src/qseek/waveforms/squirrel.py +++ b/src/qseek/waveforms/squirrel.py @@ -112,12 +112,14 @@ async def post_process_batch(batch: Batch) -> None: await asyncio.to_thread(post_processing, batch) await self.queue.put(batch) - async with asyncio.TaskGroup() as group: - while not done.is_set(): - batch = await load_next() - if batch is None: - break - group.create_task(post_process_batch(batch)) + post_processing_task: asyncio.Task | None = None + while not done.is_set(): + batch = await load_next() + if batch is None: + break + if post_processing_task: + await post_processing_task + post_processing_task = asyncio.create_task(post_process_batch(batch)) await self.queue.put(None) @@ -278,7 +280,6 @@ async def iter_batches( tinc=window_increment.total_seconds(), tpad=window_padding.total_seconds(), want_incomplete=False, - accessor_id="qseek.squirrel", codes=[ (*nsl, self.channel_selector) for nsl in self._stations.get_all_nsl() ],