diff --git a/src/servers/http/server.rs b/src/servers/http/server.rs index 7c8148f2..a68f7d16 100644 --- a/src/servers/http/server.rs +++ b/src/servers/http/server.rs @@ -12,6 +12,7 @@ use tokio::sync::oneshot::{Receiver, Sender}; use super::v1::routes::router; use crate::bootstrap::jobs::Started; use crate::core::Tracker; +use crate::servers::custom_axum_server::{self, TimeoutAcceptor}; use crate::servers::registar::{ServiceHealthCheckJob, ServiceRegistration, ServiceRegistrationForm}; use crate::servers::signals::{graceful_shutdown, Halted}; @@ -60,13 +61,15 @@ impl Launcher { let running = Box::pin(async { match tls { - Some(tls) => axum_server::from_tcp_rustls(socket, tls) + Some(tls) => custom_axum_server::from_tcp_rustls_with_timeouts(socket, tls) .handle(handle) + .acceptor(TimeoutAcceptor) .serve(app.into_make_service_with_connect_info::()) .await .expect("Axum server crashed."), - None => axum_server::from_tcp(socket) + None => custom_axum_server::from_tcp_with_timeouts(socket) .handle(handle) + .acceptor(TimeoutAcceptor) .serve(app.into_make_service_with_connect_info::()) .await .expect("Axum server crashed."), diff --git a/src/servers/http/v1/routes.rs b/src/servers/http/v1/routes.rs index 05cd3871..c54da51a 100644 --- a/src/servers/http/v1/routes.rs +++ b/src/servers/http/v1/routes.rs @@ -3,12 +3,15 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use axum::error_handling::HandleErrorLayer; use axum::http::HeaderName; use axum::response::Response; use axum::routing::get; -use axum::Router; +use axum::{BoxError, Router}; use axum_client_ip::SecureClientIpSource; -use hyper::Request; +use hyper::{Request, StatusCode}; +use tower::timeout::TimeoutLayer; +use tower::ServiceBuilder; use tower_http::compression::CompressionLayer; use tower_http::propagate_header::PropagateHeaderLayer; use tower_http::request_id::{MakeRequestUuid, SetRequestIdLayer}; @@ -18,6 +21,8 @@ use tracing::{Level, Span}; use super::handlers::{announce, health_check, scrape}; use crate::core::Tracker; +const TIMEOUT: Duration = Duration::from_secs(5); + /// It adds the routes to the router. /// /// > **NOTICE**: it's added a layer to get the client IP from the connection @@ -69,4 +74,11 @@ pub fn router(tracker: Arc, server_socket_addr: SocketAddr) -> Router { }), ) .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) + .layer( + ServiceBuilder::new() + // this middleware goes above `TimeoutLayer` because it will receive + // errors returned by `TimeoutLayer` + .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT })) + .layer(TimeoutLayer::new(TIMEOUT)), + ) }