From a4efecb73147e2d006bb6b9c7a8ce7dbeaa3d66f Mon Sep 17 00:00:00 2001
From: Amaury <1293565+amaury1093@users.noreply.github.com>
Date: Sat, 14 Dec 2024 14:38:18 +0100
Subject: [PATCH] refactor: Add throttle as a global config (#1547)
* cursor v1
* fix default
* cleanup
* add docs for worker
* fix bulk and worker
* add logs
* better throttle logs
* remove warn
* better errors
* better error displau
* fmt
* clippy
* fix test
---
Makefile | 18 ++
backend/backend_config.toml | 56 ++--
backend/openapi.json | 4 +-
backend/src/config.rs | 71 ++++--
backend/src/http/error.rs | 11 +-
backend/src/http/mod.rs | 20 --
backend/src/http/v1/bulk/get_progress.rs | 5 +-
backend/src/http/v1/bulk/get_results/mod.rs | 5 +-
backend/src/http/v1/bulk/mod.rs | 33 +++
backend/src/http/v1/bulk/post.rs | 4 +-
backend/src/http/v1/check_email/post.rs | 135 ++++++----
backend/src/lib.rs | 1 +
.../src/storage/commercial_license_trial.rs | 17 +-
backend/src/throttle.rs | 239 ++++++++++++++++++
backend/src/worker/consume.rs | 121 +--------
backend/src/worker/do_work.rs | 4 +-
backend/src/worker/single_shot.rs | 2 +-
backend/tests/check_email.rs | 4 +-
18 files changed, 511 insertions(+), 239 deletions(-)
create mode 100644 Makefile
create mode 100644 backend/src/throttle.rs
diff --git a/Makefile b/Makefile
new file mode 100644
index 000000000..aaadc8854
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,18 @@
+# Run the backend without worker mode, i.e. only enabling single-shot
+# verifications via the /v1/check_email endpoint.
+.PHONY: run
+run:
+ cd backend && cargo run --bin reacher_backend
+
+
+# Run the backend with worker mode on. This enables the /v1/bulk endpoints.
+# Make sure to have a Postgres DB and a RabbitMQ instance running.
+.PHONY: run-with-worker
+run-with-worker: export RCH__WORKER__ENABLE=true
+run-with-worker: export RCH__WORKER__RABBITMQ__URL=amqp://guest:guest@localhost:5672
+run-with-worker: export RCH__STORAGE__POSTGRES__DB_URL=postgresql://localhost/reacherdb
+run-with-worker: run
+
+.PHONY: run-with-commercial-license-trial
+run-with-commercial-license-trial: export RCH__COMMERCIAL_LICENSE_TRIAL__URL=http://localhost:3000/api/v1/commercial_license_trial
+run-with-commercial-license-trial: run
\ No newline at end of file
diff --git a/backend/backend_config.toml b/backend/backend_config.toml
index 6a8978323..aee353a2f 100644
--- a/backend/backend_config.toml
+++ b/backend/backend_config.toml
@@ -88,12 +88,47 @@ hotmailb2c = "headless"
# recommended.
yahoo = "headless"
+# Throttle the maximum number of requests per second, per minute, per hour, and
+# per day for this worker.
+# All fields are optional; comment them out to disable the limit.
+#
+# We however recommend setting the throttle for at least the per-minute and
+# per-day limits to prevent the IPs from being blocked by the email providers.
+# The default values are set to 60 requests per minute and 10,000 requests per
+# day.
+#
+# Important: these throttle configurations only apply to /v1/* endpoints, and
+# not to the previous /v0/check_email endpoint. The latter endpoint always
+# executes the verification immediately, regardless of the throttle settings.
+#
+# Env variables:
+# - RCH__THROTTLE__MAX_REQUESTS_PER_SECOND
+# - RCH__THROTTLE__MAX_REQUESTS_PER_MINUTE
+# - RCH__THROTTLE__MAX_REQUESTS_PER_HOUR
+# - RCH__THROTTLE__MAX_REQUESTS_PER_DAY
+[throttle]
+# max_requests_per_second = 20
+max_requests_per_minute = 60
+# max_requests_per_hour = 1000
+max_requests_per_day = 10000
+
+# Configuration for a queue-based architecture for Reacher. This feature is
+# currently in **beta**. The queue-based architecture allows Reacher to scale
+# horizontally by running multiple workers that consume emails from a RabbitMQ
+# queue.
+#
+# To enable the queue-based architecture, set the "enable" field to "true" and
+# configure the RabbitMQ connection below. The "concurrency" field specifies
+# the number of concurrent emails to verify for this worker.
+#
+# For more information, see the documentation at:
+# https://docs.reacher.email/self-hosting/scaling-for-production
[worker]
# Enable the worker to consume emails from the RabbitMQ queues. If set, the
# RabbitMQ configuration below must be set as well.
#
# Env variable: RCH__WORKER__ENABLE
-enable = true
+enable = false
# RabbitMQ configuration.
[worker.rabbitmq]
@@ -105,25 +140,6 @@ url = "amqp://guest:guest@localhost:5672"
# Env variable: RCH__WORKER__RABBITMQ__CONCURRENCY
concurrency = 5
-# Throttle the maximum number of requests per second, per minute, per hour, and
-# per day for this worker.
-# All fields are optional; comment them out to disable the limit.
-#
-# Important: these throttle configurations only apply to /v1/* endpoints, and
-# not to the previous /v0/check_email endpoint. The latter endpoint always
-# executes the verification immediately, regardless of the throttle settings.
-#
-# Env variables:
-# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_SECOND
-# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_MINUTE
-# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_HOUR
-# - RCH__WORKER__THROTTLE__MAX_REQUESTS_PER_DAY
-[worker.throttle]
-# max_requests_per_second = 20
-# max_requests_per_minute = 100
-# max_requests_per_hour = 1000
-# max_requests_per_day = 20000
-
# Below are the configurations for the storage of the email verification
# results. We currently support the following storage backends:
# - Postgres
diff --git a/backend/openapi.json b/backend/openapi.json
index a37655f5b..9bd647ab3 100644
--- a/backend/openapi.json
+++ b/backend/openapi.json
@@ -702,7 +702,7 @@
"duration": {
"$ref": "#/components/schemas/Duration"
},
- "server_name": {
+ "backend_name": {
"type": "string",
"x-stoplight": {
"id": "2jrbdecvqh4t5"
@@ -717,7 +717,7 @@
"start_time",
"end_time",
"duration",
- "server_name",
+ "backend_name",
"smtp"
]
},
diff --git a/backend/src/config.rs b/backend/src/config.rs
index 65abb133d..a9455d738 100644
--- a/backend/src/config.rs
+++ b/backend/src/config.rs
@@ -15,6 +15,7 @@
// along with this program. If not, see .
use crate::storage::{postgres::PostgresStorage, StorageAdapter};
+use crate::throttle::ThrottleManager;
use crate::worker::do_work::TaskWebhook;
use crate::worker::setup_rabbit_mq;
use anyhow::{bail, Context};
@@ -29,7 +30,7 @@ use sqlx::PgPool;
use std::sync::Arc;
use tracing::warn;
-#[derive(Debug, Default, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
pub struct BackendConfig {
/// Name of the backend.
pub backend_name: String,
@@ -65,36 +66,61 @@ pub struct BackendConfig {
/// Whether to enable the Commercial License Trial. Setting this to true
pub commercial_license_trial: Option,
+ /// Throttle configuration for all requests
+ pub throttle: ThrottleConfig,
+
// Internal fields, not part of the configuration.
#[serde(skip)]
channel: Option>,
#[serde(skip)]
storage_adapter: Arc,
+
+ #[serde(skip)]
+ throttle_manager: Arc,
}
impl BackendConfig {
+ /// Create an empty BackendConfig. This is useful for testing purposes.
+ pub fn empty() -> Self {
+ Self {
+ backend_name: "".to_string(),
+ from_email: "".to_string(),
+ hello_name: "".to_string(),
+ webdriver_addr: "".to_string(),
+ proxy: None,
+ verif_method: VerifMethodConfig::default(),
+ http_host: "127.0.0.1".to_string(),
+ http_port: 8080,
+ header_secret: None,
+ smtp_timeout: None,
+ sentry_dsn: None,
+ worker: WorkerConfig::default(),
+ storage: Some(StorageConfig::Noop),
+ commercial_license_trial: None,
+ throttle: ThrottleConfig::new_without_throttle(),
+ channel: None,
+ storage_adapter: Arc::new(StorageAdapter::Noop),
+ throttle_manager: Arc::new(
+ ThrottleManager::new(ThrottleConfig::new_without_throttle()),
+ ),
+ }
+ }
+
/// Get the worker configuration.
///
/// # Panics
///
/// Panics if the worker configuration is missing.
pub fn must_worker_config(&self) -> Result {
- match (
- self.worker.enable,
- &self.worker.throttle,
- &self.worker.rabbitmq,
- &self.channel,
- ) {
- (true, Some(throttle), Some(rabbitmq), Some(channel)) => Ok(MustWorkerConfig {
+ match (self.worker.enable, &self.worker.rabbitmq, &self.channel) {
+ (true, Some(rabbitmq), Some(channel)) => Ok(MustWorkerConfig {
channel: channel.clone(),
- throttle: throttle.clone(),
rabbitmq: rabbitmq.clone(),
-
webhook: self.worker.webhook.clone(),
}),
- (true, _, _, _) => bail!("Worker configuration is missing"),
+ (true, _, _) => bail!("Worker configuration is missing"),
_ => bail!("Calling must_worker_config on a non-worker backend"),
}
}
@@ -126,6 +152,9 @@ impl BackendConfig {
};
self.channel = channel;
+ // Initialize throttle manager
+ self.throttle_manager = Arc::new(ThrottleManager::new(self.throttle.clone()));
+
Ok(())
}
@@ -142,6 +171,10 @@ impl BackendConfig {
StorageAdapter::Noop => None,
}
}
+
+ pub fn get_throttle_manager(&self) -> Arc {
+ self.throttle_manager.clone()
+ }
}
#[derive(Debug, Default, Deserialize, Clone, Serialize)]
@@ -159,9 +192,6 @@ pub struct VerifMethodConfig {
#[derive(Debug, Default, Deserialize, Clone, Serialize)]
pub struct WorkerConfig {
pub enable: bool,
-
- /// Throttle configuration for the worker.
- pub throttle: Option,
pub rabbitmq: Option,
/// Optional webhook configuration to send email verification results.
pub webhook: Option,
@@ -172,8 +202,6 @@ pub struct WorkerConfig {
#[derive(Debug, Clone)]
pub struct MustWorkerConfig {
pub channel: Arc,
-
- pub throttle: ThrottleConfig,
pub rabbitmq: RabbitMQConfig,
pub webhook: Option,
}
@@ -185,7 +213,7 @@ pub struct RabbitMQConfig {
pub concurrency: u16,
}
-#[derive(Debug, Deserialize, Clone, Serialize)]
+#[derive(Debug, Default, Deserialize, Clone, Serialize)]
pub struct ThrottleConfig {
pub max_requests_per_second: Option,
pub max_requests_per_minute: Option,
@@ -236,8 +264,13 @@ pub async fn load_config() -> Result {
let cfg = cfg.build()?.try_deserialize::()?;
- if !cfg.worker.enable && (cfg.worker.rabbitmq.is_some() || cfg.worker.throttle.is_some()) {
- warn!(target: LOG_TARGET, "worker.enable is set to false, ignoring throttling and concurrency settings.")
+ if cfg.worker.enable {
+ warn!(target: LOG_TARGET, "The worker feature is currently in beta. Please send any feedback to amaury@reacher.email.");
+
+ match &cfg.storage {
+ Some(StorageConfig::Postgres(_)) => {}
+ _ => bail!("When worker mode is enabled, a Postgres database must be configured."),
+ }
}
Ok(cfg)
diff --git a/backend/src/http/error.rs b/backend/src/http/error.rs
index 108e39268..48db2cd0e 100644
--- a/backend/src/http/error.rs
+++ b/backend/src/http/error.rs
@@ -28,7 +28,7 @@ pub trait DisplayDebug: fmt::Display + Debug + Sync + Send {}
impl DisplayDebug for T {}
/// Struct describing an error response.
-#[derive(Debug)]
+#[derive(Debug, thiserror::Error)]
pub struct ReacherResponseError {
pub code: StatusCode,
pub error: Box,
@@ -121,7 +121,14 @@ impl From for ReacherResponseError {
impl From for ReacherResponseError {
fn from(e: reqwest::Error) -> Self {
- ReacherResponseError::new(StatusCode::INTERNAL_SERVER_ERROR, e)
+ ReacherResponseError::new(
+ e.status()
+ .map(|s| s.as_u16())
+ .map(StatusCode::from_u16)
+ .and_then(Result::ok)
+ .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
+ e,
+ )
}
}
diff --git a/backend/src/http/mod.rs b/backend/src/http/mod.rs
index be5fe373b..4aca92fa2 100644
--- a/backend/src/http/mod.rs
+++ b/backend/src/http/mod.rs
@@ -23,14 +23,12 @@ use crate::config::BackendConfig;
use check_if_email_exists::LOG_TARGET;
use error::handle_rejection;
pub use error::ReacherResponseError;
-use sqlx::PgPool;
use sqlxmq::JobRunnerHandle;
use std::env;
use std::net::IpAddr;
use std::sync::Arc;
use tracing::info;
pub use v0::check_email::post::CheckEmailRequest;
-use warp::http::StatusCode;
use warp::Filter;
pub fn create_routes(
@@ -101,24 +99,6 @@ pub async fn run_warp_server(
Ok(runner)
}
-/// Warp filter to add the database pool to the handler. If the pool is not
-/// configured, it will return an error.
-pub fn with_db(
- pg_pool: Option,
-) -> impl Filter + Clone {
- warp::any().and_then(move || {
- let pool = pg_pool.clone();
- async move {
- pool.ok_or_else(|| {
- warp::reject::custom(ReacherResponseError::new(
- StatusCode::SERVICE_UNAVAILABLE,
- "Please configure a Postgres database on Reacher before calling this endpoint",
- ))
- })
- }
- })
-}
-
/// The header which holds the Reacher backend secret.
pub const REACHER_SECRET_HEADER: &str = "x-reacher-secret";
diff --git a/backend/src/http/v1/bulk/get_progress.rs b/backend/src/http/v1/bulk/get_progress.rs
index a77ff973a..f7511703c 100644
--- a/backend/src/http/v1/bulk/get_progress.rs
+++ b/backend/src/http/v1/bulk/get_progress.rs
@@ -25,8 +25,9 @@ use sqlx::PgPool;
use warp::http::StatusCode;
use warp::Filter;
+use super::with_worker_db;
use crate::config::BackendConfig;
-use crate::http::{with_db, ReacherResponseError};
+use crate::http::ReacherResponseError;
/// NOTE: Type conversions from postgres to rust types
/// are according to the table given by
@@ -149,7 +150,7 @@ pub fn v1_get_bulk_job_progress(
) -> impl Filter + Clone {
warp::path!("v1" / "bulk" / i32)
.and(warp::get())
- .and(with_db(config.get_pg_pool()))
+ .and(with_worker_db(config))
.and_then(http_handler)
// View access logs by setting `RUST_LOG=reacher`.
.with(warp::log(LOG_TARGET))
diff --git a/backend/src/http/v1/bulk/get_results/mod.rs b/backend/src/http/v1/bulk/get_results/mod.rs
index ad11eb401..54b5049ee 100644
--- a/backend/src/http/v1/bulk/get_results/mod.rs
+++ b/backend/src/http/v1/bulk/get_results/mod.rs
@@ -25,8 +25,9 @@ use std::{convert::TryInto, sync::Arc};
use warp::http::StatusCode;
use warp::Filter;
+use super::with_worker_db;
use crate::config::BackendConfig;
-use crate::http::{with_db, ReacherResponseError};
+use crate::http::ReacherResponseError;
use csv_helper::{CsvResponse, CsvWrapper};
mod csv_helper;
@@ -180,7 +181,7 @@ pub fn v1_get_bulk_job_results(
) -> impl Filter + Clone {
warp::path!("v1" / "bulk" / i32 / "results")
.and(warp::get())
- .and(with_db(config.get_pg_pool()))
+ .and(with_worker_db(config))
.and(warp::query::())
.and_then(http_handler)
// View access logs by setting `RUST_LOG=reacher_backend`.
diff --git a/backend/src/http/v1/bulk/mod.rs b/backend/src/http/v1/bulk/mod.rs
index ad83e7be1..93dd9f1b6 100644
--- a/backend/src/http/v1/bulk/mod.rs
+++ b/backend/src/http/v1/bulk/mod.rs
@@ -14,6 +14,39 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
+use crate::config::BackendConfig;
+use crate::http::ReacherResponseError;
+use sqlx::PgPool;
+use std::sync::Arc;
+use warp::http::StatusCode;
+use warp::Filter;
+
pub mod get_progress;
pub mod get_results;
pub mod post;
+
+/// Warp filter to add the database pool to the handler. This function should
+/// only be used for /v1/bulk endpoints, which are only enabled when worker mode
+/// is enabled.
+pub fn with_worker_db(
+ config: Arc,
+) -> impl Filter + Clone {
+ warp::any().and_then(move || {
+ let config = Arc::clone(&config);
+ let pool = config.get_pg_pool();
+ async move {
+ if !config.worker.enable {
+ return Err(warp::reject::custom(ReacherResponseError::new(
+ StatusCode::SERVICE_UNAVAILABLE,
+ "Please enable worker mode on Reacher before calling this endpoint",
+ )));
+ }
+ pool.ok_or_else(|| {
+ warp::reject::custom(ReacherResponseError::new(
+ StatusCode::SERVICE_UNAVAILABLE,
+ "Please configure a Postgres database on Reacher before calling this endpoint",
+ ))
+ })
+ }
+ })
+}
diff --git a/backend/src/http/v1/bulk/post.rs b/backend/src/http/v1/bulk/post.rs
index abe05e4a0..577ab78e7 100644
--- a/backend/src/http/v1/bulk/post.rs
+++ b/backend/src/http/v1/bulk/post.rs
@@ -29,10 +29,10 @@ use tracing::{debug, info};
use warp::http::StatusCode;
use warp::Filter;
+use super::with_worker_db;
use crate::config::BackendConfig;
use crate::http::check_header;
use crate::http::v0::check_email::post::with_config;
-use crate::http::with_db;
use crate::http::CheckEmailRequest;
use crate::http::ReacherResponseError;
use crate::worker::consume::CHECK_EMAIL_QUEUE;
@@ -154,7 +154,7 @@ pub fn v1_create_bulk_job(
.and(warp::post())
.and(check_header(Arc::clone(&config)))
.and(with_config(Arc::clone(&config)))
- .and(with_db(config.get_pg_pool()))
+ .and(with_worker_db(config))
// When accepting a body, we want a JSON body (and to reject huge
// payloads)...
// TODO: Configure max size limit for a bulk job
diff --git a/backend/src/http/v1/check_email/post.rs b/backend/src/http/v1/check_email/post.rs
index 16e2a595d..602efc8ae 100644
--- a/backend/src/http/v1/check_email/post.rs
+++ b/backend/src/http/v1/check_email/post.rs
@@ -24,6 +24,7 @@ use lapin::options::{
use lapin::types::FieldTable;
use lapin::BasicProperties;
use std::sync::Arc;
+use tracing::info;
use warp::http::StatusCode;
use warp::{http, Filter};
@@ -36,55 +37,49 @@ use crate::worker::consume::MAX_QUEUE_PRIORITY;
use crate::worker::do_work::{CheckEmailJobId, CheckEmailTask};
use crate::worker::single_shot::SingleShotReply;
-/// The main endpoint handler that implements the logic of this route.
-async fn http_handler(
+async fn handle_without_worker(
config: Arc,
- body: CheckEmailRequest,
-) -> Result {
- // The to_email field must be present
- if body.to_email.is_empty() {
- return Err(ReacherResponseError::new(
- http::StatusCode::BAD_REQUEST,
- "to_email field is required.",
+ body: &CheckEmailRequest,
+ throttle_manager: &crate::throttle::ThrottleManager,
+) -> Result, warp::Rejection> {
+ info!(target: LOG_TARGET, email=body.to_email, "Starting verification");
+ let input = body.to_check_email_input(Arc::clone(&config));
+ let result = check_email(&input).await;
+ let result_ok = Ok(result);
+
+ // Increment counters after successful verification
+ throttle_manager.increment_counters().await;
+
+ // Store the result regardless of how we got it
+ let storage = Arc::clone(&config).get_storage_adapter();
+ storage
+ .store(
+ &CheckEmailTask {
+ input: body.to_check_email_input(Arc::clone(&config)),
+ job_id: CheckEmailJobId::SingleShot,
+ webhook: None,
+ },
+ &result_ok,
+ storage.get_extra(),
)
- .into());
- }
+ .map_err(ReacherResponseError::from)
+ .await?;
- // If worker mode is disabled, we do a direct check, and skip rabbitmq.
- if !config.worker.enable {
- let input = body.to_check_email_input(Arc::clone(&config));
- let result = check_email(&input).await;
- let value = Ok(result);
-
- // Also store the result "manually", since we don't have a worker.
- let storage = config.get_storage_adapter();
- storage
- .store(
- &CheckEmailTask {
- input: input.clone(),
- job_id: CheckEmailJobId::SingleShot,
- webhook: None,
- },
- &value,
- storage.get_extra(),
- )
- .map_err(ReacherResponseError::from)
- .await?;
-
- // If we're in the Commercial License Trial, we also store the
- // result by sending it to back to Reacher.
- send_to_reacher(Arc::clone(&config), &input.to_email, &value)
- .await
- .map_err(ReacherResponseError::from)?;
-
- let result_bz = serde_json::to_vec(&value).map_err(ReacherResponseError::from)?;
-
- return Ok(warp::reply::with_header(
- result_bz,
- "Content-Type",
- "application/json",
- ));
- }
+ // If we're in the Commercial License Trial, we also store the
+ // result by sending it to back to Reacher.
+ send_to_reacher(Arc::clone(&config), &body.to_email, &result_ok)
+ .await
+ .map_err(ReacherResponseError::from)?;
+
+ let result = result_ok.unwrap();
+ info!(target: LOG_TARGET, email=body.to_email, is_reachable=?result.is_reachable, "Done verification");
+ Ok(serde_json::to_vec(&result).map_err(ReacherResponseError::from)?)
+}
+
+async fn handle_with_worker(
+ config: Arc,
+ body: &CheckEmailRequest,
+) -> Result, warp::Rejection> {
let channel = config
.must_worker_config()
.map_err(ReacherResponseError::from)?
@@ -116,7 +111,7 @@ async fn http_handler(
publish_task(
channel.clone(),
CheckEmailTask {
- input: body.to_check_email_input(config),
+ input: body.to_check_email_input(config.clone()),
job_id: CheckEmailJobId::SingleShot,
webhook: None,
},
@@ -156,11 +151,7 @@ async fn http_handler(
match single_shot_response {
SingleShotReply::Ok(body) => {
- return Ok(warp::reply::with_header(
- body,
- "Content-Type",
- "application/json",
- ));
+ return Ok(body);
}
SingleShotReply::Err((e, code)) => {
let status_code =
@@ -189,6 +180,46 @@ async fn http_handler(
.into())
}
+/// The main endpoint handler that implements the logic of this route.
+async fn http_handler(
+ config: Arc,
+ body: CheckEmailRequest,
+) -> Result {
+ // The to_email field must be present
+ if body.to_email.is_empty() {
+ return Err(ReacherResponseError::new(
+ http::StatusCode::BAD_REQUEST,
+ "to_email field is required.",
+ )
+ .into());
+ }
+
+ // Check throttle regardless of worker mode
+ let throttle_manager = config.get_throttle_manager();
+ if let Some(throttle_result) = throttle_manager.check_throttle().await {
+ return Err(ReacherResponseError::new(
+ http::StatusCode::TOO_MANY_REQUESTS,
+ format!(
+ "Rate limit {} exceeded, please wait {:?}",
+ throttle_result.limit_type, throttle_result.delay
+ ),
+ )
+ .into());
+ }
+
+ let result_bz = if !config.worker.enable {
+ handle_without_worker(Arc::clone(&config), &body, &throttle_manager).await?
+ } else {
+ handle_with_worker(Arc::clone(&config), &body).await?
+ };
+
+ Ok(warp::reply::with_header(
+ result_bz,
+ "Content-Type",
+ "application/json",
+ ))
+}
+
/// Create the `POST /v1/check_email` endpoint.
pub fn v1_check_email(
config: Arc,
diff --git a/backend/src/lib.rs b/backend/src/lib.rs
index 71c140ccf..a44ccc9d8 100644
--- a/backend/src/lib.rs
+++ b/backend/src/lib.rs
@@ -17,6 +17,7 @@
pub mod config;
pub mod http;
mod storage;
+pub mod throttle;
pub mod worker;
const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION");
diff --git a/backend/src/storage/commercial_license_trial.rs b/backend/src/storage/commercial_license_trial.rs
index 4c84f5e65..551c130ee 100644
--- a/backend/src/storage/commercial_license_trial.rs
+++ b/backend/src/storage/commercial_license_trial.rs
@@ -15,10 +15,12 @@
// along with this program. If not, see .
use crate::config::{BackendConfig, CommercialLicenseTrialConfig};
+use crate::http::ReacherResponseError;
use crate::worker::do_work::TaskError;
use check_if_email_exists::{CheckEmailOutput, LOG_TARGET};
use std::sync::Arc;
use tracing::debug;
+use warp::http::StatusCode;
/// If we're in the Commercial License Trial, we also store the
/// result by sending it to back to Reacher.
@@ -26,7 +28,7 @@ pub async fn send_to_reacher(
config: Arc,
email: &str,
worker_output: &Result,
-) -> Result<(), reqwest::Error> {
+) -> Result<(), ReacherResponseError> {
if let Some(CommercialLicenseTrialConfig { api_token, url }) = &config.commercial_license_trial
{
let res = reqwest::Client::new()
@@ -35,6 +37,19 @@ pub async fn send_to_reacher(
.json(worker_output)
.send()
.await?;
+
+ // Error if not 2xx status code
+ if !res.status().is_success() {
+ let status = StatusCode::from_u16(res.status().as_u16())?;
+ let body: serde_json::Value = res.json().await?;
+
+ // Extract error message from the "error" field, if it exists, or
+ // else just return the whole body.
+ let error_body = body.get("error").unwrap_or(&body).to_owned();
+
+ return Err(ReacherResponseError::new(status, error_body));
+ }
+
let res = res.text().await?;
debug!(target: LOG_TARGET, email=email, res=res, "Sent result to Reacher Commercial License Trial");
}
diff --git a/backend/src/throttle.rs b/backend/src/throttle.rs
new file mode 100644
index 000000000..90c6126ea
--- /dev/null
+++ b/backend/src/throttle.rs
@@ -0,0 +1,239 @@
+// Reacher - Email Verification
+// Copyright (C) 2018-2023 Reacher
+
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+use crate::config::ThrottleConfig;
+use std::fmt;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::Mutex;
+
+/// Represents the type of throttle limit that was hit.
+/// - `PerSecond`: The per-second request limit was exceeded
+/// - `PerMinute`: The per-minute request limit was exceeded
+/// - `PerHour`: The per-hour request limit was exceeded
+/// - `PerDay`: The per-day request limit was exceeded
+#[derive(Debug, Clone, PartialEq)]
+pub enum ThrottleLimit {
+ PerSecond,
+ PerMinute,
+ PerHour,
+ PerDay,
+}
+
+impl fmt::Display for ThrottleLimit {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::PerSecond => write!(f, "per second"),
+ Self::PerMinute => write!(f, "per minute"),
+ Self::PerHour => write!(f, "per hour"),
+ Self::PerDay => write!(f, "per day"),
+ }
+ }
+}
+
+/// Represents the result of a throttle check.
+/// - `delay`: How long to wait before making the next request
+/// - `limit_type`: Which rate limit was exceeded (second/minute/hour/day)
+#[derive(Debug, Clone, PartialEq)]
+pub struct ThrottleResult {
+ pub delay: Duration,
+ pub limit_type: ThrottleLimit,
+}
+
+#[derive(Debug, Clone)]
+struct Throttle {
+ requests_per_second: u32,
+ requests_per_minute: u32,
+ requests_per_hour: u32,
+ requests_per_day: u32,
+ last_reset_second: Instant,
+ last_reset_minute: Instant,
+ last_reset_hour: Instant,
+ last_reset_day: Instant,
+}
+
+impl Default for Throttle {
+ fn default() -> Self {
+ let now = Instant::now();
+ Self {
+ requests_per_second: 0,
+ requests_per_minute: 0,
+ requests_per_hour: 0,
+ requests_per_day: 0,
+ last_reset_second: now,
+ last_reset_minute: now,
+ last_reset_hour: now,
+ last_reset_day: now,
+ }
+ }
+}
+
+impl Throttle {
+ fn new() -> Self {
+ let now = Instant::now();
+ Throttle {
+ requests_per_second: 0,
+ requests_per_minute: 0,
+ requests_per_hour: 0,
+ requests_per_day: 0,
+ last_reset_second: now,
+ last_reset_minute: now,
+ last_reset_hour: now,
+ last_reset_day: now,
+ }
+ }
+
+ fn reset_if_needed(&mut self) {
+ let now = Instant::now();
+
+ // Reset per-second counter
+ if now.duration_since(self.last_reset_second) >= Duration::from_secs(1) {
+ self.requests_per_second = 0;
+ self.last_reset_second = now;
+ }
+
+ // Reset per-minute counter
+ if now.duration_since(self.last_reset_minute) >= Duration::from_secs(60) {
+ self.requests_per_minute = 0;
+ self.last_reset_minute = now;
+ }
+
+ // Reset per-hour counter
+ if now.duration_since(self.last_reset_hour) >= Duration::from_secs(3600) {
+ self.requests_per_hour = 0;
+ self.last_reset_hour = now;
+ }
+
+ // Reset per-day counter
+ if now.duration_since(self.last_reset_day) >= Duration::from_secs(86400) {
+ self.requests_per_day = 0;
+ self.last_reset_day = now;
+ }
+ }
+
+ fn increment_counters(&mut self) {
+ self.requests_per_second += 1;
+ self.requests_per_minute += 1;
+ self.requests_per_hour += 1;
+ self.requests_per_day += 1;
+ }
+
+ fn should_throttle(&self, config: &ThrottleConfig) -> Option {
+ let now = Instant::now();
+
+ if let Some(max_per_second) = config.max_requests_per_second {
+ if self.requests_per_second >= max_per_second {
+ return Some(ThrottleResult {
+ delay: Duration::from_secs(1) - now.duration_since(self.last_reset_second),
+ limit_type: ThrottleLimit::PerSecond,
+ });
+ }
+ }
+
+ if let Some(max_per_minute) = config.max_requests_per_minute {
+ if self.requests_per_minute >= max_per_minute {
+ return Some(ThrottleResult {
+ delay: Duration::from_secs(60) - now.duration_since(self.last_reset_minute),
+ limit_type: ThrottleLimit::PerMinute,
+ });
+ }
+ }
+
+ if let Some(max_per_hour) = config.max_requests_per_hour {
+ if self.requests_per_hour >= max_per_hour {
+ return Some(ThrottleResult {
+ delay: Duration::from_secs(3600) - now.duration_since(self.last_reset_hour),
+ limit_type: ThrottleLimit::PerHour,
+ });
+ }
+ }
+
+ if let Some(max_per_day) = config.max_requests_per_day {
+ if self.requests_per_day >= max_per_day {
+ return Some(ThrottleResult {
+ delay: Duration::from_secs(86400) - now.duration_since(self.last_reset_day),
+ limit_type: ThrottleLimit::PerDay,
+ });
+ }
+ }
+
+ None
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct ThrottleManager {
+ inner: Arc>,
+ config: ThrottleConfig,
+}
+
+impl ThrottleManager {
+ pub fn new(config: ThrottleConfig) -> Self {
+ Self {
+ inner: Arc::new(Mutex::new(Throttle::new())),
+ config,
+ }
+ }
+
+ pub async fn check_throttle(&self) -> Option {
+ let mut throttle = self.inner.lock().await;
+ throttle.reset_if_needed();
+ throttle.should_throttle(&self.config)
+ }
+
+ pub async fn increment_counters(&self) {
+ self.inner.lock().await.increment_counters();
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use tokio::time::{sleep, Duration};
+
+ #[tokio::test]
+ async fn test_throttle_limits() {
+ // Create config with low limits for testing
+ let config = ThrottleConfig {
+ max_requests_per_second: Some(2),
+ max_requests_per_minute: Some(5),
+ max_requests_per_hour: Some(10),
+ max_requests_per_day: Some(20),
+ };
+
+ let manager = ThrottleManager::new(config);
+
+ // Should allow initial requests
+ assert_eq!(manager.check_throttle().await, None);
+ manager.increment_counters().await;
+ assert_eq!(manager.check_throttle().await, None);
+ manager.increment_counters().await;
+
+ // Should throttle after hitting per-second limit
+ let throttle_result = manager.check_throttle().await;
+ assert!(throttle_result.is_some());
+ assert_eq!(
+ throttle_result.unwrap().limit_type,
+ ThrottleLimit::PerSecond
+ );
+
+ // Wait 1 second for per-second counter to reset
+ sleep(Duration::from_secs(1)).await;
+
+ // Should allow more requests
+ assert_eq!(manager.check_throttle().await, None);
+ }
+}
diff --git a/backend/src/worker/consume.rs b/backend/src/worker/consume.rs
index 9104f33c3..b4ea712b2 100644
--- a/backend/src/worker/consume.rs
+++ b/backend/src/worker/consume.rs
@@ -16,7 +16,7 @@
use super::do_work::{do_check_email_work, CheckEmailTask, TaskError};
use super::single_shot::send_single_shot_reply;
-use crate::config::{BackendConfig, RabbitMQConfig, ThrottleConfig};
+use crate::config::{BackendConfig, RabbitMQConfig};
use crate::worker::do_work::CheckEmailJobId;
use anyhow::Context;
use check_if_email_exists::LOG_TARGET;
@@ -24,9 +24,7 @@ use futures::stream::StreamExt;
use lapin::{options::*, types::FieldTable, Channel, Connection, ConnectionProperties};
use sentry_anyhow::capture_anyhow;
use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tokio::sync::Mutex;
-use tracing::{debug, error, info};
+use tracing::{debug, error, info, trace};
/// Our RabbitMQ only has one queue: "check_email".
pub const CHECK_EMAIL_QUEUE: &str = "check_email";
@@ -96,12 +94,9 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E
let config_clone = Arc::clone(&config);
let worker_config = config_clone.must_worker_config()?;
let channel = worker_config.channel;
-
- let throttle = Arc::new(Mutex::new(Throttle::new()));
+ let throttle = config.get_throttle_manager();
tokio::spawn(async move {
- let worker_config = config_clone.must_worker_config()?;
-
let mut consumer = channel
.basic_consume(
CHECK_EMAIL_QUEUE,
@@ -117,16 +112,11 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E
let payload = serde_json::from_slice::(&delivery.data)?;
debug!(target: LOG_TARGET, email=?payload.input.to_email, "Consuming message");
- // Reset throttle counters if needed
- throttle.lock().await.reset_if_needed();
-
// Check if we should throttle before fetching the next message
- if let Some(wait_duration) = throttle
- .lock()
- .await
- .should_throttle(&worker_config.throttle)
- {
- info!(target: LOG_TARGET, wait=?wait_duration, email=?payload.input.to_email, "Too many requests, throttling");
+ if let Some(throttle_result) = throttle.check_throttle().await {
+ // This line below will log every time the worker fetches from
+ // RabbitMQ. It's noisy
+ trace!(target: LOG_TARGET, wait=?throttle_result.delay, email=?payload.input.to_email, "Too many requests {}, throttling", throttle_result.limit_type);
// For single-shot tasks, we return an error early, so that the user knows they need to retry.
match payload.job_id {
@@ -139,7 +129,7 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E
send_single_shot_reply(
Arc::clone(&channel),
&delivery,
- &Err(TaskError::Throttle(wait_duration)),
+ &Err(TaskError::Throttle(throttle_result)),
)
.await?;
}
@@ -170,7 +160,7 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E
});
// Increment throttle counters once we spawn the task
- throttle.lock().await.increment_counters();
+ throttle.increment_counters().await;
}
Ok::<(), anyhow::Error>(())
@@ -178,96 +168,3 @@ async fn consume_check_email(config: Arc) -> Result<(), anyhow::E
Ok(())
}
-
-#[derive(Clone)]
-struct Throttle {
- requests_per_second: u32,
- requests_per_minute: u32,
- requests_per_hour: u32,
- requests_per_day: u32,
- last_reset_second: Instant,
- last_reset_minute: Instant,
- last_reset_hour: Instant,
- last_reset_day: Instant,
-}
-
-impl Throttle {
- fn new() -> Self {
- let now = Instant::now();
- Throttle {
- requests_per_second: 0,
- requests_per_minute: 0,
- requests_per_hour: 0,
- requests_per_day: 0,
- last_reset_second: now,
- last_reset_minute: now,
- last_reset_hour: now,
- last_reset_day: now,
- }
- }
-
- fn reset_if_needed(&mut self) {
- let now = Instant::now();
-
- // Reset per-second counter
- if now.duration_since(self.last_reset_second) >= Duration::from_secs(1) {
- self.requests_per_second = 0;
- self.last_reset_second = now;
- }
-
- // Reset per-minute counter
- if now.duration_since(self.last_reset_minute) >= Duration::from_secs(60) {
- self.requests_per_minute = 0;
- self.last_reset_minute = now;
- }
-
- // Reset per-hour counter
- if now.duration_since(self.last_reset_hour) >= Duration::from_secs(3600) {
- self.requests_per_hour = 0;
- self.last_reset_hour = now;
- }
-
- // Reset per-day counter
- if now.duration_since(self.last_reset_day) >= Duration::from_secs(86400) {
- self.requests_per_day = 0;
- self.last_reset_day = now;
- }
- }
-
- fn increment_counters(&mut self) {
- self.requests_per_second += 1;
- self.requests_per_minute += 1;
- self.requests_per_hour += 1;
- self.requests_per_day += 1;
- }
-
- fn should_throttle(&self, config: &ThrottleConfig) -> Option {
- let now = Instant::now();
-
- if let Some(max_per_second) = config.max_requests_per_second {
- if self.requests_per_second >= max_per_second {
- return Some(Duration::from_secs(1) - now.duration_since(self.last_reset_second));
- }
- }
-
- if let Some(max_per_minute) = config.max_requests_per_minute {
- if self.requests_per_minute >= max_per_minute {
- return Some(Duration::from_secs(60) - now.duration_since(self.last_reset_minute));
- }
- }
-
- if let Some(max_per_hour) = config.max_requests_per_hour {
- if self.requests_per_hour >= max_per_hour {
- return Some(Duration::from_secs(3600) - now.duration_since(self.last_reset_hour));
- }
- }
-
- if let Some(max_per_day) = config.max_requests_per_day {
- if self.requests_per_day >= max_per_day {
- return Some(Duration::from_secs(86400) - now.duration_since(self.last_reset_day));
- }
- }
-
- None
- }
-}
diff --git a/backend/src/worker/do_work.rs b/backend/src/worker/do_work.rs
index 214ca110a..bfe2c928c 100644
--- a/backend/src/worker/do_work.rs
+++ b/backend/src/worker/do_work.rs
@@ -16,11 +16,11 @@
use crate::config::BackendConfig;
use crate::storage::commercial_license_trial::send_to_reacher;
+use crate::throttle::ThrottleResult;
use crate::worker::single_shot::send_single_shot_reply;
use check_if_email_exists::{
check_email, CheckEmailInput, CheckEmailOutput, Reachable, LOG_TARGET,
};
-use core::time;
use lapin::message::Delivery;
use lapin::{options::*, Channel};
use serde::{Deserialize, Serialize};
@@ -53,7 +53,7 @@ pub enum TaskError {
/// verification, as for bulk verification tasks the task will simply stay
/// in the queue until one worker is ready to process it.
#[error("Worker at full capacity, wait {0:?}")]
- Throttle(time::Duration),
+ Throttle(ThrottleResult),
#[error("Lapin error: {0}")]
Lapin(lapin::Error),
#[error("Reqwest error during webhook: {0}")]
diff --git a/backend/src/worker/single_shot.rs b/backend/src/worker/single_shot.rs
index 1b10c5cf2..a5804f0e5 100644
--- a/backend/src/worker/single_shot.rs
+++ b/backend/src/worker/single_shot.rs
@@ -47,7 +47,7 @@ impl TryFrom<&Result> for SingleShotReply {
match result {
Ok(output) => Ok(Self::Ok(serde_json::to_vec(output)?)),
Err(TaskError::Throttle(e)) => Ok(Self::Err((
- TaskError::Throttle(*e).to_string(),
+ TaskError::Throttle(e.clone()).to_string(),
StatusCode::TOO_MANY_REQUESTS.as_u16(),
))),
Err(e) => Ok(Self::Err((
diff --git a/backend/tests/check_email.rs b/backend/tests/check_email.rs
index 326133e89..93c62f42e 100644
--- a/backend/tests/check_email.rs
+++ b/backend/tests/check_email.rs
@@ -14,7 +14,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-#[cfg(not(feature = "worker"))]
+#[cfg(test)]
mod tests {
use std::sync::Arc;
@@ -27,7 +27,7 @@ mod tests {
const FOO_BAR_BAZ_RESPONSE: &str = r#"{"input":"foo@bar.baz","is_reachable":"invalid","misc":{"is_disposable":false,"is_role_account":false,"gravatar_url":null,"haveibeenpwned":null},"mx":{"accepts_mail":false,"records":[]},"smtp":{"can_connect_smtp":false,"has_full_inbox":false,"is_catch_all":false,"is_deliverable":false,"is_disabled":false},"syntax":{"address":"foo@bar.baz","domain":"bar.baz","is_valid_syntax":true,"username":"foo","normalized_email":"foo@bar.baz","suggestion":null}"#;
fn create_backend_config(header_secret: &str) -> Arc {
- let mut config = BackendConfig::default();
+ let mut config = BackendConfig::empty();
config.header_secret = Some(header_secret.to_string());
Arc::new(config)
}