From e2e6b0b7cf9ac4f4adc9d80117c2f8eff1813f5f Mon Sep 17 00:00:00 2001 From: miili Date: Mon, 11 Sep 2023 18:00:49 +0200 Subject: [PATCH] wip: fast-marching implementation --- lassie/models/station.py | 9 +- lassie/octree.py | 2 +- lassie/search/base.py | 5 +- lassie/tracers/__init__.py | 6 +- lassie/tracers/cake.py | 3 +- lassie/tracers/fast_marching/__init__.py | 1 + lassie/tracers/fast_marching/fast_marching.py | 161 +++++++++++++----- .../tracers/fast_marching/velocity_models.py | 30 +++- test/conftest.py | 20 +++ test/test_fast_marching.py | 70 +++++--- 10 files changed, 224 insertions(+), 83 deletions(-) diff --git a/lassie/models/station.py b/lassie/models/station.py index d3face3b..48b3d887 100644 --- a/lassie/models/station.py +++ b/lassie/models/station.py @@ -63,7 +63,6 @@ class Stations(BaseModel): pyrocko_station_yamls: list[Path] = [] _cached_coordinates: np.ndarray | None = PrivateAttr(None) - _cached_iter: list[Station] | None = PrivateAttr(None) def model_post_init(self, __context: Any) -> None: loaded_stations = [] @@ -131,17 +130,11 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: raise ValueError("no stations available, add waveforms to start detection") def __iter__(self) -> Iterator[Station]: - if self._cached_iter is None: - self._cached_iter = [ - sta for sta in self.stations if sta.pretty_nsl not in self.blacklist - ] - return iter(self._cached_iter) + return (sta for sta in self.stations if sta.pretty_nsl not in self.blacklist) @property def n_stations(self) -> int: """Number of stations in the stations object.""" - if self._cached_iter: - return len(self._cached_iter) return sum(1 for _ in self) def get_all_nsl(self) -> list[tuple[str, str, str]]: diff --git a/lassie/octree.py b/lassie/octree.py index e798ab75..f8a75d69 100644 --- a/lassie/octree.py +++ b/lassie/octree.py @@ -374,7 +374,7 @@ def maximum_number_nodes(self) -> int: (self.east_bounds[1] - self.east_bounds[0]) * (self.north_bounds[1] - self.north_bounds[0]) * (self.depth_bounds[1] - self.depth_bounds[0]) - / self.smallest_node_size() ** 3 + / (self.smallest_node_size() ** 3) ) def copy(self, deep=False) -> Self: diff --git a/lassie/search/base.py b/lassie/search/base.py index 5bb273a1..d12953ea 100644 --- a/lassie/search/base.py +++ b/lassie/search/base.py @@ -21,6 +21,7 @@ from lassie.signals import Signal from lassie.station_corrections import StationCorrections from lassie.tracers import CakeTracer, ConstantVelocityTracer, RayTracers +from lassie.tracers.fast_marching import FastMarchingTracer from lassie.utils import PhaseDescription, Symbols, alog_call, time_to_path if TYPE_CHECKING: @@ -50,7 +51,9 @@ class Search(BaseModel): octree: Octree = Octree() stations: Stations = Stations() - ray_tracers: RayTracers = RayTracers(root=[ConstantVelocityTracer(), CakeTracer()]) + ray_tracers: RayTracers = RayTracers( + root=[ConstantVelocityTracer(), CakeTracer(), FastMarchingTracer()] + ) image_functions: ImageFunctions station_corrections: StationCorrections | None = None diff --git a/lassie/tracers/__init__.py b/lassie/tracers/__init__.py index ab5a16d0..029bf9a5 100644 --- a/lassie/tracers/__init__.py +++ b/lassie/tracers/__init__.py @@ -5,12 +5,12 @@ from pydantic import Field, RootModel -from lassie.tracers.base import ModelledArrival from lassie.tracers.cake import CakeArrival, CakeTracer from lassie.tracers.constant_velocity import ( ConstantVelocityArrival, ConstantVelocityTracer, ) +from lassie.tracers.fast_marching import FastMarchingArrival, FastMarchingTracer if TYPE_CHECKING: from lassie.models.station import Stations @@ -22,12 +22,12 @@ RayTracerType = Annotated[ - Union[CakeTracer, ConstantVelocityTracer], + Union[ConstantVelocityTracer, CakeTracer, FastMarchingTracer], Field(..., discriminator="tracer"), ] RayTracerArrival = Annotated[ - Union[CakeArrival, ConstantVelocityArrival, ModelledArrival], + Union[ConstantVelocityArrival, CakeArrival, FastMarchingArrival], Field(..., discriminator="tracer"), ] diff --git a/lassie/tracers/cake.py b/lassie/tracers/cake.py index 1e7deed7..8e1e94c8 100644 --- a/lassie/tracers/cake.py +++ b/lassie/tracers/cake.py @@ -468,7 +468,7 @@ class CakeTracer(RayTracer): trim_earth_model_depth: bool = Field( True, description="Trim earth model to max depth of the octree." ) - lut_cache_size: ByteSize = Field("4GB", description="Size of the LUT cache.") + lut_cache_size: ByteSize = Field(2 * GiB, description="Size of the LUT cache.") _traveltime_trees: dict[PhaseDescription, TravelTimeTree] = PrivateAttr({}) @@ -499,6 +499,7 @@ async def prepare(self, octree: Octree, stations: Stations) -> None: n_trees = len(self.phases) LRU_CACHE_SIZE = int(self.lut_cache_size / bytes_per_node / n_trees) + # TODO: This should be total number nodes. Not only leaf nodes. node_cache_fraction = LRU_CACHE_SIZE / octree.maximum_number_nodes() logging.info( "limiting traveltime LUT size to %d nodes (%s)," diff --git a/lassie/tracers/fast_marching/__init__.py b/lassie/tracers/fast_marching/__init__.py index e69de29b..e67c9f0b 100644 --- a/lassie/tracers/fast_marching/__init__.py +++ b/lassie/tracers/fast_marching/__init__.py @@ -0,0 +1 @@ +from .fast_marching import FastMarchingArrival, FastMarchingTracer # noqa diff --git a/lassie/tracers/fast_marching/fast_marching.py b/lassie/tracers/fast_marching/fast_marching.py index 42d87c1c..360777d4 100644 --- a/lassie/tracers/fast_marching/fast_marching.py +++ b/lassie/tracers/fast_marching/fast_marching.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Literal, Self, Sequence import numpy as np +from lru import LRU from pydantic import BaseModel, ByteSize, Field, PrivateAttr from pyrocko.modelling import eikonal from rich.progress import Progress @@ -18,13 +19,14 @@ from lassie.models.location import Location from lassie.models.station import Station, Stations +from lassie.octree import Node from lassie.tracers.base import ModelledArrival, RayTracer from lassie.tracers.fast_marching.velocity_models import ( - NonLinLocVelocityModel, + Constant3DVelocityModel, VelocityModel3D, VelocityModels, ) -from lassie.utils import CACHE_DIR, PhaseDescription, datetime_now +from lassie.utils import CACHE_DIR, PhaseDescription, datetime_now, human_readable_bytes if TYPE_CHECKING: from lassie.octree import Octree @@ -111,7 +113,8 @@ async def calculate_from_eikonal( arrival_times = model.get_source_arrival_grid(station) if not model.is_inside(station): - raise ValueError("station is outside the velocity model") + offset = station.offset_to(model.center) + raise ValueError(f"station is outside the velocity model {offset}") def eikonal_wrapper( velocity_model: VelocityModel3D, @@ -179,18 +182,14 @@ def interpolate_travel_time( offset = location.offset_to(self.center) return interpolator([offset], method=method)[0] - def interpolate_travel_times( + def interpolate_nodes( self, - octree: Octree, + nodes: Sequence[Node], method: Literal["nearest", "linear", "cubic"] = "linear", ) -> np.ndarray: interpolator = self.get_traveltime_interpolator() - coordinates = [] - for node in octree: - location = node.as_location() - coordinates.append(location.offset_to(self.center)) - + coordinates = [node.as_location().offset_to(self.center) for node in nodes] return interpolator(coordinates, method=method) def save(self, path: Path) -> Path: @@ -251,48 +250,91 @@ class FastMarchingTracer(RayTracer): phase: PhaseDescription = "fm:P" interpolation_method: Literal["nearest", "linear", "cubic"] = "nearest" - nthreads: int = Field(default_factory=os.cpu_count) - lut_cache_size: ByteSize = Field("4GB", description="Size of the LUT cache.") + nthreads: int = Field( + 0, + description="Number of threads to use for travel time." + "If set to 0, cpu_count*2 will be used.", + ) + + lut_cache_size: ByteSize = Field(2 * GiB, description="Size of the LUT cache.") - velocity_model: VelocityModels = NonLinLocVelocityModel() + velocity_model: VelocityModels = Constant3DVelocityModel() - _traveltime_models: dict[str, StationTravelTimeVolume] = PrivateAttr({}) + _travel_time_volumes: dict[str, StationTravelTimeVolume] = PrivateAttr({}) _velocity_model: VelocityModel3D | None = PrivateAttr(None) + _cached_stations: Stations = PrivateAttr() + _cached_station_indeces: dict[str, int] = PrivateAttr({}) + _node_lut: dict[bytes, np.ndarray] = PrivateAttr() + def get_available_phases(self) -> tuple[str, ...]: return (self.phase,) + def get_travel_time_volume(self, location: Location) -> StationTravelTimeVolume: + return self._travel_time_volumes[location.location_hash()] + + def add_travel_time_volume( + self, + location: Location, + volume: StationTravelTimeVolume, + ) -> None: + self._travel_time_volumes[location.location_hash()] = volume + + def has_travel_time_volume(self, location: Location) -> bool: + return location.location_hash() in self._travel_time_volumes + + def lut_fill_level(self) -> float: + """Return the fill level of the LUT as a float between 0.0 and 1.0""" + return len(self._node_lut) / self._node_lut.get_size() + async def prepare( self, octree: Octree, stations: Stations, ) -> None: - velocity_model = self.velocity_model.get_model(octree, stations) + velocity_model = self.velocity_model.get_model(octree) self._velocity_model = velocity_model for station in stations: if not velocity_model.is_inside(station): + offset = station.offset_to(velocity_model.center) stations.blacklist_station( - station, reason="outside the fast-marching velocity model" + station, + reason=f"outside the fast-marching velocity model, offset {offset}", ) + self._cached_stations = stations + self._cached_station_indeces = { + sta.pretty_nsl: idx for idx, sta in enumerate(stations) + } + bytes_per_node = stations.n_stations * np.float32().itemsize + lru_cache_size = int(self.lut_cache_size / bytes_per_node) + self._node_lut = LRU(lru_cache_size) + + # TODO: This should be total number nodes. Not only leaf nodes. + node_cache_fraction = lru_cache_size / octree.maximum_number_nodes() + logging.info( + "limiting traveltime LUT size to %d nodes (%s)," + " caching %.1f%% of possible octree nodes", + lru_cache_size, + human_readable_bytes(self.lut_cache_size), + node_cache_fraction * 100, + ) + cache_dir = FMM_CACHE_DIR / f"{velocity_model.hash()}" if not cache_dir.exists(): - logger.info("creating cache directory %s", cache_dir) cache_dir.mkdir(parents=True) else: self._load_cached_tavel_times(cache_dir) - work_stations = [ - station - for station in stations - if station.location_hash() not in self._traveltime_models + calc_stations = [ + station for station in stations if not self.has_travel_time_volume(station) ] - await self._calculate_travel_times(work_stations, cache_dir) + await self._calculate_travel_times(calc_stations, cache_dir) def _load_cached_tavel_times(self, cache_dir: Path) -> None: logger.debug("loading travel times volumes from cache %s...", cache_dir) - models: dict[str, StationTravelTimeVolume] = {} + volumes: dict[str, StationTravelTimeVolume] = {} for file in cache_dir.glob("*.3dtt"): try: travel_times = StationTravelTimeVolume.load(file) @@ -300,30 +342,32 @@ def _load_cached_tavel_times(self, cache_dir: Path) -> None: logger.warning("removing bad travel time file %s", file) file.unlink() continue - models[travel_times.station.location_hash()] = travel_times + volumes[travel_times.station.location_hash()] = travel_times - logger.info("loaded %d travel times volumes from cache", len(models)) - self._traveltime_models.update(models) + logger.info("loaded %d travel times volumes from cache", len(volumes)) + self._travel_time_volumes.update(volumes) async def _calculate_travel_times( self, stations: list[Station], cache_dir: Path, ) -> None: + nthreads = self.nthreads if self.nthreads > 0 else os.cpu_count() * 2 executor = ThreadPoolExecutor( - max_workers=self.nthreads, thread_name_prefix="lassie-fmm" + max_workers=nthreads, + thread_name_prefix="lassie-fmm", ) if self._velocity_model is None: raise AttributeError("velocity model has not been prepared yet") async def worker_station_travel_time(station: Station) -> None: - model = await StationTravelTimeVolume.calculate_from_eikonal( + volume = await StationTravelTimeVolume.calculate_from_eikonal( self._velocity_model, # noqa station, save=cache_dir, executor=executor, ) - self._traveltime_models[station.location_hash()] = model + self.add_travel_time_volume(station, volume) calculate_work = [worker_station_travel_time(station) for station in stations] if not calculate_work: @@ -333,7 +377,8 @@ async def worker_station_travel_time(station: Station) -> None: tasks = [asyncio.create_task(work) for work in calculate_work] with Progress() as progress: status = progress.add_task( - f"calculating travel time volumes for {len(tasks)} station", + f"calculating travel time volumes for {len(tasks)} stations" + f" ({nthreads} threads)", total=len(tasks), ) for _task in asyncio.as_completed(tasks): @@ -350,7 +395,7 @@ def get_travel_time_location( if phase != self.phase: raise ValueError(f"phase {phase} is not supported by this tracer") - station_travel_times = self._traveltime_models[receiver.location_hash()] + station_travel_times = self.get_travel_time_volume(receiver) return station_travel_times.interpolate_travel_time( source, method=self.interpolation_method, @@ -365,15 +410,53 @@ def get_travel_times( if phase != self.phase: raise ValueError(f"phase {phase} is not supported by this tracer") - result = [] - for station in stations: - station_travel_times = self._traveltime_models[station.location_hash()] - result.append( - station_travel_times.interpolate_travel_times( - octree, method=self.interpolation_method - ) + station_indices = np.fromiter( + (self._cached_station_indeces[sta.pretty_nsl] for sta in stations), + dtype=int, + ) + + stations_traveltimes = [] + fill_nodes = [] + for node in octree: + try: + node_traveltimes = self._node_lut[node.hash()][station_indices] + except KeyError: + fill_nodes.append(node) + continue + stations_traveltimes.append(node_traveltimes) + + if fill_nodes: + self.fill_lut(fill_nodes) + + cache_hits, cache_misses = self._node_lut.get_stats() + cache_hit_rate = cache_hits / (cache_hits + cache_misses) + logger.info( + "node LUT cache fill level %.1f%%, cache hit rate %.1f%%", + self.lut_fill_level() * 100, + cache_hit_rate * 100, + ) + return self.get_travel_times(phase, octree, stations) + + return np.asarray(stations_traveltimes).astype(float, copy=False) + + def fill_lut(self, nodes: Sequence[Node]) -> None: + travel_times = [] + n_nodes = len(nodes) + + with Progress() as progress: + status = progress.add_task( + f"interpolating station traveltimes for {n_nodes} nodes", + total=self._cached_stations.n_stations, ) - return np.array(result).T + for station in self._cached_stations: + volume = self.get_travel_time_volume(station) + travel_times.append(volume.interpolate_nodes(nodes)) + progress.advance(status) + + travel_times = np.array(travel_times).T + + for node, station_travel_times in zip(nodes, travel_times, strict=True): + self._node_lut[node.hash()] = station_travel_times def get_arrivals( self, diff --git a/lassie/tracers/fast_marching/velocity_models.py b/lassie/tracers/fast_marching/velocity_models.py index d0aebe86..4f03c61e 100644 --- a/lassie/tracers/fast_marching/velocity_models.py +++ b/lassie/tracers/fast_marching/velocity_models.py @@ -21,7 +21,7 @@ from lassie.models.location import Location if TYPE_CHECKING: - from lassie.models.station import Station, Stations + from lassie.models.station import Station from lassie.octree import Octree @@ -79,7 +79,7 @@ def set_velocity_model(self, velocity_model: np.ndarray) -> None: f"Velocity model shape {velocity_model.shape} does not match" f" expected shape {self._velocity_model.shape}" ) - self._velocity_model = velocity_model + self._velocity_model = velocity_model.astype(float, copy=False) def hash(self) -> str: if self._hash is None: @@ -117,6 +117,9 @@ def resample( grid_spacing: float, method: Literal["nearest", "linear", "cubic"] = "linear", ) -> Self: + if grid_spacing == self.grid_spacing: + return self + logger.info("resampling velocity model to grid spacing %s m", grid_spacing) interpolator = RegularGridInterpolator( (self._east_coords, self._north_coords, self._depth_coords), @@ -149,16 +152,18 @@ class VelocityModelFactory(BaseModel): " If 'quadtree' defaults to smallest octreee node size.", ) - def get_model(self, octree: Octree, stations: Stations) -> VelocityModel3D: + def get_model(self, octree: Octree) -> VelocityModel3D: raise NotImplementedError class Constant3DVelocityModel(VelocityModelFactory): + """This model is for mere testing of the method.""" + model: Literal["Constant3DVelocityModel"] = "Constant3DVelocityModel" velocity: PositiveFloat = 5000.0 - def get_model(self, octree: Octree, stations: Stations) -> VelocityModel3D: + def get_model(self, octree: Octree) -> VelocityModel3D: if self.grid_spacing == "quadtree": grid_spacing = octree.smallest_node_size() else: @@ -275,7 +280,18 @@ class NonLinLocVelocityModel(VelocityModelFactory): description="Path to NonLinLoc model buffer file. If none, the filename will be" "infered from the header file.", ) - interpolation: Literal["nearest", "linear", "cubic"] = "linear" + + grid_spacing: PositiveFloat | Literal["quadtree", "input"] = Field( + "input", + description="Grid spacing in meters." + " If 'quadtree' defaults to smallest octreee node size. If 'input' uses the" + " grid spacing from the NonLinLoc header file.", + ) + interpolation: Literal["nearest", "linear", "cubic"] = Field( + "linear", + description="Interpolation method for resampling the grid " + "for the fast-marching method.", + ) _header: NonLinLocHeader = PrivateAttr() _velocity_model: np.ndarray = PrivateAttr() @@ -302,9 +318,11 @@ def load_header(self) -> Self: return self - def get_model(self, octree: Octree, stations: Stations) -> VelocityModel3D: + def get_model(self, octree: Octree) -> VelocityModel3D: if self.grid_spacing == "quadtree": grid_spacing = octree.smallest_node_size() + if self.grid_spacing == "input": + grid_spacing = self._header.grid_spacing else: grid_spacing = self.grid_spacing diff --git a/test/conftest.py b/test/conftest.py index 9a5f508b..6a101dbc 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,6 +4,7 @@ from tempfile import TemporaryDirectory from typing import Generator +import numpy as np import pytest from lassie.models.detection import EventDetection, EventDetections @@ -77,6 +78,25 @@ def stations() -> Stations: return Stations(stations=stations) +@pytest.fixture(scope="session") +def fixed_stations() -> Stations: + n_stations = 20 + rng = np.random.RandomState(0) + stations: list[Station] = [] + for i_sta in range(n_stations): + station = Station( + network="FX", + station="STA%02d" % i_sta, + lat=10.0, + lon=10.0, + elevation=rng.uniform(0, 1) * KM, + north_shift=rng.uniform(-10, 10) * KM, + east_shift=rng.uniform(-10, 10) * KM, + ) + stations.append(station) + return Stations(stations=stations) + + @pytest.fixture(scope="session") def detections() -> Generator[EventDetections, None, None]: n_detections = 2000 diff --git a/test/test_fast_marching.py b/test/test_fast_marching.py index b215c226..0fb62069 100644 --- a/test/test_fast_marching.py +++ b/test/test_fast_marching.py @@ -6,29 +6,50 @@ import pytest import pytest_asyncio -from lassie.models.station import Stations +from lassie.models.station import Station, Stations from lassie.octree import Octree from lassie.tracers.fast_marching.fast_marching import ( - FastMarchingRayTracer, FastMarchingTracer, StationTravelTimeVolume, ) from lassie.tracers.fast_marching.velocity_models import ( Constant3DVelocityModel, NonLinLocVelocityModel, + VelocityModel3D, ) +from lassie.utils import datetime_now CONSTANT_VELOCITY = 5000 KM = 1e3 +def stations_inside( + model: VelocityModel3D, nstations: int = 20, seed: int = 0 +) -> Stations: + stations = [] + rng = np.random.RandomState(seed) + for i_sta in range(nstations): + station = Station( + network="FM", + station="STA%02d" % i_sta, + lat=model.center.lat, + lon=model.center.lon, + elevation=model.center.elevation, + north_shift=model.center.north_shift + rng.uniform(*model.north_bounds), + east_shift=model.center.east_shift + rng.uniform(*model.east_bounds), + depth=model.center.depth + rng.uniform(*model.depth_bounds), + ) + stations.append(station) + return Stations(stations=stations) + + @pytest_asyncio.fixture async def station_travel_times( octree: Octree, stations: Stations ) -> StationTravelTimeVolume: octree.surface_elevation = 1 * KM model = Constant3DVelocityModel(velocity=CONSTANT_VELOCITY, grid_spacing=100.0) - model_3d = model.get_model(octree, stations) + model_3d = model.get_model(octree) return await StationTravelTimeVolume.calculate_from_eikonal( model_3d, stations.stations[0] ) @@ -74,36 +95,37 @@ async def test_load_interpolation( decimal=1, ) - station_travel_times.interpolate_travel_times(octree) + station_travel_times.interpolate_nodes(octree) @pytest.mark.asyncio -async def test_fast_marching_phase_tracer(octree: Octree, stations: Stations) -> None: +async def test_fast_marching_phase_tracer( + octree: Octree, fixed_stations: Stations +) -> None: tracer = FastMarchingTracer( + phase="fm:P", velocity_model=Constant3DVelocityModel( velocity=CONSTANT_VELOCITY, grid_spacing=80.0 - ) + ), ) - await tracer.prepare(octree, stations) + await tracer.prepare(octree, fixed_stations) + tracer.get_travel_times("fm:P", octree, fixed_stations) @pytest.mark.asyncio -async def test_fast_marching_ray_tracer(octree: Octree, stations: Stations) -> None: - tracer = FastMarchingRayTracer( - tracers={ - "fmm:P": FastMarchingTracer( - velocity_model=Constant3DVelocityModel( - velocity=CONSTANT_VELOCITY, grid_spacing=80.0 - ) - ) - } - ) - await tracer.prepare(octree, stations) - tracer.get_traveltimes("fmm:P", octree, stations) - - -def test_non_lin_loc_load(data_dir: Path, octree: Octree, stations: Stations) -> None: +async def test_non_lin_loc(data_dir: Path, octree: Octree, stations: Stations) -> None: header_file = data_dir / "FORGE_3D_5_large.P.mod.hdr" - velocity_model = NonLinLocVelocityModel(header_file=header_file) - velocity_model.get_model(octree=octree, stations=stations) + tracer = FastMarchingTracer( + phase="fm:P", + velocity_model=NonLinLocVelocityModel(header_file=header_file), + ) + stations = stations_inside(tracer.velocity_model.get_model(octree)) + await tracer.prepare(octree, stations) + source = octree[1].as_location() + tracer.get_arrivals( + "fm:P", + event_time=datetime_now(), + source=source, + receivers=list(stations), + )