Skip to content

Commit

Permalink
Feature: Introducing Request Context (#416)
Browse files Browse the repository at this point in the history
* wip: introduce context to request

* fix: get request context working

* lint: cargo fmt

* fix: get tests compiling

* add: push_request and shedule_request

* fix: task_id for Testwrapper

* fix: minor checks and fixes on postgres tests

* fix: bug on postgres fetch_next
  • Loading branch information
geofmureithi authored Sep 17, 2024
1 parent 4ec676f commit 7a496ad
Show file tree
Hide file tree
Showing 49 changed files with 1,327 additions and 1,090 deletions.
3 changes: 1 addition & 2 deletions examples/actix-web/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn push_email(
let mut storage = storage.clone();
let res = storage.push(email.into_inner()).await;
match res {
Ok(jid) => HttpResponse::Ok().body(format!("Email with job_id [{jid}] added to queue")),
Ok(ctx) => HttpResponse::Ok().json(ctx),
Err(e) => HttpResponse::InternalServerError().body(format!("{e}")),
}
}
Expand Down Expand Up @@ -46,7 +46,6 @@ async fn main() -> Result<()> {
WorkerBuilder::new("tasty-avocado")
.layer(TraceLayer::new())
.backend(storage)
// .chain(|svc|svc.map_err(|e| Box::new(e)))
.build_fn(send_email)
})
.run_with_signal(signal::ctrl_c());
Expand Down
12 changes: 6 additions & 6 deletions examples/async-std-runtime/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use apalis_cron::{CronStream, Schedule};
use chrono::{DateTime, Utc};
use tracing::{debug, info, Instrument, Level, Span};

type WorkerCtx = Context<AsyncStdExecutor>;
type WorkerCtx = Data<Context<AsyncStdExecutor>>;

#[derive(Default, Debug, Clone)]
struct Reminder(DateTime<Utc>);
Expand Down Expand Up @@ -48,7 +48,7 @@ async fn main() -> Result<()> {
.build_fn(send_reminder);

Monitor::<AsyncStdExecutor>::new()
.register_with_count(2, worker)
.register(worker)
.on_event(|e| debug!("Worker event: {e:?}"))
.run_with_signal(async {
ctrl_c.recv().await.ok();
Expand Down Expand Up @@ -95,10 +95,10 @@ impl ReminderSpan {
}
}

impl<B> MakeSpan<B> for ReminderSpan {
fn make_span(&mut self, req: &Request<B>) -> Span {
let task_id: &TaskId = req.get().unwrap();
let attempts: Attempt = req.get().cloned().unwrap_or_default();
impl<B, Ctx> MakeSpan<B, Ctx> for ReminderSpan {
fn make_span(&mut self, req: &Request<B, Ctx>) -> Span {
let task_id: &TaskId = &req.parts.task_id;
let attempts: &Attempt = &req.parts.attempt;
let span = Span::current();
macro_rules! make_span {
($level:expr) => {
Expand Down
6 changes: 3 additions & 3 deletions examples/axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ where
let new_job = storage.push(input).await;

match new_job {
Ok(id) => (
Ok(ctx) => (
StatusCode::CREATED,
format!("Job [{id}] was successfully added"),
format!("Job [{ctx:?}] was successfully added"),
),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Expand Down Expand Up @@ -74,7 +74,7 @@ async fn main() -> Result<()> {
};
let monitor = async {
Monitor::<TokioExecutor>::new()
.register_with_count(2, {
.register({
WorkerBuilder::new("tasty-pear")
.layer(TraceLayer::new())
.backend(storage.clone())
Expand Down
2 changes: 1 addition & 1 deletion examples/basics/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license = "MIT OR Apache-2.0"
thiserror = "1"
tokio = { version = "1", features = ["full"] }
apalis = { path = "../../", features = ["limit", "tokio-comp", "catch-panic"] }
apalis-sql = { path = "../../packages/apalis-sql" }
apalis-sql = { path = "../../packages/apalis-sql", features = ["sqlite"] }
serde = "1"
tracing-subscriber = "0.3.11"
email-service = { path = "../email-service" }
Expand Down
14 changes: 9 additions & 5 deletions examples/basics/src/layer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::task::{Context, Poll};
use std::{
fmt::Debug,
task::{Context, Poll},
};

use apalis::prelude::Request;
use tower::{Layer, Service};
Expand Down Expand Up @@ -34,10 +37,11 @@ pub struct LogService<S> {
service: S,
}

impl<S, Req> Service<Request<Req>> for LogService<S>
impl<S, Req, Ctx> Service<Request<Req, Ctx>> for LogService<S>
where
S: Service<Request<Req>> + Clone,
Req: std::fmt::Debug,
S: Service<Request<Req, Ctx>> + Clone,
Req: Debug,
Ctx: Debug,
{
type Response = S::Response;
type Error = S::Error;
Expand All @@ -47,7 +51,7 @@ where
self.service.poll_ready(cx)
}

fn call(&mut self, request: Request<Req>) -> Self::Future {
fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
// Use service to apply middleware before or(and) after a request
info!("request = {:?}, target = {:?}", request, self.target);
self.service.call(request)
Expand Down
27 changes: 21 additions & 6 deletions examples/basics/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod cache;
mod layer;
mod service;

use std::time::Duration;
use std::{sync::Arc, time::Duration};

use apalis::{
layers::{catch_panic::CatchPanicLayer, tracing::TraceLayer},
Expand Down Expand Up @@ -35,7 +35,7 @@ async fn produce_jobs(storage: &SqliteStorage<Email>) {
}

#[derive(thiserror::Error, Debug)]
pub enum Error {
pub enum ServiceError {
#[error("data store disconnected")]
Disconnect(#[from] std::io::Error),
#[error("the data for key `{0}` is not available")]
Expand All @@ -46,15 +46,21 @@ pub enum Error {
Unknown,
}

#[derive(thiserror::Error, Debug)]
pub enum PanicError {
#[error("{0}")]
Panic(String),
}

/// Quick solution to prevent spam.
/// If email in cache, then send email else complete the job but let a validation process run in the background,
async fn send_email(
email: Email,
svc: Data<EmailService>,
worker_ctx: Data<WorkerCtx>,
worker_id: WorkerId,
worker_id: Data<WorkerId>,
cache: Data<ValidEmailCache>,
) -> Result<(), Error> {
) -> Result<(), ServiceError> {
info!("Job started in worker {:?}", worker_id);
let cache_clone = cache.clone();
let email_to = email.to.clone();
Expand Down Expand Up @@ -97,10 +103,19 @@ async fn main() -> Result<(), std::io::Error> {
produce_jobs(&sqlite).await;

Monitor::<TokioExecutor>::new()
.register_with_count(2, {
.register({
WorkerBuilder::new("tasty-banana")
// This handles any panics that may occur in any of the layers below
.layer(CatchPanicLayer::new())
.layer(CatchPanicLayer::with_panic_handler(|e| {
let panic_info = if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic".to_string()
};
Error::Abort(Arc::new(Box::new(PanicError::Panic(panic_info))))
}))
.layer(TraceLayer::new())
.layer(LogLayer::new("some-log-example"))
// Add shared context to all jobs executed by this worker
Expand Down
1 change: 1 addition & 0 deletions examples/cron/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ apalis = { path = "../../", default-features = false, features = [
"tokio-comp",
"tracing",
"limit",
"catch-panic"
] }
apalis-cron = { path = "../../packages/apalis-cron" }
tokio = { version = "1", features = ["full"] }
Expand Down
4 changes: 3 additions & 1 deletion examples/cron/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use apalis::layers::tracing::TraceLayer;
use apalis::prelude::*;
use apalis::utils::TokioExecutor;
use apalis_cron::CronStream;
Expand Down Expand Up @@ -31,13 +32,14 @@ async fn send_reminder(job: Reminder, svc: Data<FakeService>) {
async fn main() {
let schedule = Schedule::from_str("1/1 * * * * *").unwrap();
let worker = WorkerBuilder::new("morning-cereal")
.layer(TraceLayer::new())
.layer(LoadShedLayer::new()) // Important when you have layers that block the service
.layer(RateLimitLayer::new(1, Duration::from_secs(2)))
.data(FakeService)
.backend(CronStream::new(schedule))
.build_fn(send_reminder);
Monitor::<TokioExecutor>::new()
.register(worker)
.register_with_count(2, worker)
.run()
.await
.unwrap();
Expand Down
6 changes: 3 additions & 3 deletions examples/fn-args/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ struct SimpleJob {}
// A task can have up to 16 arguments
async fn simple_job(
_: SimpleJob, // Required, must be of the type of the job/message
worker_id: WorkerId, // The worker running the job, added by worker
worker_id: Data<WorkerId>, // The worker running the job, added by worker
_worker_ctx: Context<TokioExecutor>, // The worker context, added by worker
_sqlite: Data<SqliteStorage<SimpleJob>>, // The source, added by storage
task_id: Data<TaskId>, // The task id, added by storage
ctx: Data<SqlContext>, // The task context, added by storage
ctx: SqlContext, // The task context
count: Data<Count>, // Our custom data added via layer
) {
// increment the counter
let current = count.fetch_add(1, Ordering::Relaxed);
info!("worker: {worker_id}; task_id: {task_id:?}, ctx: {ctx:?}, count: {current:?}");
info!("worker: {worker_id:?}; task_id: {task_id:?}, ctx: {ctx:?}, count: {current:?}");
}

async fn produce_jobs(storage: &mut SqliteStorage<SimpleJob>) {
Expand Down
8 changes: 4 additions & 4 deletions examples/prometheus/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ async fn main() -> Result<()> {
};
let monitor = async {
Monitor::<TokioExecutor>::new()
.register_with_count(2, {
.register({
WorkerBuilder::new("tasty-banana")
.layer(PrometheusLayer)
.layer(PrometheusLayer::default())
.backend(storage.clone())
.build_fn(send_email)
})
Expand Down Expand Up @@ -94,9 +94,9 @@ where
let new_job = storage.push(input).await;

match new_job {
Ok(jid) => (
Ok(ctx) => (
StatusCode::CREATED,
format!("Job [{jid}] was successfully added"),
format!("Job [{ctx:?}] was successfully added"),
),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Expand Down
2 changes: 1 addition & 1 deletion examples/redis-deadpool/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async fn main() -> Result<()> {
.build_fn(send_email);

Monitor::<TokioExecutor>::new()
.register_with_count(2, worker)
.register(worker)
.shutdown_timeout(Duration::from_millis(5000))
.run_with_signal(async {
tokio::signal::ctrl_c().await?;
Expand Down
57 changes: 34 additions & 23 deletions examples/redis-mq-example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ use std::{fmt::Debug, marker::PhantomData, time::Duration};

use apalis::{layers::tracing::TraceLayer, prelude::*};

use apalis_redis::{self, Config, RedisJob};
use apalis_redis::{self, Config};

use apalis_core::{
codec::json::JsonCodec,
layers::{Ack, AckLayer},
response::Response,
};
use email_service::{send_email, Email};
use futures::{channel::mpsc, SinkExt};
use rsmq_async::{Rsmq, RsmqConnection, RsmqError};
use serde::{de::DeserializeOwned, Serialize};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::time::sleep;
use tracing::{error, info};

Expand All @@ -22,6 +23,18 @@ struct RedisMq<T, C = JsonCodec<Vec<u8>>> {
codec: PhantomData<C>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct RedisMqContext {
max_attempts: usize,
message_id: String,
}

impl<Req> FromRequest<Request<Req, RedisMqContext>> for RedisMqContext {
fn from_request(req: &Request<Req, RedisMqContext>) -> Result<Self, Error> {
Ok(req.parts.context.clone())
}
}

// Manually implement Clone for RedisMq
impl<T, C> Clone for RedisMq<T, C> {
fn clone(&self) -> Self {
Expand All @@ -34,32 +47,30 @@ impl<T, C> Clone for RedisMq<T, C> {
}
}

impl<M, C, Res> Backend<Request<M>, Res> for RedisMq<M, C>
impl<Req, C, Res> Backend<Request<Req, RedisMqContext>, Res> for RedisMq<Req, C>
where
M: Send + DeserializeOwned + 'static,
Req: Send + DeserializeOwned + 'static,
C: Codec<Compact = Vec<u8>>,
{
type Stream = RequestStream<Request<M>>;
type Stream = RequestStream<Request<Req, RedisMqContext>>;

type Layer = AckLayer<Self, M, Res>;
type Layer = AckLayer<Self, Req, RedisMqContext, Res>;

fn poll<Svc>(mut self, _worker_id: WorkerId) -> Poller<Self::Stream, Self::Layer> {
let (mut tx, rx) = mpsc::channel(self.config.get_buffer_size());
let stream: RequestStream<Request<M>> = Box::pin(rx);
let stream: RequestStream<Request<Req, RedisMqContext>> = Box::pin(rx);
let layer = AckLayer::new(self.clone());
let heartbeat = async move {
loop {
sleep(*self.config.get_poll_interval()).await;
let msg: Option<Request<M>> = self
let msg: Option<Request<Req, RedisMqContext>> = self
.conn
.receive_message(self.config.get_namespace(), None)
.await
.unwrap()
.map(|r| {
let mut req: Request<M> = C::decode::<RedisJob<M>>(r.message)
.map_err(Into::into)
.unwrap()
.into();
let mut req: Request<Req, RedisMqContext> =
C::decode(r.message).map_err(Into::into).unwrap();
req.insert(r.id);
req
});
Expand All @@ -76,18 +87,20 @@ where
Res: Debug + Send + Sync,
C: Send,
{
type Context = String;
type Context = RedisMqContext;

type AckError = RsmqError;

async fn ack(
&mut self,
ctx: &Self::Context,
_res: &Result<Res, apalis_core::error::Error>,
res: &Response<Res>,
) -> Result<(), Self::AckError> {
self.conn
.delete_message(self.config.get_namespace(), ctx)
.await?;
if res.is_success() || res.attempt.current() >= ctx.max_attempts {
self.conn
.delete_message(self.config.get_namespace(), &ctx.message_id)
.await?;
}
Ok(())
}
}
Expand All @@ -100,7 +113,7 @@ where
type Error = RsmqError;

async fn enqueue(&mut self, message: Message) -> Result<(), Self::Error> {
let bytes = C::encode(&RedisJob::new(message, Default::default()))
let bytes = C::encode(&Request::<Message, RedisMqContext>::new(message))
.map_err(Into::into)
.unwrap();
self.conn
Expand All @@ -115,11 +128,9 @@ where
.receive_message(self.config.get_namespace(), None)
.await?
.map(|r| {
let req: Request<Message> = C::decode::<RedisJob<Message>>(r.message)
.map_err(Into::into)
.unwrap()
.into();
req.take()
let req: Request<Message, RedisMqContext> =
C::decode(r.message).map_err(Into::into).unwrap();
req.args
}))
}

Expand Down
2 changes: 1 addition & 1 deletion examples/redis-with-msg-pack/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async fn main() -> Result<()> {
.build_fn(send_email);

Monitor::<TokioExecutor>::new()
.register_with_count(2, worker)
.register(worker)
.shutdown_timeout(Duration::from_millis(5000))
.run_with_signal(async {
tokio::signal::ctrl_c().await?;
Expand Down
Loading

0 comments on commit 7a496ad

Please sign in to comment.