Skip to content

Commit

Permalink
wip: fast-marching implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Sep 11, 2023
1 parent b7c31c4 commit dfa2aa0
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 24 deletions.
11 changes: 9 additions & 2 deletions lassie/models/location.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import hashlib
import math
import struct
from functools import cached_property
from typing import TYPE_CHECKING, Iterable, Literal, TypeVar

Expand Down Expand Up @@ -124,8 +126,12 @@ def offset_to(self, other: Location) -> tuple[float, float, float]:
return sx - ox, sy - oy, sz - oz

def __hash__(self) -> int:
return hash(
(
return hash(self.location_hash())

def location_hash(self) -> str:
sha1 = hashlib.sha1(
struct.pack(
"dddddd",
self.lat,
self.lon,
self.east_shift,
Expand All @@ -134,6 +140,7 @@ def __hash__(self) -> int:
self.depth,
)
)
return sha1.hexdigest()


def locations_to_csv(locations: Iterable[Location], filename: Path) -> Path:
Expand Down
4 changes: 2 additions & 2 deletions lassie/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def init_search(self) -> None:

# Timing ranges
for phase, tracer in self.ray_tracers.iter_phase_tracer():
traveltimes = tracer.get_traveltimes(phase, self.octree, self.stations)
traveltimes = tracer.get_travel_times(phase, self.octree, self.stations)
self.travel_time_ranges[phase] = (
timedelta(seconds=np.nanmin(traveltimes)),
timedelta(seconds=np.nanmax(traveltimes)),
Expand Down Expand Up @@ -259,7 +259,7 @@ async def calculate_semblance(
logger.debug("stacking image %s", image.image_function.name)
parent = self.parent

traveltimes = ray_tracer.get_traveltimes(image.phase, octree, image.stations)
traveltimes = ray_tracer.get_travel_times(image.phase, octree, image.stations)

if parent.station_corrections:
station_delays = parent.station_corrections.get_delays(
Expand Down
4 changes: 2 additions & 2 deletions lassie/tracers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_travel_time_location(
) -> float:
raise NotImplementedError

def get_traveltimes_locations(
def get_travel_times_locations(
self,
phase: str,
source: Location,
Expand All @@ -48,7 +48,7 @@ def get_traveltimes_locations(
[self.get_travel_time_location(phase, source, recv) for recv in receivers]
)

def get_traveltimes(
def get_travel_times(
self,
phase: str,
octree: Octree,
Expand Down
8 changes: 4 additions & 4 deletions lassie/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def _interpolate_travel_times(

return np.asarray(traveltimes).astype(float)

def get_traveltime(self, source: Location, receiver: Location) -> float:
def get_travel_time(self, source: Location, receiver: Location) -> float:
coordinates = [
receiver.effective_depth,
source.effective_depth,
Expand Down Expand Up @@ -560,10 +560,10 @@ def get_travel_time_location(
if phase not in self.phases:
raise ValueError(f"Phase {phase} is not defined.")
tree = self._get_sptree_model(phase)
return tree.get_traveltime(source, receiver)
return tree.get_travel_time(source, receiver)

@log_call
def get_traveltimes(
def get_travel_times(
self,
phase: str,
octree: Octree,
Expand All @@ -580,7 +580,7 @@ def get_arrivals(
source: Location,
receivers: Sequence[Location],
) -> list[CakeArrival | None]:
traveltimes = self.get_traveltimes_locations(
traveltimes = self.get_travel_times_locations(
phase,
source=source,
receivers=receivers,
Expand Down
4 changes: 2 additions & 2 deletions lassie/tracers/constant_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_travel_time_location(
return source.distance_to(receiver) / self.velocity

@log_call
def get_traveltimes(
def get_travel_times(
self,
phase: str,
octree: Octree,
Expand All @@ -63,7 +63,7 @@ def get_arrivals(
) -> list[ConstantVelocityArrival]:
self._check_phase(phase)

traveltimes = self.get_traveltimes_locations(
traveltimes = self.get_travel_times_locations(
phase,
source=source,
receivers=receivers,
Expand Down
47 changes: 36 additions & 11 deletions lassie/tracers/fast_marching/fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ class FastMarchingTracer(RayTracer):
_traveltime_models: dict[str, StationTravelTimeVolume] = PrivateAttr({})
_velocity_model: VelocityModel3D | None = PrivateAttr(None)

def get_available_phases(self) -> tuple[str, ...]:
return (self.phase,)

async def prepare(
self,
octree: Octree,
Expand All @@ -280,10 +283,12 @@ async def prepare(
else:
self._load_cached_tavel_times(cache_dir)

await self._calculate_travel_times(
[sta for sta in stations if sta.pretty_nsl not in self._traveltime_models],
cache_dir,
)
work_stations = [
station
for station in stations
if station.location_hash() not in self._traveltime_models
]
await self._calculate_travel_times(work_stations, cache_dir)

def _load_cached_tavel_times(self, cache_dir: Path) -> None:
logger.debug("loading travel times volumes from cache %s...", cache_dir)
Expand All @@ -295,7 +300,7 @@ def _load_cached_tavel_times(self, cache_dir: Path) -> None:
logger.warning("removing bad travel time file %s", file)
file.unlink()
continue
models[travel_times.station.pretty_nsl] = travel_times
models[travel_times.station.location_hash()] = travel_times

logger.info("loaded %d travel times volumes from cache", len(models))
self._traveltime_models.update(models)
Expand All @@ -318,7 +323,7 @@ async def worker_station_travel_time(station: Station) -> None:
save=cache_dir,
executor=executor,
)
self._traveltime_models[station.pretty_nsl] = model
self._traveltime_models[station.location_hash()] = model

calculate_work = [worker_station_travel_time(station) for station in stations]
if not calculate_work:
Expand All @@ -336,16 +341,33 @@ async def worker_station_travel_time(station: Station) -> None:
progress.advance(status)
logger.info("calculated travel time volumes in %s", datetime_now() - start)

def get_travel_time(self, source: Location, receiver: Location) -> float:
station_travel_times = self._traveltime_models[hash(receiver)]
def get_travel_time_location(
self,
phase: str,
source: Location,
receiver: Location,
) -> float:
if phase != self.phase:
raise ValueError(f"phase {phase} is not supported by this tracer")

station_travel_times = self._traveltime_models[receiver.location_hash()]
return station_travel_times.interpolate_travel_time(
source, method=self.interpolation_method
source,
method=self.interpolation_method,
)

def get_travel_times(self, octree: Octree, stations: Stations) -> np.ndarray:
def get_travel_times(
self,
phase: str,
octree: Octree,
stations: Stations,
) -> np.ndarray:
if phase != self.phase:
raise ValueError(f"phase {phase} is not supported by this tracer")

result = []
for station in stations:
station_travel_times = self._traveltime_models[station.pretty_nsl]
station_travel_times = self._traveltime_models[station.location_hash()]
result.append(
station_travel_times.interpolate_travel_times(
octree, method=self.interpolation_method
Expand All @@ -360,6 +382,9 @@ def get_arrivals(
source: Location,
receivers: Sequence[Location],
) -> list[ModelledArrival | None]:
if phase != self.phase:
raise ValueError(f"phase {phase} is not supported by this tracer")

traveltimes = []
for receiver in receivers:
traveltimes.append(self.get_travel_time_location(phase, source, receiver))
Expand Down
2 changes: 1 addition & 1 deletion test/test_cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_sptree_model(traveltime_tree: TravelTimeTree):
depth=0,
)

model.get_traveltime(source, receiver)
model.get_travel_time(source, receiver)


def test_lut(
Expand Down

0 comments on commit dfa2aa0

Please sign in to comment.