From 61bfaf01863de6caa0760e6852d47ee8af2e3a98 Mon Sep 17 00:00:00 2001 From: Shing Him Ng Date: Fri, 3 Jan 2025 09:03:31 -0600 Subject: [PATCH] Introduce DbPoolError to store Redis and timeout errors --- payjoin-directory/src/db.rs | 18 ++++++++++++++---- payjoin-directory/src/error.rs | 25 +++++++++++++++++++++++++ payjoin-directory/src/lib.rs | 20 ++++++++++++-------- 3 files changed, 51 insertions(+), 12 deletions(-) create mode 100644 payjoin-directory/src/error.rs diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 6165abf9..33780067 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -4,6 +4,8 @@ use futures::StreamExt; use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult}; use tracing::debug; +use crate::error::DbPoolError; + const DEFAULT_COLUMN: &str = ""; const PJ_V1_COLUMN: &str = "pjv1"; @@ -19,11 +21,12 @@ impl DbPool { Ok(Self { client, timeout }) } + /// Peek using [`DEFAULT_COLUMN`] as the channel type. pub async fn push_default(&self, subdirectory_id: &str, data: Vec) -> RedisResult<()> { self.push(subdirectory_id, DEFAULT_COLUMN, data).await } - pub async fn peek_default(&self, subdirectory_id: &str) -> Option>> { + pub async fn peek_default(&self, subdirectory_id: &str) -> Result, DbPoolError> { self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await } @@ -31,7 +34,8 @@ impl DbPool { self.push(subdirectory_id, PJ_V1_COLUMN, data).await } - pub async fn peek_v1(&self, subdirectory_id: &str) -> Option>> { + /// Peek using [`PJ_V1_COLUMN`] as the channel type. + pub async fn peek_v1(&self, subdirectory_id: &str) -> Result, DbPoolError> { self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await } @@ -52,8 +56,14 @@ impl DbPool { &self, subdirectory_id: &str, channel_type: &str, - ) -> Option>> { - tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await.ok() + ) -> Result, DbPoolError> { + match tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await { + Ok(redis_result) => match redis_result { + Ok(result) => Ok(result), + Err(redis_err) => Err(DbPoolError::Redis(redis_err)), + }, + Err(elapsed) => Err(DbPoolError::Timeout(elapsed)), + } } async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult> { diff --git a/payjoin-directory/src/error.rs b/payjoin-directory/src/error.rs new file mode 100644 index 00000000..426da28b --- /dev/null +++ b/payjoin-directory/src/error.rs @@ -0,0 +1,25 @@ +#[derive(Debug)] +pub enum DbPoolError { + Redis(redis::RedisError), + Timeout(tokio::time::error::Elapsed), +} + +impl std::fmt::Display for DbPoolError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use DbPoolError::*; + + match &self { + Redis(error) => write!(f, "Redis error: {}", error), + Timeout(timeout) => write!(f, "Timeout: {}", timeout), + } + } +} + +impl std::error::Error for DbPoolError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DbPoolError::Redis(e) => Some(e), + DbPoolError::Timeout(e) => Some(e), + } + } +} diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 9a1c651c..aa6512a2 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -34,7 +34,10 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message" const ID_LENGTH: usize = 13; mod db; +mod error; + use crate::db::DbPool; +use crate::error::DbPoolError; #[cfg(feature = "_danger-local-https")] type BoxError = Box; @@ -341,11 +344,11 @@ async fn post_fallback_v1( .await .map_err(|e| HandlerError::BadRequest(e.into()))?; match pool.peek_v1(id).await { - Some(result) => match result { - Ok(buffered_req) => Ok(Response::new(full(buffered_req))), - Err(e) => Err(HandlerError::BadRequest(e.into())), + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => match e { + DbPoolError::Redis(_) => Err(HandlerError::BadRequest(e.into())), + DbPoolError::Timeout(_) => Ok(none_response), }, - None => Ok(none_response), } } @@ -409,11 +412,12 @@ async fn get_subdir( trace!("get_subdir"); let id = check_id_length(id)?; match pool.peek_default(id).await { - Some(result) => match result { - Ok(buffered_req) => Ok(Response::new(full(buffered_req))), - Err(e) => Err(HandlerError::BadRequest(e.into())), + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => match e { + DbPoolError::Redis(_) => Err(HandlerError::BadRequest(e.into())), + DbPoolError::Timeout(_) => + Ok(Response::builder().status(StatusCode::ACCEPTED).body(empty())?), }, - None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(empty())?), } }