From 18818ce1d7bf706a85c7bec1d10e35ffbd2356dd Mon Sep 17 00:00:00 2001 From: miili Date: Mon, 11 Sep 2023 23:36:14 +0200 Subject: [PATCH] wip: fast-marching implementation --- lassie/images/__init__.py | 4 +- lassie/images/base.py | 2 +- lassie/images/phase_net.py | 4 +- lassie/models/location.py | 13 +++--- lassie/models/station.py | 6 ++- lassie/octree.py | 37 +++++++--------- lassie/search/base.py | 2 +- lassie/search/squirrel.py | 18 ++++---- lassie/tracers/fast_marching/fast_marching.py | 19 ++++++-- .../tracers/fast_marching/velocity_models.py | 44 ++++++++++++++----- test/conftest.py | 9 ++-- test/test_octree.py | 15 +------ 12 files changed, 97 insertions(+), 76 deletions(-) diff --git a/lassie/images/__init__.py b/lassie/images/__init__.py index 5825df13..b85f656c 100644 --- a/lassie/images/__init__.py +++ b/lassie/images/__init__.py @@ -57,9 +57,7 @@ def get_phases(self) -> tuple[str, ...]: Returns: tuple[str, ...]: All available phases. """ - return tuple( - chain.from_iterable(image.get_available_phases() for image in self) - ) + return tuple(chain.from_iterable(image.get_provided_phases() for image in self)) def get_blinding(self) -> timedelta: return max(image.blinding for image in self) diff --git a/lassie/images/base.py b/lassie/images/base.py index 00a3a96f..cc083eb4 100644 --- a/lassie/images/base.py +++ b/lassie/images/base.py @@ -35,7 +35,7 @@ def blinding(self) -> timedelta: """Blinding duration for the image function. Added to padded waveforms.""" raise NotImplementedError("must be implemented by subclass") - def get_available_phases(self) -> tuple[str, ...]: + def get_provided_phases(self) -> tuple[str, ...]: ... diff --git a/lassie/images/phase_net.py b/lassie/images/phase_net.py index beac1e3d..670d8bfe 100644 --- a/lassie/images/phase_net.py +++ b/lassie/images/phase_net.py @@ -146,5 +146,5 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]: return [annotation_s, annotation_p] - def get_available_phases(self) -> tuple[str, ...]: - return tuple(self.phase_map.keys()) + def get_provided_phases(self) -> tuple[str, ...]: + return tuple(self.phase_map.values()) diff --git a/lassie/models/location.py b/lassie/models/location.py index 8bb2aebc..eb5465b0 100644 --- a/lassie/models/location.py +++ b/lassie/models/location.py @@ -115,15 +115,18 @@ def offset_to(self, other: Location) -> tuple[float, float, float]: return ( self.east_shift - other.east_shift, self.north_shift - other.north_shift, - self.effective_depth - other.effective_depth, + -(self.effective_elevation - other.effective_elevation), ) - sx, sy, sz = od.geodetic_to_ecef(*self.effective_lat_lon, self.effective_depth) - ox, oy, oz = od.geodetic_to_ecef( - *other.effective_lat_lon, other.effective_depth + shift_north, shift_east = od.latlon_to_ne_numpy( + self.lat, self.lon, other.lat, other.lon ) - return sx - ox, sy - oy, sz - oz + return ( + self.east_shift - other.east_shift + shift_east[0], + self.north_shift - other.north_shift + shift_north[0], + -(self.effective_elevation - other.effective_elevation), + ) def __hash__(self) -> int: return hash(self.location_hash()) diff --git a/lassie/models/station.py b/lassie/models/station.py index 48b3d887..c7cdf3ec 100644 --- a/lassie/models/station.py +++ b/lassie/models/station.py @@ -102,6 +102,8 @@ def weed_stations(self) -> None: def blacklist_station(self, station: Station, reason: str) -> None: logger.warning("blacklisting station %s: %s", station.pretty_nsl, reason) self.blacklist.add(station.pretty_nsl) + if self.n_stations == 0: + raise ValueError("no stations available, all stations blacklisted") def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: """Remove stations without waveforms from squirrel instances. @@ -117,7 +119,7 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: n_removed_stations = 0 for sta in self.stations.copy(): if sta.pretty_nsl not in available_squirrel_nsls: - logger.debug( + logger.info( "removing station %s: waveforms not available in squirrel", sta.pretty_nsl, ) @@ -125,7 +127,7 @@ def weed_from_squirrel_waveforms(self, squirrel: Squirrel) -> None: n_removed_stations += 1 if n_removed_stations: - logger.info("removed %d stations without waveforms", n_removed_stations) + logger.warning("removed %d stations without waveforms", n_removed_stations) if not self.stations: raise ValueError("no stations available, add waveforms to start detection") diff --git a/lassie/octree.py b/lassie/octree.py index f8a75d69..14903126 100644 --- a/lassie/octree.py +++ b/lassie/octree.py @@ -65,13 +65,13 @@ class Node(BaseModel): semblance: float = 0.0 tree: Octree | None = Field(None, exclude=True) - children: tuple[Node] = Field((), exclude=True) + children: tuple[Node, ...] = Field((), exclude=True) _hash: bytes | None = PrivateAttr(None) - _children_cached: tuple[Node] = PrivateAttr(()) + _children_cached: tuple[Node, ...] = PrivateAttr(()) _location: Location | None = PrivateAttr(None) - def split(self) -> tuple[Node]: + def split(self) -> tuple[Node, ...]: if not self.tree: raise EnvironmentError("Parent tree is not set.") @@ -137,14 +137,17 @@ def distance_to_location(self, location: Location) -> float: return location.distance_to(self.as_location()) def as_location(self) -> Location: + if not self.tree: + raise AttributeError("parent tree not set") if not self._location: + reference = self.tree.reference self._location = Location.model_construct( - lat=self.tree.center_lat, - lon=self.tree.center_lon, - elevation=self.tree.surface_elevation, - east_shift=float(self.east), - north_shift=float(self.north), - depth=float(self.depth), + lat=reference.lat, + lon=reference.lon, + elevation=reference.elevation, + east_shift=reference.east_shift + float(self.east), + north_shift=reference.north_shift + float(self.north), + depth=reference.depth + float(self.depth), ) return self._location @@ -160,8 +163,8 @@ def hash(self) -> bytes: self._hash = sha1( struct.pack( "dddddd", - self.tree.center_lat, - self.tree.center_lon, + self.tree.reference.lat, + self.tree.reference.lon, self.east, self.north, self.depth, @@ -175,9 +178,7 @@ def __hash__(self) -> int: class Octree(BaseModel): - center_lat: float = 0.0 - center_lon: float = 0.0 - surface_elevation: float = 0.0 + reference: Location = Location(lat=0.0, lon=0) size_initial: PositiveFloat = 2 * KM size_limit: PositiveFloat = 500 east_bounds: tuple[float, float] = (-10 * KM, 10 * KM) @@ -239,14 +240,6 @@ def _get_root_nodes(self, size: float) -> list[Node]: for depth in depth_nodes ] - @property - def center_location(self) -> Location: - return Location( - lat=self.center_lat, - lon=self.center_lon, - elevation=self.surface_elevation, - ) - @cached_property def n_nodes(self) -> int: """Number of nodes in the octree""" diff --git a/lassie/search/base.py b/lassie/search/base.py index d12953ea..68029f54 100644 --- a/lassie/search/base.py +++ b/lassie/search/base.py @@ -365,7 +365,7 @@ async def search( semblance.normalize(images.cumulative_weight()) parent.semblance_stats.update(semblance.get_stats()) - logger.info("semblance stats: %s", parent.semblance_stats) + logger.debug("semblance stats: %s", parent.semblance_stats) detection_idx, detection_semblance = semblance.find_peaks( height=parent.detection_threshold, diff --git a/lassie/search/squirrel.py b/lassie/search/squirrel.py index 46002879..629b8573 100644 --- a/lassie/search/squirrel.py +++ b/lassie/search/squirrel.py @@ -61,13 +61,13 @@ class SquirrelSearch(Search): _squirrel: Squirrel | None = PrivateAttr(None) def model_post_init(self, __context: Any) -> None: - if not all(self.time_span): - squirrel = self.get_squirrel() - sq_tmin, sq_tmax = squirrel.get_time_span(["waveform", "waveform_promise"]) - self.time_span = ( - self.time_span[0] or to_datetime(sq_tmin), - self.time_span[1] or to_datetime(sq_tmax), - ) + squirrel = self.get_squirrel() + sq_tmin, sq_tmax = squirrel.get_time_span(["waveform"]) + + self.time_span = ( + self.time_span[0] or to_datetime(sq_tmin), + self.time_span[1] or to_datetime(sq_tmax), + ) logger.info( "searching time span from %s to %s (%s)", @@ -78,8 +78,8 @@ def model_post_init(self, __context: Any) -> None: @field_validator("time_span") @classmethod - def _validate_time_span(cls, range): # noqa: N805 - if range[0] >= range[1]: + def _validate_time_span(cls, range) -> Any: # noqa: N805 + if range[0] and range[1] and range[0] >= range[1]: raise ValueError(f"time range is invalid {range[0]} - {range[1]}") return range diff --git a/lassie/tracers/fast_marching/fast_marching.py b/lassie/tracers/fast_marching/fast_marching.py index 360777d4..a7ec0716 100644 --- a/lassie/tracers/fast_marching/fast_marching.py +++ b/lassie/tracers/fast_marching/fast_marching.py @@ -60,9 +60,9 @@ class StationTravelTimeVolume(BaseModel): _travel_times: np.ndarray | None = PrivateAttr(None) - _north_coords: np.ndarray = PrivateAttr(None) - _east_coords: np.ndarray = PrivateAttr(None) - _depth_coords: np.ndarray = PrivateAttr(None) + _north_coords: np.ndarray = PrivateAttr() + _east_coords: np.ndarray = PrivateAttr() + _depth_coords: np.ndarray = PrivateAttr() # Cached values _file: Path | None = PrivateAttr(None) @@ -249,7 +249,7 @@ class FastMarchingTracer(RayTracer): tracer: Literal["FastMarchingRayTracer"] = "FastMarchingRayTracer" phase: PhaseDescription = "fm:P" - interpolation_method: Literal["nearest", "linear", "cubic"] = "nearest" + interpolation_method: Literal["nearest", "linear", "cubic"] = "linear" nthreads: int = Field( 0, description="Number of threads to use for travel time." @@ -303,6 +303,17 @@ async def prepare( reason=f"outside the fast-marching velocity model, offset {offset}", ) + nodes_covered = [ + node for node in octree if velocity_model.is_inside(node.as_location()) + ] + if not nodes_covered: + raise ValueError("no octree node is inside the velocity model") + + logger.info( + "%d%% octree nodes are inside the velocity model", + len(nodes_covered) / octree.n_nodes * 100, + ) + self._cached_stations = stations self._cached_station_indeces = { sta.pretty_nsl: idx for idx, sta in enumerate(stations) diff --git a/lassie/tracers/fast_marching/velocity_models.py b/lassie/tracers/fast_marching/velocity_models.py index 4f03c61e..2085eba0 100644 --- a/lassie/tracers/fast_marching/velocity_models.py +++ b/lassie/tracers/fast_marching/velocity_models.py @@ -170,7 +170,7 @@ def get_model(self, octree: Octree) -> VelocityModel3D: grid_spacing = self.grid_spacing model = VelocityModel3D( - center=octree.center_location, + center=octree.reference, grid_spacing=grid_spacing, east_bounds=octree.east_bounds, north_bounds=octree.north_bounds, @@ -199,7 +199,11 @@ class NonLinLocHeader: grid_type: NonLinLocGridType @classmethod - def from_header_file(cls, file: Path) -> Self: + def from_header_file( + cls, + file: Path, + reference_location: Location | None = None, + ) -> Self: logger.info("loading NonLinLoc velocity model header file %s", file) header_text = file.read_text().split("\n")[0] header_text = re.sub(r"\s+", " ", header_text) # remove excessive spaces @@ -220,12 +224,20 @@ def from_header_file(cls, file: Path) -> Self: if not delta_x == delta_y == delta_z: raise ValueError("NonLinLoc velocity model must have equal spacing.") - return cls( - origin=Location( + if reference_location: + origin = reference_location + origin.east_shift += float(orig_x) * KM + origin.north_shift += float(orig_y) * KM + origin.elevation -= float(orig_z) * KM + else: + origin = Location( lon=float(orig_x), lat=float(orig_y), elevation=-float(orig_z) * KM, - ), + ) + + return cls( + origin=origin, nx=int(nx), ny=int(ny), nz=int(nz), @@ -261,9 +273,9 @@ def depth_bounds(self) -> tuple[float, float]: @property def center(self) -> Location: - center = self.origin.model_copy() - center.north_shift = self.delta_x * self.nx / 2 - center.east_shift = self.delta_y * self.ny / 2 + center = self.origin.model_copy(deep=True) + center.east_shift += self.delta_x * self.nx / 2 + center.north_shift += self.delta_y * self.ny / 2 return center @@ -293,17 +305,28 @@ class NonLinLocVelocityModel(VelocityModelFactory): "for the fast-marching method.", ) + reference_location: Location | None = Field( + None, + description="relative location of NonLinLoc model, " + "used for models with relative coordinates.", + ) + _header: NonLinLocHeader = PrivateAttr() _velocity_model: np.ndarray = PrivateAttr() @model_validator(mode="after") def load_header(self) -> Self: - self._header = NonLinLocHeader.from_header_file(self.header_file) + self._header = NonLinLocHeader.from_header_file( + self.header_file, + reference_location=self.reference_location, + ) self.buffer_file = self.buffer_file or self.header_file.with_suffix(".buf") if not self.buffer_file.exists(): raise FileNotFoundError(f"Buffer file {self.buffer_file} not found.") - logger.info("loading NonLinLoc velocity model buffer file %s", self.buffer_file) + logger.debug( + "loading NonLinLoc velocity model buffer file %s", self.buffer_file + ) self._velocity_model = np.fromfile( self.buffer_file, dtype=self._header.dtype ).reshape((self._header.nx, self._header.ny, self._header.nz)) @@ -316,6 +339,7 @@ def load_header(self) -> Self: elif self._header.grid_type == "VELOCITY": self._velocity_model *= KM + logging.info("loaded NonLinLoc velocity model %s", self._header) return self def get_model(self, octree: Octree) -> VelocityModel3D: diff --git a/test/conftest.py b/test/conftest.py index 6a101dbc..bc02cf64 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -8,6 +8,7 @@ import pytest from lassie.models.detection import EventDetection, EventDetections +from lassie.models.location import Location from lassie.models.station import Station, Stations from lassie.octree import Octree from lassie.tracers.cake import EarthModel, Timing, TravelTimeTree @@ -48,9 +49,11 @@ def data_dir() -> Path: @pytest.fixture(scope="session") def octree() -> Octree: return Octree( - center_lat=10.0, - center_lon=10.0, - surface_elevation=1.0 * KM, + reference=Location( + lat=10.0, + lon=10.0, + elevation=1.0 * KM, + ), size_initial=2 * KM, size_limit=500, east_bounds=(-10 * KM, 10 * KM), diff --git a/test/test_octree.py b/test/test_octree.py index fcdf37ab..65c1bae8 100644 --- a/test/test_octree.py +++ b/test/test_octree.py @@ -1,23 +1,10 @@ -import pytest +from __future__ import annotations from lassie.octree import Octree km = 1e3 -@pytest.fixture(scope="function") -def octree(): - yield Octree( - center_lat=0.0, - center_lon=0.0, - east_bounds=(-25 * km, 25 * km), - north_bounds=(-25 * km, 25 * km), - depth_bounds=(0, 40 * km), - size_initial=5 * km, - size_limit=0.5 * km, - ) - - def test_octree(octree: Octree, plot: bool) -> None: assert octree.n_nodes > 0