semblance: adding caching
miili committed Nov 12, 2023
1 parent 948db0a · commit f39f90e
Showing 2 changed files with 33 additions and 5 deletions.
lassie/models/semblance.py (25 changes: 23 additions & 2 deletions)
@@ -3,7 +3,7 @@
 import asyncio
 import logging
 from datetime import timedelta
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterable
 
 import numpy as np
 from pydantic import BaseModel, PrivateAttr
@@ -16,6 +16,8 @@
 if TYPE_CHECKING:
     from datetime import datetime
 
+    from lassie.octree import Node
+
 
 logger = logging.getLogger(__name__)
 
@@ -46,10 +48,11 @@ def mean(old_attr, new_attr) -> float:
 class Semblance:
     _max_semblance: np.ndarray | None = None
     _node_idx_max: np.ndarray | None = None
+    _node_hashes: list[bytes]
 
     def __init__(
         self,
-        n_nodes: int,
+        nodes: Iterable[Node],
         n_samples: int,
         start_time: datetime,
         sampling_rate: float,
@@ -59,6 +62,8 @@ def __init__(
         self.sampling_rate = sampling_rate
         self.padding_samples = padding_samples
         self.n_samples_unpadded = n_samples
+        self._node_hashes = [node.hash() for node in nodes]
+        n_nodes = len(self._node_hashes)
 
         self.semblance_unpadded = np.zeros((n_nodes, n_samples), dtype=np.float32)
         logger.debug(
@@ -96,6 +101,22 @@ def maximum_semblance(self) -> np.ndarray:
         self._max_semblance = self.semblance.max(axis=0)
         return self._max_semblance
 
+    def get_cache(self) -> dict[bytes, np.ndarray]:
+        return {
+            node_hash: self.semblance_unpadded[i, :]
+            for i, node_hash in enumerate(self._node_hashes)
+        }
+
+    def get_cache_mask(self, cache: dict[bytes, np.ndarray]) -> np.ndarray:
+        return np.array([node in cache for node in self._node_hashes])
+
+    def apply_cache(self, cache: dict[bytes, np.ndarray]):
+        if not cache:
+            return
+        mask = self.get_cache_mask(cache)
+        data = [cache[node] for node in self._node_hashes if node in cache]
+        self.semblance_unpadded[mask, :] = np.stack(data)
+
     def maximum_node_semblance(self) -> np.ndarray:
         return self.semblance.max(axis=1)
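For orientation, here is a small standalone sketch of the round trip the three new methods implement: `get_cache()` keys each row of `semblance_unpadded` by its node hash, `get_cache_mask()` marks which nodes of a refined octree are already cached, and `apply_cache()` copies those rows back. The node hashes and array shapes below are made up for illustration and use plain NumPy rather than the real `Semblance`/`Node` objects.

```python
import numpy as np

# Hypothetical node hashes standing in for Node.hash() results.
node_hashes = [b"node-a", b"node-b", b"node-c"]
semblance_unpadded = np.arange(12, dtype=np.float32).reshape(3, 4)  # (n_nodes, n_samples)

# get_cache(): map every node hash to its semblance row.
cache = {h: semblance_unpadded[i, :] for i, h in enumerate(node_hashes)}

# A refined octree keeps some old nodes and introduces new ones.
new_hashes = [b"node-b", b"node-d", b"node-a"]
new_semblance = np.zeros((3, 4), dtype=np.float32)

# get_cache_mask(): boolean mask of rows whose semblance is already cached.
mask = np.array([h in cache for h in new_hashes])  # [True, False, True]

# apply_cache(): copy the cached rows back in, leaving new nodes to be stacked.
rows = [cache[h] for h in new_hashes if h in cache]
new_semblance[mask, :] = np.stack(rows)
```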
lassie/search.py (13 changes: 10 additions & 3 deletions)
@@ -512,6 +512,7 @@ async def calculate_semblance(
         ray_tracer: RayTracer,
         n_samples_semblance: int,
         semblance_data: np.ndarray,
+        mask: np.ndarray | None = None,
     ) -> np.ndarray:
         logger.debug("stacking image %s", image.image_function.name)
         parent = self.parent
@@ -536,6 +537,9 @@
         weights /= station_contribution[:, np.newaxis]
         weights[traveltimes_bad] = 0.0
 
+        if mask is not None:
+            weights[mask] = 0.0
+
         semblance_data, offsets = await asyncio.to_thread(
             parstack.parstack,
             arrays=image.get_trace_data(),
@@ -579,6 +583,7 @@ async def get_images(self, sampling_rate: float | None = None) -> WaveformImages
     async def search(
         self,
         octree: Octree | None = None,
+        semblance_cache: dict[bytes, np.ndarray] | None = None,
     ) -> tuple[list[EventDetection], Trace]:
         """Searches for events in the given traces.
@@ -600,7 +605,7 @@ async def search(
             round(parent._window_padding.total_seconds() * sampling_rate)
         )
         semblance = Semblance(
-            n_nodes=octree.n_nodes,
+            nodes=octree,
             n_samples=self._n_samples_semblance(),
             start_time=self.start_time,
             sampling_rate=sampling_rate,
@@ -614,7 +619,9 @@
                 ray_tracer=parent.ray_tracers.get_phase_tracer(image.phase),
                 semblance_data=semblance.semblance_unpadded,
                 n_samples_semblance=semblance.n_samples_unpadded,
+                mask=semblance.get_cache_mask(semblance_cache or {}),
             )
+        semblance.apply_cache(semblance_cache or {})
         semblance.apply_exponent(1.0 / parent.image_mean_p)
         semblance.normalize(images.cumulative_weight())
 
@@ -657,9 +664,9 @@ async def search(
                     node.split()
                 except NodeSplitError:
                     continue
-
+            cache = semblance.get_cache()
             del semblance
-            return await self.search(octree)
+            return await self.search(octree, semblance_cache=cache)
 
         detections = []
         for time_idx, semblance_detection in zip(
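The mask returned by `get_cache_mask()` is what keeps already-computed nodes out of the stacking step: their rows get zero weight in `calculate_semblance`, so `parstack` only works on nodes that are new after the octree was refined, and `apply_cache()` restores those rows from the cache that the previous pass handed over via `search(octree, semblance_cache=cache)`. Below is a minimal sketch of that interplay, with assumed shapes and integer keys instead of node hashes, not the actual implementation.

```python
import numpy as np

n_nodes, n_stations, n_samples = 4, 3, 8
weights = np.ones((n_nodes, n_stations), dtype=np.float32)
semblance = np.zeros((n_nodes, n_samples), dtype=np.float32)

# Rows 0 and 2 were already stacked in the previous pass (hypothetical cache,
# keyed by row index here instead of node hashes).
cache = {
    0: np.full(n_samples, 0.5, dtype=np.float32),
    2: np.full(n_samples, 0.7, dtype=np.float32),
}
mask = np.array([idx in cache for idx in range(n_nodes)])

# The new `mask` argument: cached rows get zero weight, so stacking
# contributes nothing to them and effort goes only into uncached nodes.
weights[mask] = 0.0
semblance += weights.sum(axis=1, keepdims=True)  # stand-in for parstack.parstack(...)

# apply_cache(): the cached rows are copied back in after stacking.
for idx, row in cache.items():
    semblance[idx, :] = row
```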
