Commit

fixing tests
miili committed Sep 12, 2023
1 parent 344624e commit d400398
Showing 10 changed files with 202 additions and 20 deletions.
3 changes: 0 additions & 3 deletions lassie/apps/lassie.py
@@ -7,7 +7,6 @@
 from datetime import datetime
 from pathlib import Path

-import nest_asyncio
 from pkg_resources import get_distribution

 from lassie.console import console
@@ -19,8 +18,6 @@
 from lassie.station_corrections import StationCorrections
 from lassie.utils import CACHE_DIR, setup_rich_logging

-nest_asyncio.apply()
-
 logger = logging.getLogger(__name__)
21 changes: 21 additions & 0 deletions lassie/octree.py
@@ -359,6 +359,27 @@ def smallest_node_size(self) -> float:
             size /= 2
         return size

+    def n_levels(self) -> int:
+        """Returns the number of levels in the octree.
+
+        Returns:
+            int: Number of levels.
+        """
+        levels = 0
+        size = self.size_initial
+        while size >= self.size_limit * 2:
+            levels += 1
+            size /= 2
+        return levels
+
+    def total_number_nodes(self) -> int:
+        """Returns the total number of nodes of all levels.
+
+        Returns:
+            int: Total number of nodes.
+        """
+        return len(self._root_nodes) * (8 ** self.n_levels())
+
     def maximum_number_nodes(self) -> int:
         """Returns the maximum number of nodes.
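
A quick worked example of what the two new Octree methods compute (a sketch with invented sizes, not code from the repository):

    # Mirrors n_levels() and total_number_nodes() with hypothetical values.
    size_initial = 1000.0  # m, coarsest node edge length
    size_limit = 125.0     # m, smallest allowed node edge length

    levels, size = 0, size_initial
    while size >= size_limit * 2:
        levels += 1
        size /= 2
    # 1000 -> 500 -> 250 -> 125: three halvings, so levels == 3

    n_root_nodes = 40                # hypothetical root grid
    print(n_root_nodes * 8**levels)  # 20480 nodes at the finest subdivision
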
14 changes: 12 additions & 2 deletions lassie/search/squirrel.py
@@ -8,7 +8,14 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Deque, Iterator

-from pydantic import AwareDatetime, Field, PositiveInt, PrivateAttr, field_validator
+from pydantic import (
+    AwareDatetime,
+    Field,
+    PositiveInt,
+    PrivateAttr,
+    constr,
+    field_validator,
+)
 from pyrocko.squirrel import Squirrel

 from lassie.features import FeatureExtractors
@@ -49,6 +56,7 @@ async def prefetch_worker(self) -> None:
 class SquirrelSearch(Search):
     time_span: tuple[AwareDatetime | None, AwareDatetime | None] = (None, None)
     squirrel_environment: Path = Path(".")
+    channel_selector: constr(max_length=3) = "*"
     waveform_data: list[Path]
     waveform_prefetch_batches: PositiveInt = 4

@@ -131,7 +139,9 @@ async def scan_squirrel(self) -> None:
             tinc=window_increment.total_seconds(),
             tpad=self.window_padding.total_seconds(),
             want_incomplete=False,
-            codes=[(*nsl, "*") for nsl in self.stations.get_all_nsl()],
+            codes=[
+                (*nsl, self.channel_selector) for nsl in self.stations.get_all_nsl()
+            ],
         )
         prefetcher = SquirrelPrefetcher(iterator, self.waveform_prefetch_batches)
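
The new channel_selector field flows straight into the Squirrel codes tuples built in scan_squirrel. A minimal sketch of the resulting selection (station codes and selector value are invented):

    # Hypothetical illustration of the channel selection.
    channel_selector = "HH?"  # constr(max_length=3) caps this at three characters
    stations_nsl = [("6A", "ST01", ""), ("6A", "ST02", "00")]

    codes = [(*nsl, channel_selector) for nsl in stations_nsl]
    # [('6A', 'ST01', '', 'HH?'), ('6A', 'ST02', '00', 'HH?')]
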
4 changes: 2 additions & 2 deletions lassie/tracers/cake.py
@@ -503,7 +503,7 @@ async def prepare(self, octree: Octree, stations: Stations) -> None:
         LRU_CACHE_SIZE = int(self.lut_cache_size / bytes_per_node / n_trees)

         # TODO: This should be total number nodes. Not only leaf nodes.
-        node_cache_fraction = LRU_CACHE_SIZE / octree.maximum_number_nodes()
+        node_cache_fraction = LRU_CACHE_SIZE / octree.total_number_nodes()
         logging.info(
             "limiting traveltime LUT size to %d nodes (%s),"
             " caching %.1f%% of possible octree nodes",
@@ -515,7 +515,7 @@ async def prepare(self, octree: Octree, stations: Stations) -> None:
         cached_trees = [
             TravelTimeTree.load(file) for file in self.cache_dir.glob("*.sptree")
         ]
-        logger.debug("loaded %d cached traveltime trees", len(cached_trees))
+        logger.debug("loaded %d cached travel time trees", len(cached_trees))

         distances = octree.distances_stations(stations)
         source_depths = np.asarray(octree.depth_bounds)
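
For intuition on the changed denominator, a back-of-the-envelope version of the cache-fraction log line (every number here is invented):

    # Hypothetical LUT sizing mirroring prepare() in cake.py.
    lut_cache_size = 2e9    # 2 GB cache budget
    bytes_per_node = 8_000  # bytes one octree node occupies in a LUT
    n_trees = 2             # e.g. one P and one S travel time tree
    LRU_CACHE_SIZE = int(lut_cache_size / bytes_per_node / n_trees)  # 125_000

    total_number_nodes = 40 * 8**4  # stand-in for octree.total_number_nodes()
    node_cache_fraction = LRU_CACHE_SIZE / total_number_nodes
    print(f"caching {node_cache_fraction * 100:.1f}% of octree nodes")  # 76.3%
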
22 changes: 19 additions & 3 deletions lassie/tracers/fast_marching/fast_marching.py
@@ -312,15 +312,29 @@ async def prepare(
                 reason=f"outside fast-marching velocity model, offset {offset}",
             )

+        for station in stations:
+            velocity_station = velocity_model.get_velocity(station)
+            if velocity_station < 0.0:
+                raise ValueError(
+                    f"station {station.pretty_nsl} has negative velocity"
+                    f" {velocity_station}"
+                )
+            logger.info(
+                "velocity at station %s: %.1f m/s",
+                station.pretty_nsl,
+                velocity_station,
+            )
+
         nodes_covered = [
             node for node in octree if velocity_model.is_inside(node.as_location())
         ]
         if not nodes_covered:
             raise ValueError("no octree node is inside the velocity model")

         logger.info(
-            "%d%% octree nodes are inside the velocity model",
+            "%d%% octree nodes are inside the %s velocity model",
             len(nodes_covered) / octree.n_nodes * 100,
+            self.phase,
         )

         self._cached_stations = stations
@@ -332,7 +346,7 @@ async def prepare(
         self._node_lut = LRU(lru_cache_size)

         # TODO: This should be total number nodes. Not only leaf nodes.
-        node_cache_fraction = lru_cache_size / octree.maximum_number_nodes()
+        node_cache_fraction = lru_cache_size / octree.total_number_nodes()
         logging.info(
             "limiting traveltime LUT size to %d nodes (%s),"
             " caching %.1f%% of possible octree nodes",
@@ -364,7 +378,9 @@ def _load_cached_tavel_times(self, cache_dir: Path) -> None:
                 continue
             volumes[travel_times.station.location_hash()] = travel_times

-        logger.info("loaded %d travel times volumes from cache", len(volumes))
+        logger.info(
+            "loaded %d travel times volumes for %s from cache", len(volumes), self.phase
+        )
         self._travel_time_volumes.update(volumes)

     async def _calculate_travel_times(
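
The new per-station loop assumes VelocityModel3D.get_velocity() returns the model's negative fill value wherever no velocity is defined. A hypothetical sketch of the failure mode the check guards against (grid values and indices are made up):

    import numpy as np

    velocity_grid = np.full((10, 10, 10), fill_value=-1.0)  # undefined cells
    velocity_grid[2:8, 2:8, 2:8] = 5900.0  # defined P velocities in m/s

    # Nearest-neighbour lookup as in get_velocity():
    assert velocity_grid[5, 5, 5] > 0.0   # station in a defined cell: accepted
    assert velocity_grid[0, 0, 0] < 0.0   # undefined cell: prepare() now raises
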
92 changes: 86 additions & 6 deletions lassie/tracers/fast_marching/velocity_models.py
@@ -22,7 +22,6 @@
 from lassie.models.location import Location

 if TYPE_CHECKING:
-    from lassie.models.station import Station
     from lassie.octree import Octree

@@ -89,25 +88,71 @@ def velocity_model(self) -> np.ndarray:
         return self._velocity_model

     def hash(self) -> str:
+        """Return hash of velocity model.
+
+        Returns:
+            str: The hash.
+        """
         if self._hash is None:
             sha1_hash = sha1(self._velocity_model.tobytes())
             self._hash = sha1_hash.hexdigest()
         return self._hash

-    def get_source_arrival_grid(self, station: Station) -> np.ndarray:
-        if not self.is_inside(station):
-            raise ValueError("Station is outside of velocity model.")
-
-        times = np.full_like(self.velocity_model, fill_value=-1.0)
-
-        station_offset = station.offset_to(self.center)
+    def _get_location_indices(self, location: Location) -> tuple[int, int, int]:
+        """Return indices of location in velocity model, by nearest neighbor.
+
+        Args:
+            location (Location): The location.
+
+        Returns:
+            tuple[int, int, int]: The indices as (east, north, depth).
+        """
+        if not self.is_inside(location):
+            raise ValueError("Location is outside of velocity model.")
+        station_offset = location.offset_to(self.center)
         east_idx = np.argmin(np.abs(self._east_coords - station_offset[0]))
         north_idx = np.argmin(np.abs(self._north_coords - station_offset[1]))
         depth_idx = np.argmin(np.abs(self._depth_coords - station_offset[2]))
+        return int(east_idx), int(north_idx), int(depth_idx)
+
+    def get_velocity(self, location: Location) -> float:
+        """Return velocity at location in [m/s], nearest neighbor.
+
+        Args:
+            location (Location): The location.
+
+        Returns:
+            float: The velocity in m/s.
+        """
+        east_idx, north_idx, depth_idx = self._get_location_indices(location)
+        return self.velocity_model[east_idx, north_idx, depth_idx]
+
+    def get_source_arrival_grid(self, location: Location) -> np.ndarray:
+        """Return the initial travel time grid for the Eikonal solver.
+
+        The initial travel time grid is filled with -1.0, except for the source
+        location, which is set to 0.0 s.
+
+        Args:
+            location (Location): The location.
+
+        Returns:
+            np.ndarray: The initial travel times grid.
+        """
+        times = np.full_like(self.velocity_model, fill_value=-1.0)
+        east_idx, north_idx, depth_idx = self._get_location_indices(location)
         times[east_idx, north_idx, depth_idx] = 0.0
         return times

     def is_inside(self, location: Location) -> bool:
+        """Return True if location is inside velocity model.
+
+        Args:
+            location (Location): The location.
+
+        Returns:
+            bool: True if location is inside velocity model.
+        """
         offset_to_center = location.offset_to(self.center)
         return (
             self.east_bounds[0] <= offset_to_center[0] <= self.east_bounds[1]
@@ -116,6 +161,12 @@ def is_inside(self, location: Location) -> bool:
         )

     def get_meshgrid(self) -> list[np.ndarray]:
+        """Return meshgrid of velocity model coordinates.
+
+        Returns:
+            list[np.ndarray]: The meshgrid as list of numpy arrays for east, north,
+                depth.
+        """
         return np.meshgrid(
             self._east_coords,
             self._north_coords,
@@ -128,6 +179,16 @@ def resample(
         grid_spacing: float,
         method: Literal["nearest", "linear", "cubic"] = "linear",
     ) -> Self:
+        """Resample velocity model to new grid spacing.
+
+        Args:
+            grid_spacing (float): The new grid spacing in [m].
+            method (Literal['nearest', 'linear', 'cubic'], optional): Interpolation
+                method. Defaults to "linear".
+
+        Returns:
+            Self: A new, resampled velocity model.
+        """
         if grid_spacing == self.grid_spacing:
             return self

@@ -215,6 +276,20 @@ def from_header_file(
         file: Path,
         reference_location: Location | None = None,
     ) -> Self:
+        """Load NonLinLoc velocity model header file.
+
+        Args:
+            file (Path): Path to NonLinLoc model header file.
+            reference_location (Location | None, optional): relative location of
+                NonLinLoc model, used for models with relative coordinates.
+                Defaults to None.
+
+        Raises:
+            ValueError: If grid spacing is not equal in all dimensions.
+
+        Returns:
+            Self: The header.
+        """
         logger.info("loading NonLinLoc velocity model header file %s", file)
         header_text = file.read_text().split("\n")[0]
         header_text = re.sub(r"\s+", " ", header_text)  # remove excessive spaces
Expand Down Expand Up @@ -284,6 +359,11 @@ def depth_bounds(self) -> tuple[float, float]:

@property
def center(self) -> Location:
"""Return center location of velocity model.
Returns:
Location: The center location of the grid.
"""
center = self.origin.model_copy(deep=True)
center.east_shift += self.delta_x * self.nx / 2
center.north_shift += self.delta_y * self.ny / 2
Expand Down
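
The -1.0/0.0 convention produced by get_source_arrival_grid() is what a fast-marching Eikonal solver consumes: negative cells are to be computed, the zero cell seeds the source. A sketch assuming pyrocko's eikonal module as the solver (the toy model and grid spacing are invented):

    import numpy as np
    from pyrocko.modelling import eikonal

    speeds = np.full((32, 32, 32), 5000.0)  # constant 5 km/s toy model
    times = np.full_like(speeds, -1.0)      # -1.0 marks cells to be computed
    times[16, 16, 16] = 0.0                 # 0.0 s seeds the source node

    eikonal.eikonal_solver_fill(speeds, times, 100.0)  # 100 m grid spacing
    print(times.max())  # travel time to the farthest corner, in s
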
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -30,14 +30,13 @@ dependencies = [
     "numpy>=1.17.3",
     "scipy>=1.8.0",
     "pyrocko>=2022.06.10",
-    "seisbench>=0.4.0",
+    "seisbench>=0.5.0",
     "pydantic>=2.3",
     "aiohttp>=3.8",
     "aiohttp_cors>=0.7.0",
     "typing-extensions>=4.6",
     "lru-dict>=1.2",
     "rich>=13.4",
-    "nest-asyncio>=1.5", # wait for seisbench merge https://github.com/seisbench/seisbench/pull/214
 ]

 classifiers = [
43 changes: 43 additions & 0 deletions test/conftest.py
@@ -1,11 +1,14 @@
+import asyncio
 import random
 from datetime import timedelta
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Generator

+import aiohttp
 import numpy as np
 import pytest
+from rich.progress import Progress

 from lassie.models.detection import EventDetection, EventDetections
 from lassie.models.location import Location
@@ -16,6 +19,14 @@

 DATA_DIR = Path(__file__).parent / "data"

+DATA_URL = "https://data.pyrocko.org/testing/lassie-v2/"
+DATA_FILES = {
+    "FORGE_3D_5_large.P.mod.hdr",
+    "FORGE_3D_5_large.P.mod.buf",
+    "FORGE_3D_5_large.S.mod.hdr",
+    "FORGE_3D_5_large.S.mod.buf",
+}
+
 KM = 1e3

@@ -43,6 +54,38 @@ def travel_time_tree() -> TravelTimeTree:

 @pytest.fixture(scope="session")
 def data_dir() -> Path:
+    if not DATA_DIR.exists():
+        DATA_DIR.mkdir()
+
+    async def download_data():
+        download_files = DATA_FILES.copy()
+        for filename in DATA_FILES:
+            filepath = DATA_DIR / filename
+            if filepath.exists():
+                download_files.remove(filename)
+
+        if not download_files:
+            return
+
+        async with aiohttp.ClientSession() as session:
+            for filename in download_files:
+                filepath = DATA_DIR / filename
+                url = DATA_URL + filename
+                with Progress() as progress:
+                    async with session.get(url) as response:
+                        task = progress.add_task(
+                            f"Downloading {url}",
+                            total=response.content_length,
+                        )
+                        with filepath.open("wb") as f:
+                            while True:
+                                chunk = await response.content.read(1024)
+                                if not chunk:
+                                    break
+                                f.write(chunk)
+                                progress.advance(task, len(chunk))
+
+    asyncio.run(download_data())
     return DATA_DIR

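
A hypothetical test consuming the extended session fixture (not part of the commit; the file name is taken from DATA_FILES above):

    def test_forge_model_header(data_dir: Path) -> None:
        header = data_dir / "FORGE_3D_5_large.P.mod.hdr"
        assert header.exists()
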