Skip to content

Commit

Permalink
add: test utils that allow backend polling during tests
Browse files Browse the repository at this point in the history
  • Loading branch information
geofmureithi committed Jul 16, 2024
1 parent 1238fb0 commit 4a4eb20
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 87 deletions.
205 changes: 205 additions & 0 deletions packages/apalis-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,208 @@ impl crate::executor::Executor for TestExecutor {
tokio::spawn(future);
}
}

/// Test utilities that allows you to test backends
pub mod test_utils {
use crate::error::{BoxDynError};

use crate::request::Request;

use crate::task::task_id::TaskId;
use crate::worker::WorkerId;
use crate::Backend;
use futures::channel::mpsc::{channel, Sender};
use futures::stream::{Stream, StreamExt};
use futures::{Future, FutureExt, SinkExt};
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::thread;
use tower::{Layer, Service};

/// Define a dummy service
#[derive(Debug, Clone)]
pub struct DummyService;

impl<Request: Send + 'static> Service<Request> for DummyService {
type Response = Request;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, req: Request) -> Self::Future {
let fut = async move { Ok(req) };
Box::pin(fut)
}
}

/// A generic backend wrapper that polls and executes jobs
#[derive(Debug)]
pub struct TestWrapper<B, Req> {
stop_tx: Sender<()>,
executions: Arc<Mutex<HashMap<TaskId, Result<String, String>>>>,
_p: PhantomData<Req>,
backend: B,
}

impl<B: Clone, Req> Clone for TestWrapper<B, Req> {
fn clone(&self) -> Self {
TestWrapper {
stop_tx: self.stop_tx.clone(),
executions: Arc::clone(&self.executions),
_p: PhantomData,
backend: self.backend.clone(),
}
}
}

impl<B, Req> TestWrapper<B, Req>
where
B: Backend<Request<Req>> + Send + Sync + 'static + Clone,
Req: Send + 'static,
B::Stream: Send + 'static,
B::Stream: Stream<Item = Result<Option<Request<Req>>, crate::error::Error>> + Unpin,
{
/// Build a new instance provided a custom service
pub fn new_with_service<S>(backend: B, service: S) -> Self
where
S: Service<Request<Req>> + Send + 'static,
B::Layer: Layer<S>,
<<B as Backend<Request<Req>>>::Layer as Layer<S>>::Service: Service<Request<Req>> + Send + 'static,
<<<B as Backend<Request<Req>>>::Layer as Layer<S>>::Service as Service<Request<Req>>>::Response: Debug,
<<<B as Backend<Request<Req>>>::Layer as Layer<S>>::Service as Service<Request<Req>>>::Error: Send + Into<BoxDynError> + Sync
{
let worker_id = WorkerId::new("test-worker");
let b = backend.clone();
let mut poller = b.poll(worker_id);
let (stop_tx, mut stop_rx) = channel::<()>(1);

let mut service = poller.layer.layer(service);

let executions: Arc<Mutex<HashMap<TaskId, Result<String, String>>>> =
Default::default();
let executions_clone = executions.clone();
thread::spawn(move || {
futures::executor::block_on(async move {
let heartbeat = poller.heartbeat.shared();
loop {
futures::select! {

item = poller.stream.next().fuse() => match item {
Some(Ok(Some(req))) => {

let task_id = req.get::<TaskId>().cloned().expect("Request does not contain Task_ID");
// handle request
match service.call(req).await {
Ok(res) => {
executions_clone.lock().unwrap().insert(task_id, Ok(format!("{res:?}")));
},
Err(err) => {
executions_clone.lock().unwrap().insert(task_id, Err(err.into().to_string()));
}
}
}
Some(Ok(None)) | None => break,
Some(Err(_e)) => {
// handle error
break;
}
},
_ = stop_rx.next().fuse() => break,
_ = heartbeat.clone().fuse() => {

},
}
}
});
});

Self {
stop_tx,
executions,
_p: PhantomData,
backend,
}
}

/// Stop polling
pub fn stop(mut self) {
let _ = self.stop_tx.send(());

Check failure on line 304 in packages/apalis-core/src/lib.rs

View workflow job for this annotation

GitHub Actions / Clippy

non-binding `let` on a future
}

/// Gets the current state of results
pub fn get_results(&self) -> HashMap<TaskId, Result<String, String>> {
self.executions.lock().unwrap().clone()
}
}

impl<B, Req> Deref for TestWrapper<B, Req>
where
B: Backend<Request<Req>>,
{
type Target = B;

fn deref(&self) -> &Self::Target {
&self.backend
}
}

impl<B, Req> DerefMut for TestWrapper<B, Req>
where
B: Backend<Request<Req>>,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.backend
}
}

pub use tower::service_fn as apalis_test_service_fn;

#[macro_export]
/// Tests a generic mq
macro_rules! test_message_queue {
($backend_instance:expr) => {
#[tokio::test]
async fn it_works_as_an_mq_backend() {
let backend = $backend_instance;
let service = apalis_test_service_fn(|request: Request<u32>| async {
Ok::<_, io::Error>(request)
});
let mut t = TestWrapper::new_with_service(backend, service);
let res = t.get_results();
assert_eq!(res.len(), 0); // No job is done
t.enqueue(1).await.unwrap();
tokio::time::sleep(Duration::from_secs(1)).await;
let res = t.get_results();
assert_eq!(res.len(), 1); // One job is done
}
};
}
#[macro_export]
/// Tests a generic storage
macro_rules! test_storage {
($backend_instance:expr) => {
#[tokio::test]
async fn it_works_as_a_storage_backend() {
let backend = $backend_instance;
let service = apalis_test_service_fn(|request: Request<u32>| async {
Ok::<_, io::Error>(request)
});
let mut t = TestWrapper::new_with_service(backend, service);
let res = t.get_results();
assert_eq!(res.len(), 0); // No job is done
t.push(1).await.unwrap();
::apalis_core::sleep(Duration::from_secs(1)).await;
let res = t.get_results();
assert_eq!(res.len(), 1); // One job is done
}
};
}
}
12 changes: 6 additions & 6 deletions packages/apalis-core/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ impl<T> Clone for MemoryStorage<T> {
/// In-memory queue that implements [Stream]
#[derive(Debug)]
pub struct MemoryWrapper<T> {
sender: Sender<T>,
receiver: Arc<futures::lock::Mutex<Receiver<T>>>,
sender: Sender<Request<T>>,
receiver: Arc<futures::lock::Mutex<Receiver<Request<T>>>>,
}

impl<T> Clone for MemoryWrapper<T> {
Expand Down Expand Up @@ -84,7 +84,7 @@ impl<T> Default for MemoryWrapper<T> {
}

impl<T> Stream for MemoryWrapper<T> {
type Item = T;
type Item = Request<T>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if let Some(mut receiver) = self.receiver.try_lock() {
Expand All @@ -102,7 +102,7 @@ impl<T: Send + 'static + Sync> Backend<Request<T>> for MemoryStorage<T> {
type Layer = Identity;

fn poll(self, _worker: WorkerId) -> Poller<Self::Stream> {
let stream = self.inner.map(|r| Ok(Some(Request::new(r)))).boxed();
let stream = self.inner.map(|r| Ok(Some(r))).boxed();
Poller {
stream: BackendStream::new(stream, self.controller),
heartbeat: Box::pin(async {}),
Expand All @@ -114,12 +114,12 @@ impl<T: Send + 'static + Sync> Backend<Request<T>> for MemoryStorage<T> {
impl<Message: Send + 'static + Sync> MessageQueue<Message> for MemoryStorage<Message> {
type Error = ();
async fn enqueue(&mut self, message: Message) -> Result<(), Self::Error> {
self.inner.sender.try_send(message).unwrap();
self.inner.sender.try_send(Request::new(message)).map_err(|_| ())?;
Ok(())
}

async fn dequeue(&mut self) -> Result<Option<Message>, ()> {
Ok(self.inner.receiver.lock().await.next().await)
Ok(self.inner.receiver.lock().await.next().await.map(|r| r.req))
}

async fn size(&mut self) -> Result<usize, ()> {
Expand Down
9 changes: 7 additions & 2 deletions packages/apalis-core/src/monitor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ impl<E> Monitor<E> {

#[cfg(test)]
mod tests {
use crate::test_utils::apalis_test_service_fn;
use std::{io, time::Duration};

use tokio::time::sleep;
Expand All @@ -307,11 +308,15 @@ mod tests {
monitor::Monitor,
mq::MessageQueue,
request::Request,
test_message_queue,
test_utils::TestWrapper,
TestExecutor,
};

test_message_queue!(MemoryStorage::new());

#[tokio::test]
async fn it_works() {
async fn it_works_with_workers() {
let backend = MemoryStorage::new();
let mut handle = backend.clone();

Expand Down Expand Up @@ -342,7 +347,7 @@ mod tests {
let mut handle = backend.clone();

tokio::spawn(async move {
for i in 0..1000 {
for i in 0..10 {
handle.enqueue(i).await.unwrap();
}
});
Expand Down
10 changes: 5 additions & 5 deletions packages/apalis-core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tower::layer::util::Identity;

use std::{fmt::Debug, pin::Pin};

use crate::{data::Extensions, error::Error, poller::Poller, worker::WorkerId, Backend};
use crate::{data::Extensions, error::Error, poller::Poller, task::task_id::TaskId, worker::WorkerId, Backend};

/// Represents a job which can be serialized and executed
Expand All @@ -18,10 +18,10 @@ pub struct Request<T> {
impl<T> Request<T> {
/// Creates a new [Request]
pub fn new(req: T) -> Self {
Self {
req,
data: Extensions::new(),
}
let id = TaskId::new();
let mut data = Extensions::new();
data.insert(id);
Self::new_with_data(req, data)
}

/// Creates a request with context provided
Expand Down
2 changes: 1 addition & 1 deletion packages/apalis-core/src/task/task_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
use ulid::Ulid;

/// A wrapper type that defines a task id.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, Hash, PartialEq)]
pub struct TaskId(Ulid);

impl TaskId {
Expand Down
11 changes: 11 additions & 0 deletions packages/apalis-redis/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -983,8 +983,19 @@ impl<T, Conn: ConnectionLike + Send + Sync + 'static> RedisStorage<T, Conn> {

#[cfg(test)]
mod tests {
use apalis_core::test_storage;
use email_service::Email;

use apalis_core::test_utils::apalis_test_service_fn;
use apalis_core::test_utils::TestWrapper;

test_storage!({
let redis_url = std::env::var("REDIS_URL").expect("No REDIS_URL is specified");
let conn = connect(redis_url).await.unwrap();
let storage = RedisStorage::new(conn);
storage
});

use super::*;

/// migrate DB and return a storage instance.
Expand Down
Loading

0 comments on commit 4a4eb20

Please sign in to comment.