Skip to content

Commit

Permalink
wip: finishing up fast-marching
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Sep 12, 2023
1 parent 7d9b55d commit cd33e2a
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ jobs:
pip install .[dev]
- name: Test with pytest
run: |
pytest
pytest -m "not plot"
2 changes: 1 addition & 1 deletion lassie/models/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def jitter_location(self, meters: float) -> EventDetection:
detection.east_shift += uniform(-half_meters, half_meters)
detection.north_shift += uniform(-half_meters, half_meters)
detection.depth += uniform(-half_meters, half_meters)
del detection.effective_lat_lon
detection._cached_lat_lon = None
return detection

def snuffle(self, squirrel: Squirrel, restituted: bool = False) -> None:
Expand Down
48 changes: 32 additions & 16 deletions lassie/models/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import hashlib
import math
import struct
from functools import cached_property
from typing import TYPE_CHECKING, Iterable, Literal, TypeVar
from typing import TYPE_CHECKING, Iterable, Literal, Self, TypeVar

from pydantic import BaseModel, computed_field
from pydantic import BaseModel, PrivateAttr
from pyrocko import orthodrome as od

if TYPE_CHECKING:
Expand All @@ -23,6 +22,8 @@ class Location(BaseModel):
elevation: float = 0.0
depth: float = 0.0

_cached_lat_lon: tuple[float, float] | None = PrivateAttr(None)

@property
def effective_lat(self) -> float:
return self.effective_lat_lon[0]
Expand All @@ -31,19 +32,21 @@ def effective_lat(self) -> float:
def effective_lon(self) -> float:
return self.effective_lat_lon[1]

@computed_field
@cached_property
@property
def effective_lat_lon(self) -> tuple[float, float]:
"""Shift-corrected lat/lon pair of the location."""
if self.north_shift == 0.0 and self.east_shift == 0.0:
return self.lat, self.lon
lat, lon = od.ne_to_latlon(
self.lat,
self.lon,
self.north_shift,
self.east_shift,
)
return float(lat), float(lon)
if self._cached_lat_lon is None:
if self.north_shift == 0.0 and self.east_shift == 0.0:
self._cached_lat_lon = self.lat, self.lon
else:
lat, lon = od.ne_to_latlon(
self.lat,
self.lon,
self.north_shift,
self.east_shift,
)
self._cached_lat_lon = float(lat), float(lon)
return self._cached_lat_lon

@property
def effective_elevation(self) -> float:
Expand Down Expand Up @@ -123,11 +126,24 @@ def offset_to(self, other: Location) -> tuple[float, float, float]:
)

return (
self.east_shift - other.east_shift + shift_east[0],
self.north_shift - other.north_shift + shift_north[0],
self.east_shift - other.east_shift - shift_east[0],
self.north_shift - other.north_shift - shift_north[0],
-(self.effective_elevation - other.effective_elevation),
)

def shifted_origin(self) -> Self:
"""Shift the origin of the location to the effective lat/lon.
Returns:
Self: The shifted location.
"""
shifted = self.model_copy()
shifted.lat = self.effective_lat
shifted.lon = self.effective_lon
shifted.east_shift = 0.0
shifted.north_shift = 0.0
return shifted

def __hash__(self) -> int:
return hash(self.location_hash())

Expand Down
11 changes: 9 additions & 2 deletions lassie/tracers/fast_marching/velocity_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def set_velocity_model(self, velocity_model: np.ndarray) -> None:
)
self._velocity_model = velocity_model.astype(float, copy=False)

@property
def velocity_model(self) -> np.ndarray:
if self._velocity_model is None:
raise ValueError("Velocity model not set.")
return self._velocity_model

def hash(self) -> str:
if self._hash is None:
sha1_hash = sha1(self._velocity_model.tobytes())
Expand Down Expand Up @@ -340,8 +346,9 @@ def load_header(self) -> Self:
self._velocity_model *= KM

logging.info(
"loaded NonLinLoc velocity model, "
"east_bounds: %s, north_bounds %s, depth_bounds %s",
"NonLinLoc velocity model: %s"
" east_bounds: %s, north_bounds %s, depth_bounds %s",
self._header.center,
self._header.east_bounds,
self._header.north_bounds,
self._header.depth_bounds,
Expand Down
32 changes: 32 additions & 0 deletions test/test_fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,35 @@ async def test_non_lin_loc(data_dir: Path, octree: Octree, stations: Stations) -
source=source,
receivers=list(stations),
)


@pytest.mark.plot
def test_non_lin_loc_model(
data_dir: Path,
octree: Octree,
stations: Stations,
) -> None:
import matplotlib.pyplot as plt

header_file = data_dir / "FORGE_3D_5_large.P.mod.hdr"

model = NonLinLocVelocityModel(header_file=header_file)
velocity_model = model.get_model(octree).resample(
grid_spacing=200.0,
method="linear",
)

# 3d figure of velocity model
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
coords = velocity_model.get_meshgrid()
print(coords[0].shape)
cmap = ax.scatter(
coords[0],
coords[1],
-coords[2],
s=np.log(velocity_model.velocity_model.ravel() / KM),
c=velocity_model.velocity_model.ravel(),
)
fig.colorbar(cmap)
plt.show()
67 changes: 67 additions & 0 deletions test/test_location.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,75 @@
from __future__ import annotations

import random

import numpy as np

from lassie.models import Location

KM = 1e3


def test_location() -> None:
loc = Location(lat=11.0, lon=23.55)
loc_other = Location(lat=13.123, lon=21.12)

loc.surface_distance_to(loc_other)


def test_distance_same_origin():
loc = Location(lat=11.0, lon=23.55)

perturb_attributes = {"north_shift", "east_shift", "elevation", "depth"}
for _ in range(100):
distance = random.uniform(-10 * KM, 10 * KM)
for attr in perturb_attributes:
loc_other = loc.model_copy()
loc_other._cached_lat_lon = None
setattr(loc_other, attr, distance)
assert loc.distance_to(loc_other) == abs(distance)

loc_shifted = loc_other.shifted_origin()
np.testing.assert_approx_equal(
loc.distance_to(loc_shifted),
abs(distance),
significant=2,
)


def test_location_offset():
loc = Location(lat=11.0, lon=23.55)
loc_other = Location(
lat=11.0,
lon=23.55,
north_shift=100.0,
east_shift=100.0,
depth=100.0,
)

offset = loc_other.offset_to(loc)
assert offset == (100.0, 100.0, 100.0)

loc_other = Location(
lat=11.0,
lon=23.55,
north_shift=100.0,
east_shift=100.0,
elevation=100.0,
)
offset = loc_other.offset_to(loc)
assert offset == (100.0, 100.0, -100.0)

loc_other = Location(
lat=11.0,
lon=23.55,
north_shift=100.0,
east_shift=100.0,
elevation=100.0,
depth=10.0,
)
offset = loc_other.offset_to(loc)
assert offset == (100.0, 100.0, -90.0)

loc_other = loc_other.shifted_origin()
offset = loc_other.offset_to(loc)
np.testing.assert_almost_equal(offset, (100.0, 100.0, -90.0), decimal=0)

0 comments on commit cd33e2a

Please sign in to comment.