Skip to content

Commit

Permalink
squirrel: async loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Isken committed Mar 22, 2024
1 parent 7a6e175 commit c9ae757
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions src/qseek/waveforms/squirrel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Literal
from typing import TYPE_CHECKING, AsyncIterator, Literal

from pydantic import (
AwareDatetime,
Expand Down Expand Up @@ -42,36 +42,30 @@ class SquirrelPrefetcher:

def __init__(
self,
iterator: Iterator[Batch],
iterator: AsyncIterator[Batch],
queue_size: int = 8,
) -> None:
self.iterator = iterator
self.queue = asyncio.Queue(maxsize=queue_size)
self._load_queue = asyncio.Queue(maxsize=queue_size)
self._fetched_batches = 0

self._task = asyncio.create_task(self.prefetch_worker())

async def prefetch_worker(self) -> None:
logger.info(
"start prefetching data, queue size %d",
"start pre-fetching data, queue size %d",
self.queue.maxsize,
)

async def load_data() -> None | Batch:
while True:
start = datetime_now()
batch = await asyncio.to_thread(next, self.iterator, None)
if batch is None:
await self.queue.put(None)
return
logger.debug("read waveform batch in %s", datetime_now() - start)
self._fetched_batches += 1
self.load_time = datetime_now() - start
await self.queue.put(batch)
start = datetime_now()
async for batch in self.iterator:
self.load_time = datetime_now() - start
self._fetched_batches += 1
logger.debug("read waveform batch in %s", self.load_time)
start = datetime_now()
await self.queue.put(batch)

await asyncio.create_task(load_data())
logger.debug("loading waveform batches to finish")
await self.queue.put(None)


class SquirrelStats(Stats):
Expand Down Expand Up @@ -209,7 +203,7 @@ async def iter_batches(
end_time - start_time,
)

iterator = squirrel.chopper_waveforms(
iterator = squirrel.chopper_waveforms_async(
tmin=(start_time + window_padding).timestamp(),
tmax=(end_time - window_padding).timestamp(),
tinc=window_increment.total_seconds(),
Expand Down

0 comments on commit c9ae757

Please sign in to comment.