From 7d9b55dc842324fd62520e116c5926c3d8a985e9 Mon Sep 17 00:00:00 2001 From: miili Date: Tue, 12 Sep 2023 09:45:18 +0200 Subject: [PATCH] wip: fixing tests --- lassie/models/station.py | 14 +++++--------- lassie/octree.py | 6 ++++-- lassie/tracers/cake.py | 8 +++++--- test/conftest.py | 2 +- test/test_cake.py | 12 +++++++----- test/test_fast_marching.py | 15 ++++++++++++++- test/test_octree.py | 8 ++++++-- 7 files changed, 42 insertions(+), 23 deletions(-) diff --git a/lassie/models/station.py b/lassie/models/station.py index c7cdf3ec..0bb0c523 100644 --- a/lassie/models/station.py +++ b/lassie/models/station.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Iterable, Iterator import numpy as np -from pydantic import BaseModel, PrivateAttr, constr +from pydantic import BaseModel, constr from pyrocko.io.stationxml import load_xml from pyrocko.model import Station as PyrockoStation from pyrocko.model import dump_stations_yaml, load_stations @@ -41,7 +41,7 @@ def from_pyrocko_station(cls, station: PyrockoStation) -> Station: ) def to_pyrocko_station(self) -> PyrockoStation: - return PyrockoStation(**self.dict(exclude={"effective_lat_lon"})) + return PyrockoStation(**self.model_dump(exclude={"effective_lat_lon"})) @property def pretty_nsl(self) -> str: @@ -62,8 +62,6 @@ class Stations(BaseModel): station_xmls: list[Path] = [] pyrocko_station_yamls: list[Path] = [] - _cached_coordinates: np.ndarray | None = PrivateAttr(None) - def model_post_init(self, __context: Any) -> None: loaded_stations = [] for file in self.pyrocko_station_yamls: @@ -181,11 +179,9 @@ def get_centroid(self) -> Location: ) def get_coordinates(self, system: CoordSystem = "geographic") -> np.ndarray: - if self._cached_coordinates is None: - self._cached_coordinates = np.array( - [(*sta.effective_lat_lon, sta.effective_elevation) for sta in self] - ) - return self._cached_coordinates + return np.array( + [(*sta.effective_lat_lon, sta.effective_elevation) for sta in self] + ) def dump_pyrocko_stations(self, filename: Path) -> None: """Dump stations to pyrocko station yaml file. diff --git a/lassie/octree.py b/lassie/octree.py index 14903126..a07519ce 100644 --- a/lassie/octree.py +++ b/lassie/octree.py @@ -15,6 +15,7 @@ Field, PositiveFloat, PrivateAttr, + confloat, field_validator, model_validator, ) @@ -184,7 +185,7 @@ class Octree(BaseModel): east_bounds: tuple[float, float] = (-10 * KM, 10 * KM) north_bounds: tuple[float, float] = (-10 * KM, 10 * KM) depth_bounds: tuple[float, float] = (0 * KM, 20 * KM) - absorbing_boundary: PositiveFloat = 1 * KM + absorbing_boundary: confloat(ge=0.0) = 1 * KM _root_nodes: list[Node] = PrivateAttr([]) _cached_coordinates: dict[CoordSystem, np.ndarray] = PrivateAttr({}) @@ -205,7 +206,8 @@ def check_limits(self) -> Octree: """Check that the size limits are valid.""" if self.size_limit > self.size_initial: raise ValueError( - "invalid octree size limits, expected size_limit <= size_initial" + f"invalid octree size limits ({self.size_initial}, {self.size_limit})," + " expected size_limit <= size_initial" ) return self diff --git a/lassie/tracers/cake.py b/lassie/tracers/cake.py index 8e1e94c8..09924324 100644 --- a/lassie/tracers/cake.py +++ b/lassie/tracers/cake.py @@ -104,7 +104,8 @@ class EarthModel(BaseModel): ) raw_file_data: str | None = Field( - None, description="Raw .nd file data.", exclude=True + None, + description="Raw .nd file data.", ) _layered_model: LayeredModel = PrivateAttr() @@ -278,8 +279,8 @@ def save(self, path: Path) -> Path: "model.json", self.model_dump_json( indent=2, - exclude={"earthmodel": {"nd_file"}}, - include={"earthmodel": {"raw_file_data"}}, + exclude={"earthmodel": {"filename"}}, + # include={"earthmodel": {"raw_file_data"}}, ), ) with NamedTemporaryFile() as tmpfile: @@ -301,6 +302,7 @@ def load(cls, file: Path) -> Self: with zipfile.ZipFile(file, "r") as archive: path = zipfile.Path(archive) model_file = path / "model.json" + print(model_file.read_text()) model = cls.model_validate_json(model_file.read_text()) model._file = file return model diff --git a/test/conftest.py b/test/conftest.py index bc02cf64..36edbfb5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -29,7 +29,7 @@ def plot(pytestconfig) -> bool: @pytest.fixture(scope="session") -def traveltime_tree() -> TravelTimeTree: +def travel_time_tree() -> TravelTimeTree: return TravelTimeTree.new( earthmodel=EarthModel(), distance_bounds=(0 * KM, 15 * KM), diff --git a/test/test_cake.py b/test/test_cake.py index fef5eafb..cccb4079 100644 --- a/test/test_cake.py +++ b/test/test_cake.py @@ -16,8 +16,8 @@ KM = 1e3 -def test_sptree_model(traveltime_tree: TravelTimeTree): - model = traveltime_tree +def test_sptree_model(travel_time_tree: TravelTimeTree): + model = travel_time_tree with TemporaryDirectory() as d: tmp = Path(d) @@ -45,9 +45,11 @@ def test_sptree_model(traveltime_tree: TravelTimeTree): def test_lut( - traveltime_tree: TravelTimeTree, octree: Octree, stations: Stations + travel_time_tree: TravelTimeTree, + octree: Octree, + stations: Stations, ) -> None: - model = traveltime_tree + model = travel_time_tree model.init_lut(octree, stations) traveltimes_tree = model.interpolate_travel_times(octree, stations) @@ -61,7 +63,7 @@ def test_lut( np.testing.assert_equal(traveltimes_tree, traveltimes_lut) assert len(model._node_lut) > 0, "did not refill lut" - stations_selection = stations.copy() + stations_selection = stations.model_copy() stations_selection.stations = stations_selection.stations[:5] traveltimes_tree = model.interpolate_travel_times(octree, stations_selection) traveltimes_lut = model.get_travel_times(octree, stations_selection) diff --git a/test/test_fast_marching.py b/test/test_fast_marching.py index 0fb62069..502180af 100644 --- a/test/test_fast_marching.py +++ b/test/test_fast_marching.py @@ -43,11 +43,23 @@ def stations_inside( return Stations(stations=stations) +def octree_cover(model: VelocityModel3D) -> Octree: + return Octree( + reference=model.center, + size_initial=2 * KM, + size_limit=500, + east_bounds=model.east_bounds, + north_bounds=model.north_bounds, + depth_bounds=model.depth_bounds, + absorbing_boundary=0, + ) + + @pytest_asyncio.fixture async def station_travel_times( octree: Octree, stations: Stations ) -> StationTravelTimeVolume: - octree.surface_elevation = 1 * KM + octree.reference.elevation = 1 * KM model = Constant3DVelocityModel(velocity=CONSTANT_VELOCITY, grid_spacing=100.0) model_3d = model.get_model(octree) return await StationTravelTimeVolume.calculate_from_eikonal( @@ -120,6 +132,7 @@ async def test_non_lin_loc(data_dir: Path, octree: Octree, stations: Stations) - phase="fm:P", velocity_model=NonLinLocVelocityModel(header_file=header_file), ) + octree = octree_cover(tracer.velocity_model.get_model(octree)) stations = stations_inside(tracer.velocity_model.get_model(octree)) await tracer.prepare(octree, stations) source = octree[1].as_location() diff --git a/test/test_octree.py b/test/test_octree.py index 65c1bae8..d9f5c3c7 100644 --- a/test/test_octree.py +++ b/test/test_octree.py @@ -1,6 +1,6 @@ from __future__ import annotations -from lassie.octree import Octree +from lassie.octree import NodeSplitError, Octree km = 1e3 @@ -16,7 +16,11 @@ def test_octree(octree: Octree, plot: bool) -> None: assert nnodes * 8 == octree.n_nodes child, *_ = octree[80].split() - child, *_ = child.split() + while True: + try: + child, *_ = child.split() + except NodeSplitError: + break for node in octree: node.semblance = node.depth + node.east + node.north