Skip to content

Commit

Permalink
wip: fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Sep 12, 2023
1 parent b1c3373 commit 7d9b55d
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 23 deletions.
14 changes: 5 additions & 9 deletions lassie/models/station.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, Iterable, Iterator

import numpy as np
from pydantic import BaseModel, PrivateAttr, constr
from pydantic import BaseModel, constr
from pyrocko.io.stationxml import load_xml
from pyrocko.model import Station as PyrockoStation
from pyrocko.model import dump_stations_yaml, load_stations
Expand Down Expand Up @@ -41,7 +41,7 @@ def from_pyrocko_station(cls, station: PyrockoStation) -> Station:
)

def to_pyrocko_station(self) -> PyrockoStation:
return PyrockoStation(**self.dict(exclude={"effective_lat_lon"}))
return PyrockoStation(**self.model_dump(exclude={"effective_lat_lon"}))

@property
def pretty_nsl(self) -> str:
Expand All @@ -62,8 +62,6 @@ class Stations(BaseModel):
station_xmls: list[Path] = []
pyrocko_station_yamls: list[Path] = []

_cached_coordinates: np.ndarray | None = PrivateAttr(None)

def model_post_init(self, __context: Any) -> None:
loaded_stations = []
for file in self.pyrocko_station_yamls:
Expand Down Expand Up @@ -181,11 +179,9 @@ def get_centroid(self) -> Location:
)

def get_coordinates(self, system: CoordSystem = "geographic") -> np.ndarray:
if self._cached_coordinates is None:
self._cached_coordinates = np.array(
[(*sta.effective_lat_lon, sta.effective_elevation) for sta in self]
)
return self._cached_coordinates
return np.array(
[(*sta.effective_lat_lon, sta.effective_elevation) for sta in self]
)

def dump_pyrocko_stations(self, filename: Path) -> None:
"""Dump stations to pyrocko station yaml file.
Expand Down
6 changes: 4 additions & 2 deletions lassie/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Field,
PositiveFloat,
PrivateAttr,
confloat,
field_validator,
model_validator,
)
Expand Down Expand Up @@ -184,7 +185,7 @@ class Octree(BaseModel):
east_bounds: tuple[float, float] = (-10 * KM, 10 * KM)
north_bounds: tuple[float, float] = (-10 * KM, 10 * KM)
depth_bounds: tuple[float, float] = (0 * KM, 20 * KM)
absorbing_boundary: PositiveFloat = 1 * KM
absorbing_boundary: confloat(ge=0.0) = 1 * KM

_root_nodes: list[Node] = PrivateAttr([])
_cached_coordinates: dict[CoordSystem, np.ndarray] = PrivateAttr({})
Expand All @@ -205,7 +206,8 @@ def check_limits(self) -> Octree:
"""Check that the size limits are valid."""
if self.size_limit > self.size_initial:
raise ValueError(
"invalid octree size limits, expected size_limit <= size_initial"
f"invalid octree size limits ({self.size_initial}, {self.size_limit}),"
" expected size_limit <= size_initial"
)
return self

Expand Down
8 changes: 5 additions & 3 deletions lassie/tracers/cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class EarthModel(BaseModel):
)

raw_file_data: str | None = Field(
None, description="Raw .nd file data.", exclude=True
None,
description="Raw .nd file data.",
)
_layered_model: LayeredModel = PrivateAttr()

Expand Down Expand Up @@ -278,8 +279,8 @@ def save(self, path: Path) -> Path:
"model.json",
self.model_dump_json(
indent=2,
exclude={"earthmodel": {"nd_file"}},
include={"earthmodel": {"raw_file_data"}},
exclude={"earthmodel": {"filename"}},
# include={"earthmodel": {"raw_file_data"}},
),
)
with NamedTemporaryFile() as tmpfile:
Expand All @@ -301,6 +302,7 @@ def load(cls, file: Path) -> Self:
with zipfile.ZipFile(file, "r") as archive:
path = zipfile.Path(archive)
model_file = path / "model.json"
print(model_file.read_text())
model = cls.model_validate_json(model_file.read_text())
model._file = file
return model
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def plot(pytestconfig) -> bool:


@pytest.fixture(scope="session")
def traveltime_tree() -> TravelTimeTree:
def travel_time_tree() -> TravelTimeTree:
return TravelTimeTree.new(
earthmodel=EarthModel(),
distance_bounds=(0 * KM, 15 * KM),
Expand Down
12 changes: 7 additions & 5 deletions test/test_cake.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
KM = 1e3


def test_sptree_model(traveltime_tree: TravelTimeTree):
model = traveltime_tree
def test_sptree_model(travel_time_tree: TravelTimeTree):
model = travel_time_tree

with TemporaryDirectory() as d:
tmp = Path(d)
Expand Down Expand Up @@ -45,9 +45,11 @@ def test_sptree_model(traveltime_tree: TravelTimeTree):


def test_lut(
traveltime_tree: TravelTimeTree, octree: Octree, stations: Stations
travel_time_tree: TravelTimeTree,
octree: Octree,
stations: Stations,
) -> None:
model = traveltime_tree
model = travel_time_tree
model.init_lut(octree, stations)

traveltimes_tree = model.interpolate_travel_times(octree, stations)
Expand All @@ -61,7 +63,7 @@ def test_lut(
np.testing.assert_equal(traveltimes_tree, traveltimes_lut)
assert len(model._node_lut) > 0, "did not refill lut"

stations_selection = stations.copy()
stations_selection = stations.model_copy()
stations_selection.stations = stations_selection.stations[:5]
traveltimes_tree = model.interpolate_travel_times(octree, stations_selection)
traveltimes_lut = model.get_travel_times(octree, stations_selection)
Expand Down
15 changes: 14 additions & 1 deletion test/test_fast_marching.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,23 @@ def stations_inside(
return Stations(stations=stations)


def octree_cover(model: VelocityModel3D) -> Octree:
return Octree(
reference=model.center,
size_initial=2 * KM,
size_limit=500,
east_bounds=model.east_bounds,
north_bounds=model.north_bounds,
depth_bounds=model.depth_bounds,
absorbing_boundary=0,
)


@pytest_asyncio.fixture
async def station_travel_times(
octree: Octree, stations: Stations
) -> StationTravelTimeVolume:
octree.surface_elevation = 1 * KM
octree.reference.elevation = 1 * KM
model = Constant3DVelocityModel(velocity=CONSTANT_VELOCITY, grid_spacing=100.0)
model_3d = model.get_model(octree)
return await StationTravelTimeVolume.calculate_from_eikonal(
Expand Down Expand Up @@ -120,6 +132,7 @@ async def test_non_lin_loc(data_dir: Path, octree: Octree, stations: Stations) -
phase="fm:P",
velocity_model=NonLinLocVelocityModel(header_file=header_file),
)
octree = octree_cover(tracer.velocity_model.get_model(octree))
stations = stations_inside(tracer.velocity_model.get_model(octree))
await tracer.prepare(octree, stations)
source = octree[1].as_location()
Expand Down
8 changes: 6 additions & 2 deletions test/test_octree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from lassie.octree import Octree
from lassie.octree import NodeSplitError, Octree

km = 1e3

Expand All @@ -16,7 +16,11 @@ def test_octree(octree: Octree, plot: bool) -> None:
assert nnodes * 8 == octree.n_nodes

child, *_ = octree[80].split()
child, *_ = child.split()
while True:
try:
child, *_ = child.split()
except NodeSplitError:
break

for node in octree:
node.semblance = node.depth + node.east + node.north
Expand Down

0 comments on commit 7d9b55d

Please sign in to comment.