diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..aaadc8854 --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +# Run the backend without worker mode, i.e. only enabling single-shot +# verifications via the /v1/check_email endpoint. +.PHONY: run +run: + cd backend && cargo run --bin reacher_backend + + +# Run the backend with worker mode on. This enables the /v1/bulk endpoints. +# Make sure to have a Postgres DB and a RabbitMQ instance running. +.PHONY: run-with-worker +run-with-worker: export RCH__WORKER__ENABLE=true +run-with-worker: export RCH__WORKER__RABBITMQ__URL=amqp://guest:guest@localhost:5672 +run-with-worker: export RCH__STORAGE__POSTGRES__DB_URL=postgresql://localhost/reacherdb +run-with-worker: run + +.PHONY: run-with-commercial-license-trial +run-with-commercial-license-trial: export RCH__COMMERCIAL_LICENSE_TRIAL__URL=http://localhost:3000/api/v1/commercial_license_trial +run-with-commercial-license-trial: run \ No newline at end of file diff --git a/backend/backend_config.toml b/backend/backend_config.toml index 6a8978323..aee353a2f 100644 --- a/backend/backend_config.toml +++ b/backend/backend_config.toml @@ -88,12 +88,47 @@ hotmailb2c = "headless" # recommended. yahoo = "headless" +# Throttle the maximum number of requests per second, per minute, per hour, and +# per day for this worker. +# All fields are optional; comment them out to disable the limit. +# +# We however recommend setting the throttle for at least the per-minute and +# per-day limits to prevent the IPs from being blocked by the email providers. +# The default values are set to 60 requests per minute and 10,000 requests per +# day. +# +# Important: these throttle configurations only apply to /v1/* endpoints, and +# not to the previous /v0/check_email endpoint. The latter endpoint always +# executes the verification immediately, regardless of the throttle settings. 
+# +# Env variables: +# - RCH__THROTTLE__MAX_REQUESTS_PER_SECOND +# - RCH__THROTTLE__MAX_REQUESTS_PER_MINUTE +# - RCH__THROTTLE__MAX_REQUESTS_PER_HOUR +# - RCH__THROTTLE__MAX_REQUESTS_PER_DAY +[throttle] +# max_requests_per_second = 20 +max_requests_per_minute = 60 +# max_requests_per_hour = 1000 +max_requests_per_day = 10000 + +# Configuration for a queue-based architecture for Reacher. This feature is +# currently in **beta**. The queue-based architecture allows Reacher to scale +# horizontally by running multiple workers that consume emails from a RabbitMQ +# queue. +# +# To enable the queue-based architecture, set the "enable" field to "true" and +# configure the RabbitMQ connection below. The "concurrency" field specifies +# the number of concurrent emails to verify for this worker. +# +# For more information, see the documentation at: +# https://docs.reacher.email/self-hosting/scaling-for-production [worker] # Enable the worker to consume emails from the RabbitMQ queues. If set, the # RabbitMQ configuration below must be set as well. # # Env variable: RCH__WORKER__ENABLE -enable = true +enable = false # RabbitMQ configuration. [worker.rabbitmq] @@ -105,25 +140,6 @@ url = "amqp://guest:guest@localhost:5672" # Env variable: RCH__WORKER__RABBITMQ__CONCURRENCY concurrency = 5 -# Throttle the maximum number of requests per second, per minute, per hour, and -# per day for this worker. -# All fields are optional; comment them out to disable the limit. -# -# Important: these throttle configurations only apply to /v1/* endpoints, and -# not to the previous /v0/check_email endpoint. The latter endpoint always -# executes the verification immediately, regardless of the throttle settings. 
-# -# Env variables: -# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_SECOND -# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_MINUTE -# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_HOUR -# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_DAY -[worker.throttle] -# max_requests_per_second = 20 -# max_requests_per_minute = 100 -# max_requests_per_hour = 1000 -# max_requests_per_day = 20000 - # Below are the configurations for the storage of the email verification # results. We currently support the following storage backends: # - Postgres diff --git a/backend/openapi.json b/backend/openapi.json index a37655f5b..9bd647ab3 100644 --- a/backend/openapi.json +++ b/backend/openapi.json @@ -702,7 +702,7 @@ "duration": { "$ref": "#/components/schemas/Duration" }, - "server_name": { + "backend_name": { "type": "string", "x-stoplight": { "id": "2jrbdecvqh4t5" @@ -717,7 +717,7 @@ "start_time", "end_time", "duration", - "server_name", + "backend_name", "smtp" ] }, diff --git a/backend/src/config.rs b/backend/src/config.rs index 65abb133d..a9455d738 100644 --- a/backend/src/config.rs +++ b/backend/src/config.rs @@ -15,6 +15,7 @@ // along with this program. If not, see . use crate::storage::{postgres::PostgresStorage, StorageAdapter}; +use crate::throttle::ThrottleManager; use crate::worker::do_work::TaskWebhook; use crate::worker::setup_rabbit_mq; use anyhow::{bail, Context}; @@ -29,7 +30,7 @@ use sqlx::PgPool; use std::sync::Arc; use tracing::warn; -#[derive(Debug, Default, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct BackendConfig { /// Name of the backend. pub backend_name: String, @@ -65,36 +66,61 @@ pub struct BackendConfig { /// Whether to enable the Commercial License Trial. Setting this to true pub commercial_license_trial: Option, + /// Throttle configuration for all requests + pub throttle: ThrottleConfig, + // Internal fields, not part of the configuration. 
#[serde(skip)] channel: Option>, #[serde(skip)] storage_adapter: Arc, + + #[serde(skip)] + throttle_manager: Arc, } impl BackendConfig { + /// Create an empty BackendConfig. This is useful for testing purposes. + pub fn empty() -> Self { + Self { + backend_name: "".to_string(), + from_email: "".to_string(), + hello_name: "".to_string(), + webdriver_addr: "".to_string(), + proxy: None, + verif_method: VerifMethodConfig::default(), + http_host: "127.0.0.1".to_string(), + http_port: 8080, + header_secret: None, + smtp_timeout: None, + sentry_dsn: None, + worker: WorkerConfig::default(), + storage: Some(StorageConfig::Noop), + commercial_license_trial: None, + throttle: ThrottleConfig::new_without_throttle(), + channel: None, + storage_adapter: Arc::new(StorageAdapter::Noop), + throttle_manager: Arc::new( + ThrottleManager::new(ThrottleConfig::new_without_throttle()), + ), + } + } + /// Get the worker configuration. /// /// # Panics /// /// Panics if the worker configuration is missing. pub fn must_worker_config(&self) -> Result { - match ( - self.worker.enable, - &self.worker.throttle, - &self.worker.rabbitmq, - &self.channel, - ) { - (true, Some(throttle), Some(rabbitmq), Some(channel)) => Ok(MustWorkerConfig { + match (self.worker.enable, &self.worker.rabbitmq, &self.channel) { + (true, Some(rabbitmq), Some(channel)) => Ok(MustWorkerConfig { channel: channel.clone(), - throttle: throttle.clone(), rabbitmq: rabbitmq.clone(), - webhook: self.worker.webhook.clone(), }), - (true, _, _, _) => bail!("Worker configuration is missing"), + (true, _, _) => bail!("Worker configuration is missing"), _ => bail!("Calling must_worker_config on a non-worker backend"), } } @@ -126,6 +152,9 @@ impl BackendConfig { }; self.channel = channel; + // Initialize throttle manager + self.throttle_manager = Arc::new(ThrottleManager::new(self.throttle.clone())); + Ok(()) } @@ -142,6 +171,10 @@ impl BackendConfig { StorageAdapter::Noop => None, } } + + pub fn get_throttle_manager(&self) -> Arc 
{ + self.throttle_manager.clone() + } } #[derive(Debug, Default, Deserialize, Clone, Serialize)] @@ -159,9 +192,6 @@ pub struct VerifMethodConfig { #[derive(Debug, Default, Deserialize, Clone, Serialize)] pub struct WorkerConfig { pub enable: bool, - - /// Throttle configuration for the worker. - pub throttle: Option, pub rabbitmq: Option, /// Optional webhook configuration to send email verification results. pub webhook: Option, @@ -172,8 +202,6 @@ pub struct WorkerConfig { #[derive(Debug, Clone)] pub struct MustWorkerConfig { pub channel: Arc, - - pub throttle: ThrottleConfig, pub rabbitmq: RabbitMQConfig, pub webhook: Option, } @@ -185,7 +213,7 @@ pub struct RabbitMQConfig { pub concurrency: u16, } -#[derive(Debug, Deserialize, Clone, Serialize)] +#[derive(Debug, Default, Deserialize, Clone, Serialize)] pub struct ThrottleConfig { pub max_requests_per_second: Option, pub max_requests_per_minute: Option, @@ -236,8 +264,13 @@ pub async fn load_config() -> Result { let cfg = cfg.build()?.try_deserialize::()?; - if !cfg.worker.enable && (cfg.worker.rabbitmq.is_some() || cfg.worker.throttle.is_some()) { - warn!(target: LOG_TARGET, "worker.enable is set to false, ignoring throttling and concurrency settings.") + if cfg.worker.enable { + warn!(target: LOG_TARGET, "The worker feature is currently in beta. Please send any feedback to amaury@reacher.email."); + + match &cfg.storage { + Some(StorageConfig::Postgres(_)) => {} + _ => bail!("When worker mode is enabled, a Postgres database must be configured."), + } } Ok(cfg) diff --git a/backend/src/http/error.rs b/backend/src/http/error.rs index 108e39268..48db2cd0e 100644 --- a/backend/src/http/error.rs +++ b/backend/src/http/error.rs @@ -28,7 +28,7 @@ pub trait DisplayDebug: fmt::Display + Debug + Sync + Send {} impl DisplayDebug for T {} /// Struct describing an error response. 
-#[derive(Debug)] +#[derive(Debug, thiserror::Error)] pub struct ReacherResponseError { pub code: StatusCode, pub error: Box, @@ -121,7 +121,14 @@ impl From for ReacherResponseError { impl From for ReacherResponseError { fn from(e: reqwest::Error) -> Self { - ReacherResponseError::new(StatusCode::INTERNAL_SERVER_ERROR, e) + ReacherResponseError::new( + e.status() + .map(|s| s.as_u16()) + .map(StatusCode::from_u16) + .and_then(Result::ok) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + e, + ) } } diff --git a/backend/src/http/mod.rs b/backend/src/http/mod.rs index be5fe373b..4aca92fa2 100644 --- a/backend/src/http/mod.rs +++ b/backend/src/http/mod.rs @@ -23,14 +23,12 @@ use crate::config::BackendConfig; use check_if_email_exists::LOG_TARGET; use error::handle_rejection; pub use error::ReacherResponseError; -use sqlx::PgPool; use sqlxmq::JobRunnerHandle; use std::env; use std::net::IpAddr; use std::sync::Arc; use tracing::info; pub use v0::check_email::post::CheckEmailRequest; -use warp::http::StatusCode; use warp::Filter; pub fn create_routes( @@ -101,24 +99,6 @@ pub async fn run_warp_server( Ok(runner) } -/// Warp filter to add the database pool to the handler. If the pool is not -/// configured, it will return an error. -pub fn with_db( - pg_pool: Option, -) -> impl Filter + Clone { - warp::any().and_then(move || { - let pool = pg_pool.clone(); - async move { - pool.ok_or_else(|| { - warp::reject::custom(ReacherResponseError::new( - StatusCode::SERVICE_UNAVAILABLE, - "Please configure a Postgres database on Reacher before calling this endpoint", - )) - }) - } - }) -} - /// The header which holds the Reacher backend secret. 
pub const REACHER_SECRET_HEADER: &str = "x-reacher-secret"; diff --git a/backend/src/http/v1/bulk/get_progress.rs b/backend/src/http/v1/bulk/get_progress.rs index a77ff973a..f7511703c 100644 --- a/backend/src/http/v1/bulk/get_progress.rs +++ b/backend/src/http/v1/bulk/get_progress.rs @@ -25,8 +25,9 @@ use sqlx::PgPool; use warp::http::StatusCode; use warp::Filter; +use super::with_worker_db; use crate::config::BackendConfig; -use crate::http::{with_db, ReacherResponseError}; +use crate::http::ReacherResponseError; /// NOTE: Type conversions from postgres to rust types /// are according to the table given by @@ -149,7 +150,7 @@ pub fn v1_get_bulk_job_progress( ) -> impl Filter + Clone { warp::path!("v1" / "bulk" / i32) .and(warp::get()) - .and(with_db(config.get_pg_pool())) + .and(with_worker_db(config)) .and_then(http_handler) // View access logs by setting `RUST_LOG=reacher`. .with(warp::log(LOG_TARGET)) diff --git a/backend/src/http/v1/bulk/get_results/mod.rs b/backend/src/http/v1/bulk/get_results/mod.rs index ad11eb401..54b5049ee 100644 --- a/backend/src/http/v1/bulk/get_results/mod.rs +++ b/backend/src/http/v1/bulk/get_results/mod.rs @@ -25,8 +25,9 @@ use std::{convert::TryInto, sync::Arc}; use warp::http::StatusCode; use warp::Filter; +use super::with_worker_db; use crate::config::BackendConfig; -use crate::http::{with_db, ReacherResponseError}; +use crate::http::ReacherResponseError; use csv_helper::{CsvResponse, CsvWrapper}; mod csv_helper; @@ -180,7 +181,7 @@ pub fn v1_get_bulk_job_results( ) -> impl Filter + Clone { warp::path!("v1" / "bulk" / i32 / "results") .and(warp::get()) - .and(with_db(config.get_pg_pool())) + .and(with_worker_db(config)) .and(warp::query::()) .and_then(http_handler) // View access logs by setting `RUST_LOG=reacher_backend`. 
diff --git a/backend/src/http/v1/bulk/mod.rs b/backend/src/http/v1/bulk/mod.rs index ad83e7be1..93dd9f1b6 100644 --- a/backend/src/http/v1/bulk/mod.rs +++ b/backend/src/http/v1/bulk/mod.rs @@ -14,6 +14,39 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . +use crate::config::BackendConfig; +use crate::http::ReacherResponseError; +use sqlx::PgPool; +use std::sync::Arc; +use warp::http::StatusCode; +use warp::Filter; + pub mod get_progress; pub mod get_results; pub mod post; + +/// Warp filter to add the database pool to the handler. This function should +/// only be used for /v1/bulk endpoints, which are only enabled when worker mode +/// is enabled. +pub fn with_worker_db( + config: Arc, +) -> impl Filter + Clone { + warp::any().and_then(move || { + let config = Arc::clone(&config); + let pool = config.get_pg_pool(); + async move { + if !config.worker.enable { + return Err(warp::reject::custom(ReacherResponseError::new( + StatusCode::SERVICE_UNAVAILABLE, + "Please enable worker mode on Reacher before calling this endpoint", + ))); + } + pool.ok_or_else(|| { + warp::reject::custom(ReacherResponseError::new( + StatusCode::SERVICE_UNAVAILABLE, + "Please configure a Postgres database on Reacher before calling this endpoint", + )) + }) + } + }) +} diff --git a/backend/src/http/v1/bulk/post.rs b/backend/src/http/v1/bulk/post.rs index abe05e4a0..577ab78e7 100644 --- a/backend/src/http/v1/bulk/post.rs +++ b/backend/src/http/v1/bulk/post.rs @@ -29,10 +29,10 @@ use tracing::{debug, info}; use warp::http::StatusCode; use warp::Filter; +use super::with_worker_db; use crate::config::BackendConfig; use crate::http::check_header; use crate::http::v0::check_email::post::with_config; -use crate::http::with_db; use crate::http::CheckEmailRequest; use crate::http::ReacherResponseError; use crate::worker::consume::CHECK_EMAIL_QUEUE; @@ -154,7 +154,7 @@ pub fn v1_create_bulk_job( .and(warp::post()) 
.and(check_header(Arc::clone(&config))) .and(with_config(Arc::clone(&config))) - .and(with_db(config.get_pg_pool())) + .and(with_worker_db(config)) // When accepting a body, we want a JSON body (and to reject huge // payloads)... // TODO: Configure max size limit for a bulk job diff --git a/backend/src/http/v1/check_email/post.rs b/backend/src/http/v1/check_email/post.rs index 16e2a595d..602efc8ae 100644 --- a/backend/src/http/v1/check_email/post.rs +++ b/backend/src/http/v1/check_email/post.rs @@ -24,6 +24,7 @@ use lapin::options::{ use lapin::types::FieldTable; use lapin::BasicProperties; use std::sync::Arc; +use tracing::info; use warp::http::StatusCode; use warp::{http, Filter}; @@ -36,55 +37,49 @@ use crate::worker::consume::MAX_QUEUE_PRIORITY; use crate::worker::do_work::{CheckEmailJobId, CheckEmailTask}; use crate::worker::single_shot::SingleShotReply; -/// The main endpoint handler that implements the logic of this route. -async fn http_handler( +async fn handle_without_worker( config: Arc, - body: CheckEmailRequest, -) -> Result { - // The to_email field must be present - if body.to_email.is_empty() { - return Err(ReacherResponseError::new( - http::StatusCode::BAD_REQUEST, - "to_email field is required.", + body: &CheckEmailRequest, + throttle_manager: &crate::throttle::ThrottleManager, +) -> Result, warp::Rejection> { + info!(target: LOG_TARGET, email=body.to_email, "Starting verification"); + let input = body.to_check_email_input(Arc::clone(&config)); + let result = check_email(&input).await; + let result_ok = Ok(result); + + // Increment counters after successful verification + throttle_manager.increment_counters().await; + + // Store the result regardless of how we got it + let storage = Arc::clone(&config).get_storage_adapter(); + storage + .store( + &CheckEmailTask { + input: body.to_check_email_input(Arc::clone(&config)), + job_id: CheckEmailJobId::SingleShot, + webhook: None, + }, + &result_ok, + storage.get_extra(), ) - .into()); - } + 
.map_err(ReacherResponseError::from) + .await?; - // If worker mode is disabled, we do a direct check, and skip rabbitmq. - if !config.worker.enable { - let input = body.to_check_email_input(Arc::clone(&config)); - let result = check_email(&input).await; - let value = Ok(result); - - // Also store the result "manually", since we don't have a worker. - let storage = config.get_storage_adapter(); - storage - .store( - &CheckEmailTask { - input: input.clone(), - job_id: CheckEmailJobId::SingleShot, - webhook: None, - }, - &value, - storage.get_extra(), - ) - .map_err(ReacherResponseError::from) - .await?; - - // If we're in the Commercial License Trial, we also store the - // result by sending it to back to Reacher. - send_to_reacher(Arc::clone(&config), &input.to_email, &value) - .await - .map_err(ReacherResponseError::from)?; - - let result_bz = serde_json::to_vec(&value).map_err(ReacherResponseError::from)?; - - return Ok(warp::reply::with_header( - result_bz, - "Content-Type", - "application/json", - )); - } + // If we're in the Commercial License Trial, we also store the + // result by sending it to back to Reacher. + send_to_reacher(Arc::clone(&config), &body.to_email, &result_ok) + .await + .map_err(ReacherResponseError::from)?; + + let result = result_ok.unwrap(); + info!(target: LOG_TARGET, email=body.to_email, is_reachable=?result.is_reachable, "Done verification"); + Ok(serde_json::to_vec(&result).map_err(ReacherResponseError::from)?) +} + +async fn handle_with_worker( + config: Arc, + body: &CheckEmailRequest, +) -> Result, warp::Rejection> { let channel = config .must_worker_config() .map_err(ReacherResponseError::from)? 
@@ -116,7 +111,7 @@ async fn http_handler( publish_task( channel.clone(), CheckEmailTask { - input: body.to_check_email_input(config), + input: body.to_check_email_input(config.clone()), job_id: CheckEmailJobId::SingleShot, webhook: None, }, @@ -156,11 +151,7 @@ async fn http_handler( match single_shot_response { SingleShotReply::Ok(body) => { - return Ok(warp::reply::with_header( - body, - "Content-Type", - "application/json", - )); + return Ok(body); } SingleShotReply::Err((e, code)) => { let status_code = @@ -189,6 +180,46 @@ async fn http_handler( .into()) } +/// The main endpoint handler that implements the logic of this route. +async fn http_handler( + config: Arc, + body: CheckEmailRequest, +) -> Result { + // The to_email field must be present + if body.to_email.is_empty() { + return Err(ReacherResponseError::new( + http::StatusCode::BAD_REQUEST, + "to_email field is required.", + ) + .into()); + } + + // Check throttle regardless of worker mode + let throttle_manager = config.get_throttle_manager(); + if let Some(throttle_result) = throttle_manager.check_throttle().await { + return Err(ReacherResponseError::new( + http::StatusCode::TOO_MANY_REQUESTS, + format!( + "Rate limit {} exceeded, please wait {:?}", + throttle_result.limit_type, throttle_result.delay + ), + ) + .into()); + } + + let result_bz = if !config.worker.enable { + handle_without_worker(Arc::clone(&config), &body, &throttle_manager).await? + } else { + handle_with_worker(Arc::clone(&config), &body).await? + }; + + Ok(warp::reply::with_header( + result_bz, + "Content-Type", + "application/json", + )) +} + /// Create the `POST /v1/check_email` endpoint. 
pub fn v1_check_email( config: Arc, diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 71c140ccf..a44ccc9d8 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -17,6 +17,7 @@ pub mod config; pub mod http; mod storage; +pub mod throttle; pub mod worker; const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/backend/src/storage/commercial_license_trial.rs b/backend/src/storage/commercial_license_trial.rs index 4c84f5e65..551c130ee 100644 --- a/backend/src/storage/commercial_license_trial.rs +++ b/backend/src/storage/commercial_license_trial.rs @@ -15,10 +15,12 @@ // along with this program. If not, see . use crate::config::{BackendConfig, CommercialLicenseTrialConfig}; +use crate::http::ReacherResponseError; use crate::worker::do_work::TaskError; use check_if_email_exists::{CheckEmailOutput, LOG_TARGET}; use std::sync::Arc; use tracing::debug; +use warp::http::StatusCode; /// If we're in the Commercial License Trial, we also store the /// result by sending it to back to Reacher. @@ -26,7 +28,7 @@ pub async fn send_to_reacher( config: Arc, email: &str, worker_output: &Result, -) -> Result<(), reqwest::Error> { +) -> Result<(), ReacherResponseError> { if let Some(CommercialLicenseTrialConfig { api_token, url }) = &config.commercial_license_trial { let res = reqwest::Client::new() @@ -35,6 +37,19 @@ pub async fn send_to_reacher( .json(worker_output) .send() .await?; + + // Error if not 2xx status code + if !res.status().is_success() { + let status = StatusCode::from_u16(res.status().as_u16())?; + let body: serde_json::Value = res.json().await?; + + // Extract error message from the "error" field, if it exists, or + // else just return the whole body. 
+ let error_body = body.get("error").unwrap_or(&body).to_owned(); + + return Err(ReacherResponseError::new(status, error_body)); + } + let res = res.text().await?; debug!(target: LOG_TARGET, email=email, res=res, "Sent result to Reacher Commercial License Trial"); } diff --git a/backend/src/throttle.rs b/backend/src/throttle.rs new file mode 100644 index 000000000..90c6126ea --- /dev/null +++ b/backend/src/throttle.rs @@ -0,0 +1,239 @@ +// Reacher - Email Verification +// Copyright (C) 2018-2023 Reacher + +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +use crate::config::ThrottleConfig; +use std::fmt; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; + +/// Represents the type of throttle limit that was hit. 
+/// - `PerSecond`: The per-second request limit was exceeded
+/// - `PerMinute`: The per-minute request limit was exceeded
+/// - `PerHour`: The per-hour request limit was exceeded
+/// - `PerDay`: The per-day request limit was exceeded
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum ThrottleLimit {
+    PerSecond,
+    PerMinute,
+    PerHour,
+    PerDay,
+}
+
+impl fmt::Display for ThrottleLimit {
+    /// Writes the limit as a lowercase phrase such as "per second".
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::PerSecond => write!(f, "per second"),
+            Self::PerMinute => write!(f, "per minute"),
+            Self::PerHour => write!(f, "per hour"),
+            Self::PerDay => write!(f, "per day"),
+        }
+    }
+}
+
+/// Represents the result of a throttle check.
+/// - `delay`: How long to wait before making the next request
+/// - `limit_type`: Which rate limit was exceeded (second/minute/hour/day)
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ThrottleResult {
+    pub delay: Duration,
+    pub limit_type: ThrottleLimit,
+}
+
+/// Request counters for each time window, plus the instant at which each
+/// window was last reset.
+#[derive(Debug, Clone)]
+struct Throttle {
+    requests_per_second: u32,
+    requests_per_minute: u32,
+    requests_per_hour: u32,
+    requests_per_day: u32,
+    last_reset_second: Instant,
+    last_reset_minute: Instant,
+    last_reset_hour: Instant,
+    last_reset_day: Instant,
+}
+
+impl Default for Throttle {
+    /// All counters start at zero and every window starts at the current
+    /// instant.
+    fn default() -> Self {
+        let now = Instant::now();
+        Self {
+            requests_per_second: 0,
+            requests_per_minute: 0,
+            requests_per_hour: 0,
+            requests_per_day: 0,
+            last_reset_second: now,
+            last_reset_minute: now,
+            last_reset_hour: now,
+            last_reset_day: now,
+        }
+    }
+}
+
+impl Throttle {
+    /// Same as [`Throttle::default`]; delegates to it so the two
+    /// constructors cannot drift apart.
+    fn new() -> Self {
+        Self::default()
+    }
+
+    /// Zeroes every counter whose window has fully elapsed and restarts that
+    /// window at the current instant.
+    fn reset_if_needed(&mut self) {
+        let now = Instant::now();
+
+        // Reset per-second counter
+        if now.duration_since(self.last_reset_second) >= Duration::from_secs(1) {
+            self.requests_per_second = 0;
+            self.last_reset_second = now;
+        }
+
+        // Reset per-minute counter
+        if now.duration_since(self.last_reset_minute) >= Duration::from_secs(60) {
+            self.requests_per_minute = 0;
+            self.last_reset_minute = now;
+        }
+
+        // Reset per-hour counter
+        if now.duration_since(self.last_reset_hour) >= Duration::from_secs(3600) {
+            self.requests_per_hour = 0;
+            self.last_reset_hour = now;
+        }
+
+        // Reset per-day counter
+        if now.duration_since(self.last_reset_day) >= Duration::from_secs(86400) {
+            self.requests_per_day = 0;
+            self.last_reset_day = now;
+        }
+    }
+
+    /// Records one request in every time window.
+    fn increment_counters(&mut self) {
+        // Saturate rather than overflow: a wrapped counter would read as zero
+        // and silently lift the throttle.
+        self.requests_per_second = self.requests_per_second.saturating_add(1);
+        self.requests_per_minute = self.requests_per_minute.saturating_add(1);
+        self.requests_per_hour = self.requests_per_hour.saturating_add(1);
+        self.requests_per_day = self.requests_per_day.saturating_add(1);
+    }
+
+    /// Returns the first exceeded limit, if any, together with the time left
+    /// until that limit's window resets.
+    ///
+    /// Limits that are `None` in `config` are skipped. Windows are checked
+    /// from shortest to longest. This method does not reset or modify any
+    /// counter.
+    fn should_throttle(&self, config: &ThrottleConfig) -> Option<ThrottleResult> {
+        let now = Instant::now();
+
+        // `saturating_sub` instead of `-`: this `now` is taken later than the
+        // one used by `reset_if_needed`, so a window may already have elapsed
+        // here. `Duration - Duration` panics on underflow; a zero delay simply
+        // means the caller can retry immediately.
+        if let Some(max_per_second) = config.max_requests_per_second {
+            if self.requests_per_second >= max_per_second {
+                return Some(ThrottleResult {
+                    delay: Duration::from_secs(1)
+                        .saturating_sub(now.duration_since(self.last_reset_second)),
+                    limit_type: ThrottleLimit::PerSecond,
+                });
+            }
+        }
+
+        if let Some(max_per_minute) = config.max_requests_per_minute {
+            if self.requests_per_minute >= max_per_minute {
+                return Some(ThrottleResult {
+                    delay: Duration::from_secs(60)
+                        .saturating_sub(now.duration_since(self.last_reset_minute)),
+                    limit_type: ThrottleLimit::PerMinute,
+                });
+            }
+        }
+
+        if let Some(max_per_hour) = config.max_requests_per_hour {
+            if self.requests_per_hour >= max_per_hour {
+                return Some(ThrottleResult {
+                    delay: Duration::from_secs(3600)
+                        .saturating_sub(now.duration_since(self.last_reset_hour)),
+                    limit_type: ThrottleLimit::PerHour,
+                });
+            }
+        }
+
+        if let Some(max_per_day) = config.max_requests_per_day {
+            if self.requests_per_day >= max_per_day {
+                return Some(ThrottleResult {
+                    delay: Duration::from_secs(86400)
+                        .saturating_sub(now.duration_since(self.last_reset_day)),
+                    limit_type: ThrottleLimit::PerDay,
+                });
+            }
+        }
+
+        None
+    }
+}
+
+/// Pairs the shared [`Throttle`] counters with the limits they are checked
+/// against.
+#[derive(Debug, Default)]
+pub
struct ThrottleManager { + inner: Arc>, + config: ThrottleConfig, +} + +impl ThrottleManager { + pub fn new(config: ThrottleConfig) -> Self { + Self { + inner: Arc::new(Mutex::new(Throttle::new())), + config, + } + } + + pub async fn check_throttle(&self) -> Option { + let mut throttle = self.inner.lock().await; + throttle.reset_if_needed(); + throttle.should_throttle(&self.config) + } + + pub async fn increment_counters(&self) { + self.inner.lock().await.increment_counters(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::{sleep, Duration}; + + #[tokio::test] + async fn test_throttle_limits() { + // Create config with low limits for testing + let config = ThrottleConfig { + max_requests_per_second: Some(2), + max_requests_per_minute: Some(5), + max_requests_per_hour: Some(10), + max_requests_per_day: Some(20), + }; + + let manager = ThrottleManager::new(config); + + // Should allow initial requests + assert_eq!(manager.check_throttle().await, None); + manager.increment_counters().await; + assert_eq!(manager.check_throttle().await, None); + manager.increment_counters().await; + + // Should throttle after hitting per-second limit + let throttle_result = manager.check_throttle().await; + assert!(throttle_result.is_some()); + assert_eq!( + throttle_result.unwrap().limit_type, + ThrottleLimit::PerSecond + ); + + // Wait 1 second for per-second counter to reset + sleep(Duration::from_secs(1)).await; + + // Should allow more requests + assert_eq!(manager.check_throttle().await, None); + } +} diff --git a/backend/src/worker/consume.rs b/backend/src/worker/consume.rs index 9104f33c3..b4ea712b2 100644 --- a/backend/src/worker/consume.rs +++ b/backend/src/worker/consume.rs @@ -16,7 +16,7 @@ use super::do_work::{do_check_email_work, CheckEmailTask, TaskError}; use super::single_shot::send_single_shot_reply; -use crate::config::{BackendConfig, RabbitMQConfig, ThrottleConfig}; +use crate::config::{BackendConfig, RabbitMQConfig}; use 
crate::worker::do_work::CheckEmailJobId; use anyhow::Context; use check_if_email_exists::LOG_TARGET; @@ -24,9 +24,7 @@ use futures::stream::StreamExt; use lapin::{options::*, types::FieldTable, Channel, Connection, ConnectionProperties}; use sentry_anyhow::capture_anyhow; use std::sync::Arc; -use std::time::{Duration, Instant}; -use tokio::sync::Mutex; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, trace}; /// Our RabbitMQ only has one queue: "check_email". pub const CHECK_EMAIL_QUEUE: &str = "check_email"; @@ -96,12 +94,9 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E let config_clone = Arc::clone(&config); let worker_config = config_clone.must_worker_config()?; let channel = worker_config.channel; - - let throttle = Arc::new(Mutex::new(Throttle::new())); + let throttle = config.get_throttle_manager(); tokio::spawn(async move { - let worker_config = config_clone.must_worker_config()?; - let mut consumer = channel .basic_consume( CHECK_EMAIL_QUEUE, @@ -117,16 +112,11 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E let payload = serde_json::from_slice::(&delivery.data)?; debug!(target: LOG_TARGET, email=?payload.input.to_email, "Consuming message"); - // Reset throttle counters if needed - throttle.lock().await.reset_if_needed(); - // Check if we should throttle before fetching the next message - if let Some(wait_duration) = throttle - .lock() - .await - .should_throttle(&worker_config.throttle) - { - info!(target: LOG_TARGET, wait=?wait_duration, email=?payload.input.to_email, "Too many requests, throttling"); + if let Some(throttle_result) = throttle.check_throttle().await { + // This line below will log every time the worker fetches from + // RabbitMQ. 
It's noisy + trace!(target: LOG_TARGET, wait=?throttle_result.delay, email=?payload.input.to_email, "Too many requests {}, throttling", throttle_result.limit_type); // For single-shot tasks, we return an error early, so that the user knows they need to retry. match payload.job_id { @@ -139,7 +129,7 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E send_single_shot_reply( Arc::clone(&channel), &delivery, - &Err(TaskError::Throttle(wait_duration)), + &Err(TaskError::Throttle(throttle_result)), ) .await?; } @@ -170,7 +160,7 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E }); // Increment throttle counters once we spawn the task - throttle.lock().await.increment_counters(); + throttle.increment_counters().await; } Ok::<(), anyhow::Error>(()) @@ -178,96 +168,3 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E Ok(()) } - -#[derive(Clone)] -struct Throttle { - requests_per_second: u32, - requests_per_minute: u32, - requests_per_hour: u32, - requests_per_day: u32, - last_reset_second: Instant, - last_reset_minute: Instant, - last_reset_hour: Instant, - last_reset_day: Instant, -} - -impl Throttle { - fn new() -> Self { - let now = Instant::now(); - Throttle { - requests_per_second: 0, - requests_per_minute: 0, - requests_per_hour: 0, - requests_per_day: 0, - last_reset_second: now, - last_reset_minute: now, - last_reset_hour: now, - last_reset_day: now, - } - } - - fn reset_if_needed(&mut self) { - let now = Instant::now(); - - // Reset per-second counter - if now.duration_since(self.last_reset_second) >= Duration::from_secs(1) { - self.requests_per_second = 0; - self.last_reset_second = now; - } - - // Reset per-minute counter - if now.duration_since(self.last_reset_minute) >= Duration::from_secs(60) { - self.requests_per_minute = 0; - self.last_reset_minute = now; - } - - // Reset per-hour counter - if now.duration_since(self.last_reset_hour) >= Duration::from_secs(3600) { - self.requests_per_hour = 0; - 
self.last_reset_hour = now; - } - - // Reset per-day counter - if now.duration_since(self.last_reset_day) >= Duration::from_secs(86400) { - self.requests_per_day = 0; - self.last_reset_day = now; - } - } - - fn increment_counters(&mut self) { - self.requests_per_second += 1; - self.requests_per_minute += 1; - self.requests_per_hour += 1; - self.requests_per_day += 1; - } - - fn should_throttle(&self, config: &ThrottleConfig) -> Option { - let now = Instant::now(); - - if let Some(max_per_second) = config.max_requests_per_second { - if self.requests_per_second >= max_per_second { - return Some(Duration::from_secs(1) - now.duration_since(self.last_reset_second)); - } - } - - if let Some(max_per_minute) = config.max_requests_per_minute { - if self.requests_per_minute >= max_per_minute { - return Some(Duration::from_secs(60) - now.duration_since(self.last_reset_minute)); - } - } - - if let Some(max_per_hour) = config.max_requests_per_hour { - if self.requests_per_hour >= max_per_hour { - return Some(Duration::from_secs(3600) - now.duration_since(self.last_reset_hour)); - } - } - - if let Some(max_per_day) = config.max_requests_per_day { - if self.requests_per_day >= max_per_day { - return Some(Duration::from_secs(86400) - now.duration_since(self.last_reset_day)); - } - } - - None - } -} diff --git a/backend/src/worker/do_work.rs b/backend/src/worker/do_work.rs index 214ca110a..bfe2c928c 100644 --- a/backend/src/worker/do_work.rs +++ b/backend/src/worker/do_work.rs @@ -16,11 +16,11 @@ use crate::config::BackendConfig; use crate::storage::commercial_license_trial::send_to_reacher; +use crate::throttle::ThrottleResult; use crate::worker::single_shot::send_single_shot_reply; use check_if_email_exists::{ check_email, CheckEmailInput, CheckEmailOutput, Reachable, LOG_TARGET, }; -use core::time; use lapin::message::Delivery; use lapin::{options::*, Channel}; use serde::{Deserialize, Serialize}; @@ -53,7 +53,7 @@ pub enum TaskError { /// verification, as for bulk verification 
tasks the task will simply stay /// in the queue until one worker is ready to process it. #[error("Worker at full capacity, wait {0:?}")] - Throttle(time::Duration), + Throttle(ThrottleResult), #[error("Lapin error: {0}")] Lapin(lapin::Error), #[error("Reqwest error during webhook: {0}")] diff --git a/backend/src/worker/single_shot.rs b/backend/src/worker/single_shot.rs index 1b10c5cf2..a5804f0e5 100644 --- a/backend/src/worker/single_shot.rs +++ b/backend/src/worker/single_shot.rs @@ -47,7 +47,7 @@ impl TryFrom<&Result> for SingleShotReply { match result { Ok(output) => Ok(Self::Ok(serde_json::to_vec(output)?)), Err(TaskError::Throttle(e)) => Ok(Self::Err(( - TaskError::Throttle(*e).to_string(), + TaskError::Throttle(e.clone()).to_string(), StatusCode::TOO_MANY_REQUESTS.as_u16(), ))), Err(e) => Ok(Self::Err(( diff --git a/backend/tests/check_email.rs b/backend/tests/check_email.rs index 326133e89..93c62f42e 100644 --- a/backend/tests/check_email.rs +++ b/backend/tests/check_email.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -#[cfg(not(feature = "worker"))] +#[cfg(test)] mod tests { use std::sync::Arc; @@ -27,7 +27,7 @@ mod tests { const FOO_BAR_BAZ_RESPONSE: &str = r#"{"input":"foo@bar.baz","is_reachable":"invalid","misc":{"is_disposable":false,"is_role_account":false,"gravatar_url":null,"haveibeenpwned":null},"mx":{"accepts_mail":false,"records":[]},"smtp":{"can_connect_smtp":false,"has_full_inbox":false,"is_catch_all":false,"is_deliverable":false,"is_disabled":false},"syntax":{"address":"foo@bar.baz","domain":"bar.baz","is_valid_syntax":true,"username":"foo","normalized_email":"foo@bar.baz","suggestion":null}"#; fn create_backend_config(header_secret: &str) -> Arc { - let mut config = BackendConfig::default(); + let mut config = BackendConfig::empty(); config.header_secret = Some(header_secret.to_string()); Arc::new(config) }