Skip to content

Commit

Permalink
Postgres and Sqlite passing
Browse files Browse the repository at this point in the history
  • Loading branch information
pxp9 committed Apr 14, 2024
1 parent 69a2c57 commit 38d6632
Show file tree
Hide file tree
Showing 8 changed files with 1,086 additions and 695 deletions.
5 changes: 5 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"rust-analyzer.linkedProjects": [
"./fang/Cargo.toml",
]
}
198 changes: 130 additions & 68 deletions fang/src/asynk/async_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,32 @@ use async_trait::async_trait;
use chrono::DateTime;
use chrono::Utc;
use cron::Schedule;
//use sqlx::any::install_default_drivers; // this is supported in sqlx 0.7
use sqlx::any::AnyConnectOptions;
use sqlx::any::AnyKind;
#[cfg(any(
feature = "asynk-postgres",
feature = "asynk-mysql",
feature = "asynk-sqlite"
))]
use sqlx::pool::PoolOptions;
use sqlx::Database;

use sqlx::MySql;
use sqlx::MySqlPool;
use sqlx::Pool;
use sqlx::Postgres;
use sqlx::Sqlite;
use sqlx::SqlitePool;
use std::any::Any;
//use sqlx::any::install_default_drivers; // this is supported in sqlx 0.7
use std::str::FromStr;
use thiserror::Error;
use typed_builder::TypedBuilder;
use uuid::Uuid;

#[cfg(feature = "asynk-postgres")]
use sqlx::PgPool;
#[cfg(feature = "asynk-postgres")]
use sqlx::Postgres;

#[cfg(feature = "asynk-mysql")]
use sqlx::MySql;
#[cfg(feature = "asynk-mysql")]
use sqlx::MySqlPool;

#[cfg(feature = "asynk-sqlite")]
use sqlx::Sqlite;
#[cfg(feature = "asynk-sqlite")]
use sqlx::SqlitePool;

Expand Down Expand Up @@ -148,14 +152,48 @@ pub trait AsyncQueueable: Send {
/// .build();
/// ```
///
///
#[derive(Debug, Clone)]
pub(crate) enum InternalPool {
#[cfg(feature = "asynk-postgres")]
Pg(PgPool),
#[cfg(feature = "asynk-mysql")]
MySql(MySqlPool),
#[cfg(feature = "asynk-sqlite")]
Sqlite(SqlitePool),
}

impl InternalPool {
#[cfg(feature = "asynk-postgres")]
pub(crate) fn unwrap_pg_pool(&self) -> &PgPool {
match self {
InternalPool::Pg(pool) => pool,
_ => panic!("Not a PgPool!"),
}
}

#[cfg(feature = "asynk-mysql")]
pub(crate) fn unwrap_mysql_pool(&self) -> &MySqlPool {
match self {
InternalPool::MySql(pool) => pool,
_ => panic!("Not a MySqlPool!"),
}
}

#[cfg(feature = "asynk-sqlite")]
pub(crate) fn unwrap_sqlite_pool(&self) -> &SqlitePool {
match self {
InternalPool::Sqlite(pool) => pool,
_ => panic!("Not a SqlitePool!"),
}
}
}

#[derive(TypedBuilder, Debug, Clone)]
pub struct AsyncQueue<DB>
where
DB: Database,
{
pub struct AsyncQueue {
#[builder(default=None, setter(skip))]
pool: Option<Pool<DB>>,
pool: Option<InternalPool>,
#[builder(setter(into))]
uri: String,
#[builder(setter(into))]
Expand All @@ -169,49 +207,64 @@ where
#[cfg(test)]
use tokio::sync::Mutex;

#[cfg(test)]
#[cfg(all(test, feature = "asynk-postgres"))]
static ASYNC_QUEUE_POSTGRES_TEST_COUNTER: Mutex<u32> = Mutex::const_new(0);

#[cfg(test)]
#[cfg(all(test, feature = "asynk-sqlite"))]
static ASYNC_QUEUE_SQLITE_TEST_COUNTER: Mutex<u32> = Mutex::const_new(0);

#[cfg(test)]
#[cfg(all(test, feature = "asynk-mysql"))]
static ASYNC_QUEUE_MYSQL_TEST_COUNTER: Mutex<u32> = Mutex::const_new(0);

#[cfg(test)]
use sqlx::Executor;

#[cfg(test)]
#[cfg(all(test, feature = "asynk-sqlite"))]
use std::path::Path;

#[cfg(test)]
use std::env;

use super::backend_sqlx::BackendSqlX;

fn get_backend<'a, DB: Database>(pool: &'a Pool<DB>) -> BackendSqlX {
let type_pool = pool.type_id();
#[cfg(feature = "asynk-postgres")]
if std::any::TypeId::of::<PgPool>() == type_pool {
return BackendSqlX::Pg;
}
#[cfg(feature = "asynk-mysql")]
if std::any::TypeId::of::<MySqlPool>() == type_pool {
return BackendSqlX::MySql;
}

#[cfg(feature = "asynk-sqlite")]
if std::any::TypeId::of::<SqlitePool>() == type_pool {
return BackendSqlX::Sqlite;
async fn get_backend(
kind: AnyKind,
_uri: &str,
_max_connections: u32,
) -> Result<(BackendSqlX, InternalPool), AsyncQueueError> {
match kind {
#[cfg(feature = "asynk-postgres")]
AnyKind::Postgres => {
let pool = PoolOptions::<Postgres>::new()
.max_connections(_max_connections)
.connect(_uri)
.await?;

Ok((BackendSqlX::Pg, InternalPool::Pg(pool)))
}
#[cfg(feature = "asynk-mysql")]
AnyKind::MySql => {
let pool = PoolOptions::<MySql>::new()
.max_connections(_max_connections)
.connect(_uri)
.await?;

Ok((BackendSqlX::MySql, InternalPool::MySql(pool)))
}
#[cfg(feature = "asynk-sqlite")]
AnyKind::Sqlite => {
let pool = PoolOptions::<Sqlite>::new()
.max_connections(_max_connections)
.connect(_uri)
.await?;

Ok((BackendSqlX::Sqlite, InternalPool::Sqlite(pool)))
}
_ => panic!("Not a valid backend"),
}

unreachable!()
}

impl<DB> AsyncQueue<DB>
where
DB: Database,
{
impl AsyncQueue {
/// Check if the connection with db is established
pub fn check_if_connection(&self) -> Result<(), AsyncQueueError> {
if self.connected {
Expand All @@ -225,20 +278,18 @@ where
pub async fn connect(&mut self) -> Result<(), AsyncQueueError> {
//install_default_drivers();

let pool: Pool<DB> = PoolOptions::new()
.max_connections(self.max_pool_size)
.connect(&self.uri)
.await?;
let kind: AnyKind = self.uri.parse::<AnyConnectOptions>()?.kind();

self.backend = get_backend(&pool);
let (backend, pool) = get_backend(kind, &self.uri, self.max_pool_size).await?;

self.pool = Some(pool);
self.backend = backend;
self.connected = true;
Ok(())
}

async fn fetch_and_touch_task_query(
pool: &Pool<DB>,
pool: &InternalPool,
backend: &BackendSqlX,
task_type: Option<String>,
) -> Result<Option<Task>, AsyncQueueError> {
Expand Down Expand Up @@ -274,7 +325,7 @@ where
}

async fn insert_task_query(
pool: &Pool<DB>,
pool: &InternalPool,
backend: &BackendSqlX,
metadata: &serde_json::Value,
task_type: &str,
Expand All @@ -287,15 +338,15 @@ where
.build();

let task = backend
.execute_query(SqlXQuery::InsertTask, pool, query_params)
.execute_query(SqlXQuery::InsertTask, &pool, query_params)
.await?
.unwrap_task();

Ok(task)
}

async fn insert_task_if_not_exist_query(
pool: &Pool<DB>,
pool: &InternalPool,
backend: &BackendSqlX,
metadata: &serde_json::Value,
task_type: &str,
Expand All @@ -308,15 +359,15 @@ where
.build();

let task = backend
.execute_query(SqlXQuery::InsertTaskIfNotExists, pool, query_params)
.execute_query(SqlXQuery::InsertTaskIfNotExists, &pool, query_params)
.await?
.unwrap_task();

Ok(task)
}

async fn schedule_task_query(
pool: &Pool<DB>,
pool: &InternalPool,
backend: &BackendSqlX,
task: &dyn AsyncRunnable,
) -> Result<Task, AsyncQueueError> {
Expand Down Expand Up @@ -358,10 +409,7 @@ where
}

#[async_trait]
impl<DB> AsyncQueueable for AsyncQueue<DB>
where
DB: Database,
{
impl AsyncQueueable for AsyncQueue {
async fn find_task_by_id(&mut self, id: &Uuid) -> Result<Task, AsyncQueueError> {
self.check_if_connection()?;
let pool = self.pool.as_ref().unwrap();
Expand Down Expand Up @@ -580,8 +628,8 @@ where
}
}

#[cfg(test)]
impl AsyncQueue<Postgres> {
#[cfg(all(test, feature = "asynk-postgres"))]
impl AsyncQueue {
/// Provides an AsyncQueue connected to its own DB
pub async fn test_postgres() -> Self {
dotenvy::dotenv().expect(".env file not found");
Expand All @@ -602,7 +650,14 @@ impl AsyncQueue<Postgres> {
let create_query: &str = &format!("CREATE DATABASE {} WITH TEMPLATE fang;", db_name);
let delete_query: &str = &format!("DROP DATABASE IF EXISTS {};", db_name);

let mut conn = res.pool.as_mut().unwrap().acquire().await.unwrap();
let mut conn = res
.pool
.as_mut()
.unwrap()
.unwrap_pg_pool()
.acquire()
.await
.unwrap();

log::info!("Deleting database {db_name} ...");
conn.execute(delete_query).await.unwrap();
Expand All @@ -629,8 +684,8 @@ impl AsyncQueue<Postgres> {
}
}

#[cfg(test)]
impl AsyncQueue<Sqlite> {
#[cfg(all(test, feature = "asynk-sqlite"))]
impl AsyncQueue {
/// Provides an AsyncQueue connected to its own DB
pub async fn test_sqlite() -> Self {
dotenvy::dotenv().expect(".env file not found");
Expand Down Expand Up @@ -664,8 +719,8 @@ impl AsyncQueue<Sqlite> {
}
}

#[cfg(test)]
impl AsyncQueue<MySql> {
#[cfg(all(test, feature = "asynk-mysql"))]
impl AsyncQueue {
/// Provides an AsyncQueue connected to its own DB
pub async fn test_mysql() -> Self {
dotenvy::dotenv().expect(".env file not found");
Expand All @@ -690,7 +745,14 @@ impl AsyncQueue<MySql> {

let delete_query: &str = &format!("DROP DATABASE IF EXISTS {};", db_name);

let mut conn = res.pool.as_mut().unwrap().acquire().await.unwrap();
let mut conn = res
.pool
.as_mut()
.unwrap()
.unwrap_mysql_pool()
.acquire()
.await
.unwrap();

log::info!("Deleting database {db_name} ...");
conn.execute(delete_query).await.unwrap();
Expand All @@ -717,11 +779,11 @@ impl AsyncQueue<MySql> {
}
}

#[cfg(test)]
test_asynk_queue! {postgres, crate::AsyncQueue<Postgres>, crate::AsyncQueue::test_postgres()}
#[cfg(all(test, feature = "asynk-postgres"))]
test_asynk_queue! {postgres, crate::AsyncQueue,crate::AsyncQueue::test_postgres()}

#[cfg(test)]
test_asynk_queue! {sqlite, crate::AsyncQueue<Sqlite>, crate::AsyncQueue::test_sqlite()}
#[cfg(all(test, feature = "asynk-sqlite"))]
test_asynk_queue! {sqlite, crate::AsyncQueue,crate::AsyncQueue::test_sqlite()}

#[cfg(test)]
test_asynk_queue! {mysql, crate::AsyncQueue<MySql>, crate::AsyncQueue::test_mysql()}
#[cfg(all(test, feature = "asynk-mysql"))]
test_asynk_queue! {mysql, crate::AsyncQueue, crate::AsyncQueue::test_mysql()}
2 changes: 1 addition & 1 deletion fang/src/asynk/async_queue/async_queue_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ macro_rules! test_asynk_queue {
}

#[tokio::test]
async fn failed_task_query_test() {
async fn failed_task_test() {
let mut test: $q = $e.await;

let task = test.insert_task(&AsyncTask { number: 1 }).await.unwrap();
Expand Down
6 changes: 1 addition & 5 deletions fang/src/asynk/async_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ mod async_worker_tests {
use chrono::Duration;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use sqlx::Database;

#[derive(Serialize, Deserialize)]
struct WorkerAsyncTask {
Expand Down Expand Up @@ -564,10 +563,7 @@ mod async_worker_tests {
assert_eq!(id2, task2.id);
}

async fn insert_task<DB: Database>(
test: &mut AsyncQueue<DB>,
task: &dyn AsyncRunnable,
) -> Task {
async fn insert_task(test: &mut AsyncQueue, task: &dyn AsyncRunnable) -> Task {
test.insert_task(task).await.unwrap()
}

Expand Down
Loading

0 comments on commit 38d6632

Please sign in to comment.