Skip to content

Commit

Permalink
Remove cProfile imports and unused variables // refactor prefetch
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Jan 2, 2024
1 parent 4772156 commit 3cb832d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
2 changes: 0 additions & 2 deletions src/qseek/models/semblance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import cProfile
import logging
from datetime import timedelta
from typing import TYPE_CHECKING, ClassVar, Iterable
Expand All @@ -22,7 +21,6 @@
from qseek.octree import Node


p = cProfile.Profile()
logger = logging.getLogger(__name__)


Expand Down
2 changes: 0 additions & 2 deletions src/qseek/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import cProfile
import logging
from collections import deque
from datetime import datetime, timedelta, timezone
Expand Down Expand Up @@ -53,7 +52,6 @@
logger = logging.getLogger(__name__)

SamplingRate = Literal[10, 20, 25, 50, 100]
p = cProfile.Profile()


class SearchStats(Stats):
Expand Down
53 changes: 28 additions & 25 deletions src/qseek/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class SquirrelPrefetcher:
downsample_to: float | None
load_time: timedelta = timedelta(seconds=0.0)

_load_queue: asyncio.Queue[Batch | None]
_fetched_batches: int
_task: asyncio.Task[None]

Expand All @@ -53,6 +54,8 @@ def __init__(
) -> None:
self.iterator = iterator
self.queue = asyncio.Queue(maxsize=queue_size)
self._load_queue = asyncio.Queue(maxsize=queue_size)

self.downsample_to = downsample_to
self.highpass = highpass
self.lowpass = lowpass
Expand All @@ -65,7 +68,6 @@ async def prefetch_worker(self) -> None:
"start prefetching data, queue size %d",
self.queue.maxsize,
)
done = asyncio.Event()

def post_processing(batch: Batch) -> Batch:
# Filter traces in-place
Expand All @@ -79,7 +81,7 @@ def post_processing(batch: Batch) -> Batch:
desired_deltat = 1.0 / self.downsample_to
for tr in batch.traces:
if tr.deltat < desired_deltat:
tr.downsample_to(desired_deltat, allow_upsample_max=2)
tr.downsample_to(desired_deltat, allow_upsample_max=3)
except Exception as exc:
logger.exception(exc)

Expand All @@ -97,29 +99,30 @@ def post_processing(batch: Batch) -> Batch:
logger.debug("filtered waveform batch in %s", datetime_now() - start)
return batch

async def load_next() -> None | Batch:
start = datetime_now()
batch = await asyncio.to_thread(next, self.iterator, None)
if batch is None:
done.set()
return None
logger.debug("read waveform batch in %s", datetime_now() - start)
self._fetched_batches += 1
self.load_time = datetime_now() - start
return batch

async def post_process_batch(batch: Batch) -> None:
await asyncio.to_thread(post_processing, batch)
await self.queue.put(batch)

post_processing_task: asyncio.Task | None = None
while not done.is_set():
batch = await load_next()
if batch is None:
break
if post_processing_task:
await post_processing_task
post_processing_task = asyncio.create_task(post_process_batch(batch))
async def load_data() -> None:
    """Read batches from ``self.iterator`` and enqueue them for post-processing.

    Each call to ``next`` runs in a worker thread via ``asyncio.to_thread``.
    Every batch is put on ``self._load_queue``. When the iterator is
    exhausted, a ``None`` sentinel is put on the queue so that the consumer
    waiting on ``self._load_queue.get()`` can terminate instead of blocking
    forever.
    """
    while True:
        start = datetime_now()
        # next(..., None) returns None on exhaustion instead of raising
        # StopIteration inside the worker thread.
        batch = await asyncio.to_thread(next, self.iterator, None)
        if batch is None:
            # End of stream: forward the sentinel. Without it the consumer
            # never sees None and the prefetch worker never finishes.
            await self._load_queue.put(None)
            return
        # Measure once so the logged and the stored duration agree.
        elapsed = datetime_now() - start
        logger.debug("read waveform batch in %s", elapsed)
        self._fetched_batches += 1
        self.load_time = elapsed
        await self._load_queue.put(batch)

async def post_process() -> None:
    """Drain ``self._load_queue``, post-process each batch and publish it.

    Items are taken from ``self._load_queue`` one at a time. The loop ends
    as soon as a ``None`` item is received. Every other item is passed to
    ``post_processing`` in a worker thread and then put on ``self.queue``.
    """
    while (loaded := await self._load_queue.get()) is not None:
        # Run post_processing in a thread via asyncio.to_thread.
        await asyncio.to_thread(post_processing, loaded)
        await self.queue.put(loaded)

post_process_task = asyncio.create_task(post_process())
load_task = asyncio.create_task(load_data())

await load_task
await post_process_task

await self.queue.put(None)

Expand Down

0 comments on commit 3cb832d

Please sign in to comment.