From f39f90efa4da445e24e854acab5ba60d9a60f6a1 Mon Sep 17 00:00:00 2001 From: miili Date: Mon, 13 Nov 2023 00:26:03 +0100 Subject: [PATCH] semblance: adding caching --- lassie/models/semblance.py | 25 +++++++++++++++++++++++-- lassie/search.py | 13 ++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/lassie/models/semblance.py b/lassie/models/semblance.py index d84246c8..794cff75 100644 --- a/lassie/models/semblance.py +++ b/lassie/models/semblance.py @@ -3,7 +3,7 @@ import asyncio import logging from datetime import timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable import numpy as np from pydantic import BaseModel, PrivateAttr @@ -16,6 +16,8 @@ if TYPE_CHECKING: from datetime import datetime + from lassie.octree import Node + logger = logging.getLogger(__name__) @@ -46,10 +48,11 @@ def mean(old_attr, new_attr) -> float: class Semblance: _max_semblance: np.ndarray | None = None _node_idx_max: np.ndarray | None = None + _node_hashes: list[bytes] def __init__( self, - n_nodes: int, + nodes: Iterable[Node], n_samples: int, start_time: datetime, sampling_rate: float, @@ -59,6 +62,8 @@ def __init__( self.sampling_rate = sampling_rate self.padding_samples = padding_samples self.n_samples_unpadded = n_samples + self._node_hashes = [node.hash() for node in nodes] + n_nodes = len(self._node_hashes) self.semblance_unpadded = np.zeros((n_nodes, n_samples), dtype=np.float32) logger.debug( @@ -96,6 +101,22 @@ def maximum_semblance(self) -> np.ndarray: self._max_semblance = self.semblance.max(axis=0) return self._max_semblance + def get_cache(self) -> dict[bytes, np.ndarray]: + return { + node_hash: self.semblance_unpadded[i, :] + for i, node_hash in enumerate(self._node_hashes) + } + + def get_cache_mask(self, cache: dict[bytes, np.ndarray]) -> np.ndarray: + return np.array([node in cache for node in self._node_hashes]) + + def apply_cache(self, cache: dict[bytes, np.ndarray]): + if not cache: + return + mask = self.get_cache_mask(cache) + data = [cache[node] for node in self._node_hashes if node in cache] + self.semblance_unpadded[mask, :] = np.stack(data) + def maximum_node_semblance(self) -> np.ndarray: return self.semblance.max(axis=1) diff --git a/lassie/search.py b/lassie/search.py index 9fa95099..78160d6a 100644 --- a/lassie/search.py +++ b/lassie/search.py @@ -512,6 +512,7 @@ async def calculate_semblance( ray_tracer: RayTracer, n_samples_semblance: int, semblance_data: np.ndarray, + mask: np.ndarray | None = None, ) -> np.ndarray: logger.debug("stacking image %s", image.image_function.name) parent = self.parent @@ -536,6 +537,9 @@ async def calculate_semblance( weights /= station_contribution[:, np.newaxis] weights[traveltimes_bad] = 0.0 + if mask is not None: + weights[mask] = 0.0 + semblance_data, offsets = await asyncio.to_thread( parstack.parstack, arrays=image.get_trace_data(), @@ -579,6 +583,7 @@ async def get_images(self, sampling_rate: float | None = None) -> WaveformImages async def search( self, octree: Octree | None = None, + semblance_cache: dict[bytes, np.ndarray] | None = None, ) -> tuple[list[EventDetection], Trace]: """Searches for events in the given traces. @@ -600,7 +605,7 @@ async def search( round(parent._window_padding.total_seconds() * sampling_rate) ) semblance = Semblance( - n_nodes=octree.n_nodes, + nodes=octree, n_samples=self._n_samples_semblance(), start_time=self.start_time, sampling_rate=sampling_rate, @@ -614,7 +619,9 @@ async def search( ray_tracer=parent.ray_tracers.get_phase_tracer(image.phase), semblance_data=semblance.semblance_unpadded, n_samples_semblance=semblance.n_samples_unpadded, + mask=semblance.get_cache_mask(semblance_cache or {}), ) + semblance.apply_cache(semblance_cache or {}) semblance.apply_exponent(1.0 / parent.image_mean_p) semblance.normalize(images.cumulative_weight()) @@ -657,9 +664,9 @@ async def search( node.split() except NodeSplitError: continue - + cache = semblance.get_cache() del semblance - return await self.search(octree) + return await self.search(octree, semblance_cache=cache) detections = [] for time_idx, semblance_detection in zip(