Skip to content

Commit

Permalink
Update code: Add debug log message, save octree to pickle, and fix timing issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Jan 30, 2024
1 parent fa01262 commit 4fbeeaf
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/qseek/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ async def worker() -> None:
break
yield ret

logger.debug("waiting for image function to finish")
await task

def get_phases(self) -> tuple[PhaseDescription, ...]:
Expand Down
10 changes: 7 additions & 3 deletions src/qseek/models/detection_uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# Equivalent to one standard deviation
THRESHOLD = 1.0 / np.sqrt(np.e)
PERCENTILE = 0.02


class DetectionUncertainty(BaseModel):
Expand All @@ -30,22 +30,26 @@ class DetectionUncertainty(BaseModel):

@classmethod
def from_event(
cls, source_node: Node, octree: Octree, width: float = THRESHOLD
cls, source_node: Node, octree: Octree, percentile: float = PERCENTILE
) -> Self:
"""
Calculate the uncertainty of an event detection.
Args:
event: The event detection to calculate the uncertainty for.
octree: The octree to use for the calculation.
percentile: The percentile to use for the calculation.
Defaults to 0.02 (2%).
Returns:
The calculated uncertainty.
"""
if not source_node.semblance:
raise ValueError("Source node must have semblance value.")

nodes = octree.get_nodes(semblance_threshold=source_node.semblance * width)
nodes = octree.get_nodes(
semblance_threshold=source_node.semblance * (1.0 - percentile)
)
vicinity_coords = np.array(
[(node.east, node.north, node.depth) for node in nodes]
)
Expand Down
13 changes: 13 additions & 0 deletions src/qseek/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from functools import cached_property, reduce
from hashlib import sha1
from operator import mul
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence

import numpy as np
Expand Down Expand Up @@ -519,6 +520,18 @@ def copy(self, deep=False) -> Self:
node.set_tree(tree)
return tree

def save_pickle(self, filename: Path) -> None:
    """Serialize this octree into a pickle file.

    Args:
        filename (Path): Destination file; it is opened in binary write mode.
    """
    import pickle

    logger.info("saving octree pickle to %s", filename)
    with filename.open("wb") as stream:
        # Pickler(stream).dump(obj) is what pickle.dump(obj, stream) does.
        pickle.Pickler(stream).dump(self)

def __hash__(self) -> int:
return hash(
(
Expand Down
6 changes: 4 additions & 2 deletions src/qseek/plot/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,17 @@ def octree_to_rectangles(
height=size,
)
rectangles.append(rect)

if normalize:
semblances /= semblances.max()
colors = cmap(semblances)
edge_colors = cm.get_cmap("binary")(semblances**2, alpha=0.8)

return PatchCollection(
patches=rectangles,
facecolors=colors,
edgecolors=(0, 0, 0, 0.3),
linewidths=0.5,
edgecolors=edge_colors,
linewidths=0.1,
)


Expand Down
10 changes: 3 additions & 7 deletions src/qseek/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import cProfile
import logging
from collections import deque
from datetime import datetime, timedelta, timezone
Expand Down Expand Up @@ -54,7 +53,6 @@
logger = logging.getLogger(__name__)

SamplingRate = Literal[10, 20, 25, 50, 100]
p = cProfile.Profile()


class SearchStats(Stats):
Expand Down Expand Up @@ -445,7 +443,6 @@ async def start(self, force_rundir: bool = False) -> None:
)

console = asyncio.create_task(RuntimeStats.live_view())
p.enable()

async for images, batch in self.image_functions.iter_images(waveform_iterator):
batch_processing_start = datetime_now()
Expand All @@ -471,12 +468,11 @@ async def start(self, force_rundir: bool = False) -> None:
)

self.set_progress(batch.end_time)
p.dump_stats("qseek.prof")
p.disable()

# await BackgroundTasks.wait_all()
console.cancel()
await BackgroundTasks.wait_all()
await self._catalog.save()
await self._catalog.export_detections(jitter_location=self.octree.size_limit)
console.cancel()
logger.info("finished search in %s", datetime_now() - processing_start)
logger.info("found %d detections", self._catalog.n_events)

Expand Down
5 changes: 4 additions & 1 deletion src/qseek/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ def generate_grid() -> Table:
) as live:
while True:
live.update(generate_grid())
await asyncio.sleep(0.4)
try:
await asyncio.sleep(0.2)
except asyncio.CancelledError:
break


class Stats(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions src/qseek/tracers/cake.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import logging
import re
import struct
Expand Down Expand Up @@ -476,6 +477,7 @@ async def _interpolate_travel_times(
for coords in coordinates:
travel_times.append(self._interpolate_traveltimes_sptree(coords))
PROGRESS.update(status, advance=1)
await asyncio.sleep(0.0)

PROGRESS.remove_task(status)

Expand Down
22 changes: 16 additions & 6 deletions src/qseek/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def cancel_all(cls) -> None:

@classmethod
async def wait_all(cls) -> None:
    """Wait until every task registered in ``cls.tasks`` has finished.

    Does nothing when no tasks are registered. Otherwise the number of
    tasks is logged at debug level and all of them are awaited together
    through ``asyncio.gather``.
    """
    pending = cls.tasks
    if pending:
        logger.debug("waiting for %d tasks to finish", len(pending))
        await asyncio.gather(*pending)


Expand Down Expand Up @@ -472,8 +475,8 @@ def load_insights() -> None:
import qseek.insights # noqa: F401

logger.info("loaded qseek.insights package")
except ImportError as exc:
logger.warning("package qseek.insights not installed", exc_info=exc)
except ImportError:
logger.debug("package qseek.insights not installed")


MeasurementUnit = Literal[
Expand Down Expand Up @@ -512,6 +515,14 @@ def get_traces(self, traces_flt: list[Trace]) -> list[Trace]:

traces_flt = [tr for tr in traces_flt if tr.channel[-1] in self.channels]

tmins = {tr.tmin for tr in traces_flt}
tmaxs = {tr.tmax for tr in traces_flt}
if len(tmins) != 1 or len(tmaxs) != 1:
raise KeyError(
f"unhealthy timing on channels {self.channels}",
f" for: {', '.join('.'.join(tr.nslc_id) for tr in traces_flt)}",
)

if len(traces_flt) != self.number_channels:
raise KeyError(
f"cannot get {self.number_channels} channels"
Expand All @@ -520,10 +531,9 @@ def get_traces(self, traces_flt: list[Trace]) -> list[Trace]:
)
if self.normalize:
traces_norm = traces_flt[0].copy()
traces_norm.ydata = np.linalg.norm(
np.atleast_2d(np.array([tr.ydata for tr in traces_flt])),
axis=0,
)
data = np.atleast_2d(np.array([tr.ydata for tr in traces_flt]))

traces_norm.ydata = np.linalg.norm(data, axis=0)
return [traces_norm]
return traces_flt

Expand Down
2 changes: 2 additions & 0 deletions src/qseek/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ async def load_data() -> None | Batch:
start = datetime_now()
batch = await asyncio.to_thread(next, self.iterator, None)
if batch is None:
await self._load_queue.put(None)
return
logger.debug("read waveform batch in %s", datetime_now() - start)
self._fetched_batches += 1
Expand All @@ -123,6 +124,7 @@ async def post_process() -> None:

await load_task
await post_process_task
logger.debug("waiting for waveform batches to finish")

await self.queue.put(None)

Expand Down

0 comments on commit 4fbeeaf

Please sign in to comment.