From e26d5d2afe9232b8bfb7c69fc7be860031b53530 Mon Sep 17 00:00:00 2001 From: Mi! Date: Wed, 13 Sep 2023 11:50:39 +0200 Subject: [PATCH] Development (#2) * fixes * adding plots * adding plots * eq rate plot * minor changes * adding fast-marching * fmm: implementation * wip: fast-marching implementation * cake: loading .nd files // adding depth constrain * wip: fast-marching implementation * cake: fixing new structure * fast-marching: adding NonLinLoc support * wip: fast-marching implementation * wip: fast-marching implementation * wip: fast-marching implementation * wip: fast-marching implementation * wip: fast-marching implementation * wip: fast-marching implementation * wip: fixing tests * wip: finishing up fast-marching * wip: finishing up fast-marching * fix typing * fixing tests * refactoring waveform provider * finishing up * update README --------- Co-authored-by: Marius Isken --- .github/workflows/build.yaml | 2 +- .pre-commit-config.yaml | 2 +- README.md | 32 +- docs/index.md | 17 + lassie/apps/lassie.py | 80 ++- lassie/features/local_magnitude.py | 2 +- lassie/images/__init__.py | 19 +- lassie/images/base.py | 6 +- lassie/images/phase_net.py | 10 +- lassie/models/detection.py | 21 +- lassie/models/location.py | 99 +++- lassie/models/semblance.py | 3 + lassie/models/station.py | 59 +- lassie/octree.py | 76 ++- lassie/plot/detections.py | 66 +++ lassie/plot/octree.py | 95 +++- lassie/{ => plot}/plot.py | 9 +- lassie/plot/utils.py | 40 ++ lassie/{search/base.py => search.py} | 309 ++++++++--- lassie/search/__init__.py | 1 - lassie/search/squirrel.py | 209 ------- lassie/station_corrections.py | 9 +- lassie/tracers/__init__.py | 18 +- lassie/tracers/base.py | 12 +- lassie/tracers/cake.py | 229 +++++--- lassie/tracers/constant_velocity.py | 33 +- lassie/tracers/fast_marching/__init__.py | 1 + lassie/tracers/fast_marching/fast_marching.py | 524 ++++++++++++++++++ .../tracers/fast_marching/velocity_models.py | 467 ++++++++++++++++ lassie/utils.py | 11 +- lassie/waveforms/__init__.py | 11 + lassie/waveforms/base.py | 67 +++ lassie/waveforms/squirrel.py | 142 +++++ mkdocs.yml | 14 + pyproject.toml | 10 +- test/conftest.py | 110 +++- test/test_cake.py | 30 +- test/test_fast_marching.py | 235 ++++++++ test/test_location.py | 67 +++ test/test_octree.py | 23 +- test/test_plot.py | 45 ++ test/test_search.py | 4 +- test/upload_data.sh | 3 + 43 files changed, 2623 insertions(+), 599 deletions(-) create mode 100644 docs/index.md create mode 100644 lassie/plot/detections.py rename lassie/{ => plot}/plot.py (80%) create mode 100644 lassie/plot/utils.py rename lassie/{search/base.py => search.py} (60%) delete mode 100644 lassie/search/__init__.py delete mode 100644 lassie/search/squirrel.py create mode 100644 lassie/tracers/fast_marching/__init__.py create mode 100644 lassie/tracers/fast_marching/fast_marching.py create mode 100644 lassie/tracers/fast_marching/velocity_models.py create mode 100644 lassie/waveforms/__init__.py create mode 100644 lassie/waveforms/base.py create mode 100644 lassie/waveforms/squirrel.py create mode 100644 mkdocs.yml create mode 100644 test/test_fast_marching.py create mode 100644 test/test_plot.py create mode 100755 test/upload_data.sh diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index a786957c..bc9f1a25 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -23,4 +23,4 @@ jobs: pip install .[dev] - name: Test with pytest run: | - pytest + pytest -m "not plot" diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml
index 9c65a47b..df667eb2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,6 +19,6 @@ repos:
 #      language_version: python3.9
   - repo: https://github.com/charliermarsh/ruff-pre-commit
     # Ruff version.
-    rev: "v0.0.275"
+    rev: "v0.0.287"
     hooks:
       - id: ruff
diff --git a/README.md b/README.md
index 224820b1..d81621f1 100644
--- a/README.md
+++ b/README.md
@@ -8,21 +8,28 @@
 [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/)
 
-Lassie is an earthquake detector based on stacking and migration method. It combines neural network phase picks with an iterative octree localisation approach.
+Lassie is an earthquake detection and localisation framework based on the stacking and migration method. It combines neural network phase picks with an iterative octree localisation approach for accurate localisation of seismic events.
 
-Key features are of the tools are:
+Key features of the earthquake detection and localisation framework are:
 
-* Phase detection using SeisBench
-* Efficient and accurate Octree localisation approach
-* Extraction of event features
+* Earthquake phase detection using machine-learning pickers from [SeisBench](https://github.com/seisbench/seisbench)
+* Octree localisation approach for efficient and accurate search
+* Different velocity models:
+  * Constant velocity
+  * 1D layered velocity model
+  * 3D fast-marching velocity model (NonLinLoc compatible)
+* Extraction of earthquake event features:
   * Local magnitudes
   * Ground motion attributes
-* Determination of station corrections
+* Automatic extraction of modelled and picked travel times
+* Calculation and application of station corrections / station delay times
+
+Lassie is built on top of [Pyrocko](https://pyrocko.org).
 
 ## Installation
 
 ```sh
-git clone https://github.com/miili/lassie-v2
+git clone https://github.com/pyrocko/lassie-v2
 cd lassie-v2
 pip3 install .
 ```
@@ -32,12 +39,12 @@ pip3 install .
 
 Initialize a new project in a fresh directory.
 
 ```sh
-lassie new project-dir/
+lassie init my-project/
 ```
 
-Edit the `search.json`
+Edit the `my-project.json` configuration file.
 
-Start the detection
+Start the earthquake detection with
 
 ```sh
 lassie run my-project.json
@@ -52,14 +59,13 @@ The simplest and recommended way of installing from source:
 
 Local development through pip.
 
 ```sh
-cd lightguide
+cd lassie-v2
 pip3 install .[dev]
 ```
 
 The project utilizes pre-commit for clean commits; install the hooks via:
 
 ```sh
-pip install pre-commit
 pre-commit install
 ```
 
@@ -73,4 +79,4 @@ Please cite lassie as:
 
 Contribution and merge requests by the community are welcome!
 
-Lassie-v@ was written by Marius Paul Isken and is licensed under the GNU GENERAL PUBLIC LICENSE v3.
+Lassie-v2 was written by Marius Paul Isken and is licensed under the GNU GENERAL PUBLIC LICENSE v3.
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 00000000..000ea345
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,17 @@
+# Welcome to MkDocs
+
+For full documentation visit [mkdocs.org](https://www.mkdocs.org).
+
+## Commands
+
+* `mkdocs new [dir-name]` - Create a new project.
+* `mkdocs serve` - Start the live-reloading docs server.
+* `mkdocs build` - Build the documentation site.
+* `mkdocs -h` - Print help message and exit.
+
+## Project layout
+
+    mkdocs.yml    # The configuration file.
+    docs/
+        index.md  # The documentation homepage.
+        ...       # Other markdown pages, images and other files.
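For orientation on the quickstart in the README above: `lassie init` writes a JSON serialization of the `Search` model introduced by this patch. The following is a minimal sketch of what such a `my-project.json` might contain — key names and defaults are taken from the `Search`, `Stations` and `Octree` models further down in this patch, but the exact serialized layout (for example pydantic's timedelta encoding) may differ:

```json
{
  "stations": {
    "pyrocko_station_yamls": ["pyrocko-stations.yaml"],
    "station_xmls": [],
    "blacklist": []
  },
  "octree": {
    "size_initial": 2000.0,
    "size_limit": 500.0,
    "east_bounds": [-10000.0, 10000.0],
    "north_bounds": [-10000.0, 10000.0],
    "depth_bounds": [0.0, 20000.0],
    "absorbing_boundary": 1000.0
  },
  "sampling_rate": 50,
  "detection_threshold": 0.05,
  "node_split_threshold": 0.9,
  "window_length": "PT5M"
}
```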
diff --git a/lassie/apps/lassie.py b/lassie/apps/lassie.py
index 8e80a3b8..ae5328fb 100644
--- a/lassie/apps/lassie.py
+++ b/lassie/apps/lassie.py
@@ -3,22 +3,18 @@
 import argparse
 import asyncio
 import logging
-from datetime import datetime
+import shutil
 from pathlib import Path
 
 import nest_asyncio
 from pkg_resources import get_distribution
 
 from lassie.console import console
-from lassie.images import ImageFunctions
-from lassie.images.phase_net import PhaseNet
 from lassie.models import Stations
-from lassie.search import SquirrelSearch
+from lassie.search import Search
 from lassie.server import WebServer
 from lassie.station_corrections import StationCorrections
-from lassie.tracers import RayTracers
-from lassie.tracers.cake import CakeTracer
-from lassie.utils import ANSI, setup_rich_logging
+from lassie.utils import CACHE_DIR, setup_rich_logging
 
 nest_asyncio.apply()
 
@@ -66,6 +62,14 @@ def main() -> None:
     )
     continue_run.add_argument("rundir", type=Path, help="existing rundir to continue")
 
+    init_project = subparsers.add_parser(
+        "init",
+        help="initialize a new Lassie project",
+    )
+    init_project.add_argument(
+        "folder", type=Path, help="folder to initialize project in"
+    )
+
     features = subparsers.add_parser(
         "feature-extraction",
         help="extract features from an existing run",
     )
     features.add_argument("rundir", type=Path, help="path of existing run")
 
     station_corrections = subparsers.add_parser(
-        "station-corrections",
+        "corrections",
         help="analyse station corrections from existing run",
         description="analyze and plot station corrections from a finished run",
     )
@@ -88,27 +92,26 @@ def main() -> None:
 
     serve = subparsers.add_parser(
         "serve",
-        help="serve results from an existing run",
+        help="start webserver and serve results from an existing run",
         description="start a webserver and serve detections and results from a run",
     )
     serve.add_argument("rundir", type=Path, help="rundir to serve")
 
-    new = subparsers.add_parser(
-        "new",
-        help="initialize a new project",
+    subparsers.add_parser(
+        "clear-cache",
+        help="clear the cache directory",
     )
-    new.add_argument("folder", type=Path, help="folder to initialize project in")
 
     dump_schemas = subparsers.add_parser(
         "dump-schemas",
-        help="dump models to json-schema (development)",
+        help="dump data models to json-schema (development)",
     )
     dump_schemas.add_argument("folder", type=Path, help="folder to dump schemas to")
 
     args = parser.parse_args()
     setup_rich_logging(level=logging.INFO - args.verbose * 10)
 
-    if args.command == "new":
+    if args.command == "init":
         folder: Path = args.folder
         if folder.exists():
             raise FileExistsError(f"Folder {folder} already exists")
 
         pyrocko_stations = folder / "pyrocko-stations.yaml"
         pyrocko_stations.touch()
 
-        config = SquirrelSearch(
-            ray_tracers=RayTracers(root=[CakeTracer()]),
-            image_functions=ImageFunctions(
-                root=[PhaseNet(phase_map={"P": "cake:P", "S": "cake:S"})]
-            ),
-            stations=Stations(pyrocko_station_yamls=[pyrocko_stations]),
-            waveform_data=[Path("/data/")],
-            time_span=(
-                datetime.fromisoformat("2023-04-11T00:00:00+00:00"),
-                datetime.fromisoformat("2023-04-18T00:00:00+00:00"),
-            ),
+        config = Search(
+            stations=Stations(
+                pyrocko_station_yamls=[pyrocko_stations.relative_to(folder)]
+            )
         )
 
-        config_file = folder / "config.json"
+        config_file = folder / f"{folder.name}.json"
         config_file.write_text(config.model_dump_json(by_alias=False, indent=2))
+        logger.info("initialized new project in folder %s", folder)
 
-        logger.info(
-
"start detecting with:\n\t%slassie run config.json%s", ANSI.Bold, ANSI.Reset - ) + logger.info("start detection with: lassie run %s", config_file.name) elif args.command == "run": - search = SquirrelSearch.from_config(args.config) - search.init_rundir(force=args.force) + search = Search.from_config(args.config) webserver = WebServer(search) async def _run() -> None: http = asyncio.create_task(webserver.start()) - await search.scan_squirrel() + await search.start(force_rundir=args.force) await http asyncio.run(_run()) elif args.command == "continue": - search = SquirrelSearch.load_rundir(args.rundir) - if search.progress.time_progress: - console.rule(f"Continuing search from {search.progress.time_progress}") + search = Search.load_rundir(args.rundir) + if search._progress.time_progress: + console.rule(f"Continuing search from {search._progress.time_progress}") else: console.rule("Starting search from scratch") @@ -161,13 +155,13 @@ async def _run() -> None: async def _run() -> None: http = asyncio.create_task(webserver.start()) - await search.scan_squirrel() + await search.start() await http asyncio.run(_run()) elif args.command == "feature-extraction": - search = SquirrelSearch.load_rundir(args.rundir) + search = Search.load_rundir(args.rundir) async def extract() -> None: for detection in search._detections.detections: @@ -175,7 +169,7 @@ async def extract() -> None: asyncio.run(extract()) - elif args.command == "station-corrections": + elif args.command == "corrections": rundir = Path(args.rundir) station_corrections = StationCorrections(rundir=rundir) if args.plot: @@ -183,13 +177,17 @@ async def extract() -> None: station_corrections.save_csv(filename=rundir / "station_corrections_stats.csv") elif args.command == "serve": - search = SquirrelSearch.load_rundir(args.rundir) + search = Search.load_rundir(args.rundir) webserver = WebServer(search) loop = asyncio.get_event_loop() loop.create_task(webserver.start()) loop.run_forever() + elif args.command == "clear-cache": + logger.info("clearing cache directory %s", CACHE_DIR) + shutil.rmtree(CACHE_DIR) + elif args.command == "dump-schemas": from lassie.models.detection import EventDetections @@ -198,7 +196,7 @@ async def extract() -> None: file = args.folder / "search.schema.json" print(f"writing JSON schemas to {args.folder}") - file.write_text(SquirrelSearch.model_json_schema(indent=2)) + file.write_text(Search.model_json_schema(indent=2)) file = args.folder / "detections.schema.json" file.write_text(EventDetections.model_json_schema(indent=2)) diff --git a/lassie/features/local_magnitude.py b/lassie/features/local_magnitude.py index 0c78b021..91087667 100644 --- a/lassie/features/local_magnitude.py +++ b/lassie/features/local_magnitude.py @@ -254,4 +254,4 @@ async def add_features(self, squirrel: Squirrel, event: EventDetection) -> None: print(event.time, local_magnitude) event.magnitude = local_magnitude.median event.magnitude_type = "local" - event.feature.add_feature(local_magnitude) + event.features.add_feature(local_magnitude) diff --git a/lassie/images/__init__.py b/lassie/images/__init__.py index 6a4c1451..5773511f 100644 --- a/lassie/images/__init__.py +++ b/lassie/images/__init__.py @@ -2,7 +2,8 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Annotated, Iterator, Union +from itertools import chain +from typing import TYPE_CHECKING, Annotated, Any, Iterator, Union from pydantic import Field, RootModel @@ -34,7 +35,13 @@ class ImageFunctions(RootModel): - root: list[ImageFunctionType] = [] + 
root: list[ImageFunctionType] = [PhaseNet()] + + def model_post_init(self, __context: Any) -> None: + # Check if phases are provided twice + phases = self.get_phases() + if len(set(phases)) != len(phases): + raise ValueError("A phase was provided twice") async def process_traces(self, traces: list[Trace]) -> WaveformImages: images = [] @@ -44,6 +51,14 @@ async def process_traces(self, traces: list[Trace]) -> WaveformImages: return WaveformImages(root=images) + def get_phases(self) -> tuple[str, ...]: + """Get all phases that are available in the image functions. + + Returns: + tuple[str, ...]: All available phases. + """ + return tuple(chain.from_iterable(image.get_provided_phases() for image in self)) + def get_blinding(self) -> timedelta: return max(image.blinding for image in self) diff --git a/lassie/images/base.py b/lassie/images/base.py index a28887dc..cc083eb4 100644 --- a/lassie/images/base.py +++ b/lassie/images/base.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Literal import numpy as np -from pydantic import BaseModel +from pydantic import BaseModel, Field from lassie.models.phase_arrival import PhaseArrival from lassie.models.station import Stations @@ -35,7 +35,7 @@ def blinding(self) -> timedelta: """Blinding duration for the image function. Added to padded waveforms.""" raise NotImplementedError("must be implemented by subclass") - def get_available_phases(self) -> tuple[str]: + def get_provided_phases(self) -> tuple[str, ...]: ... @@ -45,7 +45,7 @@ class WaveformImage: phase: PhaseDescription weight: float traces: list[Trace] - stations: Stations = Stations.construct() + stations: Stations = Field(default_factory=lambda: Stations.model_construct()) @property def sampling_rate(self) -> float: diff --git a/lassie/images/phase_net.py b/lassie/images/phase_net.py index 6c3df32a..670d8bfe 100644 --- a/lassie/images/phase_net.py +++ b/lassie/images/phase_net.py @@ -4,12 +4,10 @@ from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal -import torch from obspy import Stream from pydantic import PositiveFloat, PositiveInt, PrivateAttr, conint from pyrocko import obspy_compat from seisbench import logger -from seisbench.models import PhaseNet as PhaseNetSeisBench from lassie.images.base import ImageFunction, PickedArrival, WaveformImage from lassie.utils import alog_call, to_datetime @@ -20,6 +18,7 @@ if TYPE_CHECKING: from pyrocko.trace import Trace + from seisbench.models import PhaseNet as PhaseNetSeisBench ModelName = Literal[ "diting", @@ -102,6 +101,9 @@ class PhaseNet(ImageFunction): _phase_net: PhaseNetSeisBench = PrivateAttr(None) def model_post_init(self, __context: Any) -> None: + import torch + from seisbench.models import PhaseNet as PhaseNetSeisBench + torch.set_num_threads(self.torch_cpu_threads) self._phase_net = PhaseNetSeisBench.from_pretrained(self.model) if self.torch_use_cuda: @@ -144,5 +146,5 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: return [annotation_s, annotation_p] - def get_available_phases(self) -> tuple[str]: - return tuple(self.phase_map.keys()) + def get_provided_phases(self) -> tuple[str, ...]: + return tuple(self.phase_map.values()) diff --git a/lassie/models/detection.py b/lassie/models/detection.py index 417e33fd..f4027182 100644 --- a/lassie/models/detection.py +++ b/lassie/models/detection.py @@ -20,7 +20,7 @@ from lassie.models.location import Location from lassie.models.station import Station, Stations from lassie.tracers import RayTracerArrival -from lassie.utils import 
PhaseDescription, time_to_path +from lassie.utils import PhaseDescription, Symbols, time_to_path if TYPE_CHECKING: from pyrocko.squirrel import Response, Squirrel @@ -368,7 +368,7 @@ def as_pyrocko_event(self) -> Event: north_shift=self.north_shift, depth=self.depth, elevation=self.elevation, - magnitude=self.magnitude, + magnitude=self.magnitude or self.semblance, magnitude_type=self.magnitude_type, ) @@ -407,7 +407,7 @@ def jitter_location(self, meters: float) -> EventDetection: detection.east_shift += uniform(-half_meters, half_meters) detection.north_shift += uniform(-half_meters, half_meters) detection.depth += uniform(-half_meters, half_meters) - del detection.effective_lat_lon + detection._cached_lat_lon = None return detection def snuffle(self, squirrel: Squirrel, restituted: bool = False) -> None: @@ -449,6 +449,17 @@ def add(self, detection: EventDetection) -> None: marker.save_markers(detection.get_pyrocko_markers(), str(markers_file)) self.detections.append(detection) + logger.info( + "%s event detection #%d %s: %.5f°, %.5f°, depth %.1f m, " + "border distance %.1f m, semblance %.3f", + Symbols.Target, + self.n_detections, + detection.time, + *detection.effective_lat_lon, + detection.depth, + detection.distance_border, + detection.semblance, + ) # This has to happen after the markers are saved detection.dump_append(self.rundir, self.n_detections - 1) @@ -493,7 +504,7 @@ def load_rundir(cls, rundir: Path) -> EventDetections: detections = cls(rundir=rundir) - with console.status(f"Loading detections from {rundir}..."), open( + with console.status(f"loading detections from {rundir}..."), open( detection_file ) as f: for i_detection, line in enumerate(f): @@ -553,4 +564,4 @@ def save_pyrocko_markers(self, filename: Path) -> None: marker.save_markers(pyrocko_markers, str(filename)) def __iter__(self) -> Iterator[EventDetection]: - return iter(self.detections) + return iter(sorted(self.detections, key=lambda d: d.time)) diff --git a/lassie/models/location.py b/lassie/models/location.py index 7a96f40e..9347730a 100644 --- a/lassie/models/location.py +++ b/lassie/models/location.py @@ -1,11 +1,13 @@ from __future__ import annotations +import hashlib import math -from functools import cached_property +import struct from typing import TYPE_CHECKING, Iterable, Literal, TypeVar -from pydantic import BaseModel, computed_field +from pydantic import BaseModel, PrivateAttr from pyrocko import orthodrome as od +from typing_extensions import Self if TYPE_CHECKING: from pathlib import Path @@ -21,6 +23,8 @@ class Location(BaseModel): elevation: float = 0.0 depth: float = 0.0 + _cached_lat_lon: tuple[float, float] | None = PrivateAttr(None) + @property def effective_lat(self) -> float: return self.effective_lat_lon[0] @@ -29,19 +33,21 @@ def effective_lat(self) -> float: def effective_lon(self) -> float: return self.effective_lat_lon[1] - @computed_field - @cached_property + @property def effective_lat_lon(self) -> tuple[float, float]: """Shift-corrected lat/lon pair of the location.""" - if self.north_shift == 0.0 and self.east_shift == 0.0: - return self.lat, self.lon - lat, lon = od.ne_to_latlon( - self.lat, - self.lon, - self.north_shift, - self.east_shift, - ) - return float(lat), float(lon) + if self._cached_lat_lon is None: + if self.north_shift == 0.0 and self.east_shift == 0.0: + self._cached_lat_lon = self.lat, self.lon + else: + lat, lon = od.ne_to_latlon( + self.lat, + self.lon, + self.north_shift, + self.east_shift, + ) + self._cached_lat_lon = float(lat), float(lon) + return 
self._cached_lat_lon @property def effective_elevation(self) -> float: @@ -49,13 +55,20 @@ def effective_elevation(self) -> float: @property def effective_depth(self) -> float: - return self.depth + self.elevation + return self.depth - self.elevation def _same_origin(self, other: Location) -> bool: return bool(self.lat == other.lat and self.lon == other.lon) def surface_distance_to(self, other: Location) -> float: - """Compute surface distance [m] to other location object.""" + """Compute surface distance [m] to other location object. + + Args: + other (Location): The other location. + + Returns: + float: The surface distance in [m]. + """ if self._same_origin(other): return math.sqrt( @@ -69,6 +82,14 @@ def surface_distance_to(self, other: Location) -> float: ) def distance_to(self, other: Location) -> float: + """Compute 3-dimensional distance [m] to other location object. + + Args: + other (Location): The other location. + + Returns: + float: The distance in [m]. + """ if self._same_origin(other): return math.sqrt( (self.north_shift - other.north_shift) ** 2 @@ -85,9 +106,52 @@ def distance_to(self, other: Location) -> float: return math.sqrt((sx - ox) ** 2 + (sy - oy) ** 2 + (sz - oz) ** 2) + def offset_from(self, other: Location) -> tuple[float, float, float]: + """Return offset vector (east, north, depth) to other location in [m] + + Args: + other (Location): The other location. + + Returns: + tuple[float, float, float]: The offset vector. + """ + if self._same_origin(other): + return ( + self.east_shift - other.east_shift, + self.north_shift - other.north_shift, + -(self.effective_elevation - other.effective_elevation), + ) + + shift_north, shift_east = od.latlon_to_ne_numpy( + self.lat, self.lon, other.lat, other.lon + ) + + return ( + self.east_shift - other.east_shift - shift_east[0], + self.north_shift - other.north_shift - shift_north[0], + -(self.effective_elevation - other.effective_elevation), + ) + + def shifted_origin(self) -> Self: + """Shift the origin of the location to the effective lat/lon. + + Returns: + Self: The shifted location. + """ + shifted = self.model_copy() + shifted.lat = self.effective_lat + shifted.lon = self.effective_lon + shifted.east_shift = 0.0 + shifted.north_shift = 0.0 + return shifted + def __hash__(self) -> int: - return hash( - ( + return hash(self.location_hash()) + + def location_hash(self) -> str: + sha1 = hashlib.sha1( + struct.pack( + "dddddd", self.lat, self.lon, self.east_shift, @@ -96,6 +160,7 @@ def __hash__(self) -> int: self.depth, ) ) + return sha1.hexdigest() def locations_to_csv(locations: Iterable[Location], filename: Path) -> Path: diff --git a/lassie/models/semblance.py b/lassie/models/semblance.py index 5874edf1..eb74a9ae 100644 --- a/lassie/models/semblance.py +++ b/lassie/models/semblance.py @@ -96,6 +96,9 @@ def maximum_semblance(self) -> np.ndarray: self._max_semblance = self.semblance.max(axis=0) return self._max_semblance + def maximum_node_semblance(self) -> np.ndarray: + return self.semblance.max(axis=1) + async def maxima_node_idx(self, nparallel: int = 6) -> np.ndarray: """Indices of maximum semblance at any time step. 
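The reworked `Location` model above exposes Cartesian offsets between shifted locations. A short usage sketch of `offset_from` and `distance_to` as defined in this patch — the coordinates are made up for illustration:

```python
from lassie.models.location import Location

# Two locations sharing the same lat/lon origin, offset in Cartesian metres.
station = Location(lat=52.0, lon=11.0, elevation=200.0)
source = Location(lat=52.0, lon=11.0, north_shift=1500.0, depth=3000.0)

# offset_from returns (east, north, depth) in metres; the depth component is
# the difference of the effective elevations (elevation - depth) of both points.
east, north, depth = source.offset_from(station)
print(east, north, depth)  # 0.0 1500.0 3200.0

# 3-D distance in metres, here sqrt(1500**2 + 3200**2) ~ 3534.1 m.
print(source.distance_to(station))
```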
diff --git a/lassie/models/station.py b/lassie/models/station.py index a637a4d5..6e6054a2 100644 --- a/lassie/models/station.py +++ b/lassie/models/station.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Iterator import numpy as np -from pydantic import BaseModel, PrivateAttr, constr +from pydantic import BaseModel, constr from pyrocko.io.stationxml import load_xml from pyrocko.model import Station as PyrockoStation from pyrocko.model import dump_stations_yaml, load_stations @@ -41,7 +41,7 @@ def from_pyrocko_station(cls, station: PyrockoStation) -> Station: ) def to_pyrocko_station(self) -> PyrockoStation: - return PyrockoStation(**self.dict(exclude={"effective_lat_lon"})) + return PyrockoStation(**self.model_dump(exclude={"effective_lat_lon"})) @property def pretty_nsl(self) -> str: @@ -56,14 +56,11 @@ def __hash__(self) -> int: class Stations(BaseModel): - stations: list[Station] = [] - blacklist: set[constr(pattern=NSL_RE)] = set() - station_xmls: list[Path] = [] pyrocko_station_yamls: list[Path] = [] - _cached_coordinates: np.ndarray | None = PrivateAttr(None) - _cached_iter: list[Station] | None = PrivateAttr(None) + stations: list[Station] = [] + blacklist: set[constr(pattern=NSL_RE)] = set() def model_post_init(self, __context: Any) -> None: loaded_stations = [] @@ -88,11 +85,7 @@ def weed_stations(self) -> None: seen_nsls = set() for sta in self.stations.copy(): if sta.lat == 0.0 or sta.lon == 0.0: - logger.warning( - "blacklisting station %s: bad geographical coordinates", - sta.pretty_nsl, - ) - self.blacklist.add(sta.pretty_nsl) + self.blacklist_station(sta, reason="bad geographical coordinates") continue if sta.pretty_nsl in seen_nsls: @@ -104,6 +97,12 @@ def weed_stations(self) -> None: # if not self.stations: # logger.warning("no stations available, add stations to start detection") + def blacklist_station(self, station: Station, reason: str) -> None: + logger.warning("blacklisting station %s: %s", station.pretty_nsl, reason) + self.blacklist.add(station.pretty_nsl) + if self.n_stations == 0: + raise ValueError("no stations available, all stations blacklisted") + def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: """Remove stations without waveforms from squirrel instances. 
@@ -118,30 +117,24 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: n_removed_stations = 0 for sta in self.stations.copy(): if sta.pretty_nsl not in available_squirrel_nsls: - logger.debug( - "removing station %s: waveforms not available in squirrel", + logger.warning( + "removing station %s: no waveforms available in squirrel", sta.pretty_nsl, ) self.stations.remove(sta) n_removed_stations += 1 if n_removed_stations: - logger.info("removed %d stations without waveforms", n_removed_stations) + logger.warning("removed %d stations without waveforms", n_removed_stations) if not self.stations: raise ValueError("no stations available, add waveforms to start detection") def __iter__(self) -> Iterator[Station]: - if self._cached_iter is None: - self._cached_iter = [ - sta for sta in self.stations if sta.pretty_nsl not in self.blacklist - ] - return iter(self._cached_iter) + return (sta for sta in self.stations if sta.pretty_nsl not in self.blacklist) @property def n_stations(self) -> int: """Number of stations in the stations object.""" - if self._cached_iter: - return len(self._cached_iter) return sum(1 for _ in self) def get_all_nsl(self) -> list[tuple[str, str, str]]: @@ -186,11 +179,9 @@ def get_centroid(self) -> Location: ) def get_coordinates(self, system: CoordSystem = "geographic") -> np.ndarray: - if self._cached_coordinates is None: - self._cached_coordinates = np.array( - [(*sta.effective_lat_lon, sta.effective_elevation) for sta in self] - ) - return self._cached_coordinates + return np.array( + [(*sta.effective_lat_lon, sta.effective_elevation) for sta in self] + ) def dump_pyrocko_stations(self, filename: Path) -> None: """Dump stations to pyrocko station yaml file. @@ -203,5 +194,19 @@ def dump_pyrocko_stations(self, filename: Path) -> None: filename=str(filename.expanduser()), ) + def dump_csv(self, filename: Path) -> None: + """Dump stations to CSV file. + + Args: + filename (Path): Path to CSV file. + """ + with filename.open("w") as f: + f.write("network,station,location,latitude,longitude,elevation,depth\n") + for sta in self: + f.write( + f"{sta.network},{sta.station},{sta.location}," + f"{sta.lat},{sta.lon},{sta.elevation},{sta.depth}\n" + ) + def __hash__(self) -> int: return hash(sta for sta in self) diff --git a/lassie/octree.py b/lassie/octree.py index 65062f02..51c2fa84 100644 --- a/lassie/octree.py +++ b/lassie/octree.py @@ -10,11 +10,12 @@ import numpy as np from pydantic import ( - ConfigDict, BaseModel, + ConfigDict, Field, PositiveFloat, PrivateAttr, + confloat, field_validator, model_validator, ) @@ -65,13 +66,13 @@ class Node(BaseModel): semblance: float = 0.0 tree: Octree | None = Field(None, exclude=True) - children: tuple[Node] = Field((), exclude=True) + children: tuple[Node, ...] = Field((), exclude=True) _hash: bytes | None = PrivateAttr(None) - _children_cached: tuple[Node] = PrivateAttr(()) + _children_cached: tuple[Node, ...] 
= PrivateAttr(()) _location: Location | None = PrivateAttr(None) - def split(self) -> tuple[Node]: + def split(self) -> tuple[Node, ...]: if not self.tree: raise EnvironmentError("Parent tree is not set.") @@ -82,7 +83,7 @@ def split(self) -> tuple[Node]: half_size = self.size / 2 self._children_cached = tuple( - Node.construct( + Node.model_construct( east=self.east + east * half_size / 2, north=self.north + north * half_size / 2, depth=self.depth + depth * half_size / 2, @@ -118,6 +119,8 @@ def distance_border(self) -> float: ) def can_split(self) -> bool: + if self.tree is None: + raise AttributeError("parent tree not set") half_size = self.size / 2 return half_size >= self.tree.size_limit @@ -135,14 +138,17 @@ def distance_to_location(self, location: Location) -> float: return location.distance_to(self.as_location()) def as_location(self) -> Location: + if not self.tree: + raise AttributeError("parent tree not set") if not self._location: - self._location = Location.construct( - lat=self.tree.center_lat, - lon=self.tree.center_lon, - elevation=self.tree.surface_elevation, - east_shift=float(self.east), - north_shift=float(self.north), - depth=float(self.depth), + reference = self.tree.reference + self._location = Location.model_construct( + lat=reference.lat, + lon=reference.lon, + elevation=reference.elevation, + east_shift=reference.east_shift + float(self.east), + north_shift=reference.north_shift + float(self.north), + depth=reference.depth + float(self.depth), ) return self._location @@ -158,8 +164,8 @@ def hash(self) -> bytes: self._hash = sha1( struct.pack( "dddddd", - self.tree.center_lat, - self.tree.center_lon, + self.tree.reference.lat, + self.tree.reference.lon, self.east, self.north, self.depth, @@ -173,15 +179,13 @@ def __hash__(self) -> int: class Octree(BaseModel): - center_lat: float = 0.0 - center_lon: float = 0.0 - surface_elevation: float = 0.0 + reference: Location = Location(lat=0.0, lon=0) size_initial: PositiveFloat = 2 * KM size_limit: PositiveFloat = 500 east_bounds: tuple[float, float] = (-10 * KM, 10 * KM) north_bounds: tuple[float, float] = (-10 * KM, 10 * KM) depth_bounds: tuple[float, float] = (0 * KM, 20 * KM) - absorbing_boundary: PositiveFloat = 1 * KM + absorbing_boundary: confloat(ge=0.0) = 1 * KM _root_nodes: list[Node] = PrivateAttr([]) _cached_coordinates: dict[CoordSystem, np.ndarray] = PrivateAttr({}) @@ -190,19 +194,22 @@ class Octree(BaseModel): @field_validator("east_bounds", "north_bounds", "depth_bounds") def check_bounds( - cls, bounds: tuple[float, float] # noqa: N805 + cls, # noqa: N805 + bounds: tuple[float, float], ) -> tuple[float, float]: if bounds[0] >= bounds[1]: raise ValueError(f"invalid bounds {bounds}, expected (min, max)") return bounds @model_validator(mode="after") - def _check_limits(cls, m: Octree) -> Octree: # noqa: N805 - if m.size_limit >= m.size_initial: + def check_limits(self) -> Octree: + """Check that the size limits are valid.""" + if self.size_limit > self.size_initial: raise ValueError( - "invalid octree size limits, expected size_limit < size_initial" + f"invalid octree size limits ({self.size_initial}, {self.size_limit})," + " expected size_limit <= size_initial" ) - return m + return self def model_post_init(self, __context: Any) -> None: """Initialize octree. This method is called by the pydantic model""" @@ -352,6 +359,27 @@ def smallest_node_size(self) -> float: size /= 2 return size + def n_levels(self) -> int: + """Returns the number of levels in the octree. + + Returns: + int: Number of levels. 
+ """ + levels = 0 + size = self.size_initial + while size >= self.size_limit * 2: + levels += 1 + size /= 2 + return levels + + def total_number_nodes(self) -> int: + """Returns the total number of nodes of all levels. + + Returns: + int: Total number of nodes. + """ + return len(self._root_nodes) * (8 ** self.n_levels()) + def maximum_number_nodes(self) -> int: """Returns the maximum number of nodes. @@ -362,7 +390,7 @@ def maximum_number_nodes(self) -> int: (self.east_bounds[1] - self.east_bounds[0]) * (self.north_bounds[1] - self.north_bounds[0]) * (self.depth_bounds[1] - self.depth_bounds[0]) - / self.smallest_node_size() ** 3 + / (self.smallest_node_size() ** 3) ) def copy(self, deep=False) -> Self: diff --git a/lassie/plot/detections.py b/lassie/plot/detections.py new file mode 100644 index 00000000..5c7bad69 --- /dev/null +++ b/lassie/plot/detections.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, cast + +import matplotlib.pyplot as plt +import numpy as np + +from .utils import with_default_axes + +if TYPE_CHECKING: + from lassie.models.detection import EventDetections + +HOUR = 3600 +DAY = 24 * HOUR + + +@with_default_axes +def plot_detections( + detections: EventDetections, + axes: plt.Axes | None = None, + filename: Path | None = None, +) -> None: + axes = cast(plt.Axes, axes) # injected by wrapper + + semblances = [detection.semblance for detection in detections] + times = [ + detection.time.replace(tzinfo=None) # Stupid fix for matplotlib bug + for detection in detections + ] + + axes.scatter(times, semblances, cmap="viridis_r", c=semblances, s=3, alpha=0.5) + axes.set_ylabel("Detection Semblance") + axes.grid(axis="x", alpha=0.3) + # axes.figure.autofmt_xdate() + + cum_axes = axes.twinx() + + cummulative_detections = np.cumsum(np.ones(detections.n_detections)) + cum_axes.plot( + times, + cummulative_detections, + color="black", + alpha=0.8, + label="Cumulative Detections", + ) + cum_axes.set_ylabel("# Detections") + + to_timestamps = np.vectorize(lambda d: d.timestamp()) + from_timestamps = np.vectorize(lambda t: datetime.fromtimestamp(t, tz=timezone.utc)) + detection_time_span = times[-1] - times[0] + daily_rate, edges = np.histogram( + to_timestamps(times), + bins=detection_time_span.days, + ) + + cum_axes.stairs( + daily_rate, + from_timestamps(edges), + color="gray", + fill=True, + alpha=0.5, + label="Daily Detections", + ) + cum_axes.legend(loc="upper left", fontsize="small") diff --git a/lassie/plot/octree.py b/lassie/plot/octree.py index ec3ad4d2..83847d76 100644 --- a/lassie/plot/octree.py +++ b/lassie/plot/octree.py @@ -5,17 +5,64 @@ import matplotlib.pyplot as plt import numpy as np +from matplotlib import cm from matplotlib.animation import FFMpegFileWriter, FuncAnimation from matplotlib.cm import get_cmap +from matplotlib.collections import PatchCollection +from matplotlib.patches import Rectangle + +from lassie.models.detection import EventDetection if TYPE_CHECKING: from pathlib import Path + + from matplotlib.colors import Colormap + from lassie.octree import Octree logger = logging.getLogger(__name__) -def plot_octree(octree: Octree, cmap: str = "Oranges") -> None: +def octree_to_rectangles( + octree: Octree, + cmap: str | Colormap = "Oranges", + normalize: bool = False, +) -> PatchCollection: + if isinstance(cmap, str): + cmap = cm.get_cmap(cmap) + + coords = octree.reduce_surface() + coords = coords[np.argsort(coords[:, 2])[::-1]] + size_order = 
np.argsort(coords[:, 2])[::-1] + coords = coords[size_order] + + sizes = coords[:, 2] + semblances = coords[:, 3] + sizes = sorted(set(sizes), reverse=True) + # zorders = {size: 1.0 + float(order) for order, size in enumerate(sizes)} + + rectangles = [] + for node in coords: + east, north, size, semblance = node + half_size = size / 2 + rect = Rectangle( + xy=(east - half_size, north - half_size), + width=size, + height=size, + ) + rectangles.append(rect) + if normalize: + semblances /= semblances.max() + colors = cmap(semblances) + return PatchCollection( + patches=rectangles, + facecolors=colors, + edgecolors=(0, 0, 0, 0.3), + linewidths=0.5, + ) + + +def plot_octree_3d(octree: Octree, cmap: str = "Oranges") -> None: ax = plt.figure().add_subplot(projection="3d") colormap = get_cmap(cmap) @@ -29,7 +76,7 @@ def plot_octree(octree: Octree, cmap: str = "Oranges") -> None: plt.show() -def plot_octree_surface( +def plot_octree_scatter( octree: Octree, accumulator: Callable = np.max, cmap: str = "Oranges", @@ -47,7 +94,49 @@ def plot_octree_surface( plt.show() -def plot_octree_movie( +def plot_octree_surface_tiles( + octree: Octree, + axes: plt.Axes | None = None, + normalize: bool = False, + filename: Path | None = None, + detections: list[EventDetection] | None = None, +) -> None: + if axes is None: + fig = plt.figure() + ax = fig.gca() + else: + fig = axes.figure + ax = axes + + for spine in ax.spines.values(): + spine.set_visible(False) + ax.set_xticklabels([]) + ax.set_xticks([]) + ax.set_yticklabels([]) + ax.set_yticks([]) + ax.set_xlabel("East [m]") + ax.set_ylabel("North [m]") + ax.add_collection(octree_to_rectangles(octree, normalize=normalize)) + + ax.set_title(f"Octree surface tiles (nodes: {octree.n_nodes})") + + ax.autoscale() + for detection in detections or []: + ax.scatter( + detection.east_shift, + detection.north_shift, + marker="*", + s=50, + color="yellow", + ) + if filename is not None: + fig.savefig(str(filename), bbox_inches="tight", dpi=300) + plt.close() + elif axes is None: + plt.show() + + +def plot_octree_semblance_movie( octree: Octree, semblance: np.ndarray, file: Path, diff --git a/lassie/plot.py b/lassie/plot/plot.py similarity index 80% rename from lassie/plot.py rename to lassie/plot/plot.py index 49f83c6e..69a3ef88 100644 --- a/lassie/plot.py +++ b/lassie/plot/plot.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from matplotlib.colors import Colormap - from lassie.models import EventDetection + from lassie.octree import Octree @@ -45,15 +45,14 @@ def octree_to_rectangles( return PatchCollection(patches=rectangles, facecolors=colors, edgecolors="k") -def plot_detection(detection: EventDetection, axes: plt.Axes | None = None) -> None: +def plot_octree(octree: Octree, axes: plt.Axes | None = None) -> None: if axes is None: fig = plt.figure() ax = fig.gca() else: ax = axes - ax.add_collection(octree_to_rectangles(detection.octree)) - eq_location = sorted(detection.octree, key=lambda node: node.semblance)[-1] - ax.scatter(eq_location.east, eq_location.north, marker="*", s=50, c="red") + ax.add_collection(octree_to_rectangles(octree)) + ax.autoscale() if axes is None: plt.show() diff --git a/lassie/plot/utils.py b/lassie/plot/utils.py new file mode 100644 index 00000000..3a2c3e09 --- /dev/null +++ b/lassie/plot/utils.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import inspect +from functools import wraps +from typing import Callable, ParamSpec + +from matplotlib import pyplot as plt + +P = ParamSpec("P") + + +def with_default_axes(func: Callable[P, None]) -> 
Callable[P, None]:
+    signature = inspect.signature(func)
+    if "axes" not in signature.parameters:
+        raise AttributeError("Function must have an 'axes' parameter")
+    if "filename" not in signature.parameters:
+        raise AttributeError("Function must have a 'filename' parameter")
+
+    @wraps(func)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
+        axes_provided = kwargs.get("axes") is not None
+        if kwargs.get("axes") is None:
+            fig = plt.figure()
+            ax = fig.gca()
+            kwargs["axes"] = ax
+        else:
+            ax: plt.Axes = kwargs["axes"]
+            fig = ax.figure
+
+        ret = func(*args, **kwargs)
+
+        if kwargs.get("filename") is not None:
+            fig.savefig(str(kwargs.get("filename")), bbox_inches="tight", dpi=300)
+            plt.close()
+
+        if not axes_provided:
+            plt.show()
+        return ret
+
+    return wrapper
diff --git a/lassie/search/base.py b/lassie/search.py
similarity index 60%
rename from lassie/search/base.py
rename to lassie/search.py
index de82f154..abc8c372 100644
--- a/lassie/search/base.py
+++ b/lassie/search.py
@@ -1,34 +1,40 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import logging
+from collections import deque
 from copy import deepcopy
 from datetime import datetime, timedelta, timezone
 from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Deque, Literal
 
 import numpy as np
-from pydantic import (
-    BaseModel,
-    Field,
-    PositiveFloat,
-    PositiveInt,
-    PrivateAttr,
-    confloat,
-    conint,
-)
+from pydantic import BaseModel, Field, PositiveFloat, PositiveInt, PrivateAttr
 from pyrocko import parstack
 
+from lassie.features import (
+    FeatureExtractors,
+    GroundMotionExtractor,
+    LocalMagnitudeExtractor,
+)
 from lassie.images import ImageFunctions, WaveformImages
 from lassie.models import Stations
 from lassie.models.detection import EventDetection, EventDetections, PhaseDetection
 from lassie.models.semblance import Semblance, SemblanceStats
 from lassie.octree import NodeSplitError, Octree
+from lassie.plot.octree import plot_octree_surface_tiles
 from lassie.signals import Signal
 from lassie.station_corrections import StationCorrections
-from lassie.tracers import RayTracers
-from lassie.utils import PhaseDescription, Symbols, alog_call, time_to_path
+from lassie.tracers import (
+    CakeTracer,
+    ConstantVelocityTracer,
+    FastMarchingTracer,
+    RayTracers,
+)
+from lassie.utils import PhaseDescription, alog_call, datetime_now, time_to_path
+from lassie.waveforms import PyrockoSquirrel, WaveformProviderType
 
 if TYPE_CHECKING:
     from pyrocko.trace import Trace
 
 logger = logging.getLogger(__name__)
 
+SamplingRate = Literal[10, 20, 25, 50, 100]
+
 
 class SearchProgress(BaseModel):
     time_progress: datetime | None = None
     semblance_stats: SemblanceStats = SemblanceStats()
 
 
 class Search(BaseModel):
-    sampling_rate: confloat(ge=10.0, le=20.0) = 10.0
-    detection_threshold: PositiveFloat = 0.05
-    detection_blinding: timedelta = timedelta(seconds=2.0)
-    project_dir: Path = Path(".")
-
-    octree: Octree = Octree()
     stations: Stations = Stations()
-    ray_tracers: RayTracers
-    image_functions: ImageFunctions
+    data_provider: WaveformProviderType = PyrockoSquirrel()
 
-    station_corrections: StationCorrections | None = None
+    octree: Octree = Octree()
+    image_functions: ImageFunctions = ImageFunctions()
+    ray_tracers: RayTracers = RayTracers(
+        root=[ConstantVelocityTracer(), CakeTracer(), FastMarchingTracer()]
+    )
+    station_corrections: StationCorrections = StationCorrections()
+    event_features: list[FeatureExtractors] = [
+
GroundMotionExtractor(), + LocalMagnitudeExtractor(), + ] + + sampling_rate: SamplingRate = 50 + detection_threshold: PositiveFloat = 0.05 + node_split_threshold: float = Field(default=0.9, gt=0.0, lt=1.0) + detection_blinding: timedelta = timedelta(seconds=2.0) + window_length: timedelta = timedelta(minutes=5) - n_threads_parstack: conint(ge=0) = 0 + n_threads_parstack: int = Field(default=0, ge=0) n_threads_argmax: PositiveInt = 4 - split_fraction: confloat(gt=0.0, lt=1.0) = 0.9 + plot_octree_surface: bool = False + created: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) - # Overwritten at initialisation - shift_range: timedelta = timedelta(seconds=0.0) - window_padding: timedelta = timedelta(seconds=0.0) - distance_range: tuple[float, float] = (0.0, 0.0) - travel_time_ranges: dict[PhaseDescription, tuple[timedelta, timedelta]] = {} - progress: SearchProgress = SearchProgress() + _progress: SearchProgress = PrivateAttr(SearchProgress()) - created: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + _shift_range: timedelta = PrivateAttr(timedelta(seconds=0.0)) + _window_padding: timedelta = PrivateAttr(timedelta(seconds=0.0)) + _distance_range: tuple[float, float] = PrivateAttr((0.0, 0.0)) + _travel_time_ranges: dict[ + PhaseDescription, tuple[timedelta, timedelta] + ] = PrivateAttr({}) _detections: EventDetections = PrivateAttr() _config_stem: str = PrivateAttr("") _rundir: Path = PrivateAttr() + # Signals _new_detection: Signal[EventDetection] = PrivateAttr(Signal()) + _batch_processing_durations: Deque[timedelta] = PrivateAttr( + default_factory=lambda: deque(maxlen=25) + ) def init_rundir(self, force=False) -> None: rundir = ( @@ -100,65 +121,185 @@ def init_rundir(self, force=False) -> None: if not rundir.exists(): rundir.mkdir() + file_logger = logging.FileHandler(self._rundir / "lassie.log") + logging.root.addHandler(file_logger) self.write_config() - self.stations.dump_pyrocko_stations(rundir / "pyrocko-stations.yaml") logger.info("created new rundir %s", rundir) - self._detections = EventDetections(rundir=rundir) def write_config(self, path: Path | None = None) -> None: - path = path or self._rundir / "search.json" + rundir = self._rundir + path = path or rundir / "search.json" + logger.debug("writing search config to %s", path) path.write_text(self.model_dump_json(indent=2, exclude_unset=True)) + logger.debug("dumping stations...") + self.stations.dump_pyrocko_stations(rundir / "pyrocko-stations.yaml") + self.stations.dump_csv(rundir / "stations.csv") + @property def semblance_stats(self) -> SemblanceStats: - return self.progress.semblance_stats + return self._progress.semblance_stats def set_progress(self, time: datetime) -> None: - self.progress.time_progress = time + self._progress.time_progress = time progress_file = self._rundir / "progress.json" - progress_file.write_text(self.progress.model_dump_json()) + progress_file.write_text(self._progress.model_dump_json()) - def init_search(self) -> None: + def init_boundaries(self) -> None: """Initialise search.""" - file_logger = logging.FileHandler(self._rundir / "lassie.log") - logging.root.addHandler(file_logger) - # Grid/receiver distances distances = self.octree.distances_stations(self.stations) - self.distance_range = (distances.min(), distances.max()) + self._distance_range = (distances.min(), distances.max()) # Timing ranges for phase, tracer in self.ray_tracers.iter_phase_tracer(): - traveltimes = tracer.get_traveltimes(phase, self.octree, self.stations) - 
self.travel_time_ranges[phase] = (
+            traveltimes = tracer.get_travel_times(phase, self.octree, self.stations)
+            self._travel_time_ranges[phase] = (
                 timedelta(seconds=np.nanmin(traveltimes)),
                 timedelta(seconds=np.nanmax(traveltimes)),
             )
             logger.info(
-                "shift ranges: %s / %s - %s", phase, *self.travel_time_ranges[phase]
+                "time shift ranges: %s / %s - %s",
+                phase,
+                *self._travel_time_ranges[phase],
             )
 
         # TODO: minimum shift is calculated on the coarse octree grid, which is
         # not necessarily the same as the fine grid used for semblance calculation
-        shift_min = min(chain.from_iterable(self.travel_time_ranges.values()))
-        shift_max = max(chain.from_iterable(self.travel_time_ranges.values()))
-        self.shift_range = shift_max - shift_min
+        shift_min = min(chain.from_iterable(self._travel_time_ranges.values()))
+        shift_max = max(chain.from_iterable(self._travel_time_ranges.values()))
+        self._shift_range = shift_max - shift_min
 
-        self.window_padding = (
-            self.shift_range
+        self._window_padding = (
+            self._shift_range
             + self.detection_blinding
             + self.image_functions.get_blinding()
         )
+        if self.window_length < 2 * self._window_padding + self._shift_range:
+            raise ValueError(
+                f"window length {self.window_length} is too short for the "
+                f"theoretical shift range {self._shift_range} and "
+                f"cumulative window padding of {self._window_padding}."
+                " Increase the window_length."
+            )
+
+        logger.info("using trace window padding: %s", self._window_padding)
+        logger.info("time shift range %s", self._shift_range)
         logger.info(
-            "source-station distances range: %.1f - %.1f m", *self.distance_range
+            "source-station distance range: %.1f - %.1f m",
+            *self._distance_range,
         )
-        logger.info("shift range %s", self.shift_range)
-        logger.info("using trace window padding: %s", self.window_padding)
-        self.write_config()
+
+    def _plot_octree_surface(
+        self,
+        octree: Octree,
+        time: datetime,
+        detections: list[EventDetection] | None = None,
+    ) -> None:
+        logger.info("plotting octree surface...")
+        filename = (
+            self._rundir
+            / "figures"
+            / "octree_surface"
+            / f"{time_to_path(time)}-nodes-{octree.n_nodes}.png"
+        )
+        filename.parent.mkdir(parents=True, exist_ok=True)
+        plot_octree_surface_tiles(octree, filename=filename, detections=detections)
+
+    async def prepare(self) -> None:
+        logger.info("preparing search...")
+        self.data_provider.prepare(self.stations)
+        await self.ray_tracers.prepare(
+            self.octree,
+            self.stations,
+            phases=self.image_functions.get_phases(),
+        )
+        self.init_boundaries()
+
+    async def start(self, force_rundir: bool = False) -> None:
+        await self.prepare()
+        self.init_rundir(force_rundir)
+        logger.info("starting search...")
+        processing_start = datetime_now()
+
+        if self._progress.time_progress:
+            logger.info("continuing search from %s", self._progress.time_progress)
+
+        async for batch in self.data_provider.iter_batches(
+            window_increment=self.window_length,
+            window_padding=self._window_padding,
+            start_time=self._progress.time_progress,
+        ):
+            batch.clean_traces()
+
+            if batch.is_empty():
+                logger.warning("batch is empty")
+                continue
+
+            if batch.duration < 2 * self._window_padding:
+                logger.warning("batch duration is too short")
+                continue
+
+            search_block = SearchTraces(
+                parent=self,
+                traces=batch.traces,
+                start_time=batch.start_time,
+                end_time=batch.end_time,
+            )
+            detections, semblance_trace = await search_block.search()
+
+            self._detections.add_semblance(semblance_trace)
+            for detection in detections:
+                if detection.in_bounds:
+                    await self.add_features(detection)
+
self._detections.add(detection) + await self._new_detection.emit(detection) + + if batch.i_batch % 50 == 0: + self._detections.dump_detections(jitter_location=self.octree.size_limit) + + processing_time = datetime_now() - processing_start + self._batch_processing_durations.append(processing_time) + if batch.n_batches: + percent_processed = ((batch.i_batch + 1) / batch.n_batches) * 100 + else: + percent_processed = 0.0 + logger.info( + "%s%% processed - batch %d/%s - %s in %s", + f"{percent_processed:.1f}" if percent_processed else "??", + batch.i_batch + 1, + str(batch.n_batches or "?"), + batch.start_time, + processing_time, + ) + if batch.n_batches: + remaining_time = ( + sum(self._batch_processing_durations, timedelta()) + / len(self._batch_processing_durations) + * (batch.n_batches - batch.i_batch - 1) + ) + logger.info( + "%s remaining - estimated finish at %s", + remaining_time, + datetime.now() + remaining_time, # noqa: DTZ005 + ) + + processing_start = datetime_now() + self.set_progress(batch.end_time) + + async def add_features(self, event: EventDetection) -> None: + try: + squirrel = self.data_provider.get_squirrel() + except NotImplementedError: + return + + for extractor in self.event_features: + logger.info("adding features from %s", extractor.feature) + await extractor.add_features(squirrel, event) @classmethod def load_rundir(cls, rundir: Path) -> Self: @@ -169,7 +310,7 @@ def load_rundir(cls, rundir: Path) -> Self: progress_file = rundir / "progress.json" if progress_file.exists(): - search.progress = SearchProgress.model_validate_json( + search._progress = SearchProgress.model_validate_json( progress_file.read_text() ) return search @@ -190,6 +331,11 @@ def from_config( model._config_stem = filename.stem return model + def __del__(self) -> None: + if hasattr(self, "_detections"): + with contextlib.suppress(Exception): + self._detections.dump_detections(jitter_location=self.octree.size_limit) + class SearchTraces: _images: dict[float | None, WaveformImages] @@ -202,27 +348,16 @@ def __init__( end_time: datetime, ) -> None: self.parent = parent - self.traces = self.clean_traces(traces) - + self.traces = traces self.start_time = start_time self.end_time = end_time self._images = {} - @staticmethod - def clean_traces(traces: list[Trace]) -> list[Trace]: - """Remove empty or bad traces.""" - for tr in traces.copy(): - if not tr.ydata.size or not np.all(np.isfinite(tr.ydata)): - logger.warning("skipping empty or bad trace: %s", ".".join(tr.nslc_id)) - traces.remove(tr) - - return traces - def _n_samples_semblance(self) -> int: """Number of samples to use for semblance calculation, includes padding.""" parent = self.parent - window_padding = parent.window_padding + window_padding = parent._window_padding time_span = (self.end_time + window_padding) - ( self.start_time - window_padding ) @@ -240,7 +375,7 @@ async def calculate_semblance( logger.debug("stacking image %s", image.image_function.name) parent = self.parent - traveltimes = ray_tracer.get_traveltimes(image.phase, octree, image.stations) + traveltimes = ray_tracer.get_travel_times(image.phase, octree, image.stations) if parent.station_corrections: station_delays = parent.station_corrections.get_delays( @@ -263,7 +398,7 @@ async def calculate_semblance( semblance_data, offsets = await asyncio.to_thread( parstack.parstack, arrays=image.get_trace_data(), - offsets=image.get_offsets(self.start_time - parent.window_padding), + offsets=image.get_offsets(self.start_time - parent._window_padding), shifts=shifts, weights=weights, 
lengthout=n_samples_semblance, @@ -285,7 +420,6 @@ async def get_images(self, sampling_rate: float | None = None) -> WaveformImages Returns: WaveformImages: The waveform images for the specified sampling rate. """ - if None not in self._images: images = await self.parent.image_functions.process_traces(self.traces) images.set_stations(self.parent.stations) @@ -293,7 +427,7 @@ async def get_images(self, sampling_rate: float | None = None) -> WaveformImages if sampling_rate not in self._images: if not isinstance(sampling_rate, float): - raise TypeError("sampling rate has to be a float") + raise TypeError("sampling rate has to be a float or int") images_resampled = deepcopy(self._images[None]) logger.debug("downsampling images to %g Hz", sampling_rate) @@ -320,10 +454,10 @@ async def search( sampling_rate = parent.sampling_rate octree = octree or parent.octree.copy(deep=True) - images = await self.get_images(sampling_rate=sampling_rate) + images = await self.get_images(sampling_rate=float(sampling_rate)) padding_samples = int( - round(parent.window_padding.total_seconds() * sampling_rate) + round(parent._window_padding.total_seconds() * sampling_rate) ) semblance = Semblance( n_nodes=octree.n_nodes, @@ -344,7 +478,7 @@ async def search( semblance.normalize(images.cumulative_weight()) parent.semblance_stats.update(semblance.get_stats()) - logger.info("semblance stats: %s", parent.semblance_stats) + logger.debug("semblance stats: %s", parent.semblance_stats) detection_idx, detection_semblance = semblance.find_peaks( height=parent.detection_threshold, @@ -352,6 +486,10 @@ async def search( distance=round(parent.detection_blinding.total_seconds() * sampling_rate), ) + if parent.plot_octree_surface: + octree.map_semblance(semblance.maximum_node_semblance()) + parent._plot_octree_surface(octree, time=self.start_time) + if detection_idx.size == 0: return [], semblance.get_trace() @@ -369,7 +507,9 @@ async def search( if not source_node.can_split(): continue - split_nodes = octree.get_nodes(semblance_detection * parent.split_fraction) + split_nodes = octree.get_nodes( + semblance_detection * parent.node_split_threshold + ) refine_nodes.update(split_nodes) # refine_nodes is empty when all sources fall into smallest octree nodes @@ -393,11 +533,6 @@ async def search( node_idx = (await semblance.maxima_node_idx())[time_idx] source_node = octree[node_idx] - if not octree.is_node_in_bounds(source_node): - logger.info( - "source node is inside octree's absorbing boundary (%.1f m)", - source_node.distance_border, - ) source_location = source_node.as_location() detection = EventDetection( @@ -435,17 +570,13 @@ async def search( ) detections.append(detection) - logger.info( - "%s new detection %s: %.5fE, %.5fN, %.1f m, semblance %.3f", - Symbols.Target, - detection.time, - *detection.effective_lat_lon, - detection.effective_depth, - detection.semblance, - ) - # detection.plot() # plot_octree_movie(octree, semblance, file=Path("/tmp/test.mp4")) + if parent.plot_octree_surface: + octree.map_semblance(semblance.maximum_node_semblance()) + parent._plot_octree_surface( + octree, time=self.start_time, detections=detections + ) return detections, semblance.get_trace() diff --git a/lassie/search/__init__.py b/lassie/search/__init__.py deleted file mode 100644 index 473af0a1..00000000 --- a/lassie/search/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .squirrel import SquirrelSearch # noqa diff --git a/lassie/search/squirrel.py b/lassie/search/squirrel.py deleted file mode 100644 index 34580404..00000000 --- 
a/lassie/search/squirrel.py +++ /dev/null @@ -1,209 +0,0 @@ -from __future__ import annotations - -import asyncio -import glob -import logging -from collections import deque -from datetime import datetime, timedelta -from pathlib import Path -from typing import TYPE_CHECKING, Any, Deque, Iterator - -from pydantic import PositiveInt, PrivateAttr, conint, field_validator -from pyrocko.squirrel import Squirrel - -from lassie.features import FeatureExtractors -from lassie.features.ground_motion import GroundMotionExtractor -from lassie.features.local_magnitude import LocalMagnitudeExtractor -from lassie.search.base import Search, SearchTraces -from lassie.utils import datetime_now, to_datetime - -if TYPE_CHECKING: - from pyrocko.squirrel.base import Batch - from pyrocko.trace import Trace - - from lassie.models.detection import EventDetection - -logger = logging.getLogger(__name__) - - -class SquirrelPrefetcher: - def __init__(self, iterator: Iterator[Batch], queue_size: int = 4) -> None: - self.iterator = iterator - self.queue: asyncio.Queue[Batch | None] = asyncio.Queue(maxsize=queue_size) - - self._task = asyncio.create_task(self.prefetch_worker()) - - async def prefetch_worker(self) -> None: - logger.info("start prefetching squirrel data") - while True: - start = datetime_now() - batch = await asyncio.to_thread(lambda: next(self.iterator, None)) - logger.debug("prefetched waveforms in %s", datetime_now() - start) - if batch is None: - logger.debug("squirrel prefetcher finished") - await self.queue.put(None) - break - await self.queue.put(batch) - - -class SquirrelSearch(Search): - time_span: tuple[datetime | None, datetime | None] = (None, None) - squirrel_environment: Path = Path(".") - waveform_data: list[Path] - waveform_prefetch_batches: PositiveInt = 4 - - features: list[FeatureExtractors] = [ - GroundMotionExtractor(), - LocalMagnitudeExtractor(), - ] - window_length_factor: conint(ge=5, le=100) = 10 - - _squirrel: Squirrel | None = PrivateAttr(None) - - def model_post_init(self, __context: Any) -> None: - if not all(self.time_span): - squirrel = self.get_squirrel() - sq_tmin, sq_tmax = squirrel.get_time_span(["waveform", "waveform_promise"]) - self.time_span = ( - self.time_span[0] or to_datetime(sq_tmin), - self.time_span[1] or to_datetime(sq_tmax), - ) - - logger.info( - "searching time span from %s to %s (%s)", - self.start_time, - self.end_time, - self.end_time - self.start_time, - ) - - @field_validator("time_span") - @classmethod - def _validate_time_span(cls, range): # noqa: N805 - if range[0] >= range[1]: - raise ValueError(f"time range is invalid {range[0]} - {range[1]}") - return range - - @property - def start_time(self) -> datetime: - return self.time_span[0] - - @property - def end_time(self) -> datetime: - return self.time_span[1] - - def get_squirrel(self) -> Squirrel: - if not self._squirrel: - squirrel = Squirrel(str(self.squirrel_environment.expanduser())) - paths = [] - for path in self.waveform_data: - if "**" in str(path): - paths.extend(glob.glob(str(path.expanduser()), recursive=True)) - else: - paths.append(str(path.expanduser())) - paths.extend((str(p.expanduser()) for p in self.stations.station_xmls)) - - squirrel.add(paths, check=False) - self._squirrel = squirrel - return self._squirrel - - async def scan_squirrel(self) -> None: - squirrel = self.get_squirrel() - - self.stations.weed_from_squirrel_waveforms(squirrel) - self.ray_tracers.prepare(self.octree, self.stations) - self.init_search() - - window_increment = self.shift_range * 
self.window_length_factor - logger.info("using trace window increment: %s", window_increment) - - start_time = self.start_time - if self.progress.time_progress: - start_time = self.progress.time_progress - logger.info("continuing search from %s", start_time) - - iterator = squirrel.chopper_waveforms( - tmin=start_time.timestamp(), - tmax=self.end_time.timestamp(), - tinc=window_increment.total_seconds(), - tpad=self.window_padding.total_seconds(), - want_incomplete=False, - codes=[(*nsl, "*") for nsl in self.stations.get_all_nsl()], - ) - prefetcher = SquirrelPrefetcher(iterator, self.waveform_prefetch_batches) - - batch_start_time = None - batch_durations: Deque[timedelta] = deque(maxlen=25) - while True: - batch = await prefetcher.queue.get() - if batch is None: - logger.info("squirrel search finished") - break - - window_start = to_datetime(batch.tmin) - window_end = to_datetime(batch.tmax) - window_length = window_end - window_start - logger.info( - "searching time window %d/%d %s - %s", - batch.i + 1, - batch.n, - window_start, - window_end, - ) - - traces: list[Trace] = batch.traces - if not traces: - logger.warning("window is empty") - continue - - if window_start > window_end or window_length < 2 * self.window_padding: - logger.warning("window length is too short") - continue - - block = SearchTraces( - parent=self, - traces=traces, - start_time=window_start, - end_time=window_end, - ) - detections, semblance_trace = await block.search() - self._detections.add_semblance(semblance_trace) - for detection in detections: - if detection.in_bounds: - await self.add_features(detection) - - self._detections.add(detection) - await self._new_detection.emit(detection) - - if detections: - self._detections.dump_detections(self.octree.size_limit) - - if batch_start_time is not None: - batch_duration = datetime_now() - batch_start_time - batch_durations.append(batch_duration) - logger.info( - "window %d/%d took %s", - batch.i + 1, - batch.n, - batch_duration, - ) - remaining_time = ( - sum(batch_durations, timedelta()) - / len(batch_durations) - * (batch.n - batch.i - 1) - ) - logger.info( - "remaining %s, estimated finish at %s", - remaining_time, - datetime.now() + remaining_time, # noqa: DTZ005 - ) - batch_start_time = datetime_now() - - self.set_progress(window_end) - prefetcher.queue.task_done() - - async def add_features(self, event: EventDetection) -> None: - squirrel = self.get_squirrel() - - for extractor in self.features: - logger.info("adding features from %s", extractor.feature) - await extractor.add_features(squirrel, event) diff --git a/lassie/station_corrections.py b/lassie/station_corrections.py index 0900ffe8..6e6de66d 100644 --- a/lassie/station_corrections.py +++ b/lassie/station_corrections.py @@ -311,7 +311,10 @@ def from_receiver(cls, receiver: Receiver) -> Self: class StationCorrections(BaseModel): - rundir: DirectoryPath + rundir: DirectoryPath | None = Field( + default=None, + description="The rundir to load the detections from", + ) measure: Literal["median", "average"] = "median" weighting: ArrivalWeighting = "mul-PhaseNet-semblance" @@ -323,6 +326,10 @@ class StationCorrections(BaseModel): _traveltime_delay_cache: dict[tuple[NSL, PhaseDescription], float] = PrivateAttr({}) def model_post_init(self, __context: Any) -> None: + if self.rundir is None: + logger.debug("no rundir specified, skipping station corrections") + return + logger.debug("loading station detections from %s", self.rundir) detections = EventDetections.load_rundir(self.rundir) with 
console.status("aggregating station detections"): diff --git a/lassie/tracers/__init__.py b/lassie/tracers/__init__.py index 4dcc0a44..cd6e3ad5 100644 --- a/lassie/tracers/__init__.py +++ b/lassie/tracers/__init__.py @@ -5,12 +5,12 @@ from pydantic import Field, RootModel -from lassie.tracers.base import ModelledArrival from lassie.tracers.cake import CakeArrival, CakeTracer from lassie.tracers.constant_velocity import ( ConstantVelocityArrival, ConstantVelocityTracer, ) +from lassie.tracers.fast_marching import FastMarchingArrival, FastMarchingTracer if TYPE_CHECKING: from lassie.models.station import Stations @@ -22,12 +22,12 @@ RayTracerType = Annotated[ - Union[CakeTracer, ConstantVelocityTracer], + Union[ConstantVelocityTracer, CakeTracer, FastMarchingTracer], Field(..., discriminator="tracer"), ] RayTracerArrival = Annotated[ - Union[CakeArrival, ConstantVelocityArrival, ModelledArrival], + Union[ConstantVelocityArrival, CakeArrival, FastMarchingArrival], Field(..., discriminator="tracer"), ] @@ -35,10 +35,16 @@ class RayTracers(RootModel): root: list[RayTracerType] = [] - def prepare(self, octree: Octree, stations: Stations) -> None: + async def prepare( + self, + octree: Octree, + stations: Stations, + phases: tuple[PhaseDescription, ...], + ) -> None: logger.info("preparing ray tracers") - for tracer in self: - tracer.prepare(octree, stations) + for phase in phases: + tracer = self.get_phase_tracer(phase) + await tracer.prepare(octree, stations) def get_available_phases(self) -> tuple[str]: phases = [] diff --git a/lassie/tracers/base.py b/lassie/tracers/base.py index a60432be..dc654b02 100644 --- a/lassie/tracers/base.py +++ b/lassie/tracers/base.py @@ -24,13 +24,13 @@ class ModelledArrival(PhaseArrival): class RayTracer(BaseModel): tracer: Literal["RayTracer"] = "RayTracer" - def prepare(self, octree: Octree, stations: Stations): + async def prepare(self, octree: Octree, stations: Stations): ... - def get_available_phases(self) -> tuple[str]: + def get_available_phases(self) -> tuple[str, ...]: ... 
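The hunks below rename the tracer interface from `get_traveltime*` to `get_travel_time*`. For orientation, a minimal sketch of a tracer written against the renamed interface (the `HomogeneousTracer` name and its 4 km/s velocity are illustrative assumptions, not part of this patch):

```python
from typing import Literal

from lassie.models.location import Location
from lassie.tracers.base import RayTracer


class HomogeneousTracer(RayTracer):
    """Illustrative sketch: straight-ray travel times in a uniform medium."""

    tracer: Literal["HomogeneousTracer"] = "HomogeneousTracer"
    velocity: float = 4000.0  # assumed constant velocity [m/s]

    def get_available_phases(self) -> tuple[str, ...]:
        return ("homogeneous:P",)

    def get_travel_time_location(
        self,
        phase: str,
        source: Location,
        receiver: Location,
    ) -> float:
        # The base class builds get_travel_times_locations() on top of this.
        return source.distance_to(receiver) / self.velocity
```
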
- def get_traveltime_location( + def get_travel_time_location( self, phase: str, source: Location, @@ -38,17 +38,17 @@ def get_traveltime_location( ) -> float: raise NotImplementedError - def get_traveltimes_locations( + def get_travel_times_locations( self, phase: str, source: Location, receivers: Sequence[Location], ) -> np.ndarray: return np.array( - [self.get_traveltime_location(phase, source, recv) for recv in receivers] + [self.get_travel_time_location(phase, source, recv) for recv in receivers] ) - def get_traveltimes( + def get_travel_times( self, phase: str, octree: Octree, diff --git a/lassie/tracers/cake.py b/lassie/tracers/cake.py index 5beb6de1..d3e3bc9b 100644 --- a/lassie/tracers/cake.py +++ b/lassie/tracers/cake.py @@ -14,18 +14,18 @@ import numpy as np from lru import LRU from pydantic import ( - ConfigDict, BaseModel, + ByteSize, + ConfigDict, Field, - PositiveFloat, + FilePath, PrivateAttr, constr, - RootModel, - ByteSize, + model_validator, ) from pyrocko import orthodrome as od from pyrocko import spit -from pyrocko.cake import LayeredModel, PhaseDef, m2d, read_nd_model_str +from pyrocko.cake import LayeredModel, PhaseDef, load_model, m2d from pyrocko.gf import meta from rich.progress import Progress @@ -55,56 +55,107 @@ LRU_CACHE_SIZE = 2000 +# TODO: Move to a separate file +DEFAULT_VELOCITY_MODEL = """ +-1.00 5.50 3.59 2.7 + 0.00 5.50 3.59 2.7 + 1.00 5.50 3.59 2.7 + 1.00 6.00 3.92 2.7 + 4.00 6.00 3.92 2.7 + 4.00 6.20 4.05 2.7 + 8.00 6.20 4.05 2.7 + 8.00 6.30 4.12 2.7 +13.00 6.30 4.12 2.7 +13.00 6.40 4.18 2.7 +17.00 6.40 4.18 2.7 +17.00 6.50 4.25 2.7 +22.00 6.50 4.25 2.7 +22.00 6.60 4.31 2.7 +26.00 6.60 4.31 2.7 +26.00 6.80 4.44 2.7 +30.00 6.80 4.44 2.7 +30.00 8.10 5.29 2.7 +45.00 8.10 5.29 2.7 +""" + +DEFAULT_VELOCITY_MODEL_FILE = CACHE_DIR / "velocity_models" / "default.nd" +if not DEFAULT_VELOCITY_MODEL_FILE.exists(): + DEFAULT_VELOCITY_MODEL_FILE.parent.mkdir(exist_ok=True) + DEFAULT_VELOCITY_MODEL_FILE.write_text(DEFAULT_VELOCITY_MODEL) + + class CakeArrival(ModelledArrival): tracer: Literal["CakeArrival"] = "CakeArrival" phase: str -class EarthModel(RootModel): - root: list[tuple[float, PositiveFloat, PositiveFloat, PositiveFloat]] = [ - (0.00, 5.50, 3.59, 2.7), - (1.00, 5.50, 3.59, 2.7), - (1.00, 6.00, 3.92, 2.7), - (4.00, 6.00, 3.92, 2.7), - (4.00, 6.20, 4.05, 2.7), - (8.00, 6.20, 4.05, 2.7), - (8.00, 6.30, 4.12, 2.7), - (13.00, 6.30, 4.12, 2.7), - (13.00, 6.40, 4.18, 2.7), - (17.00, 6.40, 4.18, 2.7), - (17.00, 6.50, 4.25, 2.7), - (22.00, 6.50, 4.25, 2.7), - (22.00, 6.60, 4.31, 2.7), - (26.00, 6.60, 4.31, 2.7), - (26.00, 6.80, 4.44, 2.7), - (30.00, 6.80, 4.44, 2.7), - (30.00, 8.10, 5.29, 2.7), - (45.00, 8.10, 5.29, 2.7), - ] +class EarthModel(BaseModel): + filename: FilePath | None = Field( + DEFAULT_VELOCITY_MODEL_FILE, + description="Path to velocity model.", + ) + format: Literal["nd", "hyposat"] = Field( + "nd", + description="Format of the velocity model. 
nd or hyposat is supported.", + ) + crust2_profile: constr(to_upper=True) | tuple[float, float] = Field( + "", + description="Crust2 profile name or a tuple of (lat, lon) coordinates.", + ) + + raw_file_data: str | None = Field( + None, + description="Raw .nd file data.", + ) + _layered_model: LayeredModel = PrivateAttr() + model_config = ConfigDict(ignored_types=(cached_property,)) - def _as_array(self) -> np.ndarray: - return np.asarray(self.root) + @model_validator(mode="after") + def load_model(self) -> EarthModel: + if self.filename is not None: + logger.info("loading velocity model from %s", self.filename) + self.raw_file_data = self.filename.read_text() + + if self.raw_file_data is not None: + with NamedTemporaryFile("w") as tmpfile: + tmpfile.write(self.raw_file_data) + tmpfile.flush() + self._layered_model = load_model( + tmpfile.name, + format=self.format, + crust2_profile=self.crust2_profile or None, + ) + elif self.crust2_profile: + self._layered_model = load_model(crust2_profile=self.crust2_profile) + else: + raise AttributeError("No velocity model or crust2 profile defined.") + return self + + def trim(self, depth_max: float) -> None: + """Trim the model to a maximum depth. + + Args: + depth_max (float): Maximum depth in meters. + """ + logger.debug("trimming earth model to %.1f km depth", depth_max / KM) + self._layered_model = self.layered_model.extract(depth_max=depth_max) + + @property + def layered_model(self) -> LayeredModel: + return self._layered_model def get_profile_vp(self) -> np.ndarray: - # TODO: reduce to relevant layers - return self._as_array()[:, 1] * KM + return self.layered_model.profile("vp") def get_profile_vs(self) -> np.ndarray: - # TODO: reduce to relevant layers - return self._as_array()[:, 2] * KM - - def as_layered_model(self) -> LayeredModel: - line_tpl = "{} {} {} {}" - earthmodel = "\n".join(line_tpl.format(*layer) for layer in self.root) - return LayeredModel.from_scanlines(read_nd_model_str(earthmodel)) + return self.layered_model.profile("vs") @cached_property def hash(self) -> str: - layered_model = self.as_layered_model() model_serialised = BytesIO() for param in ("z", "vp", "vs", "rho"): - layered_model.profile(param).dump(model_serialised) + self.layered_model.profile(param).dump(model_serialised) return sha1(model_serialised.getvalue()).hexdigest() @@ -122,7 +173,7 @@ def id(self) -> str: return re.sub(r"[\,\s\;]", "", self.definition) -class TraveltimeTree(BaseModel): +class TravelTimeTree(BaseModel): earthmodel: EarthModel timing: Timing @@ -137,14 +188,14 @@ class TraveltimeTree(BaseModel): _sptree: spit.SPTree | None = PrivateAttr(None) _file: Path | None = PrivateAttr(None) - _cached_stations: Stations | None = PrivateAttr(None) - _cached_station_indeces: dict[str, int] | None = PrivateAttr({}) + _cached_stations: Stations = PrivateAttr() + _cached_station_indeces: dict[str, int] = PrivateAttr({}) _node_lut: dict[bytes, np.ndarray] = PrivateAttr( default_factory=lambda: LRU(LRU_CACHE_SIZE) ) def calculate_tree(self) -> spit.SPTree: - layered_model = self.earthmodel.as_layered_model() + layered_model = self.earthmodel.layered_model def evaluate(args) -> float | None: receiver_depth, source_depth, distances = args @@ -187,7 +238,7 @@ def check_bounds(self, requested) -> bool: return self[0] <= requested[0] and self[1] >= requested[1] return ( - self.earthmodel.root == earthmodel.root + str(self.earthmodel.layered_model) == str(earthmodel.layered_model) and self.timing == timing and check_bounds(self.distance_bounds, distance_bounds) 
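# A cached travel time tree is only reused if it was built for the same # layered model and phase timing, and its grid bounds cover the request,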
and check_bounds(self.source_depth_bounds, source_depth_bounds) @@ -224,7 +275,14 @@ def save(self, path: Path) -> Path: logger.info("saving traveltimes to %s", file) with zipfile.ZipFile(file, "w") as archive: - archive.writestr("model.json", self.model_dump_json(indent=2)) + archive.writestr( + "model.json", + self.model_dump_json( + indent=2, + exclude={"earthmodel": {"filename"}}, + # include={"earthmodel": {"raw_file_data"}}, + ), + ) with NamedTemporaryFile() as tmpfile: self._get_sptree().dump(tmpfile.name) archive.write(tmpfile.name, "model.sptree") @@ -281,7 +340,7 @@ def init_lut(self, octree: Octree, stations: Stations) -> None: self._cached_station_indeces = { sta.pretty_nsl: idx for idx, sta in enumerate(stations) } - station_traveltimes = self.interpolate_traveltimes(octree, stations) + station_traveltimes = self.interpolate_travel_times(octree, stations) for node, traveltimes in zip(octree, station_traveltimes, strict=True): self._node_lut[node.hash()] = traveltimes.astype(np.float32) @@ -300,7 +359,7 @@ def fill_lut(self, nodes: Sequence[Node]) -> None: sta_coords - node_coords[:, np.newaxis], axis=2 ) - traveltimes = self._interpolate_traveltimes( + traveltimes = self._interpolate_travel_times( receiver_distances, np.array([sta.effective_depth for sta in stations]), np.array([node.depth for node in nodes]), @@ -313,7 +372,7 @@ def lut_fill_level(self) -> float: """Return the fill level of the LUT as a float between 0.0 and 1.0""" return len(self._node_lut) / self._node_lut.get_size() - def get_traveltimes(self, octree: Octree, stations: Stations) -> np.ndarray: + def get_travel_times(self, octree: Octree, stations: Stations) -> np.ndarray: station_indices = np.fromiter( (self._cached_station_indeces[sta.pretty_nsl] for sta in stations), dtype=int, @@ -334,16 +393,16 @@ def get_traveltimes(self, octree: Octree, stations: Stations) -> np.ndarray: cache_hits, cache_misses = self._node_lut.get_stats() cache_hit_rate = cache_hits / (cache_hits + cache_misses) - logger.info( + logger.debug( "node LUT cache fill level %.1f%%, cache hit rate %.1f%%", self.lut_fill_level() * 100, cache_hit_rate * 100, ) - return self.get_traveltimes(octree, stations) + return self.get_travel_times(octree, stations) return np.asarray(stations_traveltimes).astype(float, copy=False) - def interpolate_traveltimes( + def interpolate_travel_times( self, octree: Octree, stations: Stations, @@ -352,11 +411,11 @@ def interpolate_traveltimes( receiver_depths = np.array([sta.effective_depth for sta in stations]) source_depths = np.array([node.depth for node in octree]) - return self._interpolate_traveltimes( + return self._interpolate_travel_times( receiver_distances, receiver_depths, source_depths ) - def _interpolate_traveltimes( + def _interpolate_travel_times( self, receiver_distances: np.ndarray, receiver_depths: np.ndarray, @@ -378,7 +437,8 @@ def _interpolate_traveltimes( n_nodes = len(coordinates) with Progress() as progress: status = progress.add_task( - f"interpolating station traveltimes for {n_nodes} nodes", + f"interpolating {self.timing.definition} travel times " + f"for {n_nodes} nodes", total=len(coordinates), ) traveltimes = [] @@ -388,7 +448,7 @@ def _interpolate_traveltimes( return
np.asarray(traveltimes).astype(float) - def get_traveltime(self, source: Location, receiver: Location) -> float: + def get_travel_time(self, source: Location, receiver: Location) -> float: coordinates = [ receiver.effective_depth, source.effective_depth, @@ -403,14 +463,21 @@ def get_traveltime(self, source: Location, receiver: Location) -> float: class CakeTracer(RayTracer): tracer: Literal["CakeTracer"] = "CakeTracer" - timings: dict[PhaseDescription, Timing] = { + phases: dict[PhaseDescription, Timing] = { "cake:P": Timing(definition="P,p"), "cake:S": Timing(definition="S,s"), } earthmodel: EarthModel = EarthModel() - lut_cache_size: ByteSize = 4 * GiB + trim_earth_model_depth: bool = Field( + default=True, + description="Trim earth model to max depth of the octree.", + ) + lut_cache_size: ByteSize = Field( + default=2 * GiB, + description="Size of the LUT cache.", + ) - _traveltime_trees: dict[PhaseDescription, TraveltimeTree] = PrivateAttr({}) + _traveltime_trees: dict[PhaseDescription, TravelTimeTree] = PrivateAttr({}) @property def cache_dir(self) -> Path: @@ -424,22 +491,23 @@ def clear_cache(self) -> None: for file in self.cache_dir.glob("*.sptree"): file.unlink() - def get_available_phases(self) -> tuple[str]: - return tuple(self.timings.keys()) + def get_available_phases(self) -> tuple[str, ...]: + return tuple(self.phases.keys()) def get_vmin(self) -> float: earthmodel = self.earthmodel vel = np.concatenate((earthmodel.get_profile_vp(), earthmodel.get_profile_vs())) return float((vel[vel != 0.0]).min()) - def prepare(self, octree: Octree, stations: Stations) -> None: + async def prepare(self, octree: Octree, stations: Stations) -> None: global LRU_CACHE_SIZE bytes_per_node = stations.n_stations * np.float32().itemsize - n_trees = len(self.timings) + n_trees = len(self.phases) LRU_CACHE_SIZE = int(self.lut_cache_size / bytes_per_node / n_trees) - node_cache_fraction = LRU_CACHE_SIZE / octree.maximum_number_nodes() + # TODO: This should be total number nodes. Not only leaf nodes. + node_cache_fraction = LRU_CACHE_SIZE / octree.total_number_nodes() logging.info( "limiting traveltime LUT size to %d nodes (%s)," " caching %.1f%% of possible octree nodes", @@ -449,9 +517,9 @@ def prepare(self, octree: Octree, stations: Stations) -> None: ) cached_trees = [ - TraveltimeTree.load(file) for file in self.cache_dir.glob("*.sptree") + TravelTimeTree.load(file) for file in self.cache_dir.glob("*.sptree") ] - logger.debug("loaded %d cached traveltime trees", len(cached_trees)) + logger.debug("loaded %d cached travel time trees", len(cached_trees)) distances = octree.distances_stations(stations) source_depths = np.asarray(octree.depth_bounds) @@ -460,8 +528,11 @@ def prepare(self, octree: Octree, stations: Stations) -> None: receiver_depths_bounds = (receiver_depths.min(), receiver_depths.max()) source_depth_bounds = (source_depths.min(), source_depths.max()) distance_bounds = (distances.min(), distances.max()) - # TODO: Time tolerance is too hardcoded - time_tolerance = octree.size_limit / (self.get_vmin() * 3.0) + # FIXME: Time tolerance is too hardcoded. Is 5x a good value? 
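+ # The tolerance bounds the SPTree interpolation error: roughly the time + # a wave needs to cross one fifth of the smallest octree node.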
+ time_tolerance = octree.smallest_node_size() / (self.get_vmin() * 5.0) + + # if self.trim_earth_model_depth: + # self.earthmodel.trim(-source_depth_bounds[1]) traveltime_tree_args = { "earthmodel": self.earthmodel, @@ -472,43 +543,43 @@ def prepare(self, octree: Octree, stations: Stations) -> None: "time_tolerance": time_tolerance, } - for phase_descr, timing in self.timings.items(): + for phase_descr, timing in self.phases.items(): for tree in cached_trees: if tree.is_suited(timing=timing, **traveltime_tree_args): logger.info("using cached traveltime tree for %s", phase_descr) break else: logger.info("pre-calculating traveltime tree for %s", phase_descr) - tree = TraveltimeTree.new(timing=timing, **traveltime_tree_args) + tree = TravelTimeTree.new(timing=timing, **traveltime_tree_args) tree.save(self.cache_dir) tree.init_lut(octree, stations) self._traveltime_trees[phase_descr] = tree - def _get_sptree_model(self, phase: str) -> TraveltimeTree: + def _get_sptree_model(self, phase: str) -> TravelTimeTree: return self._traveltime_trees[phase] - def get_traveltime_location( + def get_travel_time_location( self, phase: str, source: Location, receiver: Location, ) -> float: - if phase not in self.timings: - raise ValueError(f"Timing {phase} is not defined.") + if phase not in self.phases: + raise ValueError(f"Phase {phase} is not defined.") tree = self._get_sptree_model(phase) - return tree.get_traveltime(source, receiver) + return tree.get_travel_time(source, receiver) @log_call - def get_traveltimes( + def get_travel_times( self, phase: str, octree: Octree, stations: Stations, ) -> np.ndarray: - if phase not in self.timings: - raise ValueError(f"Timing {phase} is not defined.") - return self._get_sptree_model(phase).get_traveltimes(octree, stations) + if phase not in self.phases: + raise ValueError(f"Phase {phase} is not defined.") + return self._get_sptree_model(phase).get_travel_times(octree, stations) def get_arrivals( self, @@ -517,7 +588,7 @@ def get_arrivals( source: Location, receivers: Sequence[Location], ) -> list[CakeArrival | None]: - traveltimes = self.get_traveltimes_locations( + traveltimes = self.get_travel_times_locations( phase, source=source, receivers=receivers, diff --git a/lassie/tracers/constant_velocity.py b/lassie/tracers/constant_velocity.py index 4e83ad9a..e84b8392 100644 --- a/lassie/tracers/constant_velocity.py +++ b/lassie/tracers/constant_velocity.py @@ -23,35 +23,36 @@ class ConstantVelocityArrival(ModelledArrival): class ConstantVelocityTracer(RayTracer): tracer: Literal["ConstantVelocityTracer"] = "ConstantVelocityTracer" - velocities: dict[PhaseDescription, PositiveFloat] = { - "constant:P": 6000.0, - "constant:S": 3900.0, - } + phase: PhaseDescription = "constant:P" + velocity: PositiveFloat = 5000.0 - def get_available_phases(self) -> tuple[str]: - return tuple(self.velocities.keys()) + def get_available_phases(self) -> tuple[str, ...]: + return (self.phase,) - def get_traveltime_location( + def _check_phase(self, phase: PhaseDescription) -> None: + if phase != self.phase: + raise ValueError(f"Phase {phase} is not defined.") + + def get_travel_time_location( self, phase: str, source: Location, receiver: Location, ) -> float: - if phase not in self.velocities: - raise ValueError(f"Phase {phase} is not defined.") - return source.distance_to(receiver) / self.velocities[phase] + self._check_phase(phase) + return source.distance_to(receiver) / self.velocity @log_call - def get_traveltimes( + def get_travel_times( self, phase: str, octree: Octree, stations: 
Stations, ) -> np.ndarray: - if phase not in self.velocities: - raise ValueError(f"Phase {phase} is not defined.") + self._check_phase(phase) + distances = octree.distances_stations(stations) - return distances / self.velocities[phase] + return distances / self.velocity def get_arrivals( self, @@ -60,7 +61,9 @@ def get_arrivals( source: Location, receivers: Sequence[Location], ) -> list[ConstantVelocityArrival]: - traveltimes = self.get_traveltimes_locations( + self._check_phase(phase) + + traveltimes = self.get_travel_times_locations( phase, source=source, receivers=receivers, diff --git a/lassie/tracers/fast_marching/__init__.py b/lassie/tracers/fast_marching/__init__.py new file mode 100644 index 00000000..e67c9f0b --- /dev/null +++ b/lassie/tracers/fast_marching/__init__.py @@ -0,0 +1 @@ +from .fast_marching import FastMarchingArrival, FastMarchingTracer # noqa diff --git a/lassie/tracers/fast_marching/fast_marching.py b/lassie/tracers/fast_marching/fast_marching.py new file mode 100644 index 00000000..4b272f0d --- /dev/null +++ b/lassie/tracers/fast_marching/fast_marching.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import os +import zipfile +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, Sequence + +import numpy as np +from lru import LRU +from pydantic import BaseModel, ByteSize, Field, PrivateAttr +from pyrocko.modelling import eikonal +from rich.progress import Progress +from scipy.interpolate import RegularGridInterpolator +from typing_extensions import Self + +from lassie.models.location import Location +from lassie.models.station import Station, Stations +from lassie.octree import Node +from lassie.tracers.base import ModelledArrival, RayTracer +from lassie.tracers.fast_marching.velocity_models import ( + Constant3DVelocityModel, + VelocityModel3D, + VelocityModels, +) +from lassie.utils import CACHE_DIR, PhaseDescription, datetime_now, human_readable_bytes + +if TYPE_CHECKING: + from lassie.octree import Octree + + +FMM_CACHE_DIR = CACHE_DIR / "fast-marching-cache" + +KM = 1e3 +GiB = int(1024**3) + +logger = logging.getLogger(__name__) + + +class FastMarchingArrival(ModelledArrival): + tracer: Literal["FastMarchingArrival"] = "FastMarchingArrival" + phase: PhaseDescription + + +class StationTravelTimeVolume(BaseModel): + center: Location + station: Station + + velocity_model_hash: str + + east_bounds: tuple[float, float] + north_bounds: tuple[float, float] + depth_bounds: tuple[float, float] + grid_spacing: float + + created: datetime = Field(default_factory=datetime_now) + + _travel_times: np.ndarray | None = PrivateAttr(None) + + _north_coords: np.ndarray = PrivateAttr() + _east_coords: np.ndarray = PrivateAttr() + _depth_coords: np.ndarray = PrivateAttr() + + # Cached values + _file: Path | None = PrivateAttr(None) + _interpolator: RegularGridInterpolator | None = PrivateAttr(None) + + @property + def travel_times(self) -> np.ndarray: + if self._travel_times is None: + self._travel_times = self._load_travel_times() + return self._travel_times + + def has_travel_times(self) -> bool: + return self._travel_times is not None or self._file is not None + + def free_cache(self): + self._interpolator = None + if self._file is None: + logger.warning("cannot free travel time cache, file is not saved") + else: + self._travel_times = None + + def model_post_init(self, __context: Any) -> None: + grid_spacing
= self.grid_spacing + + self._east_coords = np.arange( + self.east_bounds[0], + self.east_bounds[1], + grid_spacing, + ) + self._north_coords = np.arange( + self.north_bounds[0], + self.north_bounds[1], + grid_spacing, + ) + self._depth_coords = np.arange( + self.depth_bounds[0], + self.depth_bounds[1], + grid_spacing, + ) + + @classmethod + async def calculate_from_eikonal( + cls, + model: VelocityModel3D, + station: Station, + save: Path | None = None, + executor: ThreadPoolExecutor | None = None, + ) -> Self: + arrival_times = model.get_source_arrival_grid(station) + + if not model.is_inside(station): + offset = station.offset_from(model.center) + raise ValueError(f"station is outside the velocity model {offset}") + + def eikonal_wrapper( + velocity_model: VelocityModel3D, + arrival_times: np.ndarray, + delta: float, + ) -> StationTravelTimeVolume: + logger.debug( + "calculating travel time volume for %s, grid size %s, spacing %s m...", + station.pretty_nsl, + arrival_times.shape, + velocity_model.grid_spacing, + ) + eikonal.eikonal_solver_fmm_cartesian( + velocity_model._velocity_model, + arrival_times, + delta=delta, + ) + station_travel_times = cls( + center=model.center, + velocity_model_hash=model.hash(), + station=station, + east_bounds=model.east_bounds, + north_bounds=model.north_bounds, + depth_bounds=model.depth_bounds, + grid_spacing=model.grid_spacing, + ) + station_travel_times._travel_times = arrival_times.astype(np.float32) + if save: + station_travel_times.save(save) + + return station_travel_times + + loop = asyncio.get_running_loop() + + work = functools.partial( + eikonal_wrapper, + model, + arrival_times, + delta=model.grid_spacing, + ) + + return await loop.run_in_executor(executor, work) + + @property + def filename(self) -> str: + # TODO: Add origin to hash to avoid collisions + return f"{self.station.pretty_nsl}-{self.velocity_model_hash}.3dtt" + + def get_travel_time_interpolator(self) -> RegularGridInterpolator: + if self._interpolator is None: + self._interpolator = RegularGridInterpolator( + (self._east_coords, self._north_coords, self._depth_coords), + self.travel_times, + bounds_error=False, + fill_value=np.nan, + ) + return self._interpolator + + def interpolate_travel_time( + self, + location: Location, + method: Literal["nearest", "linear", "cubic"] = "linear", + ) -> float: + interpolator = self.get_travel_time_interpolator() + offset = location.offset_from(self.center) + return interpolator([offset], method=method).astype(float, copy=False)[0] + + def interpolate_nodes( + self, + nodes: Sequence[Node], + method: Literal["nearest", "linear", "cubic"] = "linear", + ) -> np.ndarray: + interpolator = self.get_travel_time_interpolator() + + coordinates = [node.as_location().offset_from(self.center) for node in nodes] + return interpolator(coordinates, method=method).astype(float, copy=False) + + def get_meshgrid(self) -> list[np.ndarray]: + return np.meshgrid( + self._east_coords, + self._north_coords, + self._depth_coords, + indexing="ij", + ) + + def save(self, path: Path) -> Path: + """Save travel times to a zip file. + + The zip file contains a model.json file with the model metadata and a + numpy file with the travel times. 
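+ The archive follows the .3dtt naming scheme of the filename property + and can be reloaded lazily via load().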
+ + Args: + path (Path): path to save the travel times to + + Returns: + Path: path to the saved travel times + """ + if not self.has_travel_times(): + raise AttributeError("travel times have not been calculated yet") + + file = path / self.filename if path.is_dir() else path + logger.debug("saving travel times to %s...", file) + + with zipfile.ZipFile(str(file), "w") as archive: + archive.writestr("model.json", self.model_dump_json(indent=2)) + travel_times = archive.open("travel_times.npy", "w") + np.save(travel_times, self.travel_times) + travel_times.close() + + self._file = file + return file + + @classmethod + def load(cls, file: Path) -> Self: + """Load 3D travel times from a .3dtt file. + + Args: + file (Path): path to the .3dtt file containing the travel times + + Returns: + Self: 3D travel times + """ + logger.debug("loading travel times from %s...", file) + with zipfile.ZipFile(file, "r") as archive: + path = zipfile.Path(archive) + model_file = path / "model.json" + model = cls.model_validate_json(model_file.read_text()) + model._file = file + return model + + def _load_travel_times(self) -> np.ndarray: + if not self._file or not self._file.exists(): + raise FileNotFoundError(f"file {self._file} not found") + + with zipfile.ZipFile(self._file, "r") as archive: + return np.load(archive.open("travel_times.npy", "r")) + + +class FastMarchingTracer(RayTracer): + tracer: Literal["FastMarchingRayTracer"] = "FastMarchingRayTracer" + + phase: PhaseDescription = "fm:P" + interpolation_method: Literal["nearest", "linear", "cubic"] = "linear" + nthreads: int = Field( + default=0, + description="Number of threads to use for travel time calculation." + " If set to 0, cpu_count*2 will be used.", + ) + + lut_cache_size: ByteSize = Field( + default=2 * GiB, + description="Size of the LUT cache.", + ) + + velocity_model: VelocityModels = Constant3DVelocityModel() + + _travel_time_volumes: dict[str, StationTravelTimeVolume] = PrivateAttr({}) + _velocity_model: VelocityModel3D | None = PrivateAttr(None) + + _cached_stations: Stations = PrivateAttr() + _cached_station_indeces: dict[str, int] = PrivateAttr({}) + _node_lut: dict[bytes, np.ndarray] = PrivateAttr() + + def get_available_phases(self) -> tuple[str, ...]: + return (self.phase,) + + def get_travel_time_volume(self, location: Location) -> StationTravelTimeVolume: + return self._travel_time_volumes[location.location_hash()] + + def add_travel_time_volume( + self, + location: Location, + volume: StationTravelTimeVolume, + ) -> None: + self._travel_time_volumes[location.location_hash()] = volume + + def has_travel_time_volume(self, location: Location) -> bool: + return location.location_hash() in self._travel_time_volumes + + def lut_fill_level(self) -> float: + """Return the fill level of the LUT as a float between 0.0 and 1.0""" + return len(self._node_lut) / self._node_lut.get_size() + + async def prepare( + self, + octree: Octree, + stations: Stations, + ) -> None: + logger.info("preparing fast-marching tracer for %s phase...", self.phase) + velocity_model = self.velocity_model.get_model(octree) + self._velocity_model = velocity_model + + for station in stations: + if not velocity_model.is_inside(station): + offset = station.offset_from(velocity_model.center) + stations.blacklist_station( + station, + reason=f"outside fast-marching velocity model, offset {offset}", + ) + + for station in stations: + velocity_station = velocity_model.get_velocity(station) + if velocity_station < 0.0: + raise ValueError( + f"station {station.pretty_nsl} has negative
velocity" + f" {velocity_station}" + ) + logger.info( + "velocity at station %s: %.1f m/s", + station.pretty_nsl, + velocity_station, + ) + + nodes_covered = [ + node for node in octree if velocity_model.is_inside(node.as_location()) + ] + if not nodes_covered: + raise ValueError("no octree node is inside the velocity model") + + logger.info( + "%d%% octree nodes are inside the %s velocity model", + len(nodes_covered) / octree.n_nodes * 100, + self.phase, + ) + + self._cached_stations = stations + self._cached_station_indeces = { + sta.pretty_nsl: idx for idx, sta in enumerate(stations) + } + bytes_per_node = stations.n_stations * np.float32().itemsize + lru_cache_size = int(self.lut_cache_size / bytes_per_node) + self._node_lut = LRU(lru_cache_size) + + # TODO: This should be total number nodes. Not only leaf nodes. + node_cache_fraction = lru_cache_size / octree.total_number_nodes() + logging.info( + "limiting traveltime LUT size to %d nodes (%s)," + " caching %.1f%% of possible octree nodes", + lru_cache_size, + human_readable_bytes(self.lut_cache_size), + node_cache_fraction * 100, + ) + + cache_dir = FMM_CACHE_DIR / f"{velocity_model.hash()}" + if not cache_dir.exists(): + cache_dir.mkdir(parents=True) + else: + self._load_cached_tavel_times(cache_dir) + + calc_stations = [ + station for station in stations if not self.has_travel_time_volume(station) + ] + await self._calculate_travel_times(calc_stations, cache_dir) + + def _load_cached_tavel_times(self, cache_dir: Path) -> None: + logger.debug("loading travel times volumes from cache %s...", cache_dir) + volumes: dict[str, StationTravelTimeVolume] = {} + for file in cache_dir.glob("*.3dtt"): + try: + travel_times = StationTravelTimeVolume.load(file) + except zipfile.BadZipFile: + logger.warning("removing bad travel time file %s", file) + file.unlink() + continue + volumes[travel_times.station.location_hash()] = travel_times + + logger.info( + "loaded %d travel times volumes for %s from cache", len(volumes), self.phase + ) + self._travel_time_volumes.update(volumes) + + async def _calculate_travel_times( + self, + stations: list[Station], + cache_dir: Path, + ) -> None: + nthreads = self.nthreads if self.nthreads > 0 else os.cpu_count() * 2 + executor = ThreadPoolExecutor( + max_workers=nthreads, + thread_name_prefix="lassie-fmm", + ) + if self._velocity_model is None: + raise AttributeError("velocity model has not been prepared yet") + + async def worker_station_travel_time(station: Station) -> None: + volume = await StationTravelTimeVolume.calculate_from_eikonal( + self._velocity_model, # noqa + station, + save=cache_dir, + executor=executor, + ) + self.add_travel_time_volume(station, volume) + + calculate_work = [worker_station_travel_time(station) for station in stations] + if not calculate_work: + return + + start = datetime_now() + tasks = [asyncio.create_task(work) for work in calculate_work] + with Progress() as progress: + status = progress.add_task( + f"calculating travel time volumes for {len(tasks)} stations" + f" ({nthreads} threads)", + total=len(tasks), + ) + for _task in asyncio.as_completed(tasks): + await _task + progress.advance(status) + logger.info("calculated travel time volumes in %s", datetime_now() - start) + + def get_travel_time_location( + self, + phase: str, + source: Location, + receiver: Location, + ) -> float: + if phase != self.phase: + raise ValueError(f"phase {phase} is not supported by this tracer") + + station_travel_times = self.get_travel_time_volume(receiver) + return 
station_travel_times.interpolate_travel_time( + source, + method=self.interpolation_method, + ) + + def get_travel_times( + self, + phase: str, + octree: Octree, + stations: Stations, + ) -> np.ndarray: + if phase != self.phase: + raise ValueError(f"phase {phase} is not supported by this tracer") + + station_indices = np.fromiter( + (self._cached_station_indeces[sta.pretty_nsl] for sta in stations), + dtype=int, + ) + + stations_traveltimes = [] + fill_nodes = [] + for node in octree: + try: + node_traveltimes = self._node_lut[node.hash()][station_indices] + except KeyError: + fill_nodes.append(node) + continue + stations_traveltimes.append(node_traveltimes) + + if fill_nodes: + self.fill_lut(fill_nodes) + + cache_hits, cache_misses = self._node_lut.get_stats() + cache_hit_rate = cache_hits / (cache_hits + cache_misses) + logger.debug( + "node LUT cache fill level %.1f%%, cache hit rate %.1f%%", + self.lut_fill_level() * 100, + cache_hit_rate * 100, + ) + return self.get_travel_times(phase, octree, stations) + + return np.asarray(stations_traveltimes).astype(float, copy=False) + + def fill_lut(self, nodes: Sequence[Node]) -> None: + travel_times = [] + n_nodes = len(nodes) + + with Progress() as progress: + status = progress.add_task( + f"interpolating {self.phase} travel times for {n_nodes} nodes", + total=self._cached_stations.n_stations, + ) + for station in self._cached_stations: + volume = self.get_travel_time_volume(station) + travel_times.append(volume.interpolate_nodes(nodes)) + progress.advance(status) + + travel_times = np.array(travel_times).T + + for node, station_travel_times in zip(nodes, travel_times, strict=True): + self._node_lut[node.hash()] = station_travel_times + + def get_arrivals( + self, + phase: str, + event_time: datetime, + source: Location, + receivers: Sequence[Location], + ) -> list[ModelledArrival | None]: + if phase != self.phase: + raise ValueError(f"phase {phase} is not supported by this tracer") + + traveltimes = [] + for receiver in receivers: + traveltimes.append(self.get_travel_time_location(phase, source, receiver)) + + arrivals = [] + for traveltime, _receiver in zip(traveltimes, receivers, strict=True): + if np.isnan(traveltime): + arrivals.append(None) + continue + + arrivaltime = event_time + timedelta(seconds=traveltime) + arrival = FastMarchingArrival(time=arrivaltime, phase=phase) + arrivals.append(arrival) + return arrivals diff --git a/lassie/tracers/fast_marching/velocity_models.py b/lassie/tracers/fast_marching/velocity_models.py new file mode 100644 index 00000000..ba0b15bd --- /dev/null +++ b/lassie/tracers/fast_marching/velocity_models.py @@ -0,0 +1,467 @@ +from __future__ import annotations + +import logging +import re +from hashlib import sha1 +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union + +import numpy as np +from pydantic import ( + BaseModel, + Field, + FilePath, + PositiveFloat, + PrivateAttr, + model_validator, +) +from pydantic.dataclasses import dataclass +from scipy.interpolate import RegularGridInterpolator +from typing_extensions import Self + +from lassie.models.location import Location + +if TYPE_CHECKING: + from lassie.octree import Octree + + +KM = 1e3 +logger = logging.getLogger(__name__) + + +class VelocityModel3D(BaseModel): + center: Location + + grid_spacing: float + + east_bounds: tuple[float, float] + north_bounds: tuple[float, float] + depth_bounds: tuple[float, float] + + _east_coords: np.ndarray = PrivateAttr(None) + _north_coords: np.ndarray = PrivateAttr(None) + 
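# Grid axes and the velocity cube are created in model_post_init from the + # bounds and grid spacing; PrivateAttr keeps them out of the model schema. +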
_depth_coords: np.ndarray = PrivateAttr(None) + + _velocity_model: np.ndarray = PrivateAttr(None) + + _hash: str | None = PrivateAttr(None) + + def model_post_init(self, __context: Any) -> None: + grid_spacing = self.grid_spacing + + self._east_coords = np.arange( + self.east_bounds[0], + self.east_bounds[1], + grid_spacing, + ) + self._north_coords = np.arange( + self.north_bounds[0], + self.north_bounds[1], + grid_spacing, + ) + self._depth_coords = np.arange( + self.depth_bounds[0], + self.depth_bounds[1], + grid_spacing, + ) + + self._velocity_model = np.zeros( + ( + self._east_coords.size, + self._north_coords.size, + self._depth_coords.size, + ) + ) + + def set_velocity_model(self, velocity_model: np.ndarray) -> None: + if velocity_model.shape != self._velocity_model.shape: + raise ValueError( + f"Velocity model shape {velocity_model.shape} does not match" + f" expected shape {self._velocity_model.shape}" + ) + self._velocity_model = velocity_model.astype(float, copy=False) + + @property + def velocity_model(self) -> np.ndarray: + if self._velocity_model is None: + raise ValueError("Velocity model not set.") + return self._velocity_model + + def hash(self) -> str: + """Return hash of velocity model. + + Returns: + str: The hash. + """ + if self._hash is None: + sha1_hash = sha1(self._velocity_model.tobytes()) + self._hash = sha1_hash.hexdigest() + return self._hash + + def _get_location_indices(self, location: Location) -> tuple[int, int, int]: + """Return indices of location in velocity model, by nearest neighbor. + + Args: + location (Location): The location. + + Returns: + tuple[int, int, int]: The indices as (east, north, depth). + """ + if not self.is_inside(location): + raise ValueError("Location is outside of velocity model.") + station_offset = location.offset_from(self.center) + east_idx = np.argmin(np.abs(self._east_coords - station_offset[0])) + north_idx = np.argmin(np.abs(self._north_coords - station_offset[1])) + depth_idx = np.argmin(np.abs(self._depth_coords - station_offset[2])) + return int(east_idx), int(north_idx), int(depth_idx) + + def get_velocity(self, location: Location) -> float: + """Return velocity at location in [m/s], nearest neighbor. + + Args: + location (Location): The location. + + Returns: + float: The velocity in m/s. + """ + east_idx, north_idx, depth_idx = self._get_location_indices(location) + return self.velocity_model[east_idx, north_idx, depth_idx] + + def get_source_arrival_grid(self, location: Location) -> np.ndarray: + """Return the initial travel time grid for the eikonal solver for a specific source location. + + The initial travel time grid is filled with -1.0, except for the source + location, which is set to 0.0 s. + + Args: + location (Location): The location. + + Returns: + np.ndarray: The initial travel times grid. + """ + times = np.full_like(self.velocity_model, fill_value=-1.0) + east_idx, north_idx, depth_idx = self._get_location_indices(location) + times[east_idx, north_idx, depth_idx] = 0.0 + return times + + def is_inside(self, location: Location) -> bool: + """Return True if location is inside velocity model. + + Args: + location (Location): The location. + + Returns: + bool: True if location is inside velocity model.
+ """ + offset_to_center = location.offset_from(self.center) + return ( + self.east_bounds[0] <= offset_to_center[0] <= self.east_bounds[1] + and self.north_bounds[0] <= offset_to_center[1] <= self.north_bounds[1] + and self.depth_bounds[0] <= offset_to_center[2] <= self.depth_bounds[1] + ) + + def get_meshgrid(self) -> list[np.ndarray]: + """Return meshgrid of velocity model coordinates. + + Returns: + list[np.ndarray]: The meshgrid as list of numpy arrays for east, north, + depth. + """ + return np.meshgrid( + self._east_coords, + self._north_coords, + self._depth_coords, + indexing="ij", + ) + + def resample( + self, + grid_spacing: float, + method: Literal["nearest", "linear", "cubic"] = "linear", + ) -> Self: + """Resample velocity model to new grid spacing. + + Args: + grid_spacing (float): The new grid spacing in [m]. + method (Literal['nearest', 'linear', 'cubic'], optional): Interpolation + method. Defaults to "linear". + + Returns: + Self: A new, resampled velocity model. + """ + if grid_spacing == self.grid_spacing: + return self + + logger.info("resampling velocity model to grid spacing %s m", grid_spacing) + interpolator = RegularGridInterpolator( + (self._east_coords, self._north_coords, self._depth_coords), + self._velocity_model, + method=method, + bounds_error=False, + ) + resampled_model = VelocityModel3D( + center=self.center, + grid_spacing=grid_spacing, + east_bounds=self.east_bounds, + north_bounds=self.north_bounds, + depth_bounds=self.depth_bounds, + ) + coordinates = np.array( + [coords.ravel() for coords in resampled_model.get_meshgrid()] + ).T + resampled_model._velocity_model = interpolator(coordinates).reshape( + resampled_model._velocity_model.shape + ) + return resampled_model + + +class VelocityModelFactory(BaseModel): + model: Literal["VelocityModelFactory"] = "VelocityModelFactory" + + grid_spacing: PositiveFloat | Literal["quadtree"] = Field( + default="quadtree", + description="Grid spacing in meters." + " If 'quadtree' defaults to smallest octreee node size.", + ) + + def get_model(self, octree: Octree) -> VelocityModel3D: + raise NotImplementedError + + +class Constant3DVelocityModel(VelocityModelFactory): + """This model is for mere testing of the method.""" + + model: Literal["Constant3DVelocityModel"] = "Constant3DVelocityModel" + + velocity: PositiveFloat = 5000.0 + + def get_model(self, octree: Octree) -> VelocityModel3D: + if self.grid_spacing == "quadtree": + grid_spacing = octree.smallest_node_size() + else: + grid_spacing = self.grid_spacing + + model = VelocityModel3D( + center=octree.reference, + grid_spacing=grid_spacing, + east_bounds=octree.east_bounds, + north_bounds=octree.north_bounds, + depth_bounds=octree.depth_bounds, + ) + model._velocity_model.fill(self.velocity) + + return model + + +NonLinLocGridType = Literal["VELOCITY", "VELOCITY_METERS", "SLOW_LEN"] +GridDtype = Literal["FLOAT", "DOUBLE"] +DTYPE_MAP = {"FLOAT": np.float32, "DOUBLE": float} + + +@dataclass +class NonLinLocHeader: + origin: Location + nx: int + ny: int + nz: int + delta_x: float + delta_y: float + delta_z: float + grid_dtype: GridDtype + grid_type: NonLinLocGridType + + @classmethod + def from_header_file( + cls, + file: Path, + reference_location: Location | None = None, + ) -> Self: + """Load NonLinLoc velocity model header file. + + Args: + file (Path): Path to NonLinLoc model header file. + reference_location (Location | None, optional): relative location of + NonLinLoc model, used for models with relative coordinates. + Defaults to None. 
+ + Raises: + ValueError: If grid spacing is not equal in all dimensions. + + Returns: + Self: The header. + """ + logger.info("loading NonLinLoc velocity model header file %s", file) + header_text = file.read_text().split("\n")[0] + header_text = re.sub(r"\s+", " ", header_text) # remove excessive spaces + ( + nx, + ny, + nz, + orig_x, + orig_y, + orig_z, + delta_x, + delta_y, + delta_z, + grid_type, + grid_dtype, + ) = header_text.split() + + if not delta_x == delta_y == delta_z: + raise ValueError("NonLinLoc velocity model must have equal spacing.") + + if reference_location: + origin = reference_location + origin.east_shift += float(orig_x) * KM + origin.north_shift += float(orig_y) * KM + origin.elevation -= float(orig_z) * KM + else: + origin = Location( + lon=float(orig_x), + lat=float(orig_y), + elevation=-float(orig_z) * KM, + ) + + return cls( + origin=origin, + nx=int(nx), + ny=int(ny), + nz=int(nz), + delta_x=float(delta_x) * KM, + delta_y=float(delta_y) * KM, + delta_z=float(delta_z) * KM, + grid_dtype=grid_dtype, + grid_type=grid_type, + ) + + @property + def dtype(self) -> np.dtype: + return DTYPE_MAP[self.grid_dtype] + + @property + def grid_spacing(self) -> float: + return self.delta_x + + @property + def east_bounds(self) -> tuple[float, float]: + """Relative to center location.""" + return -self.delta_x * self.nx / 2, self.delta_x * self.nx / 2 + + @property + def north_bounds(self) -> tuple[float, float]: + """Relative to center location.""" + return -self.delta_y * self.ny / 2, self.delta_y * self.ny / 2 + + @property + def depth_bounds(self) -> tuple[float, float]: + """Relative to center location.""" + return (0, self.delta_z * self.nz) + + @property + def center(self) -> Location: + """Return center location of velocity model. + + Returns: + Location: The center location of the grid. + """ + center = self.origin.model_copy(deep=True) + center.east_shift += self.delta_x * self.nx / 2 + center.north_shift += self.delta_y * self.ny / 2 + return center + + +class NonLinLocVelocityModel(VelocityModelFactory): + model: Literal["NonLinLocVelocityModel"] = "NonLinLocVelocityModel" + + header_file: FilePath = Field( + ..., + description="Path to NonLinLoc model header file. " + "The file should be in the format of a NonLinLoc velocity model header file.", + ) + buffer_file: FilePath | None = Field( + default=None, + description="Path to NonLinLoc model buffer file. If none, the filename will be " + "inferred from the header file.", + ) + + grid_spacing: PositiveFloat | Literal["quadtree", "input"] = Field( + default="input", + description="Grid spacing in meters." + " If 'quadtree' defaults to smallest octree node size.
If 'input' uses the" + " grid spacing from the NonLinLoc header file.", + ) + interpolation: Literal["nearest", "linear", "cubic"] = Field( + default="linear", + description="Interpolation method for resampling the grid " + "for the fast-marching method.", + ) + + reference_location: Location | None = Field( + default=None, + description="relative location of NonLinLoc model, " + "used for models with relative coordinates.", + ) + + _header: NonLinLocHeader = PrivateAttr() + _velocity_model: np.ndarray = PrivateAttr() + + @model_validator(mode="after") + def load_header(self) -> Self: + self._header = NonLinLocHeader.from_header_file( + self.header_file, + reference_location=self.reference_location, + ) + self.buffer_file = self.buffer_file or self.header_file.with_suffix(".buf") + if not self.buffer_file.exists(): + raise FileNotFoundError(f"Buffer file {self.buffer_file} not found.") + + logger.debug( + "loading NonLinLoc velocity model buffer file %s", self.buffer_file + ) + self._velocity_model = np.fromfile( + self.buffer_file, dtype=self._header.dtype + ).reshape((self._header.nx, self._header.ny, self._header.nz)) + + if self._header.grid_type == "SLOW_LEN": + logger.debug("converting NonLinLoc SLOW_LEN model to velocity") + self._velocity_model = 1.0 / ( + self._velocity_model / self._header.grid_spacing + ) + elif self._header.grid_type == "VELOCITY": + self._velocity_model *= KM + + logging.info( + "NonLinLoc velocity model: %s" + " east_bounds: %s, north_bounds %s, depth_bounds %s", + self._header.center, + self._header.east_bounds, + self._header.north_bounds, + self._header.depth_bounds, + ) + return self + + def get_model(self, octree: Octree) -> VelocityModel3D: + if self.grid_spacing == "quadtree": + grid_spacing = octree.smallest_node_size() + if self.grid_spacing == "input": + grid_spacing = self._header.grid_spacing + else: + grid_spacing = self.grid_spacing + + header = self._header + + velocity_model = VelocityModel3D( + center=header.center, + grid_spacing=header.grid_spacing, + east_bounds=header.east_bounds, + north_bounds=header.north_bounds, + depth_bounds=header.depth_bounds, + ) + velocity_model.set_velocity_model(self._velocity_model) + return velocity_model.resample(grid_spacing, self.interpolation) + + +VelocityModels = Annotated[ + Union[Constant3DVelocityModel, NonLinLocVelocityModel], + Field(..., discriminator="model"), +] diff --git a/lassie/utils.py b/lassie/utils.py index 752db3f5..88ca53b1 100644 --- a/lassie/utils.py +++ b/lassie/utils.py @@ -5,14 +5,14 @@ from datetime import datetime, timedelta, timezone from functools import wraps from pathlib import Path -from typing import Awaitable, Callable, ParamSpec, TypeVar, Annotated, TYPE_CHECKING -from pydantic import constr +from typing import TYPE_CHECKING, Annotated, Awaitable, Callable, ParamSpec, TypeVar +from pydantic import constr from pyrocko.util import UnavailableDecimation from rich.logging import RichHandler if TYPE_CHECKING: - from pyrocko import Trace + from pyrocko.trace import Trace logger = logging.getLogger(__name__) FORMAT = "%(message)s" @@ -55,9 +55,12 @@ def to_datetime(time: float) -> datetime: def downsample(trace: Trace, sampling_rate: float) -> None: deltat = 1.0 / sampling_rate + + if trace.deltat == deltat: + return + try: trace.downsample_to(deltat, demean=False, snap=False, allow_upsample_max=4) - except UnavailableDecimation: logger.warning("using resample instead of decimation") trace.resample(deltat) diff --git a/lassie/waveforms/__init__.py b/lassie/waveforms/__init__.py 
new file mode 100644 index 00000000..680a929e --- /dev/null +++ b/lassie/waveforms/__init__.py @@ -0,0 +1,11 @@ +from typing import Annotated, Union + +from pydantic import Field + +from lassie.waveforms.base import WaveformProvider +from lassie.waveforms.squirrel import PyrockoSquirrel + +WaveformProviderType = Annotated[ + Union[PyrockoSquirrel, WaveformProvider], + Field(..., discriminator="provider"), +] diff --git a/lassie/waveforms/base.py b/lassie/waveforms/base.py new file mode 100644 index 00000000..38319a55 --- /dev/null +++ b/lassie/waveforms/base.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import logging +from asyncio import Queue +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, AsyncIterator, Literal + +import numpy as np +from pydantic import BaseModel, PrivateAttr +from pyrocko.trace import Trace + +if TYPE_CHECKING: + from pyrocko.squirrel import Squirrel + + from lassie.models.station import Stations + +logger = logging.getLogger(__name__) + + +@dataclass +class WaveformBatch: + traces: list[Trace] + start_time: datetime + end_time: datetime + i_batch: int + n_batches: int = 0 + + @property + def duration(self) -> timedelta: + return self.end_time - self.start_time + + def is_empty(self) -> bool: + """Check if the batch is empty. + + Returns: + bool: True if the batch is empty, False otherwise. + """ + return not bool(self.traces) + + def clean_traces(self) -> None: + """Remove empty or bad traces.""" + for tr in self.traces.copy(): + if not tr.ydata.size or not np.all(np.isfinite(tr.ydata)): + logger.warning("skipping empty or bad trace: %s", ".".join(tr.nslc_id)) + self.traces.remove(tr) + + +class WaveformProvider(BaseModel): + provider: Literal["WaveformProvider"] = "WaveformProvider" + + _queue: Queue[WaveformBatch | None] = PrivateAttr(default_factory=lambda: Queue()) + + def get_squirrel(self) -> Squirrel: + raise NotImplementedError + + def prepare(self, stations: Stations) -> None: + ... 
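For orientation, a minimal sketch of how a consumer could drive this provider interface, using the PyrockoSquirrel implementation added below; the directory path, the empty Stations() and the window lengths are illustrative assumptions, not part of this patch:

```python
import asyncio
from datetime import timedelta
from pathlib import Path

from lassie.models.station import Stations
from lassie.waveforms.squirrel import PyrockoSquirrel


async def print_batches() -> None:
    # Illustrative setup; real runs configure the provider and stations
    # from the project configuration file.
    provider = PyrockoSquirrel(waveform_dirs=[Path("data/waveforms")])
    provider.prepare(Stations())

    async for batch in provider.iter_batches(
        window_increment=timedelta(minutes=5),
        window_padding=timedelta(seconds=30),
    ):
        batch.clean_traces()  # drop empty or non-finite traces
        if batch.is_empty():
            continue
        print(batch.start_time, batch.end_time, len(batch.traces))


asyncio.run(print_batches())
```
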
diff --git a/lassie/waveforms/base.py b/lassie/waveforms/base.py
new file mode 100644
index 00000000..38319a55
--- /dev/null
+++ b/lassie/waveforms/base.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import logging
+from asyncio import Queue
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, AsyncIterator, Literal
+
+import numpy as np
+from pydantic import BaseModel, PrivateAttr
+from pyrocko.trace import Trace
+
+if TYPE_CHECKING:
+    from pyrocko.squirrel import Squirrel
+
+    from lassie.models.station import Stations
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class WaveformBatch:
+    traces: list[Trace]
+    start_time: datetime
+    end_time: datetime
+    i_batch: int
+    n_batches: int = 0
+
+    @property
+    def duration(self) -> timedelta:
+        return self.end_time - self.start_time
+
+    def is_empty(self) -> bool:
+        """Check if the batch is empty.
+
+        Returns:
+            bool: True if the batch is empty, False otherwise.
+        """
+        return not bool(self.traces)
+
+    def clean_traces(self) -> None:
+        """Remove empty or bad traces."""
+        for tr in self.traces.copy():
+            if not tr.ydata.size or not np.all(np.isfinite(tr.ydata)):
+                logger.warning("skipping empty or bad trace: %s", ".".join(tr.nslc_id))
+                self.traces.remove(tr)
+
+
+class WaveformProvider(BaseModel):
+    provider: Literal["WaveformProvider"] = "WaveformProvider"
+
+    _queue: Queue[WaveformBatch | None] = PrivateAttr(default_factory=lambda: Queue())
+
+    def get_squirrel(self) -> Squirrel:
+        raise NotImplementedError
+
+    def prepare(self, stations: Stations) -> None:
+        ...
+
+    async def iter_batches(
+        self,
+        window_increment: timedelta,
+        window_padding: timedelta,
+        start_time: datetime | None = None,
+        end_time: datetime | None = None,
+    ) -> AsyncIterator[WaveformBatch]:
+        raise NotImplementedError
+        yield
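
A sketch of the `WaveformBatch` housekeeping defined above: traces whose samples are empty or non-finite are dropped by `clean_traces()`. Assumes pyrocko `Trace` objects; the all-NaN trace is a constructed example:

```py
from datetime import datetime, timedelta, timezone

import numpy as np
from pyrocko.trace import Trace

from lassie.waveforms.base import WaveformBatch

now = datetime.now(tz=timezone.utc)
good = Trace(station="GOOD", deltat=0.01, ydata=np.ones(100, dtype=np.float32))
bad = Trace(station="BAD", deltat=0.01, ydata=np.full(100, np.nan, dtype=np.float32))

batch = WaveformBatch(
    traces=[good, bad],
    start_time=now,
    end_time=now + timedelta(seconds=10),
    i_batch=0,
)
batch.clean_traces()  # drops the all-NaN trace and logs a warning
assert not batch.is_empty()
assert len(batch.traces) == 1
```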
diff --git a/lassie/waveforms/squirrel.py b/lassie/waveforms/squirrel.py
new file mode 100644
index 00000000..e9ad91dc
--- /dev/null
+++ b/lassie/waveforms/squirrel.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+import asyncio
+import glob
+import logging
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import TYPE_CHECKING, AsyncIterator, Iterator, Literal
+
+from pydantic import (
+    AwareDatetime,
+    PositiveInt,
+    PrivateAttr,
+    constr,
+    model_validator,
+)
+from pyrocko.squirrel import Squirrel
+from typing_extensions import Self
+
+from lassie.models.station import Stations
+from lassie.utils import datetime_now, to_datetime
+from lassie.waveforms.base import WaveformBatch, WaveformProvider
+
+if TYPE_CHECKING:
+    from pyrocko.squirrel.base import Batch
+
+logger = logging.getLogger(__name__)
+
+
+class SquirrelPrefetcher:
+    def __init__(self, iterator: Iterator[Batch], queue_size: int = 4) -> None:
+        self.iterator = iterator
+        self.queue: asyncio.Queue[Batch | None] = asyncio.Queue(maxsize=queue_size)
+
+        self._task = asyncio.create_task(self.prefetch_worker())
+
+    async def prefetch_worker(self) -> None:
+        logger.info("start prefetching squirrel data")
+        while True:
+            start = datetime_now()
+            batch = await asyncio.to_thread(lambda: next(self.iterator, None))
+            logger.debug("prefetched waveforms in %s", datetime_now() - start)
+            if batch is None:
+                logger.debug("squirrel prefetcher finished")
+                await self.queue.put(None)
+                break
+            await self.queue.put(batch)
+
+
+class PyrockoSquirrel(WaveformProvider):
+    provider: Literal["PyrockoSquirrel"] = "PyrockoSquirrel"
+
+    environment: Path = Path(".")
+    waveform_dirs: list[Path] = []
+    start_time: AwareDatetime | None = None
+    end_time: AwareDatetime | None = None
+
+    channel_selector: constr(max_length=3) = "*"
+    async_prefetch_batches: PositiveInt = 4
+
+    _squirrel: Squirrel | None = PrivateAttr(None)
+    _stations: Stations = PrivateAttr()
+
+    @model_validator(mode="after")
+    def _validate_time_span(self) -> Self:  # noqa: N805
+        if self.start_time and self.end_time and self.start_time > self.end_time:
+            raise ValueError("start_time must be before end_time")
+        return self
+
+    def get_squirrel(self) -> Squirrel:
+        if not self._squirrel:
+            logger.debug("initializing squirrel")
+            squirrel = Squirrel(str(self.environment.expanduser()))
+            paths = []
+            for path in self.waveform_dirs:
+                if "**" in str(path):
+                    paths.extend(glob.glob(str(path.expanduser()), recursive=True))
+                else:
+                    paths.append(str(path.expanduser()))
+
+            squirrel.add(paths, check=False)
+            self._squirrel = squirrel
+        return self._squirrel
+
+    def prepare(self, stations: Stations) -> None:
+        logger.info("preparing squirrel waveform provider")
+        squirrel = self.get_squirrel()
+        stations.weed_from_squirrel_waveforms(squirrel)
+        self._stations = stations
+
+    async def iter_batches(
+        self,
+        window_increment: timedelta,
+        window_padding: timedelta,
+        start_time: datetime | None = None,
+        end_time: datetime | None = None,
+    ) -> AsyncIterator[WaveformBatch]:
+        if not self._stations:
+            raise ValueError("no stations provided. has prepare() been called?")
+
+        squirrel = self.get_squirrel()
+        sq_tmin, sq_tmax = squirrel.get_time_span(["waveform"])
+
+        start_time = start_time or self.start_time or to_datetime(sq_tmin)
+        end_time = end_time or self.end_time or to_datetime(sq_tmax)
+
+        logger.info(
+            "searching time span from %s to %s (%s)",
+            start_time,
+            end_time,
+            end_time - start_time,
+        )
+
+        iterator = squirrel.chopper_waveforms(
+            tmin=start_time.timestamp(),
+            tmax=end_time.timestamp(),
+            tinc=window_increment.total_seconds(),
+            tpad=window_padding.total_seconds(),
+            want_incomplete=False,
+            codes=[
+                (*nsl, self.channel_selector) for nsl in self._stations.get_all_nsl()
+            ],
+        )
+        prefetcher = SquirrelPrefetcher(iterator, self.async_prefetch_batches)
+
+        while True:
+            batch = await prefetcher.queue.get()
+            if batch is None:
+                prefetcher.queue.task_done()
+                break
+
+            yield WaveformBatch(
+                traces=batch.traces,
+                start_time=to_datetime(batch.tmin),
+                end_time=to_datetime(batch.tmax),
+                i_batch=batch.i,
+                n_batches=batch.n,
+            )
+
+            prefetcher.queue.task_done()
+
+        logger.info("squirrel search finished")
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 00000000..9198fde5
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,12 @@
+site_name: Lassie
+site_description: The friendly earthquake detector
+repo_url: https://github.com/miili/lassie-v2
+repo_name: miili/lassie-v2
+
+theme:
+  name: material
+  palette:
+    scheme: slate
+  icon:
+    repo: fontawesome/brands/git-alt
+  logo: logos/lassie.webp
diff --git a/pyproject.toml b/pyproject.toml
index 3a3a2155..a7614309 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,14 +30,14 @@ dependencies = [
     "numpy>=1.17.3",
     "scipy>=1.8.0",
     "pyrocko>=2022.06.10",
-    "seisbench>=0.4.0",
-    "pydantic>=2.0",
+    "seisbench>=0.5.0",
+    "pydantic>=2.3",
     "aiohttp>=3.8",
     "aiohttp_cors>=0.7.0",
     "typing-extensions>=4.6",
    "lru-dict>=1.2",
     "rich>=13.4",
-    "nest-asyncio>=1.5",  # wait for seisbench merge https://github.com/seisbench/seisbench/pull/214
+    "nest_asyncio>=1.5",
 ]
 
 classifiers = [
@@ -59,6 +59,7 @@ dev = [
     "black",
     "ruff",
     "pytest",
+    "pytest-asyncio",
     "mkdocs-material",
     "mkdocstrings-python",
 ]
@@ -82,3 +83,6 @@ Issues = "https://git.pyrocko.org/pyrocko/lassie/issues"
 
 [tool.ruff]
 extend-select = ['W', 'N', 'DTZ', 'FA', 'G', 'RET', 'SIM', 'B', 'RET', 'C4']
 target-version = 'py310'
+
+[tool.pytest.ini_options]
+markers = ["plot: plot figures in tests"]
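
The `plot` marker registered in `pyproject.toml` above pairs with the CI change to run `pytest -m "not plot"`, keeping figure-producing tests out of headless builds. A minimal sketch of a marked test (the test body is hypothetical):

```py
import pytest


@pytest.mark.plot
def test_shows_a_figure() -> None:  # hypothetical test, for illustration only
    import matplotlib.pyplot as plt

    plt.plot([0.0, 1.0], [0.0, 1.0])
    plt.show()  # only runs when pytest is invoked without -m "not plot"
```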
diff --git a/test/conftest.py b/test/conftest.py
index be644607..29e32c23 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,17 +1,63 @@
+import asyncio
 import random
+from datetime import timedelta
 from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Generator
 
+import aiohttp
+import numpy as np
 import pytest
+from rich.progress import Progress
 
+from lassie.models.detection import EventDetection, EventDetections
+from lassie.models.location import Location
 from lassie.models.station import Station, Stations
 from lassie.octree import Octree
-from lassie.tracers.cake import EarthModel, Timing, TraveltimeTree
+from lassie.tracers.cake import EarthModel, Timing, TravelTimeTree
+from lassie.utils import datetime_now
 
-DATA_PATH = Path(__file__).parent / "data"
+DATA_DIR = Path(__file__).parent / "data"
+
+DATA_URL = "https://data.pyrocko.org/testing/lassie-v2/"
+DATA_FILES = {
+    "FORGE_3D_5_large.P.mod.hdr",
+    "FORGE_3D_5_large.P.mod.buf",
+    "FORGE_3D_5_large.S.mod.hdr",
+    "FORGE_3D_5_large.S.mod.buf",
+}
 
 KM = 1e3
 
 
+async def download_test_data() -> None:
+    request_files = [
+        DATA_DIR / filename
+        for filename in DATA_FILES
+        if not (DATA_DIR / filename).exists()
+    ]
+
+    if not request_files:
+        return
+
+    async with aiohttp.ClientSession() as session:
+        for file in request_files:
+            url = DATA_URL + file.name
+            with Progress() as progress:
+                async with session.get(url) as response:
+                    task = progress.add_task(
+                        f"Downloading {url}",
+                        total=response.content_length,
+                    )
+                    with file.open("wb") as f:
+                        while True:
+                            chunk = await response.content.read(1024)
+                            if not chunk:
+                                break
+                            f.write(chunk)
+                            progress.advance(task, len(chunk))
+
+
 def pytest_addoption(parser) -> None:
     parser.addoption("--plot", action="store_true", default=False)
 
@@ -22,8 +68,8 @@ def plot(pytestconfig) -> bool:
 
 @pytest.fixture(scope="session")
-def traveltime_tree() -> TraveltimeTree:
-    return TraveltimeTree.new(
+def travel_time_tree() -> TravelTimeTree:
+    return TravelTimeTree.new(
         earthmodel=EarthModel(),
         distance_bounds=(0 * KM, 15 * KM),
         receiver_depth_bounds=(0 * KM, 0 * KM),
@@ -34,12 +80,23 @@
     )
 
 
+@pytest.fixture(scope="session")
+def data_dir() -> Path:
+    if not DATA_DIR.exists():
+        DATA_DIR.mkdir()
+
+    asyncio.run(download_test_data())
+    return DATA_DIR
+
+
 @pytest.fixture(scope="session")
 def octree() -> Octree:
     return Octree(
-        center_lat=10.0,
-        center_lon=10.0,
-        surface_elevation=0.0,
+        reference=Location(
+            lat=10.0,
+            lon=10.0,
+            elevation=1.0 * KM,
+        ),
         size_initial=2 * KM,
         size_limit=500,
         east_bounds=(-10 * KM, 10 * KM),
@@ -65,3 +122,42 @@
         )
         stations.append(station)
     return Stations(stations=stations)
+
+
+@pytest.fixture(scope="session")
+def fixed_stations() -> Stations:
+    n_stations = 20
+    rng = np.random.RandomState(0)
+    stations: list[Station] = []
+    for i_sta in range(n_stations):
+        station = Station(
+            network="FX",
+            station="STA%02d" % i_sta,
+            lat=10.0,
+            lon=10.0,
+            elevation=rng.uniform(0, 1) * KM,
+            north_shift=rng.uniform(-10, 10) * KM,
+            east_shift=rng.uniform(-10, 10) * KM,
+        )
+        stations.append(station)
+    return Stations(stations=stations)
+
+
+@pytest.fixture(scope="session")
+def detections() -> Generator[EventDetections, None, None]:
+    n_detections = 2000
+    detections: list[EventDetection] = []
+    for _ in range(n_detections):
+        time = datetime_now() - timedelta(days=random.uniform(0, 365))
+        detection = EventDetection(
+            lat=10.0,
+            lon=10.0,
+            east_shift=random.uniform(-10, 10) * KM,
+            north_shift=random.uniform(-10, 10) * KM,
+            distance_border=1000.0,
+            semblance=random.uniform(0, 1),
+            time=time,
+        )
+        detections.append(detection)
+    with TemporaryDirectory() as tmpdir:
+        yield EventDetections(rundir=Path(tmpdir), detections=detections)
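
A sketch of how a test consumes the session-scoped fixtures defined in the `conftest.py` above; requesting `data_dir` triggers the one-time download of the FORGE NonLinLoc grids (fixture names are from this patch, the test itself is hypothetical):

```py
from pathlib import Path

from lassie.models.station import Stations
from lassie.octree import Octree


def test_uses_fixtures(octree: Octree, fixed_stations: Stations, data_dir: Path) -> None:
    assert octree.n_nodes > 0
    assert len(fixed_stations.stations) == 20
    assert (data_dir / "FORGE_3D_5_large.P.mod.hdr").exists()
```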
diff --git a/test/test_cake.py b/test/test_cake.py
index 65350682..cccb4079 100644
--- a/test/test_cake.py
+++ b/test/test_cake.py
@@ -7,7 +7,7 @@
 import numpy as np
 
 from lassie.models.location import Location
-from lassie.tracers.cake import TraveltimeTree
+from lassie.tracers.cake import TravelTimeTree
 
 if TYPE_CHECKING:
     from lassie.models.station import Stations
@@ -16,14 +16,14 @@
 KM = 1e3
 
 
-def test_sptree_model(traveltime_tree: TraveltimeTree):
-    model = traveltime_tree
+def test_sptree_model(travel_time_tree: TravelTimeTree):
+    model = travel_time_tree
 
     with TemporaryDirectory() as d:
         tmp = Path(d)
         file = model.save(tmp)
 
-    model2 = TraveltimeTree.load(file)
+    model2 = TravelTimeTree.load(file)
     model2._load_sptree()
 
     source = Location(
@@ -41,28 +41,30 @@
         depth=0,
     )
 
-    model.get_traveltime(source, receiver)
+    model.get_travel_time(source, receiver)
 
 
 def test_lut(
-    traveltime_tree: TraveltimeTree, octree: Octree, stations: Stations
+    travel_time_tree: TravelTimeTree,
+    octree: Octree,
+    stations: Stations,
 ) -> None:
-    model = traveltime_tree
+    model = travel_time_tree
     model.init_lut(octree, stations)
 
-    traveltimes_tree = model.interpolate_traveltimes(octree, stations)
-    traveltimes_lut = model.get_traveltimes(octree, stations)
+    traveltimes_tree = model.interpolate_travel_times(octree, stations)
+    traveltimes_lut = model.get_travel_times(octree, stations)
     np.testing.assert_equal(traveltimes_tree, traveltimes_lut)
 
     # Test refilling the LUT
     model._node_lut.clear()
-    traveltimes_tree = model.interpolate_traveltimes(octree, stations)
-    traveltimes_lut = model.get_traveltimes(octree, stations)
+    traveltimes_tree = model.interpolate_travel_times(octree, stations)
+    traveltimes_lut = model.get_travel_times(octree, stations)
     np.testing.assert_equal(traveltimes_tree, traveltimes_lut)
 
     assert len(model._node_lut) > 0, "did not refill lut"
 
-    stations_selection = stations.copy()
+    stations_selection = stations.model_copy()
     stations_selection.stations = stations_selection.stations[:5]
-    traveltimes_tree = model.interpolate_traveltimes(octree, stations_selection)
-    traveltimes_lut = model.get_traveltimes(octree, stations_selection)
+    traveltimes_tree = model.interpolate_travel_times(octree, stations_selection)
+    traveltimes_lut = model.get_travel_times(octree, stations_selection)
     np.testing.assert_equal(traveltimes_tree, traveltimes_lut)
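
The renamed `TravelTimeTree` keeps a per-node lookup table, and `test_lut` above asserts that the cached path agrees with direct interpolation. A condensed sketch of that round trip, assuming the `travel_time_tree`, `octree` and `stations` fixtures from `conftest.py`:

```py
import numpy as np


def check_lut_roundtrip(travel_time_tree, octree, stations) -> None:
    travel_time_tree.init_lut(octree, stations)
    cached = travel_time_tree.get_travel_times(octree, stations)  # served from the LUT
    exact = travel_time_tree.interpolate_travel_times(octree, stations)  # recomputed
    np.testing.assert_equal(cached, exact)
```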
diff --git a/test/test_fast_marching.py b/test/test_fast_marching.py
new file mode 100644
index 00000000..59d75206
--- /dev/null
+++ b/test/test_fast_marching.py
@@ -0,0 +1,235 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+import pytest_asyncio
+
+from lassie.models.station import Station, Stations
+from lassie.octree import Octree
+from lassie.tracers.fast_marching.fast_marching import (
+    FastMarchingTracer,
+    StationTravelTimeVolume,
+)
+from lassie.tracers.fast_marching.velocity_models import (
+    Constant3DVelocityModel,
+    NonLinLocVelocityModel,
+    VelocityModel3D,
+)
+from lassie.utils import datetime_now
+
+CONSTANT_VELOCITY = 5000
+KM = 1e3
+
+
+def stations_inside(
+    model: VelocityModel3D,
+    nstations: int = 20,
+    seed: int = 0,
+    depth: float | None = None,
+) -> Stations:
+    stations = []
+    rng = np.random.RandomState(seed)
+    for i_sta in range(nstations):
+        station = Station(
+            network="FM",
+            station="STA%02d" % i_sta,
+            lat=model.center.lat,
+            lon=model.center.lon,
+            elevation=model.center.elevation,
+            north_shift=model.center.north_shift + rng.uniform(*model.north_bounds),
+            east_shift=model.center.east_shift + rng.uniform(*model.east_bounds),
+            depth=model.center.depth
+            + (depth if depth is not None else rng.uniform(*model.depth_bounds)),
+        )
+        station = station.shifted_origin()
+        stations.append(station)
+    return Stations(stations=stations)
+
+
+def octree_cover(model: VelocityModel3D) -> Octree:
+    return Octree(
+        reference=model.center,
+        size_initial=2 * KM,
+        size_limit=500,
+        east_bounds=model.east_bounds,
+        north_bounds=model.north_bounds,
+        depth_bounds=model.depth_bounds,
+        absorbing_boundary=0,
+    )
+
+
+@pytest_asyncio.fixture
+async def station_travel_times(
+    octree: Octree, stations: Stations
+) -> StationTravelTimeVolume:
+    octree.reference.elevation = 1 * KM
+    model = Constant3DVelocityModel(velocity=CONSTANT_VELOCITY, grid_spacing=100.0)
+    model_3d = model.get_model(octree)
+    return await StationTravelTimeVolume.calculate_from_eikonal(
+        model_3d, stations.stations[0]
+    )
+
+
+@pytest.mark.asyncio
+async def test_load_save(
+    station_travel_times: StationTravelTimeVolume,
+    tmp_path: Path,
+) -> None:
+    outfile = tmp_path / "test_fast_marching_station.3dtt"
+    station_travel_times.save(outfile)
+    assert outfile.exists()
+
+    station_travel_times2 = StationTravelTimeVolume.load(outfile)
+    np.testing.assert_equal(
+        station_travel_times.travel_times, station_travel_times2.travel_times
+    )
+
+
+@pytest.mark.asyncio
+async def test_travel_time_interpolation(
+    station_travel_times: StationTravelTimeVolume,
+    octree: Octree,
+) -> None:
+    eikonal_travel_times = []
+    source_distances = []
+    for node in octree:
+        source = node.as_location()
+        eikonal_travel_times.append(
+            station_travel_times.interpolate_travel_time(source)
+        )
+        source_distances.append(station_travel_times.station.distance_to(source))
+
+    eikonal_travel_times = np.array(eikonal_travel_times)
+    assert np.any(eikonal_travel_times)
+
+    analytical_travel_times = np.array(source_distances) / CONSTANT_VELOCITY
+    nan_travel_times = np.isnan(eikonal_travel_times)
+
+    assert np.any(~nan_travel_times)
+    np.testing.assert_almost_equal(
+        eikonal_travel_times[~nan_travel_times],
+        analytical_travel_times[~nan_travel_times],
+        decimal=1,
+    )
+
+    eikonal_travel_times = station_travel_times.interpolate_nodes(
+        octree, method="cubic"
+    )
+
+    nan_travel_times = np.isnan(eikonal_travel_times)
+    assert np.any(~nan_travel_times)
+    np.testing.assert_almost_equal(
+        eikonal_travel_times[~nan_travel_times],
+        analytical_travel_times[~nan_travel_times],
+        decimal=1,
+    )
+
+
+@pytest.mark.asyncio
+async def test_fast_marching_phase_tracer(
+    octree: Octree, fixed_stations: Stations
+) -> None:
+    tracer = FastMarchingTracer(
+        phase="fm:P",
+        velocity_model=Constant3DVelocityModel(
+            velocity=CONSTANT_VELOCITY, grid_spacing=80.0
+        ),
+    )
+    await tracer.prepare(octree, fixed_stations)
+    tracer.get_travel_times("fm:P", octree, fixed_stations)
+
+
+@pytest.mark.asyncio
+async def test_non_lin_loc(data_dir: Path, octree: Octree, stations: Stations) -> None:
+    header_file = data_dir / "FORGE_3D_5_large.P.mod.hdr"
+
+    tracer = FastMarchingTracer(
+        phase="fm:P",
+        velocity_model=NonLinLocVelocityModel(header_file=header_file),
+    )
+    octree = octree_cover(tracer.velocity_model.get_model(octree))
+    stations = stations_inside(tracer.velocity_model.get_model(octree))
+    await tracer.prepare(octree, stations)
+    source = octree[1].as_location()
+    tracer.get_arrivals(
+        "fm:P",
+        event_time=datetime_now(),
+        source=source,
+        receivers=list(stations),
+    )
+
+
+@pytest.mark.plot
+def test_non_lin_loc_model(
+    data_dir: Path,
+    octree: Octree,
+    stations: Stations,
+) -> None:
+    import matplotlib.pyplot as plt
+
+    header_file = data_dir / "FORGE_3D_5_large.P.mod.hdr"
+
+    model = NonLinLocVelocityModel(header_file=header_file)
+    velocity_model = model.get_model(octree).resample(
+        grid_spacing=200.0,
+        method="linear",
+    )
+
+    # 3D figure of the velocity model
+    fig = plt.figure()
+    ax = fig.add_subplot(projection="3d")
+    coords = velocity_model.get_meshgrid()
+    print(coords[0].shape)
+    cmap = ax.scatter(
+        coords[0],
+        coords[1],
+        -coords[2],
+        s=np.log(velocity_model.velocity_model.ravel() / KM),
+        c=velocity_model.velocity_model.ravel(),
+    )
+    fig.colorbar(cmap)
+    plt.show()
+
+
+@pytest.mark.plot
+@pytest.mark.asyncio
+async def test_non_lin_loc_travel_times(data_dir: Path, octree: Octree) -> None:
+    import matplotlib.pyplot as plt
+
+    header_file = data_dir / "FORGE_3D_5_large.P.mod.hdr"
+
+    tracer = FastMarchingTracer(
+        phase="fm:P",
+        velocity_model=NonLinLocVelocityModel(
+            header_file=header_file,
+            grid_spacing=100.0,
+        ),
+    )
+    model_3d = tracer.velocity_model.get_model(octree)
+    octree = octree_cover(model_3d)
+    stations = stations_inside(model_3d, depth=0.0)
+    await tracer.prepare(octree, stations)
+
+    volume = tracer.get_travel_time_volume(stations.stations[0])
+
+    # 3D figure of the travel-time volume
+    fig = plt.figure()
+    ax = fig.add_subplot(projection="3d")
+    coords = volume.get_meshgrid()
+    print(coords[0].shape)
+
+    cmap = ax.scatter(
+        coords[0],
+        coords[1],
+        coords[2],
+        c=volume.travel_times.ravel(),
+        alpha=0.2,
+    )
+
+    station_offset = volume.station.offset_from(volume.center)
+    print(station_offset)
+    ax.scatter(*station_offset, s=100, c="r")
+    fig.colorbar(cmap)
+    plt.show()
diff --git a/test/test_location.py b/test/test_location.py
index 05ac8c2d..43e552c6 100644
--- a/test/test_location.py
+++ b/test/test_location.py
@@ -1,8 +1,75 @@
+from __future__ import annotations
+
+import random
+
+import numpy as np
+
 from lassie.models import Location
 
+KM = 1e3
+
 
 def test_location() -> None:
     loc = Location(lat=11.0, lon=23.55)
     loc_other = Location(lat=13.123, lon=21.12)
 
     loc.surface_distance_to(loc_other)
+
+
+def test_distance_same_origin():
+    loc = Location(lat=11.0, lon=23.55)
+
+    perturb_attributes = {"north_shift", "east_shift", "elevation", "depth"}
+    for _ in range(100):
+        distance = random.uniform(-10 * KM, 10 * KM)
+        for attr in perturb_attributes:
+            loc_other = loc.model_copy()
+            loc_other._cached_lat_lon = None
+            setattr(loc_other, attr, distance)
+            assert loc.distance_to(loc_other) == abs(distance)
+
+            loc_shifted = loc_other.shifted_origin()
+            np.testing.assert_approx_equal(
+                loc.distance_to(loc_shifted),
+                abs(distance),
+                significant=2,
+            )
+
+
+def test_location_offset():
+    loc = Location(lat=11.0, lon=23.55)
+    loc_other = Location(
+        lat=11.0,
+        lon=23.55,
+        north_shift=100.0,
+        east_shift=100.0,
+        depth=100.0,
+    )
+
+    offset = loc_other.offset_from(loc)
+    assert offset == (100.0, 100.0, 100.0)
+
+    loc_other = Location(
+        lat=11.0,
+        lon=23.55,
+        north_shift=100.0,
+        east_shift=100.0,
+        elevation=100.0,
+    )
+    offset = loc_other.offset_from(loc)
+    assert offset == (100.0, 100.0, -100.0)
+
+    loc_other = Location(
+        lat=11.0,
+        lon=23.55,
+        north_shift=100.0,
+        east_shift=100.0,
+        elevation=100.0,
+        depth=10.0,
+    )
+    offset = loc_other.offset_from(loc)
+    assert offset == (100.0, 100.0, -90.0)
+
+    loc_other = loc_other.shifted_origin()
+    offset = loc_other.offset_from(loc)
+    np.testing.assert_almost_equal(offset, (100.0, 100.0, -90.0), decimal=0)
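
The offset convention verified in `test_location_offset` above: `offset_from()` returns a Cartesian offset tuple in metres whose vertical component is `depth - elevation`, so elevation counts against depth. A minimal sketch grounded in the assertions above:

```py
from lassie.models import Location

ref = Location(lat=11.0, lon=23.55)
other = Location(
    lat=11.0,
    lon=23.55,
    north_shift=100.0,
    east_shift=100.0,
    elevation=100.0,
    depth=10.0,
)
# vertical component is depth - elevation: 10 m - 100 m = -90 m
assert other.offset_from(ref) == (100.0, 100.0, -90.0)
```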
diff --git a/test/test_octree.py b/test/test_octree.py
index fcdf37ab..d9f5c3c7 100644
--- a/test/test_octree.py
+++ b/test/test_octree.py
@@ -1,23 +1,10 @@
-import pytest
+from __future__ import annotations
 
-from lassie.octree import Octree
+from lassie.octree import NodeSplitError, Octree
 
 km = 1e3
 
 
-@pytest.fixture(scope="function")
-def octree():
-    yield Octree(
-        center_lat=0.0,
-        center_lon=0.0,
-        east_bounds=(-25 * km, 25 * km),
-        north_bounds=(-25 * km, 25 * km),
-        depth_bounds=(0, 40 * km),
-        size_initial=5 * km,
-        size_limit=0.5 * km,
-    )
-
-
 def test_octree(octree: Octree, plot: bool) -> None:
     assert octree.n_nodes > 0
 
@@ -29,7 +16,11 @@ def test_octree(octree: Octree, plot: bool) -> None:
     assert nnodes * 8 == octree.n_nodes
 
     child, *_ = octree[80].split()
-    child, *_ = child.split()
+    while True:
+        try:
+            child, *_ = child.split()
+        except NodeSplitError:
+            break
 
     for node in octree:
         node.semblance = node.depth + node.east + node.north
diff --git a/test/test_plot.py b/test/test_plot.py
new file mode 100644
index 00000000..46d77fe9
--- /dev/null
+++ b/test/test_plot.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+from matplotlib import pyplot as plt
+
+from lassie.models.detection import EventDetection
+from lassie.plot.detections import plot_detections
+from lassie.plot.octree import plot_octree_surface_tiles
+from lassie.utils import datetime_now
+
+if TYPE_CHECKING:
+    from lassie.models.detection import EventDetections
+    from lassie.octree import Octree
+
+
+@pytest.mark.plot
+def test_octree_2d(octree: Octree) -> None:
+    semblance = np.random.uniform(size=octree.n_nodes)
+    octree.map_semblance(semblance)
+    plot_octree_surface_tiles(octree, filename=Path("/tmp/test.png"))
+
+    detection = EventDetection(
+        lat=0.0,
+        lon=0.0,
+        east_shift=0.0,
+        north_shift=0.0,
+        distance_border=1000.0,
+        semblance=1.0,
+        time=datetime_now(),
+    )
+
+    fig = plt.figure()
+    ax = fig.gca()
+
+    plot_octree_surface_tiles(octree, axes=ax, detections=[detection])
+    plt.show()
+
+
+@pytest.mark.plot
+def test_detections_semblance(detections: EventDetections) -> None:
+    plot_detections(detections, axes=None)
diff --git a/test/test_search.py b/test/test_search.py
index 9770534a..f99747b6 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -3,14 +3,14 @@
 import pytest
 
 from lassie.models.location import locations_to_csv
-from lassie.search import SquirrelSearch
+from lassie.search import Search
 
 km = 1e3
 
 
 @pytest.mark.skip(reason="Fail")
 def test_search() -> None:
-    search = SquirrelSearch()
+    search = Search()
     # search.scan_squirrel()
 
     locations = search.stations.model_copy()
diff --git a/test/upload_data.sh b/test/upload_data.sh
new file mode 100755
index 00000000..b2884ebf
--- /dev/null
+++ b/test/upload_data.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+echo "Uploading test data to data.pyrocko.org"
+scp data/* pyrocko-www@data.pyrocko.org:/srv/data.pyrocko.org/www/testing/lassie-v2
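
Closing note on the `test_octree.py` change above: node refinement now runs until the octree's `size_limit` raises `NodeSplitError` instead of assuming a fixed number of splits. A sketch of that defensive pattern, assuming an `Octree` configured like the `conftest.py` fixture (the helper name and starting node are illustrative):

```py
from lassie.octree import NodeSplitError


def refine_to_limit(octree) -> int:
    """Split the first child repeatedly; count levels until the size limit."""
    child, *_ = octree[0].split()
    n_levels = 1
    while True:
        try:
            child, *_ = child.split()
            n_levels += 1
        except NodeSplitError:
            # node reached size_limit and cannot be refined further
            return n_levels
```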