Skip to content

Commit

Permalink
wip: fast-marching implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Sep 11, 2023
1 parent e2e6b0b commit 18818ce
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 76 deletions.
4 changes: 1 addition & 3 deletions lassie/images/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def get_phases(self) -> tuple[str, ...]:
Returns:
tuple[str, ...]: All available phases.
"""
return tuple(
chain.from_iterable(image.get_available_phases() for image in self)
)
return tuple(chain.from_iterable(image.get_provided_phases() for image in self))

def get_blinding(self) -> timedelta:
return max(image.blinding for image in self)
Expand Down
2 changes: 1 addition & 1 deletion lassie/images/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def blinding(self) -> timedelta:
"""Blinding duration for the image function. Added to padded waveforms."""
raise NotImplementedError("must be implemented by subclass")

def get_available_phases(self) -> tuple[str, ...]:
def get_provided_phases(self) -> tuple[str, ...]:
...


Expand Down
4 changes: 2 additions & 2 deletions lassie/images/phase_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,5 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]:

return [annotation_s, annotation_p]

def get_available_phases(self) -> tuple[str, ...]:
return tuple(self.phase_map.keys())
def get_provided_phases(self) -> tuple[str, ...]:
return tuple(self.phase_map.values())
13 changes: 8 additions & 5 deletions lassie/models/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,18 @@ def offset_to(self, other: Location) -> tuple[float, float, float]:
return (
self.east_shift - other.east_shift,
self.north_shift - other.north_shift,
self.effective_depth - other.effective_depth,
-(self.effective_elevation - other.effective_elevation),
)

sx, sy, sz = od.geodetic_to_ecef(*self.effective_lat_lon, self.effective_depth)
ox, oy, oz = od.geodetic_to_ecef(
*other.effective_lat_lon, other.effective_depth
shift_north, shift_east = od.latlon_to_ne_numpy(
self.lat, self.lon, other.lat, other.lon
)

return sx - ox, sy - oy, sz - oz
return (
self.east_shift - other.east_shift + shift_east[0],
self.north_shift - other.north_shift + shift_north[0],
-(self.effective_elevation - other.effective_elevation),
)

def __hash__(self) -> int:
return hash(self.location_hash())
Expand Down
6 changes: 4 additions & 2 deletions lassie/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def weed_stations(self) -> None:
def blacklist_station(self, station: Station, reason: str) -> None:
logger.warning("blacklisting station %s: %s", station.pretty_nsl, reason)
self.blacklist.add(station.pretty_nsl)
if self.n_stations == 0:
raise ValueError("no stations available, all stations blacklisted")

def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None:
"""Remove stations without waveforms from squirrel instances.
Expand All @@ -117,15 +119,15 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None:
n_removed_stations = 0
for sta in self.stations.copy():
if sta.pretty_nsl not in available_squirrel_nsls:
logger.debug(
logger.info(
"removing station %s: waveforms not available in squirrel",
sta.pretty_nsl,
)
self.stations.remove(sta)
n_removed_stations += 1

if n_removed_stations:
logger.info("removed %d stations without waveforms", n_removed_stations)
logger.warning("removed %d stations without waveforms", n_removed_stations)
if not self.stations:
raise ValueError("no stations available, add waveforms to start detection")

Expand Down
37 changes: 15 additions & 22 deletions lassie/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ class Node(BaseModel):
semblance: float = 0.0

tree: Octree | None = Field(None, exclude=True)
children: tuple[Node] = Field((), exclude=True)
children: tuple[Node, ...] = Field((), exclude=True)

_hash: bytes | None = PrivateAttr(None)
_children_cached: tuple[Node] = PrivateAttr(())
_children_cached: tuple[Node, ...] = PrivateAttr(())
_location: Location | None = PrivateAttr(None)

def split(self) -> tuple[Node]:
def split(self) -> tuple[Node, ...]:
if not self.tree:
raise EnvironmentError("Parent tree is not set.")

Expand Down Expand Up @@ -137,14 +137,17 @@ def distance_to_location(self, location: Location) -> float:
return location.distance_to(self.as_location())

def as_location(self) -> Location:
if not self.tree:
raise AttributeError("parent tree not set")
if not self._location:
reference = self.tree.reference
self._location = Location.model_construct(
lat=self.tree.center_lat,
lon=self.tree.center_lon,
elevation=self.tree.surface_elevation,
east_shift=float(self.east),
north_shift=float(self.north),
depth=float(self.depth),
lat=reference.lat,
lon=reference.lon,
elevation=reference.elevation,
east_shift=reference.east_shift + float(self.east),
north_shift=reference.north_shift + float(self.north),
depth=reference.depth + float(self.depth),
)
return self._location

Expand All @@ -160,8 +163,8 @@ def hash(self) -> bytes:
self._hash = sha1(
struct.pack(
"dddddd",
self.tree.center_lat,
self.tree.center_lon,
self.tree.reference.lat,
self.tree.reference.lon,
self.east,
self.north,
self.depth,
Expand All @@ -175,9 +178,7 @@ def __hash__(self) -> int:


class Octree(BaseModel):
center_lat: float = 0.0
center_lon: float = 0.0
surface_elevation: float = 0.0
reference: Location = Location(lat=0.0, lon=0)
size_initial: PositiveFloat = 2 * KM
size_limit: PositiveFloat = 500
east_bounds: tuple[float, float] = (-10 * KM, 10 * KM)
Expand Down Expand Up @@ -239,14 +240,6 @@ def _get_root_nodes(self, size: float) -> list[Node]:
for depth in depth_nodes
]

@property
def center_location(self) -> Location:
return Location(
lat=self.center_lat,
lon=self.center_lon,
elevation=self.surface_elevation,
)

@cached_property
def n_nodes(self) -> int:
"""Number of nodes in the octree"""
Expand Down
2 changes: 1 addition & 1 deletion lassie/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ async def search(
semblance.normalize(images.cumulative_weight())

parent.semblance_stats.update(semblance.get_stats())
logger.info("semblance stats: %s", parent.semblance_stats)
logger.debug("semblance stats: %s", parent.semblance_stats)

detection_idx, detection_semblance = semblance.find_peaks(
height=parent.detection_threshold,
Expand Down
18 changes: 9 additions & 9 deletions lassie/search/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ class SquirrelSearch(Search):
_squirrel: Squirrel | None = PrivateAttr(None)

def model_post_init(self, __context: Any) -> None:
if not all(self.time_span):
squirrel = self.get_squirrel()
sq_tmin, sq_tmax = squirrel.get_time_span(["waveform", "waveform_promise"])
self.time_span = (
self.time_span[0] or to_datetime(sq_tmin),
self.time_span[1] or to_datetime(sq_tmax),
)
squirrel = self.get_squirrel()
sq_tmin, sq_tmax = squirrel.get_time_span(["waveform"])

self.time_span = (
self.time_span[0] or to_datetime(sq_tmin),
self.time_span[1] or to_datetime(sq_tmax),
)

logger.info(
"searching time span from %s to %s (%s)",
Expand All @@ -78,8 +78,8 @@ def model_post_init(self, __context: Any) -> None:

@field_validator("time_span")
@classmethod
def _validate_time_span(cls, range): # noqa: N805
if range[0] >= range[1]:
def _validate_time_span(cls, range) -> Any: # noqa: N805
if range[0] and range[1] and range[0] >= range[1]:
raise ValueError(f"time range is invalid {range[0]} - {range[1]}")
return range

Expand Down
19 changes: 15 additions & 4 deletions lassie/tracers/fast_marching/fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class StationTravelTimeVolume(BaseModel):

_travel_times: np.ndarray | None = PrivateAttr(None)

_north_coords: np.ndarray = PrivateAttr(None)
_east_coords: np.ndarray = PrivateAttr(None)
_depth_coords: np.ndarray = PrivateAttr(None)
_north_coords: np.ndarray = PrivateAttr()
_east_coords: np.ndarray = PrivateAttr()
_depth_coords: np.ndarray = PrivateAttr()

# Cached values
_file: Path | None = PrivateAttr(None)
Expand Down Expand Up @@ -249,7 +249,7 @@ class FastMarchingTracer(RayTracer):
tracer: Literal["FastMarchingRayTracer"] = "FastMarchingRayTracer"

phase: PhaseDescription = "fm:P"
interpolation_method: Literal["nearest", "linear", "cubic"] = "nearest"
interpolation_method: Literal["nearest", "linear", "cubic"] = "linear"
nthreads: int = Field(
0,
description="Number of threads to use for travel time."
Expand Down Expand Up @@ -303,6 +303,17 @@ async def prepare(
reason=f"outside the fast-marching velocity model, offset {offset}",
)

nodes_covered = [
node for node in octree if velocity_model.is_inside(node.as_location())
]
if not nodes_covered:
raise ValueError("no octree node is inside the velocity model")

logger.info(
"%d%% octree nodes are inside the velocity model",
len(nodes_covered) / octree.n_nodes * 100,
)

self._cached_stations = stations
self._cached_station_indeces = {
sta.pretty_nsl: idx for idx, sta in enumerate(stations)
Expand Down
44 changes: 34 additions & 10 deletions lassie/tracers/fast_marching/velocity_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_model(self, octree: Octree) -> VelocityModel3D:
grid_spacing = self.grid_spacing

model = VelocityModel3D(
center=octree.center_location,
center=octree.reference,
grid_spacing=grid_spacing,
east_bounds=octree.east_bounds,
north_bounds=octree.north_bounds,
Expand Down Expand Up @@ -199,7 +199,11 @@ class NonLinLocHeader:
grid_type: NonLinLocGridType

@classmethod
def from_header_file(cls, file: Path) -> Self:
def from_header_file(
cls,
file: Path,
reference_location: Location | None = None,
) -> Self:
logger.info("loading NonLinLoc velocity model header file %s", file)
header_text = file.read_text().split("\n")[0]
header_text = re.sub(r"\s+", " ", header_text) # remove excessive spaces
Expand All @@ -220,12 +224,20 @@ def from_header_file(cls, file: Path) -> Self:
if not delta_x == delta_y == delta_z:
raise ValueError("NonLinLoc velocity model must have equal spacing.")

return cls(
origin=Location(
if reference_location:
origin = reference_location
origin.east_shift += float(orig_x) * KM
origin.north_shift += float(orig_y) * KM
origin.elevation -= float(orig_z) * KM
else:
origin = Location(
lon=float(orig_x),
lat=float(orig_y),
elevation=-float(orig_z) * KM,
),
)

return cls(
origin=origin,
nx=int(nx),
ny=int(ny),
nz=int(nz),
Expand Down Expand Up @@ -261,9 +273,9 @@ def depth_bounds(self) -> tuple[float, float]:

@property
def center(self) -> Location:
center = self.origin.model_copy()
center.north_shift = self.delta_x * self.nx / 2
center.east_shift = self.delta_y * self.ny / 2
center = self.origin.model_copy(deep=True)
center.east_shift += self.delta_x * self.nx / 2
center.north_shift += self.delta_y * self.ny / 2
return center


Expand Down Expand Up @@ -293,17 +305,28 @@ class NonLinLocVelocityModel(VelocityModelFactory):
"for the fast-marching method.",
)

reference_location: Location | None = Field(
None,
description="relative location of NonLinLoc model, "
"used for models with relative coordinates.",
)

_header: NonLinLocHeader = PrivateAttr()
_velocity_model: np.ndarray = PrivateAttr()

@model_validator(mode="after")
def load_header(self) -> Self:
self._header = NonLinLocHeader.from_header_file(self.header_file)
self._header = NonLinLocHeader.from_header_file(
self.header_file,
reference_location=self.reference_location,
)
self.buffer_file = self.buffer_file or self.header_file.with_suffix(".buf")
if not self.buffer_file.exists():
raise FileNotFoundError(f"Buffer file {self.buffer_file} not found.")

logger.info("loading NonLinLoc velocity model buffer file %s", self.buffer_file)
logger.debug(
"loading NonLinLoc velocity model buffer file %s", self.buffer_file
)
self._velocity_model = np.fromfile(
self.buffer_file, dtype=self._header.dtype
).reshape((self._header.nx, self._header.ny, self._header.nz))
Expand All @@ -316,6 +339,7 @@ def load_header(self) -> Self:
elif self._header.grid_type == "VELOCITY":
self._velocity_model *= KM

logging.info("loaded NonLinLoc velocity model %s", self._header)
return self

def get_model(self, octree: Octree) -> VelocityModel3D:
Expand Down
9 changes: 6 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from lassie.models.detection import EventDetection, EventDetections
from lassie.models.location import Location
from lassie.models.station import Station, Stations
from lassie.octree import Octree
from lassie.tracers.cake import EarthModel, Timing, TravelTimeTree
Expand Down Expand Up @@ -48,9 +49,11 @@ def data_dir() -> Path:
@pytest.fixture(scope="session")
def octree() -> Octree:
return Octree(
center_lat=10.0,
center_lon=10.0,
surface_elevation=1.0 * KM,
reference=Location(
lat=10.0,
lon=10.0,
elevation=1.0 * KM,
),
size_initial=2 * KM,
size_limit=500,
east_bounds=(-10 * KM, 10 * KM),
Expand Down
15 changes: 1 addition & 14 deletions test/test_octree.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,10 @@
import pytest
from __future__ import annotations

from lassie.octree import Octree

km = 1e3


@pytest.fixture(scope="function")
def octree():
yield Octree(
center_lat=0.0,
center_lon=0.0,
east_bounds=(-25 * km, 25 * km),
north_bounds=(-25 * km, 25 * km),
depth_bounds=(0, 40 * km),
size_initial=5 * km,
size_limit=0.5 * km,
)


def test_octree(octree: Octree, plot: bool) -> None:
assert octree.n_nodes > 0

Expand Down

0 comments on commit 18818ce

Please sign in to comment.