diff --git a/examples/redis-mq-example/src/main.rs b/examples/redis-mq-example/src/main.rs index 9a73e61..92a640a 100644 --- a/examples/redis-mq-example/src/main.rs +++ b/examples/redis-mq-example/src/main.rs @@ -56,7 +56,7 @@ where type Layer = AckLayer; - fn poll(mut self, _worker_id: WorkerId) -> Poller { + fn poll(mut self, _worker_id: Worker) -> Poller { let (mut tx, rx) = mpsc::channel(self.config.get_buffer_size()); let stream: RequestStream> = Box::pin(rx); let layer = AckLayer::new(self.clone()); diff --git a/packages/apalis-core/src/lib.rs b/packages/apalis-core/src/lib.rs index f5475bb..095dfad 100644 --- a/packages/apalis-core/src/lib.rs +++ b/packages/apalis-core/src/lib.rs @@ -27,7 +27,7 @@ use futures::Stream; use poller::Poller; use serde::{Deserialize, Serialize}; use tower::Service; -use worker::WorkerId; +use worker::{Context, Worker}; /// Represent utilities for creating worker instances. pub mod builder; @@ -81,7 +81,7 @@ pub trait Backend { /// Returns a poller that is ready for streaming fn poll>( self, - worker: WorkerId, + worker: Worker, ) -> Poller; } /// A codec allows backends to encode and decode data @@ -165,7 +165,7 @@ pub mod test_utils { use crate::error::BoxDynError; use crate::request::Request; use crate::task::task_id::TaskId; - use crate::worker::WorkerId; + use crate::worker::{Worker, WorkerId}; use crate::Backend; use futures::channel::mpsc::{channel, Receiver, Sender}; use futures::future::BoxFuture; @@ -264,8 +264,9 @@ pub mod test_utils { >>::Future: Send + 'static, { let worker_id = WorkerId::new("test-worker"); + let worker = Worker::new(worker_id, crate::worker::Context::default()); let b = backend.clone(); - let mut poller = b.poll::(worker_id); + let mut poller = b.poll::(worker); let (stop_tx, mut stop_rx) = channel::<()>(1); let (mut res_tx, res_rx) = channel(10); diff --git a/packages/apalis-core/src/memory.rs b/packages/apalis-core/src/memory.rs index 5645992..033696b 100644 --- a/packages/apalis-core/src/memory.rs +++ b/packages/apalis-core/src/memory.rs @@ -2,7 +2,7 @@ use crate::{ mq::MessageQueue, poller::{controller::Controller, stream::BackendStream}, request::{Request, RequestStream}, - worker::WorkerId, + worker::{self, Worker}, Backend, Poller, }; use futures::{ @@ -101,7 +101,7 @@ impl Backend, Res> for MemoryStora type Layer = Identity; - fn poll(self, _worker: WorkerId) -> Poller { + fn poll(self, _worker: Worker) -> Poller { let stream = self.inner.map(|r| Ok(Some(r))).boxed(); Poller { stream: BackendStream::new(stream, self.controller), diff --git a/packages/apalis-core/src/request.rs b/packages/apalis-core/src/request.rs index 428b4a8..f25600a 100644 --- a/packages/apalis-core/src/request.rs +++ b/packages/apalis-core/src/request.rs @@ -9,7 +9,7 @@ use crate::{ error::Error, poller::Poller, task::{attempt::Attempt, namespace::Namespace, task_id::TaskId}, - worker::WorkerId, + worker::{Context, Worker}, Backend, }; @@ -111,10 +111,10 @@ impl Backend, Res> for RequestStream(self, _worker: WorkerId) -> Poller { + fn poll(self, _worker: Worker) -> Poller { Poller { stream: self, - heartbeat: Box::pin(async {}), + heartbeat: Box::pin(futures::future::pending()), layer: Identity::new(), } } diff --git a/packages/apalis-core/src/worker/mod.rs b/packages/apalis-core/src/worker/mod.rs index 2557b31..ed9df74 100644 --- a/packages/apalis-core/src/worker/mod.rs +++ b/packages/apalis-core/src/worker/mod.rs @@ -303,7 +303,7 @@ impl Worker> { }; let backend = self.state.backend; let service = self.state.service; - let poller = backend.poll::(worker_id.clone()); + let poller = backend.poll::(worker.clone()); let stream = poller.stream; let heartbeat = poller.heartbeat.boxed(); let layer = poller.layer; @@ -387,7 +387,7 @@ impl Future for Runnable { } /// Stores the Workers context -#[derive(Clone)] +#[derive(Clone, Default)] pub struct Context { task_count: Arc, wakers: Arc>>, diff --git a/packages/apalis-cron/src/lib.rs b/packages/apalis-cron/src/lib.rs index 621f002..cb0942a 100644 --- a/packages/apalis-cron/src/lib.rs +++ b/packages/apalis-cron/src/lib.rs @@ -57,11 +57,11 @@ //! } //! ``` +use apalis_core::worker::{Context, Worker}; use apalis_core::layers::Identity; use apalis_core::poller::Poller; use apalis_core::request::RequestStream; use apalis_core::task::namespace::Namespace; -use apalis_core::worker::WorkerId; use apalis_core::Backend; use apalis_core::{error::Error, request::Request}; use chrono::{DateTime, TimeZone, Utc}; @@ -145,8 +145,8 @@ where type Layer = Identity; - fn poll(self, _worker: WorkerId) -> Poller { + fn poll(self, _worker: Worker) -> Poller { let stream = self.into_stream(); - Poller::new(stream, async {}) + Poller::new(stream, futures::future::pending()) } } diff --git a/packages/apalis-redis/Cargo.toml b/packages/apalis-redis/Cargo.toml index 9286d52..f622029 100644 --- a/packages/apalis-redis/Cargo.toml +++ b/packages/apalis-redis/Cargo.toml @@ -34,6 +34,7 @@ tokio = { version = "1", features = ["rt", "net"], optional = true } async-std = { version = "1.13.0", optional = true } async-trait = "0.1.80" tower = "0.4" +thiserror = "1" [dev-dependencies] diff --git a/packages/apalis-redis/src/lib.rs b/packages/apalis-redis/src/lib.rs index bffe581..f2848f5 100644 --- a/packages/apalis-redis/src/lib.rs +++ b/packages/apalis-redis/src/lib.rs @@ -33,3 +33,4 @@ pub use storage::Config; pub use storage::RedisContext; pub use storage::RedisQueueInfo; pub use storage::RedisStorage; +pub use storage::RedisPollError; diff --git a/packages/apalis-redis/src/storage.rs b/packages/apalis-redis/src/storage.rs index 1f53048..9fdeb72 100644 --- a/packages/apalis-redis/src/storage.rs +++ b/packages/apalis-redis/src/storage.rs @@ -10,10 +10,10 @@ use apalis_core::service_fn::FromRequest; use apalis_core::storage::Storage; use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; +use apalis_core::worker::{Event, Worker, WorkerId}; use apalis_core::{Backend, Codec}; use chrono::{DateTime, Utc}; -use futures::channel::mpsc::{self, Sender}; +use futures::channel::mpsc::{self, SendError, Sender}; use futures::{select, FutureExt, SinkExt, StreamExt, TryFutureExt}; use log::*; use redis::aio::ConnectionLike; @@ -106,6 +106,34 @@ impl FromRequest> for RedisContext { } } +/// Errors that can occur while polling a Redis backend. +#[derive(thiserror::Error, Debug)] +pub enum RedisPollError { + /// Error during a keep-alive heartbeat. + #[error("KeepAlive heartbeat encountered an error: `{0}`")] + KeepAliveError(RedisError), + + /// Error during enqueueing scheduled tasks. + #[error("EnqueueScheduled heartbeat encountered an error: `{0}`")] + EnqueueScheduledError(RedisError), + + /// Error during polling for the next task or message. + #[error("PollNext heartbeat encountered an error: `{0}`")] + PollNextError(RedisError), + + /// Error during enqueueing tasks for worker consumption. + #[error("Enqueue for worker consumption encountered an error: `{0}`")] + EnqueueError(SendError), + + /// Error during acknowledgment of tasks. + #[error("Ack heartbeat encountered an error: `{0}`")] + AckError(RedisError), + + /// Error during re-enqueuing orphaned tasks. + #[error("ReenqueueOrphaned heartbeat encountered an error: `{0}`")] + ReenqueueOrphanedError(RedisError), +} + /// Config for a [RedisStorage] #[derive(Clone, Debug)] pub struct Config { @@ -412,7 +440,7 @@ where fn poll>>( mut self, - worker: WorkerId, + worker: Worker, ) -> Poller { let (mut tx, rx) = mpsc::channel(self.config.buffer_size); let (ack, ack_rx) = mpsc::channel(self.config.buffer_size); @@ -433,32 +461,32 @@ where let mut ack_stream = ack_rx.fuse(); - if let Err(e) = self.keep_alive(&worker).await { - error!("RegistrationError: {}", e); + if let Err(e) = self.keep_alive(worker.id()).await { + worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e)))); } loop { select! { _ = keep_alive_stm.next() => { - if let Err(e) = self.keep_alive(&worker).await { - error!("KeepAliveError: {}", e); + if let Err(e) = self.keep_alive(worker.id()).await { + worker.emit(Event::Error(Box::new(RedisPollError::KeepAliveError(e)))); } } _ = enqueue_scheduled_stm.next() => { if let Err(e) = self.enqueue_scheduled(config.buffer_size).await { - error!("EnqueueScheduledError: {}", e); + worker.emit(Event::Error(Box::new(RedisPollError::EnqueueScheduledError(e)))); } } _ = poll_next_stm.next() => { - let res = self.fetch_next(&worker).await; + let res = self.fetch_next(worker.id()).await; match res { Err(e) => { - error!("PollNextError: {}", e); + worker.emit(Event::Error(Box::new(RedisPollError::PollNextError(e)))); } Ok(res) => { for job in res { if let Err(e) = tx.send(Ok(Some(job))).await { - error!("EnqueueError: {}", e); + worker.emit(Event::Error(Box::new(RedisPollError::EnqueueError(e)))); } } } @@ -468,7 +496,7 @@ where id_to_ack = ack_stream.next() => { if let Some((ctx, res)) = id_to_ack { if let Err(e) = self.ack(&ctx, &res).await { - error!("AckError: {}", e); + worker.emit(Event::Error(Box::new(RedisPollError::AckError(e)))); } } } @@ -476,7 +504,7 @@ where let dead_since = Utc::now() - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap(); if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await { - error!("ReenqueueOrphanedError: {}", e); + worker.emit(Event::Error(Box::new(RedisPollError::ReenqueueOrphanedError(e)))); } } }; diff --git a/packages/apalis-sql/Cargo.toml b/packages/apalis-sql/Cargo.toml index 7f14706..9068003 100644 --- a/packages/apalis-sql/Cargo.toml +++ b/packages/apalis-sql/Cargo.toml @@ -37,6 +37,7 @@ tokio = { version = "1", features = ["rt", "net"], optional = true } futures-lite = "2.3.0" async-std = { version = "1.13.0", optional = true } chrono = { version = "0.4", features = ["serde"] } +thiserror = "1" [dev-dependencies] diff --git a/packages/apalis-sql/src/mysql.rs b/packages/apalis-sql/src/mysql.rs index 1795fec..a16ad36 100644 --- a/packages/apalis-sql/src/mysql.rs +++ b/packages/apalis-sql/src/mysql.rs @@ -1,5 +1,5 @@ use apalis_core::codec::json::JsonCodec; -use apalis_core::error::Error; +use apalis_core::error::{BoxDynError, Error}; use apalis_core::layers::{Ack, AckLayer}; use apalis_core::notify::Notify; use apalis_core::poller::controller::Controller; @@ -10,7 +10,7 @@ use apalis_core::response::Response; use apalis_core::storage::Storage; use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; +use apalis_core::worker::{Context, Event, Worker, WorkerId}; use apalis_core::{Backend, Codec}; use async_stream::try_stream; use chrono::{DateTime, Utc}; @@ -180,8 +180,7 @@ where yield { let (req, ctx) = job.req.take_parts(); let req = C::decode(req) - .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) - .unwrap(); + .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; let mut req: Request = Request::new_with_parts(req, ctx); req.parts.namespace = Some(Namespace(self.config.namespace.clone())); Some(req) @@ -371,16 +370,37 @@ where } } +/// Errors that can occur while polling a MySQL database. +#[derive(thiserror::Error, Debug)] +pub enum MysqlPollError { + /// Error during task acknowledgment. + #[error("Encountered an error during ACK: `{0}`")] + AckError(sqlx::Error), + + /// Error during result encoding. + #[error("Encountered an error during encoding the result: {0}")] + CodecError(BoxDynError), + + /// Error during a keep-alive heartbeat. + #[error("Encountered an error during KeepAlive heartbeat: `{0}`")] + KeepAliveError(sqlx::Error), + + /// Error during re-enqueuing orphaned tasks. + #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")] + ReenqueueOrphanedError(sqlx::Error), +} + impl Backend, Res> for MysqlStorage where Req: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static, C: Debug + Codec + Clone + Send + 'static + Sync, + C::Error: std::error::Error + 'static + Send + Sync, { type Stream = BackendStream>>; type Layer = AckLayer, Req, SqlContext, Res>; - fn poll(self, worker: WorkerId) -> Poller { + fn poll(self, worker: Worker) -> Poller { let layer = AckLayer::new(self.clone()); let config = self.config.clone(); let controller = self.controller.clone(); @@ -389,9 +409,10 @@ where let mut hb_storage = self.clone(); let requeue_storage = self.clone(); let stream = self - .stream_jobs(&worker, config.poll_interval, config.buffer_size) + .stream_jobs(worker.id(), config.poll_interval, config.buffer_size) .map_err(|e| Error::SourceError(Arc::new(Box::new(e)))); let stream = BackendStream::new(stream.boxed(), controller); + let w = worker.clone(); let ack_heartbeat = async move { while let Some(ids) = ack_notify @@ -403,28 +424,39 @@ where for (ctx, res) in ids { let query = "UPDATE jobs SET status = ?, done_at = now(), last_error = ? WHERE id = ? AND lock_by = ?"; let query = sqlx::query(query); - let query = query - .bind(calculate_status(&res.inner).to_string()) - .bind( - serde_json::to_string(&res.inner.as_ref().map_err(|e| e.to_string())) - .unwrap(), - ) - .bind(res.task_id.to_string()) - .bind(ctx.lock_by().as_ref().unwrap().to_string()); - if let Err(e) = query.execute(&pool).await { - error!("Ack failed: {e}"); + let last_result = + C::encode(res.inner.as_ref().map_err(|e| e.to_string())).map_err(Box::new); + match (last_result, ctx.lock_by()) { + (Ok(val), Some(worker_id)) => { + let query = query + .bind(calculate_status(&res.inner).to_string()) + .bind(val) + .bind(res.task_id.to_string()) + .bind(worker_id.to_string()); + if let Err(e) = query.execute(&pool).await { + w.emit(Event::Error(Box::new(MysqlPollError::AckError(e)))); + } + } + (Err(error), Some(_)) => { + w.emit(Event::Error(Box::new(MysqlPollError::CodecError(error)))); + } + _ => { + unreachable!( + "Attempted to ACK without a worker attached. This is a bug, File it on the repo" + ); + } } } apalis_core::sleep(config.poll_interval).await; } }; - + let w = worker.clone(); let heartbeat = async move { loop { let now = Utc::now(); - if let Err(e) = hb_storage.keep_alive_at::(&worker, now).await { - error!("Heartbeat failed: {e}"); + if let Err(e) = hb_storage.keep_alive_at::(w.id(), now).await { + w.emit(Event::Error(Box::new(MysqlPollError::KeepAliveError(e)))); } apalis_core::sleep(config.keep_alive).await; } @@ -432,12 +464,21 @@ where let reenqueue_beat = async move { loop { let dead_since = Utc::now() - - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap(); + - chrono::Duration::from_std(config.reenqueue_orphaned_after) + .expect("Could not calculate dead since"); if let Err(e) = requeue_storage - .reenqueue_orphaned(config.buffer_size.try_into().unwrap(), dead_since) + .reenqueue_orphaned( + config + .buffer_size + .try_into() + .expect("Could not convert usize to i32"), + dead_since, + ) .await { - error!("ReenqueueOrphaned failed: {e}"); + worker.emit(Event::Error(Box::new( + MysqlPollError::ReenqueueOrphanedError(e), + ))); } apalis_core::sleep(config.poll_interval).await; } @@ -463,7 +504,10 @@ where type AckError = sqlx::Error; async fn ack(&mut self, ctx: &Self::Context, res: &Response) -> Result<(), sqlx::Error> { self.ack_notify - .notify((ctx.clone(), res.map(|res| C::encode(res).unwrap()))) + .notify(( + ctx.clone(), + res.map(|res| C::encode(res).expect("Could not encode result")), + )) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::BrokenPipe, e)))?; Ok(()) diff --git a/packages/apalis-sql/src/postgres.rs b/packages/apalis-sql/src/postgres.rs index eacf3a9..ad7f394 100644 --- a/packages/apalis-sql/src/postgres.rs +++ b/packages/apalis-sql/src/postgres.rs @@ -41,7 +41,7 @@ use crate::context::SqlContext; use crate::{calculate_status, Config}; use apalis_core::codec::json::JsonCodec; -use apalis_core::error::Error; +use apalis_core::error::{BoxDynError, Error}; use apalis_core::layers::{Ack, AckLayer}; use apalis_core::notify::Notify; use apalis_core::poller::controller::Controller; @@ -52,7 +52,7 @@ use apalis_core::response::Response; use apalis_core::storage::Storage; use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; +use apalis_core::worker::{Context, Event, Worker, WorkerId}; use apalis_core::{Backend, Codec}; use chrono::{DateTime, Utc}; use futures::channel::mpsc; @@ -118,16 +118,45 @@ impl fmt::Debug for PostgresStorage { } } +/// Errors that can occur while polling a PostgreSQL database. +#[derive(thiserror::Error, Debug)] +pub enum PgPollError { + /// Error during task acknowledgment. + #[error("Encountered an error during ACK: `{0}`")] + AckError(sqlx::Error), + + /// Error while fetching the next item. + #[error("Encountered an error during FetchNext: `{0}`")] + FetchNextError(apalis_core::error::Error), + + /// Error while listening to PostgreSQL notifications. + #[error("Encountered an error during listening to PgNotification: {0}")] + PgNotificationError(apalis_core::error::Error), + + /// Error during a keep-alive heartbeat. + #[error("Encountered an error during KeepAlive heartbeat: `{0}`")] + KeepAliveError(sqlx::Error), + + /// Error during re-enqueuing orphaned tasks. + #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")] + ReenqueueOrphanedError(sqlx::Error), + + /// Error during result encoding. + #[error("Encountered an error during encoding the result: {0}")] + CodecError(BoxDynError), +} + impl Backend, Res> for PostgresStorage where T: Serialize + DeserializeOwned + Sync + Send + Unpin + 'static, C: Codec + Send + 'static, + C::Error: std::error::Error + 'static + Send + Sync, { type Stream = BackendStream>>; type Layer = AckLayer, T, SqlContext, Res>; - fn poll(mut self, worker: WorkerId) -> Poller { + fn poll(mut self, worker: Worker) -> Poller { let layer = AckLayer::new(self.clone()); let subscription = self.subscription.clone(); let config = self.config.clone(); @@ -168,23 +197,23 @@ where } if let Err(e) = self - .keep_alive_at::(&worker, Utc::now().timestamp()) + .keep_alive_at::(worker.id(), Utc::now().timestamp()) .await { - error!("KeepAliveError: {}", e); + worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e)))); } loop { select! { _ = keep_alive_stm.next() => { - if let Err(e) = self.keep_alive_at::(&worker, Utc::now().timestamp()).await { - error!("KeepAliveError: {}", e); + if let Err(e) = self.keep_alive_at::(worker.id(), Utc::now().timestamp()).await { + worker.emit(Event::Error(Box::new(PgPollError::KeepAliveError(e)))); } } ids = ack_stream.next() => { if let Some(ids) = ids { - let ack_ids: Vec<(String, String, String, String, u64)> = ids.iter().map(|(ctx, res)| { - (res.task_id.to_string(), ctx.lock_by().clone().unwrap().to_string(), serde_json::to_string(&res.inner.as_ref().map_err(|e| e.to_string())).expect("Could not convert response to json"), calculate_status(&res.inner).to_string(), (res.attempt.current() + 1) as u64 ) + let ack_ids: Vec<(String, String, String, String, u64)> = ids.iter().map(|(_ctx, res)| { + (res.task_id.to_string(), worker.id().to_string(), serde_json::to_string(&res.inner.as_ref().map_err(|e| e.to_string())).expect("Could not convert response to json"), calculate_status(&res.inner).to_string(), (res.attempt.current() + 1) as u64 ) }).collect(); let query = "UPDATE apalis.jobs @@ -203,31 +232,41 @@ where ) Q WHERE apalis.jobs.id = Q.id; "; - if let Err(e) = sqlx::query(query) - .bind(serde_json::to_value(&ack_ids).unwrap()) - .execute(&pool) - .await - { - panic!("AckError: {e}"); + let codec_res = C::encode(&ack_ids); + match codec_res { + Ok(val) => { + if let Err(e) = sqlx::query(query) + .bind(val) + .execute(&pool) + .await + { + worker.emit(Event::Error(Box::new(PgPollError::AckError(e)))); + } + } + Err(e) => { + worker.emit(Event::Error(Box::new(PgPollError::CodecError(e.into())))); + } } + } } _ = poll_next_stm.next() => { - if let Err(e) = fetch_next_batch(&mut self, &worker, &mut tx).await { - error!("FetchNextError: {e}"); + if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await { + worker.emit(Event::Error(Box::new(PgPollError::FetchNextError(e)))); } } _ = pg_notification.next() => { - if let Err(e) = fetch_next_batch(&mut self, &worker, &mut tx).await { - error!("PgNotificationError: {e}"); + if let Err(e) = fetch_next_batch(&mut self, worker.id(), &mut tx).await { + worker.emit(Event::Error(Box::new(PgPollError::PgNotificationError(e)))); + } } _ = reenqueue_orphaned_stm.next() => { let dead_since = Utc::now() - - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap(); + - chrono::Duration::from_std(config.reenqueue_orphaned_after).expect("could not build dead_since"); if let Err(e) = self.reenqueue_orphaned((config.buffer_size * 10) as i32, dead_since).await { - error!("ReenqueueOrphanedError: {}", e); + worker.emit(Event::Error(Box::new(PgPollError::ReenqueueOrphanedError(e)))); } } @@ -402,7 +441,7 @@ where let (req, parts) = job.req.take_parts(); let req = C::decode(req) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e))) - .unwrap(); + .expect("Unable to decode"); let mut req = Request::new_with_parts(req, parts); req.parts.namespace = Some(Namespace(self.config.namespace.clone())); req @@ -576,7 +615,7 @@ where let res = res.clone().map(|r| { C::encode(r) .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::Interrupted, e))) - .unwrap() + .expect("Could not encode result") }); self.ack_notify diff --git a/packages/apalis-sql/src/sqlite.rs b/packages/apalis-sql/src/sqlite.rs index 301d8c8..7efe8c2 100644 --- a/packages/apalis-sql/src/sqlite.rs +++ b/packages/apalis-sql/src/sqlite.rs @@ -11,7 +11,7 @@ use apalis_core::response::Response; use apalis_core::storage::Storage; use apalis_core::task::namespace::Namespace; use apalis_core::task::task_id::TaskId; -use apalis_core::worker::WorkerId; +use apalis_core::worker::{Context, Event, Worker, WorkerId}; use apalis_core::{Backend, Codec}; use async_stream::try_stream; use chrono::{DateTime, Utc}; @@ -448,40 +448,62 @@ impl SqliteStorage { } } +/// Errors that can occur while polling an SQLite database. +#[derive(thiserror::Error, Debug)] +pub enum SqlitePollError { + /// Error during a keep-alive heartbeat. + #[error("Encountered an error during KeepAlive heartbeat: `{0}`")] + KeepAliveError(sqlx::Error), + + /// Error during re-enqueuing orphaned tasks. + #[error("Encountered an error during ReenqueueOrphaned heartbeat: `{0}`")] + ReenqueueOrphanedError(sqlx::Error), +} + impl Backend, Res> for SqliteStorage { type Stream = BackendStream>>; type Layer = AckLayer, T, SqlContext, Res>; - fn poll(mut self, worker: WorkerId) -> Poller { + fn poll(mut self, worker: Worker) -> Poller { let layer = AckLayer::new(self.clone()); let config = self.config.clone(); let controller = self.controller.clone(); let stream = self - .stream_jobs(&worker, config.poll_interval, config.buffer_size) + .stream_jobs(worker.id(), config.poll_interval, config.buffer_size) .map_err(|e| Error::SourceError(Arc::new(Box::new(e)))); let stream = BackendStream::new(stream.boxed(), controller); let requeue_storage = self.clone(); + let w = worker.clone(); let heartbeat = async move { loop { let now: i64 = Utc::now().timestamp(); - self.keep_alive_at::(&worker, now) - .await - .unwrap(); + if let Err(e) = self.keep_alive_at::(worker.id(), now).await { + worker.emit(Event::Error(Box::new(SqlitePollError::KeepAliveError(e)))); + } apalis_core::sleep(Duration::from_secs(30)).await; } } .boxed(); + let reenqueue_beat = async move { loop { let dead_since = Utc::now() - chrono::Duration::from_std(config.reenqueue_orphaned_after).unwrap(); if let Err(e) = requeue_storage - .reenqueue_orphaned(config.buffer_size.try_into().unwrap(), dead_since) + .reenqueue_orphaned( + config + .buffer_size + .try_into() + .expect("could not convert usize to i32"), + dead_since, + ) .await { - error!("ReenqueueOrphaned failed: {e}"); + w.emit(Event::Error(Box::new( + SqlitePollError::ReenqueueOrphanedError(e), + ))); } apalis_core::sleep(config.poll_interval).await; } @@ -507,7 +529,12 @@ impl Ack for SqliteStorage { .map_err(|e| sqlx::Error::Io(io::Error::new(io::ErrorKind::InvalidData, e)))?; sqlx::query(query) .bind(res.task_id.to_string()) - .bind(ctx.lock_by().as_ref().unwrap().to_string()) + .bind( + ctx.lock_by() + .as_ref() + .expect("Task is not locked") + .to_string(), + ) .bind(result) .bind(calculate_status(&res.inner).to_string()) .execute(&pool)