Skip to content

Commit

Permalink
fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Sep 12, 2023
1 parent f66bd87 commit 344624e
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 35 deletions.
8 changes: 4 additions & 4 deletions lassie/apps/lassie.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main() -> None:
features.add_argument("rundir", type=Path, help="path of existing run")

station_corrections = subparsers.add_parser(
"station-corrections",
"corrections",
help="analyse station corrections from existing run",
description="analyze and plot station corrections from a finished run",
)
Expand All @@ -95,14 +95,14 @@ def main() -> None:

serve = subparsers.add_parser(
"serve",
help="serve results from an existing run",
help="start webserver and serve results from an existing run",
description="start a webserver and serve detections and results from a run",
)
serve.add_argument("rundir", type=Path, help="rundir to serve")

subparsers.add_parser(
"clear-cache",
help="clear the cached travel times",
help="clear the cach directory",
)

dump_schemas = subparsers.add_parser(
Expand Down Expand Up @@ -178,7 +178,7 @@ async def extract() -> None:

asyncio.run(extract())

elif args.command == "station-corrections":
elif args.command == "corrections":
rundir = Path(args.rundir)
station_corrections = StationCorrections(rundir=rundir)
if args.plot:
Expand Down
13 changes: 12 additions & 1 deletion lassie/models/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from lassie.models.location import Location
from lassie.models.station import Station, Stations
from lassie.tracers import RayTracerArrival
from lassie.utils import PhaseDescription, time_to_path
from lassie.utils import PhaseDescription, Symbols, time_to_path

if TYPE_CHECKING:
from pyrocko.squirrel import Response, Squirrel
Expand Down Expand Up @@ -449,6 +449,17 @@ def add(self, detection: EventDetection) -> None:
marker.save_markers(detection.get_pyrocko_markers(), str(markers_file))

self.detections.append(detection)
logger.info(
"%s detection #%d %s: %.5f°, %.5f°, depth %.1f m, "
"border distance %.1f m, semblance %.3f",
Symbols.Target,
self.n_detections,
detection.time,
*detection.effective_lat_lon,
detection.depth,
detection.distance_border,
detection.semblance,
)
# This has to happen after the markers are saved
detection.dump_append(self.rundir, self.n_detections - 1)

Expand Down
3 changes: 2 additions & 1 deletion lassie/models/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import hashlib
import math
import struct
from typing import TYPE_CHECKING, Iterable, Literal, Self, TypeVar
from typing import TYPE_CHECKING, Iterable, Literal, TypeVar

from pydantic import BaseModel, PrivateAttr
from pyrocko import orthodrome as od
from typing_extensions import Self

if TYPE_CHECKING:
from pathlib import Path
Expand Down
16 changes: 1 addition & 15 deletions lassie/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from lassie.station_corrections import StationCorrections
from lassie.tracers import CakeTracer, ConstantVelocityTracer, RayTracers
from lassie.tracers.fast_marching import FastMarchingTracer
from lassie.utils import PhaseDescription, Symbols, alog_call, time_to_path
from lassie.utils import PhaseDescription, alog_call, time_to_path

if TYPE_CHECKING:
from pyrocko.trace import Trace
Expand Down Expand Up @@ -422,11 +422,6 @@ async def search(

node_idx = (await semblance.maxima_node_idx())[time_idx]
source_node = octree[node_idx]
if not octree.is_node_in_bounds(source_node):
logger.info(
"source node is inside octree's absorbing boundary (%.1f m)",
source_node.distance_border,
)
source_location = source_node.as_location()

detection = EventDetection(
Expand Down Expand Up @@ -464,15 +459,6 @@ async def search(
)

detections.append(detection)
logger.info(
"%s new detection %s: %.5fE, %.5fN, depth %.1f m, semblance %.3f",
Symbols.Target,
detection.time,
*detection.effective_lat_lon,
detection.depth,
detection.semblance,
)

# detection.plot()

# plot_octree_movie(octree, semblance, file=Path("/tmp/test.mp4"))
Expand Down
2 changes: 1 addition & 1 deletion lassie/search/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def scan_squirrel(self) -> None:
* (batch.n - batch.i - 1)
)
logger.info(
"remaining %s, estimated finish at %s",
"remaining time %s, estimated finish at %s",
remaining_time,
datetime.now() + remaining_time, # noqa: DTZ005
)
Expand Down
2 changes: 1 addition & 1 deletion lassie/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def get_travel_times(self, octree: Octree, stations: Stations) -> np.ndarray:

cache_hits, cache_misses = self._node_lut.get_stats()
cache_hit_rate = cache_hits / (cache_hits + cache_misses)
logger.info(
logger.debug(
"node LUT cache fill level %.1f%%, cache hit rate %.1f%%",
self.lut_fill_level() * 100,
cache_hit_rate * 100,
Expand Down
25 changes: 17 additions & 8 deletions lassie/tracers/fast_marching/fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Self, Sequence
from typing import TYPE_CHECKING, Any, Literal, Sequence

import numpy as np
from lru import LRU
from pydantic import BaseModel, ByteSize, Field, PrivateAttr
from pyrocko.modelling import eikonal
from rich.progress import Progress
from scipy.interpolate import RegularGridInterpolator
from typing_extensions import Self

from lassie.models.location import Location
from lassie.models.station import Station, Stations
Expand Down Expand Up @@ -163,7 +164,7 @@ def filename(self) -> str:
# TODO: Add origin to hash to avoid collisions
return f"{self.station.pretty_nsl}-{self.velocity_model_hash}.3dtt"

def get_traveltime_interpolator(self) -> RegularGridInterpolator:
def get_travel_time_interpolator(self) -> RegularGridInterpolator:
if self._interpolator is None:
self._interpolator = RegularGridInterpolator(
(self._east_coords, self._north_coords, self._depth_coords),
Expand All @@ -178,7 +179,7 @@ def interpolate_travel_time(
location: Location,
method: Literal["nearest", "linear", "cubic"] = "linear",
) -> float:
interpolator = self.get_traveltime_interpolator()
interpolator = self.get_travel_time_interpolator()
offset = location.offset_to(self.center)
return interpolator([offset], method=method).astype(float, copy=False)[0]

Expand All @@ -187,11 +188,19 @@ def interpolate_nodes(
nodes: Sequence[Node],
method: Literal["nearest", "linear", "cubic"] = "linear",
) -> np.ndarray:
interpolator = self.get_traveltime_interpolator()
interpolator = self.get_travel_time_interpolator()

coordinates = [node.as_location().offset_to(self.center) for node in nodes]
return interpolator(coordinates, method=method).astype(float, copy=False)

def get_meshgrid(self) -> list[np.ndarray]:
return np.meshgrid(
self._east_coords,
self._north_coords,
self._depth_coords,
indexing="ij",
)

def save(self, path: Path) -> Path:
"""Save travel times to a zip file.
Expand Down Expand Up @@ -221,10 +230,10 @@ def save(self, path: Path) -> Path:

@classmethod
def load(cls, file: Path) -> Self:
"""Load 3D travel times from a zip file.
"""Load 3D travel times from a .3dtt file.
Args:
file (Path): path to the zip file containing the travel times
file (Path): path to the .3dtt file containing the travel times
Returns:
Self: 3D travel times
Expand Down Expand Up @@ -441,7 +450,7 @@ def get_travel_times(

cache_hits, cache_misses = self._node_lut.get_stats()
cache_hit_rate = cache_hits / (cache_hits + cache_misses)
logger.info(
logger.debug(
"node LUT cache fill level %.1f%%, cache hit rate %.1f%%",
self.lut_fill_level() * 100,
cache_hit_rate * 100,
Expand All @@ -456,7 +465,7 @@ def fill_lut(self, nodes: Sequence[Node]) -> None:

with Progress() as progress:
status = progress.add_task(
f"interpolating {self.phase} traveltimes for {n_nodes} nodes",
f"interpolating {self.phase} travel times for {n_nodes} nodes",
total=self._cached_stations.n_stations,
)
for station in self._cached_stations:
Expand Down
9 changes: 7 additions & 2 deletions lassie/tracers/fast_marching/velocity_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from hashlib import sha1
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal, Self, Union
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union

import numpy as np
from pydantic import (
Expand All @@ -17,6 +17,7 @@
)
from pydantic.dataclasses import dataclass
from scipy.interpolate import RegularGridInterpolator
from typing_extensions import Self

from lassie.models.location import Location

Expand Down Expand Up @@ -94,7 +95,10 @@ def hash(self) -> str:
return self._hash

def get_source_arrival_grid(self, station: Station) -> np.ndarray:
times = np.full_like(self._velocity_model, fill_value=-1.0)
if not self.is_inside(station):
raise ValueError("Station is outside of velocity model.")

times = np.full_like(self.velocity_model, fill_value=-1.0)

station_offset = station.offset_to(self.center)
east_idx = np.argmin(np.abs(self._east_coords - station_offset[0]))
Expand All @@ -116,6 +120,7 @@ def get_meshgrid(self) -> list[np.ndarray]:
self._east_coords,
self._north_coords,
self._depth_coords,
indexing="ij",
)

def resample(
Expand Down
50 changes: 48 additions & 2 deletions test/test_fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@


def stations_inside(
model: VelocityModel3D, nstations: int = 20, seed: int = 0
model: VelocityModel3D,
nstations: int = 20,
seed: int = 0,
depth: float | None = None,
) -> Stations:
stations = []
rng = np.random.RandomState(seed)
Expand All @@ -37,7 +40,8 @@ def stations_inside(
elevation=model.center.elevation,
north_shift=model.center.north_shift + rng.uniform(*model.north_bounds),
east_shift=model.center.east_shift + rng.uniform(*model.east_bounds),
depth=model.center.depth + rng.uniform(*model.depth_bounds),
depth=model.center.depth
+ (depth if depth is not None else rng.uniform(*model.depth_bounds)),
)
stations.append(station)
return Stations(stations=stations)
Expand Down Expand Up @@ -174,3 +178,45 @@ def test_non_lin_loc_model(
)
fig.colorbar(cmap)
plt.show()


@pytest.mark.plot
@pytest.mark.asyncio
async def test_non_lin_loc_travel_times(data_dir: Path, octree: Octree) -> None:
import matplotlib.pyplot as plt

header_file = data_dir / "FORGE_3D_5_large.P.mod.hdr"

tracer = FastMarchingTracer(
phase="fm:P",
velocity_model=NonLinLocVelocityModel(
header_file=header_file,
grid_spacing=100.0,
),
)
model_3d = tracer.velocity_model.get_model(octree)
octree = octree_cover(model_3d)
stations = stations_inside(model_3d, depth=0.0)
await tracer.prepare(octree, stations)

volume = tracer.get_travel_time_volume(stations.stations[0])

# 3d figure of velocity model
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
coords = volume.get_meshgrid()
print(coords[0].shape)

cmap = ax.scatter(
coords[0],
coords[1],
coords[2],
c=volume.travel_times.ravel(),
alpha=0.2,
)

station_offet = volume.station.offset_to(volume.center)
print(station_offet)
ax.scatter(*station_offet, s=100, c="r")
fig.colorbar(cmap)
plt.show()

0 comments on commit 344624e

Please sign in to comment.