Skip to content

Commit

Permalink
wip: fast-marching implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Sep 11, 2023
1 parent dfa2aa0 commit e2e6b0b
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 83 deletions.
9 changes: 1 addition & 8 deletions lassie/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class Stations(BaseModel):
pyrocko_station_yamls: list[Path] = []

_cached_coordinates: np.ndarray | None = PrivateAttr(None)
_cached_iter: list[Station] | None = PrivateAttr(None)

def model_post_init(self, __context: Any) -> None:
loaded_stations = []
Expand Down Expand Up @@ -131,17 +130,11 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None:
raise ValueError("no stations available, add waveforms to start detection")

def __iter__(self) -> Iterator[Station]:
if self._cached_iter is None:
self._cached_iter = [
sta for sta in self.stations if sta.pretty_nsl not in self.blacklist
]
return iter(self._cached_iter)
return (sta for sta in self.stations if sta.pretty_nsl not in self.blacklist)

@property
def n_stations(self) -> int:
"""Number of stations in the stations object."""
if self._cached_iter:
return len(self._cached_iter)
return sum(1 for _ in self)

def get_all_nsl(self) -> list[tuple[str, str, str]]:
Expand Down
2 changes: 1 addition & 1 deletion lassie/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def maximum_number_nodes(self) -> int:
(self.east_bounds[1] - self.east_bounds[0])
* (self.north_bounds[1] - self.north_bounds[0])
* (self.depth_bounds[1] - self.depth_bounds[0])
/ self.smallest_node_size() ** 3
/ (self.smallest_node_size() ** 3)
)

def copy(self, deep=False) -> Self:
Expand Down
5 changes: 4 additions & 1 deletion lassie/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from lassie.signals import Signal
from lassie.station_corrections import StationCorrections
from lassie.tracers import CakeTracer, ConstantVelocityTracer, RayTracers
from lassie.tracers.fast_marching import FastMarchingTracer
from lassie.utils import PhaseDescription, Symbols, alog_call, time_to_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,7 +51,9 @@ class Search(BaseModel):

octree: Octree = Octree()
stations: Stations = Stations()
ray_tracers: RayTracers = RayTracers(root=[ConstantVelocityTracer(), CakeTracer()])
ray_tracers: RayTracers = RayTracers(
root=[ConstantVelocityTracer(), CakeTracer(), FastMarchingTracer()]
)
image_functions: ImageFunctions

station_corrections: StationCorrections | None = None
Expand Down
6 changes: 3 additions & 3 deletions lassie/tracers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from pydantic import Field, RootModel

from lassie.tracers.base import ModelledArrival
from lassie.tracers.cake import CakeArrival, CakeTracer
from lassie.tracers.constant_velocity import (
ConstantVelocityArrival,
ConstantVelocityTracer,
)
from lassie.tracers.fast_marching import FastMarchingArrival, FastMarchingTracer

if TYPE_CHECKING:
from lassie.models.station import Stations
Expand All @@ -22,12 +22,12 @@


RayTracerType = Annotated[
Union[CakeTracer, ConstantVelocityTracer],
Union[ConstantVelocityTracer, CakeTracer, FastMarchingTracer],
Field(..., discriminator="tracer"),
]

RayTracerArrival = Annotated[
Union[CakeArrival, ConstantVelocityArrival, ModelledArrival],
Union[ConstantVelocityArrival, CakeArrival, FastMarchingArrival],
Field(..., discriminator="tracer"),
]

Expand Down
3 changes: 2 additions & 1 deletion lassie/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ class CakeTracer(RayTracer):
trim_earth_model_depth: bool = Field(
True, description="Trim earth model to max depth of the octree."
)
lut_cache_size: ByteSize = Field("4GB", description="Size of the LUT cache.")
lut_cache_size: ByteSize = Field(2 * GiB, description="Size of the LUT cache.")

_traveltime_trees: dict[PhaseDescription, TravelTimeTree] = PrivateAttr({})

Expand Down Expand Up @@ -499,6 +499,7 @@ async def prepare(self, octree: Octree, stations: Stations) -> None:
n_trees = len(self.phases)
LRU_CACHE_SIZE = int(self.lut_cache_size / bytes_per_node / n_trees)

# TODO: This should be the total number of nodes, not only the leaf nodes.
node_cache_fraction = LRU_CACHE_SIZE / octree.maximum_number_nodes()
logging.info(
"limiting traveltime LUT size to %d nodes (%s),"
Expand Down
1 change: 1 addition & 0 deletions lassie/tracers/fast_marching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .fast_marching import FastMarchingArrival, FastMarchingTracer # noqa
161 changes: 122 additions & 39 deletions lassie/tracers/fast_marching/fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@
from typing import TYPE_CHECKING, Any, Literal, Self, Sequence

import numpy as np
from lru import LRU
from pydantic import BaseModel, ByteSize, Field, PrivateAttr
from pyrocko.modelling import eikonal
from rich.progress import Progress
from scipy.interpolate import RegularGridInterpolator

from lassie.models.location import Location
from lassie.models.station import Station, Stations
from lassie.octree import Node
from lassie.tracers.base import ModelledArrival, RayTracer
from lassie.tracers.fast_marching.velocity_models import (
NonLinLocVelocityModel,
Constant3DVelocityModel,
VelocityModel3D,
VelocityModels,
)
from lassie.utils import CACHE_DIR, PhaseDescription, datetime_now
from lassie.utils import CACHE_DIR, PhaseDescription, datetime_now, human_readable_bytes

if TYPE_CHECKING:
from lassie.octree import Octree
Expand Down Expand Up @@ -111,7 +113,8 @@ async def calculate_from_eikonal(
arrival_times = model.get_source_arrival_grid(station)

if not model.is_inside(station):
raise ValueError("station is outside the velocity model")
offset = station.offset_to(model.center)
raise ValueError(f"station is outside the velocity model {offset}")

def eikonal_wrapper(
velocity_model: VelocityModel3D,
Expand Down Expand Up @@ -179,18 +182,14 @@ def interpolate_travel_time(
offset = location.offset_to(self.center)
return interpolator([offset], method=method)[0]

def interpolate_travel_times(
def interpolate_nodes(
self,
octree: Octree,
nodes: Sequence[Node],
method: Literal["nearest", "linear", "cubic"] = "linear",
) -> np.ndarray:
interpolator = self.get_traveltime_interpolator()

coordinates = []
for node in octree:
location = node.as_location()
coordinates.append(location.offset_to(self.center))

coordinates = [node.as_location().offset_to(self.center) for node in nodes]
return interpolator(coordinates, method=method)

def save(self, path: Path) -> Path:
Expand Down Expand Up @@ -251,79 +250,124 @@ class FastMarchingTracer(RayTracer):

phase: PhaseDescription = "fm:P"
interpolation_method: Literal["nearest", "linear", "cubic"] = "nearest"
nthreads: int = Field(default_factory=os.cpu_count)
lut_cache_size: ByteSize = Field("4GB", description="Size of the LUT cache.")
# Thread count for the travel time calculation; 0 selects cpu_count*2.
# The two description literals are implicitly concatenated, so the first one
# must end with a space to keep the sentences apart.
nthreads: int = Field(
    0,
    description="Number of threads to use for travel time. "
    "If set to 0, cpu_count*2 will be used.",
)

lut_cache_size: ByteSize = Field(2 * GiB, description="Size of the LUT cache.")

velocity_model: VelocityModels = NonLinLocVelocityModel()
velocity_model: VelocityModels = Constant3DVelocityModel()

_traveltime_models: dict[str, StationTravelTimeVolume] = PrivateAttr({})
_travel_time_volumes: dict[str, StationTravelTimeVolume] = PrivateAttr({})
_velocity_model: VelocityModel3D | None = PrivateAttr(None)

_cached_stations: Stations = PrivateAttr()
_cached_station_indeces: dict[str, int] = PrivateAttr({})
_node_lut: dict[bytes, np.ndarray] = PrivateAttr()

def get_available_phases(self) -> tuple[str, ...]:
return (self.phase,)

def get_travel_time_volume(self, location: Location) -> StationTravelTimeVolume:
    """Look up the travel time volume registered for a location.

    Volumes are keyed by `location.location_hash()`. The lookup is a plain
    mapping access, so a missing key propagates as the mapping's lookup error.
    """
    volume_key = location.location_hash()
    return self._travel_time_volumes[volume_key]

def add_travel_time_volume(
    self,
    location: Location,
    volume: StationTravelTimeVolume,
) -> None:
    """Register `volume` under the hash of `location`.

    An entry already stored under the same hash is overwritten.
    """
    volume_key = location.location_hash()
    self._travel_time_volumes[volume_key] = volume

def has_travel_time_volume(self, location: Location) -> bool:
    """Tell whether a travel time volume is registered for `location`."""
    volume_key = location.location_hash()
    return volume_key in self._travel_time_volumes

def lut_fill_level(self) -> float:
    """Return the fill level of the LUT as a float between 0.0 and 1.0.

    Computed as the number of entries currently stored in the node LUT
    divided by the size reported by its `get_size()`.
    """
    node_lut = self._node_lut
    return len(node_lut) / node_lut.get_size()

async def prepare(
self,
octree: Octree,
stations: Stations,
) -> None:
velocity_model = self.velocity_model.get_model(octree, stations)
velocity_model = self.velocity_model.get_model(octree)
self._velocity_model = velocity_model

for station in stations:
if not velocity_model.is_inside(station):
offset = station.offset_to(velocity_model.center)
stations.blacklist_station(
station, reason="outside the fast-marching velocity model"
station,
reason=f"outside the fast-marching velocity model, offset {offset}",
)

self._cached_stations = stations
self._cached_station_indeces = {
sta.pretty_nsl: idx for idx, sta in enumerate(stations)
}
bytes_per_node = stations.n_stations * np.float32().itemsize
lru_cache_size = int(self.lut_cache_size / bytes_per_node)
self._node_lut = LRU(lru_cache_size)

# TODO: This should be the total number of nodes, not only the leaf nodes.
node_cache_fraction = lru_cache_size / octree.maximum_number_nodes()
logging.info(
"limiting traveltime LUT size to %d nodes (%s),"
" caching %.1f%% of possible octree nodes",
lru_cache_size,
human_readable_bytes(self.lut_cache_size),
node_cache_fraction * 100,
)

cache_dir = FMM_CACHE_DIR / f"{velocity_model.hash()}"
if not cache_dir.exists():
logger.info("creating cache directory %s", cache_dir)
cache_dir.mkdir(parents=True)
else:
self._load_cached_tavel_times(cache_dir)

work_stations = [
station
for station in stations
if station.location_hash() not in self._traveltime_models
calc_stations = [
station for station in stations if not self.has_travel_time_volume(station)
]
await self._calculate_travel_times(work_stations, cache_dir)
await self._calculate_travel_times(calc_stations, cache_dir)

def _load_cached_tavel_times(self, cache_dir: Path) -> None:
logger.debug("loading travel times volumes from cache %s...", cache_dir)
models: dict[str, StationTravelTimeVolume] = {}
volumes: dict[str, StationTravelTimeVolume] = {}
for file in cache_dir.glob("*.3dtt"):
try:
travel_times = StationTravelTimeVolume.load(file)
except zipfile.BadZipFile:
logger.warning("removing bad travel time file %s", file)
file.unlink()
continue
models[travel_times.station.location_hash()] = travel_times
volumes[travel_times.station.location_hash()] = travel_times

logger.info("loaded %d travel times volumes from cache", len(models))
self._traveltime_models.update(models)
logger.info("loaded %d travel times volumes from cache", len(volumes))
self._travel_time_volumes.update(volumes)

async def _calculate_travel_times(
self,
stations: list[Station],
cache_dir: Path,
) -> None:
nthreads = self.nthreads if self.nthreads > 0 else os.cpu_count() * 2
executor = ThreadPoolExecutor(
max_workers=self.nthreads, thread_name_prefix="lassie-fmm"
max_workers=nthreads,
thread_name_prefix="lassie-fmm",
)
if self._velocity_model is None:
raise AttributeError("velocity model has not been prepared yet")

async def worker_station_travel_time(station: Station) -> None:
model = await StationTravelTimeVolume.calculate_from_eikonal(
volume = await StationTravelTimeVolume.calculate_from_eikonal(
self._velocity_model, # noqa
station,
save=cache_dir,
executor=executor,
)
self._traveltime_models[station.location_hash()] = model
self.add_travel_time_volume(station, volume)

calculate_work = [worker_station_travel_time(station) for station in stations]
if not calculate_work:
Expand All @@ -333,7 +377,8 @@ async def worker_station_travel_time(station: Station) -> None:
tasks = [asyncio.create_task(work) for work in calculate_work]
with Progress() as progress:
status = progress.add_task(
f"calculating travel time volumes for {len(tasks)} station",
f"calculating travel time volumes for {len(tasks)} stations"
f" ({nthreads} threads)",
total=len(tasks),
)
for _task in asyncio.as_completed(tasks):
Expand All @@ -350,7 +395,7 @@ def get_travel_time_location(
if phase != self.phase:
raise ValueError(f"phase {phase} is not supported by this tracer")

station_travel_times = self._traveltime_models[receiver.location_hash()]
station_travel_times = self.get_travel_time_volume(receiver)
return station_travel_times.interpolate_travel_time(
source,
method=self.interpolation_method,
Expand All @@ -365,15 +410,53 @@ def get_travel_times(
if phase != self.phase:
raise ValueError(f"phase {phase} is not supported by this tracer")

result = []
for station in stations:
station_travel_times = self._traveltime_models[station.location_hash()]
result.append(
station_travel_times.interpolate_travel_times(
octree, method=self.interpolation_method
)
station_indices = np.fromiter(
(self._cached_station_indeces[sta.pretty_nsl] for sta in stations),
dtype=int,
)

stations_traveltimes = []
fill_nodes = []
for node in octree:
try:
node_traveltimes = self._node_lut[node.hash()][station_indices]
except KeyError:
fill_nodes.append(node)
continue
stations_traveltimes.append(node_traveltimes)

if fill_nodes:
self.fill_lut(fill_nodes)

cache_hits, cache_misses = self._node_lut.get_stats()
cache_hit_rate = cache_hits / (cache_hits + cache_misses)
logger.info(
"node LUT cache fill level %.1f%%, cache hit rate %.1f%%",
self.lut_fill_level() * 100,
cache_hit_rate * 100,
)
return self.get_travel_times(phase, octree, stations)

return np.asarray(stations_traveltimes).astype(float, copy=False)

def fill_lut(self, nodes: Sequence[Node]) -> None:
travel_times = []
n_nodes = len(nodes)

with Progress() as progress:
status = progress.add_task(
f"interpolating station traveltimes for {n_nodes} nodes",
total=self._cached_stations.n_stations,
)
return np.array(result).T
for station in self._cached_stations:
volume = self.get_travel_time_volume(station)
travel_times.append(volume.interpolate_nodes(nodes))
progress.advance(status)

travel_times = np.array(travel_times).T

for node, station_travel_times in zip(nodes, travel_times, strict=True):
self._node_lut[node.hash()] = station_travel_times

def get_arrivals(
self,
Expand Down
Loading

0 comments on commit e2e6b0b

Please sign in to comment.