diff --git a/src/db.rs b/src/db.rs index 356f5ca..5d0ced4 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,5 +1,6 @@ pub mod impls; +use futures::{Stream, TryStreamExt}; use sqlx::migrate::{MigrateDatabase, Migrator}; use sqlx::{Postgres, QueryBuilder}; @@ -38,19 +39,18 @@ impl Db { #[tracing::instrument(skip(self))] pub async fn fetch_masks(&self, id: usize) -> eyre::Result> { - let masks: Vec<(i64, Bits)> = sqlx::query_as( + let masks_stream = sqlx::query_as( r#" SELECT id, mask FROM masks WHERE id > $1 ORDER BY id ASC - "#, + "#, ) .bind(id as i64) - .fetch_all(&self.pool) - .await?; + .fetch(&self.pool); - Ok(filter_sequential_items(masks, 1 + id as i64)) + Ok(stream_sequential_items(masks_stream, 1 + id as i64).await?) } #[tracing::instrument(skip(self))] @@ -106,19 +106,18 @@ impl Db { &self, id: usize, ) -> eyre::Result> { - let shares: Vec<(i64, EncodedBits)> = sqlx::query_as( + let shares_stream = sqlx::query_as( r#" SELECT id, share FROM shares WHERE id > $1 ORDER BY id ASC - "#, + "#, ) .bind(id as i64) - .fetch_all(&self.pool) - .await?; + .fetch(&self.pool); - Ok(filter_sequential_items(shares, 1 + id as i64)) + Ok(stream_sequential_items(shares_stream, 1 + id as i64).await?) } #[tracing::instrument(skip(self))] @@ -196,31 +195,23 @@ impl Db { } } -fn filter_sequential_items( - items: impl IntoIterator, +async fn stream_sequential_items( + mut stream: impl Stream> + Unpin, first_id: i64, -) -> Vec { - let mut last_key = None; +) -> Result, E> { + let mut items = vec![]; - let mut items = items.into_iter(); - - std::iter::from_fn(move || { - let (key, value) = items.next()?; - - if let Some(last_key) = last_key { - if key != last_key + 1 { - return None; - } - } else if key != first_id { - return None; + let mut next_key = first_id; + while let Some((key, value)) = stream.try_next().await? { + if key != next_key { + break; } - last_key = Some(key); + next_key = key + 1; + items.push(value); + } - Some(value) - }) - .fuse() - .collect() + Ok(items) } #[cfg(test)]