Commit

adding stats

miili committed Nov 7, 2023
1 parent 868a4bb commit b26e1b9
Showing 10 changed files with 251 additions and 86 deletions.
21 changes: 17 additions & 4 deletions lassie/images/__init__.py
@@ -3,18 +3,18 @@
import asyncio
import logging
from dataclasses import dataclass
from datetime import timedelta
from itertools import chain
from typing import TYPE_CHECKING, Annotated, Any, AsyncIterator, Iterator, Tuple, Union

from pydantic import Field, PrivateAttr, RootModel
from pydantic import Field, NonNegativeInt, PrivateAttr, RootModel

from lassie.images.base import ImageFunction, PickedArrival
from lassie.images.phase_net import PhaseNet, PhaseNetPick
from lassie.utils import PhaseDescription
from lassie.stats import Stats
from lassie.utils import PhaseDescription, datetime_now

if TYPE_CHECKING:
from datetime import timedelta

from pyrocko.trace import Trace

from lassie.images.base import WaveformImage
@@ -37,18 +37,26 @@
]


class ImageFunctionsStats(Stats):
    queue_size: NonNegativeInt = 0
    queue_max_size: NonNegativeInt = 0
time_per_batch: timedelta = timedelta()


class ImageFunctions(RootModel):
root: list[ImageFunctionType] = [PhaseNet()]

_queue: asyncio.Queue[Tuple[WaveformImages, WaveformBatch] | None] = PrivateAttr()
_processed_images: int = PrivateAttr(0)
    _stats: ImageFunctionsStats = PrivateAttr(ImageFunctionsStats())

def model_post_init(self, __context: Any) -> None:
# Check if phases are provided twice
phases = self.get_phases()
if len(set(phases)) != len(phases):
raise ValueError("A phase was provided twice")
self._queue = asyncio.Queue(maxsize=4)
self._stats.queue_max_size = self._queue.maxsize

async def process_traces(self, traces: list[Trace]) -> WaveformImages:
images = []
@@ -71,6 +79,8 @@ async def iter_images(
AsyncIterator[WaveformImages]: Async iterator over images.
"""

stats = self._stats

async def worker() -> None:
logger.info(
"start pre-processing images, queue size %d", self._queue.maxsize
@@ -87,7 +97,10 @@ async def worker() -> None:
task = asyncio.create_task(worker())

while True:
stats.queue_size = self._queue.qsize()
start_time = datetime_now()
ret = await self._queue.get()
stats.time_per_batch = datetime_now() - start_time
if ret is None:
logger.debug("image function finished")
break
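The `iter_images` hunk instruments a classic producer/consumer handoff: a worker task fills a bounded `asyncio.Queue` while the consumer samples the queue depth and times each `await queue.get()`. A minimal, runnable sketch of that timing pattern, with invented item types and delays:

```python
import asyncio
from datetime import datetime, timedelta


async def main() -> None:
    queue: asyncio.Queue[int | None] = asyncio.Queue(maxsize=4)

    async def worker() -> None:
        # Producer: pre-process items and hand them over; None is the
        # end-of-stream sentinel, mirroring the diff's queue of
        # `Tuple[WaveformImages, WaveformBatch] | None`.
        for item in range(10):
            await asyncio.sleep(0.1)  # stand-in for heavy pre-processing
            await queue.put(item)
        await queue.put(None)

    task = asyncio.create_task(worker())

    while True:
        backlog = queue.qsize()  # snapshot of queued, ready-to-consume items
        start = datetime.now()
        item = await queue.get()  # the wait measured as `time_per_batch`
        waited: timedelta = datetime.now() - start
        if item is None:
            break
        print(f"item {item}: backlog {backlog}, waited {waited}")

    await task


asyncio.run(main())
```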
5 changes: 4 additions & 1 deletion lassie/images/base.py
@@ -4,10 +4,11 @@
from typing import TYPE_CHECKING, Literal

import numpy as np
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr

from lassie.models.phase_arrival import PhaseArrival
from lassie.models.station import Stations
from lassie.stats import Stats
from lassie.utils import PhaseDescription, downsample

if TYPE_CHECKING:
@@ -23,6 +24,8 @@ class PickedArrival(PhaseArrival):
class ImageFunction(BaseModel):
image: Literal["base"] = "base"

_stats: Stats = PrivateAttr(Stats())

async def process_traces(self, traces: list[Trace]) -> list[WaveformImage]:
...

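Like `ImageFunctions` above, `ImageFunction` carries its stats object as a pydantic private attribute, which keeps mutable runtime state out of the model's schema and serialized output. A small illustrative sketch (model and field names invented):

```python
from pydantic import BaseModel, PrivateAttr


class Counters(BaseModel):
    processed: int = 0


class Processor(BaseModel):
    # Regular field: part of the schema, validated and serialized.
    name: str = "default"

    # Private attribute: runtime-only state, absent from dump and schema.
    _stats: Counters = PrivateAttr(default_factory=Counters)


proc = Processor()
proc._stats.processed += 1
print(proc.model_dump())  # {'name': 'default'} - no _stats key
```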
1 change: 0 additions & 1 deletion lassie/images/phase_net.py
@@ -171,7 +171,6 @@ async def process_traces(self, traces: list[Trace]) -> list[PhaseNetImage]:
scale = self.upscale_input
for tr in stream:
tr.stats.sampling_rate = tr.stats.sampling_rate / scale

annotations: Stream = await asyncio.to_thread(
self._phase_net.annotate,
stream,
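The `phase_net.py` hunk itself only removes a blank line, but the surrounding context shows the pattern used for inference: the blocking `annotate` call is handed to `asyncio.to_thread` so the event loop stays responsive. A generic sketch of that offloading pattern — `slow_annotate` is a hypothetical stand-in for the real model call:

```python
import asyncio
import time


def slow_annotate(samples: list[float]) -> list[float]:
    # Hypothetical stand-in for a blocking, CPU-bound model call that
    # would otherwise stall the event loop.
    time.sleep(1.0)
    return [abs(s) for s in samples]


async def main() -> None:
    # to_thread runs the blocking function in a worker thread; other
    # coroutines (e.g. the image queue worker) keep making progress.
    annotations = await asyncio.to_thread(slow_annotate, [1.0, -2.0, 3.0])
    print(annotations)


asyncio.run(main())
```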
128 changes: 88 additions & 40 deletions lassie/search.py
@@ -11,7 +11,14 @@
from typing import TYPE_CHECKING, Deque, Literal

import numpy as np
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt, PrivateAttr
from pydantic import (
BaseModel,
Field,
PositiveFloat,
PositiveInt,
PrivateAttr,
computed_field,
)
from pyrocko import parstack

from lassie.features import (
@@ -26,6 +33,7 @@
from lassie.octree import NodeSplitError, Octree
from lassie.signals import Signal
from lassie.station_corrections import StationCorrections
from lassie.stats import RuntimeStats, Stats
from lassie.tracers import (
CakeTracer,
ConstantVelocityTracer,
@@ -40,6 +48,7 @@
time_to_path,
)
from lassie.waveforms import PyrockoSquirrel, WaveformProviderType
from lassie.waveforms.base import WaveformBatch

if TYPE_CHECKING:
from pyrocko.trace import Trace
@@ -54,6 +63,71 @@
SamplingRate = Literal[10, 20, 25, 50, 100]


class SearchStats(Stats):
batch_time: datetime = datetime.min
batch_count: int = 0
batch_count_total: int = 0
processing_rate_bytes: float = 0.0

_batch_processing_times: Deque[timedelta] = PrivateAttr(
default_factory=lambda: deque(maxlen=25)
)

@computed_field
@property
def time_remaining(self) -> timedelta:
        if not self._batch_processing_times:
            return timedelta()

remaining_batches = self.batch_count_total - self.batch_count
if not remaining_batches:
return timedelta()

return (
sum(self._batch_processing_times, timedelta())
/ len(self._batch_processing_times)
* remaining_batches
)

@computed_field
@property
def processed_percent(self) -> float:
if not self.batch_count_total:
return 0.0
return self.batch_count / self.batch_count_total * 100.0

def add_processed_batch(
self,
batch: WaveformBatch,
duration: timedelta,
log: bool = False,
) -> None:
        self.batch_count = batch.i_batch + 1
self.batch_count_total = batch.n_batches
self.batch_time = batch.end_time
self._batch_processing_times.append(duration)
self.processing_rate_bytes = batch.cumulative_bytes / duration.total_seconds()
if log:
self.log()

def log(self) -> None:
log_str = (
f"{self.batch_count+1}/{self.batch_count_total or '?'} {self.batch_time}"
)
logger.info(
"%s%% processed - batch %s in %s",
f"{self.processed_percent:.1f}" if self.processed_percent else "??",
log_str,
self._batch_processing_times[-1],
)
logger.info(
"processing rate %s/s - %s remaining - finish at %s",
human_readable_bytes(self.processing_rate_bytes),
self.time_remaining,
datetime.now() + self.time_remaining, # noqa: DTZ005
)


class SearchProgress(BaseModel):
time_progress: datetime | None = None
semblance_stats: SemblanceStats = SemblanceStats()
@@ -153,12 +227,14 @@ class Search(BaseModel):

# Signals
_new_detection: Signal[EventDetection] = PrivateAttr(Signal())
_batch_proc_time: Deque[timedelta] = PrivateAttr(
default_factory=lambda: deque(maxlen=25)
)
_batch_cum_durations: Deque[timedelta] = PrivateAttr(
default_factory=lambda: deque(maxlen=25)
)

_stats: SearchStats = PrivateAttr(SearchStats())
_runtime_stats: RuntimeStats = PrivateAttr(default_factory=RuntimeStats.new)

def model_post_init(self, *args) -> None:
self._runtime_stats.add_stats(self._stats)
self._runtime_stats.add_stats(self.data_provider._stats)
self._runtime_stats.add_stats(self.image_functions._stats)

def init_rundir(self, force: bool = False) -> None:
rundir = (
@@ -278,6 +354,7 @@ async def start(self, force_rundir: bool = False) -> None:
await self.prepare()

logger.info("starting search...")
stats = self._stats
batch_processing_start = datetime_now()
processing_start = datetime_now()

@@ -291,6 +368,8 @@
min_length=2 * self._window_padding,
)

# console = asyncio.create_task(self._runtime_stats.live_view())

async for images, batch in self.image_functions.iter_images(waveform_iterator):
images.set_stations(self.stations)
images.apply_exponent(self.image_mean_p)
@@ -316,43 +395,12 @@
)

processing_time = datetime_now() - batch_processing_start
self._batch_proc_time.append(processing_time)
self._batch_cum_durations.append(batch.cumulative_duration)

processed_percent = (
((batch.i_batch + 1) / batch.n_batches) * 100
if batch.n_batches
else 0.0
)
# processing_rate = (
# sum(self._batch_cum_durations, timedelta())
# / sum(self._batch_proc_time, timedelta()).total_seconds()
# )
processing_rate_bytes = human_readable_bytes(
batch.cumulative_bytes / processing_time.total_seconds()
)

logger.info(
"%s%% processed - batch %s in %s",
f"{processed_percent:.1f}" if processed_percent else "??",
batch.log_str(),
processing_time,
)
if batch.n_batches:
remaining_time = sum(self._batch_proc_time, timedelta()) / len(
self._batch_proc_time
)
remaining_time *= batch.n_batches - batch.i_batch - 1
logger.info(
"processing rate %s/s - %s remaining - finish at %s",
processing_rate_bytes,
remaining_time,
datetime.now() + remaining_time, # noqa: DTZ005
)
stats.add_processed_batch(batch, processing_time, log=True)

batch_processing_start = datetime_now()
self.set_progress(batch.end_time)

# console.cancel()
self._detections.dump_detections(jitter_location=self.octree.size_limit)
logger.info("finished search in %s", datetime_now() - processing_start)
logger.info("found %d detections", self._detections.n_detections)
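This refactor replaces the inline progress bookkeeping in `start()` with `SearchStats`, which keeps the last 25 batch durations in a bounded deque and derives the ETA from their mean. A self-contained sketch of the same estimator using pydantic's `computed_field` (class and method names here are illustrative, not lassie's):

```python
from collections import deque
from datetime import timedelta
from typing import Deque

from pydantic import BaseModel, PrivateAttr, computed_field


class ProgressStats(BaseModel):
    batch_count: int = 0
    batch_count_total: int = 0

    # Rolling window: only the 25 most recent durations enter the mean.
    _durations: Deque[timedelta] = PrivateAttr(
        default_factory=lambda: deque(maxlen=25)
    )

    @computed_field  # exported by model_dump() like a regular field
    @property
    def time_remaining(self) -> timedelta:
        remaining = self.batch_count_total - self.batch_count
        if not self._durations or remaining <= 0:
            return timedelta()
        mean = sum(self._durations, timedelta()) / len(self._durations)
        return mean * remaining

    def add_batch(self, i_batch: int, n_batches: int, duration: timedelta) -> None:
        self.batch_count = i_batch + 1
        self.batch_count_total = n_batches
        self._durations.append(duration)


stats = ProgressStats()
for i in range(5):
    stats.add_batch(i, 100, timedelta(seconds=2))
print(stats.time_remaining)  # 2 s * 95 remaining batches = 0:03:10
```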
79 changes: 79 additions & 0 deletions lassie/stats.py
@@ -0,0 +1,79 @@
from __future__ import annotations

import asyncio
import logging
from typing import Iterator, Type

from pydantic import BaseModel, create_model
from pydantic.fields import ComputedFieldInfo, FieldInfo
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
from rich.table import Table

logger = logging.getLogger(__name__)

STATS_CLASSES: set[Type[Stats]] = set()


PROGRESS = Progress()


def titleify(name: str) -> str:
    return name.replace("_", " ").capitalize()


class RuntimeStats(BaseModel):
@classmethod
def new(cls) -> RuntimeStats:
return create_model(
"RuntimeStats",
**{stats.__name__: (stats, None) for stats in STATS_CLASSES},
__base__=cls,
)()

def __rich__(self) -> Group:
return Group(
*(getattr(self, stat_name) for stat_name in self.model_fields_set),
PROGRESS,
)

def add_stats(self, stats: Stats) -> None:
logger.debug("Adding stats %s", stats.__class__.__name__)
if stats.__class__.__name__ not in self.model_fields:
raise ValueError(f"{stats.__class__.__name__} is not a valid stats name")
if stats.__class__.__name__ in self.model_fields_set:
raise ValueError(f"{stats.__class__.__name__} is already set")
setattr(self, stats.__class__.__name__, stats)

    async def live_view(self) -> None:
with Live(
self,
refresh_per_second=10,
screen=True,
auto_refresh=True,
redirect_stdout=True,
redirect_stderr=True,
        ):
while True:
await asyncio.sleep(1.0)


class Stats(BaseModel):
    def __init_subclass__(cls: Type[Stats], **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        STATS_CLASSES.add(cls)

def populate_table(self, table: Table) -> None:
for name, field in self.iter_fields():
            title = field.title or titleify(name)
table.add_row(title, str(getattr(self, name)))

def iter_fields(self) -> Iterator[tuple[str, FieldInfo | ComputedFieldInfo]]:
yield from self.model_fields.items()
yield from self.model_computed_fields.items()

def __rich__(self) -> Panel:
table = Table(box=None, row_styles=["", "dim"])
self.populate_table(table)
return Panel(table, title=self.__class__.__name__)
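The new `stats.py` combines two patterns: `Stats.__init_subclass__` registers every subclass in `STATS_CLASSES` at class-definition time, and `RuntimeStats.new()` then uses `create_model` to build a model with one optional field per registered class. A condensed, self-contained sketch of that registry mechanism — `DownloadStats` is invented for illustration:

```python
from __future__ import annotations

from typing import Type

from pydantic import BaseModel, create_model

REGISTRY: set[Type[Stats]] = set()


class Stats(BaseModel):
    def __init_subclass__(cls, **kwargs) -> None:
        # Runs at class-definition time: every subclass self-registers.
        super().__init_subclass__(**kwargs)
        REGISTRY.add(cls)


class DownloadStats(Stats):  # invented example subclass
    bytes_received: int = 0


def new_runtime_stats() -> BaseModel:
    # One optional field per registered subclass, keyed by class name,
    # mirroring RuntimeStats.new() in the diff.
    return create_model(
        "RuntimeStats",
        **{cls.__name__: (cls, None) for cls in REGISTRY},
    )()


runtime = new_runtime_stats()
runtime.DownloadStats = DownloadStats(bytes_received=1024)
print(runtime.model_dump())  # {'DownloadStats': {'bytes_received': 1024}}
```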