diff --git a/fang/src/asynk/async_queue.rs b/fang/src/asynk/async_queue.rs index a9c1b18..c905ce8 100644 --- a/fang/src/asynk/async_queue.rs +++ b/fang/src/asynk/async_queue.rs @@ -14,16 +14,30 @@ use chrono::DateTime; use chrono::Utc; use cron::Schedule; //use sqlx::any::install_default_drivers; // this is supported in sqlx 0.7 -use sqlx::any::AnyKind; use sqlx::pool::PoolOptions; -use sqlx::Any; -use sqlx::AnyPool; +use sqlx::Database; + +use sqlx::MySql; +use sqlx::MySqlPool; use sqlx::Pool; +use sqlx::Postgres; +use sqlx::Sqlite; +use sqlx::SqlitePool; +use std::any::Any; use std::str::FromStr; use thiserror::Error; use typed_builder::TypedBuilder; use uuid::Uuid; +#[cfg(feature = "asynk-postgres")] +use sqlx::PgPool; + +#[cfg(feature = "asynk-mysql")] +use sqlx::MySqlPool; + +#[cfg(feature = "asynk-sqlite")] +use sqlx::SqlitePool; + #[cfg(test)] use self::async_queue_tests::test_asynk_queue; @@ -136,9 +150,12 @@ pub trait AsyncQueueable: Send { /// #[derive(TypedBuilder, Debug, Clone)] -pub struct AsyncQueue { +pub struct AsyncQueue +where + DB: Database, +{ #[builder(default=None, setter(skip))] - pool: Option, + pool: Option>, #[builder(setter(into))] uri: String, #[builder(setter(into))] @@ -172,20 +189,29 @@ use std::env; use super::backend_sqlx::BackendSqlX; -fn get_backend(_anykind: AnyKind) -> BackendSqlX { - match _anykind { - #[cfg(feature = "asynk-postgres")] - AnyKind::Postgres => BackendSqlX::Pg, - #[cfg(feature = "asynk-mysql")] - AnyKind::MySql => BackendSqlX::MySql, - #[cfg(feature = "asynk-sqlite")] - AnyKind::Sqlite => BackendSqlX::Sqlite, - #[allow(unreachable_patterns)] - _ => unreachable!(), +fn get_backend<'a, DB: Database>(pool: &'a Pool) -> BackendSqlX { + let type_pool = pool.type_id(); + #[cfg(feature = "asynk-postgres")] + if std::any::TypeId::of::() == type_pool { + return BackendSqlX::Pg; + } + #[cfg(feature = "asynk-mysql")] + if std::any::TypeId::of::() == type_pool { + return BackendSqlX::MySql; + } + + #[cfg(feature = "asynk-sqlite")] + if std::any::TypeId::of::() == type_pool { + return BackendSqlX::Sqlite; } + + unreachable!() } -impl AsyncQueue { +impl AsyncQueue +where + DB: Database, +{ /// Check if the connection with db is established pub fn check_if_connection(&self) -> Result<(), AsyncQueueError> { if self.connected { @@ -199,14 +225,12 @@ impl AsyncQueue { pub async fn connect(&mut self) -> Result<(), AsyncQueueError> { //install_default_drivers(); - let pool: AnyPool = PoolOptions::new() + let pool: Pool = PoolOptions::new() .max_connections(self.max_pool_size) .connect(&self.uri) .await?; - let anykind = pool.any_kind(); - - self.backend = get_backend(anykind); + self.backend = get_backend(&pool); self.pool = Some(pool); self.connected = true; @@ -214,7 +238,7 @@ impl AsyncQueue { } async fn fetch_and_touch_task_query( - pool: &Pool, + pool: &Pool, backend: &BackendSqlX, task_type: Option, ) -> Result, AsyncQueueError> { @@ -250,7 +274,7 @@ impl AsyncQueue { } async fn insert_task_query( - pool: &Pool, + pool: &Pool, backend: &BackendSqlX, metadata: &serde_json::Value, task_type: &str, @@ -271,7 +295,7 @@ impl AsyncQueue { } async fn insert_task_if_not_exist_query( - pool: &Pool, + pool: &Pool, backend: &BackendSqlX, metadata: &serde_json::Value, task_type: &str, @@ -292,7 +316,7 @@ impl AsyncQueue { } async fn schedule_task_query( - pool: &Pool, + pool: &Pool, backend: &BackendSqlX, task: &dyn AsyncRunnable, ) -> Result { @@ -334,7 +358,10 @@ impl AsyncQueue { } #[async_trait] -impl AsyncQueueable for AsyncQueue { +impl AsyncQueueable for AsyncQueue +where + DB: Database, +{ async fn find_task_by_id(&mut self, id: &Uuid) -> Result { self.check_if_connection()?; let pool = self.pool.as_ref().unwrap(); @@ -554,7 +581,7 @@ impl AsyncQueueable for AsyncQueue { } #[cfg(test)] -impl AsyncQueue { +impl AsyncQueue { /// Provides an AsyncQueue connected to its own DB pub async fn test_postgres() -> Self { dotenvy::dotenv().expect(".env file not found"); @@ -600,7 +627,10 @@ impl AsyncQueue { res } +} +#[cfg(test)] +impl AsyncQueue { /// Provides an AsyncQueue connected to its own DB pub async fn test_sqlite() -> Self { dotenvy::dotenv().expect(".env file not found"); @@ -632,7 +662,10 @@ impl AsyncQueue { res.connect().await.expect("fail to connect"); res } +} +#[cfg(test)] +impl AsyncQueue { /// Provides an AsyncQueue connected to its own DB pub async fn test_mysql() -> Self { dotenvy::dotenv().expect(".env file not found"); @@ -685,10 +718,10 @@ impl AsyncQueue { } #[cfg(test)] -test_asynk_queue! {postgres, crate::AsyncQueue, crate::AsyncQueue::test_postgres()} +test_asynk_queue! {postgres, crate::AsyncQueue, crate::AsyncQueue::test_postgres()} #[cfg(test)] -test_asynk_queue! {sqlite, crate::AsyncQueue, crate::AsyncQueue::test_sqlite()} +test_asynk_queue! {sqlite, crate::AsyncQueue, crate::AsyncQueue::test_sqlite()} #[cfg(test)] -test_asynk_queue! {mysql, crate::AsyncQueue, crate::AsyncQueue::test_mysql()} +test_asynk_queue! {mysql, crate::AsyncQueue, crate::AsyncQueue::test_mysql()} diff --git a/fang/src/asynk/async_worker.rs b/fang/src/asynk/async_worker.rs index 7c73227..8f6b727 100644 --- a/fang/src/asynk/async_worker.rs +++ b/fang/src/asynk/async_worker.rs @@ -263,6 +263,7 @@ mod async_worker_tests { use chrono::Duration; use chrono::Utc; use serde::{Deserialize, Serialize}; + use sqlx::Database; #[derive(Serialize, Deserialize)] struct WorkerAsyncTask { @@ -563,7 +564,10 @@ mod async_worker_tests { assert_eq!(id2, task2.id); } - async fn insert_task(test: &mut AsyncQueue, task: &dyn AsyncRunnable) -> Task { + async fn insert_task( + test: &mut AsyncQueue, + task: &dyn AsyncRunnable, + ) -> Task { test.insert_task(task).await.unwrap() } diff --git a/fang/src/asynk/backend_sqlx.rs b/fang/src/asynk/backend_sqlx.rs index 0206553..6986960 100644 --- a/fang/src/asynk/backend_sqlx.rs +++ b/fang/src/asynk/backend_sqlx.rs @@ -2,6 +2,7 @@ use chrono::{DateTime, Duration, Utc}; use sha2::Digest; use sha2::Sha256; use sqlx::Any; +use sqlx::Database; use sqlx::Pool; use std::fmt::Debug; use typed_builder::TypedBuilder; @@ -80,10 +81,10 @@ impl Res { } impl BackendSqlX { - pub(crate) async fn execute_query<'a>( + pub(crate) async fn execute_query<'a, DB: Database>( &self, _query: SqlXQuery, - _pool: &Pool, + _pool: &Pool, _params: QueryParams<'_>, ) -> Result { match self { @@ -212,9 +213,9 @@ async fn general_any_impl_insert_task_uniq( } #[allow(dead_code)] -async fn general_any_impl_update_task_state( +async fn general_any_impl_update_task_state( query: &str, - pool: &Pool, + pool: &Pool, params: QueryParams<'_>, ) -> Result { let updated_at_str = format!("{}", Utc::now().format("%F %T%.f+00")); diff --git a/fang/src/asynk/backend_sqlx/postgres.rs b/fang/src/asynk/backend_sqlx/postgres.rs index a5185a1..0535f3f 100644 --- a/fang/src/asynk/backend_sqlx/postgres.rs +++ b/fang/src/asynk/backend_sqlx/postgres.rs @@ -24,6 +24,7 @@ const RETRY_TASK_QUERY_POSTGRES: &str = include_str!("../queries_postgres/retry_ #[derive(Debug, Clone)] pub(super) struct BackendSqlXPg {} +use sqlx::Database; use SqlXQuery as Q; use crate::AsyncQueueError; @@ -41,12 +42,12 @@ use super::general_any_impl_remove_task_type; use super::general_any_impl_retry_task; use super::general_any_impl_update_task_state; use super::{QueryParams, Res, SqlXQuery}; -use sqlx::{Any, Pool}; +use sqlx::Pool; impl BackendSqlXPg { - pub(super) async fn execute_query( + pub(super) async fn execute_query( query: SqlXQuery, - pool: &Pool, + pool: &Pool, params: QueryParams<'_>, ) -> Result { match query {