From 23864cd14cf5d994b376463944b7c22dd3544345 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sun, 1 Dec 2024 12:46:27 +0100 Subject: [PATCH] Make ConnectInfo work with ListenerExt::tap_io --- axum/src/extract/connect_info.rs | 14 ++++++++++-- axum/src/serve.rs | 38 ++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 54a8d77582c..a097fe38bf9 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -4,7 +4,7 @@ //! //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info -use crate::extension::AddExtension; +use crate::{extension::AddExtension, serve}; use super::{Extension, FromRequestParts}; use http::request::Parts; @@ -84,7 +84,6 @@ pub trait Connected: Clone + Send + Sync + 'static { #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve; use tokio::net::TcpListener; impl Connected> for SocketAddr { @@ -100,6 +99,17 @@ impl Connected for SocketAddr { } } +impl<'a, L, F> Connected>> for L::Addr +where + L: serve::Listener, + L::Addr: Clone + Sync + 'static, + F: FnMut(&mut L::Io) + Send + 'static, +{ + fn connect_info(stream: serve::IncomingStream<'a, serve::TapIo>) -> Self { + stream.remote_addr().clone() + } +} + impl Service for IntoMakeServiceWithConnectInfo where S: Clone, diff --git a/axum/src/serve.rs b/axum/src/serve.rs index aa205c8d915..a1358b15022 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -434,6 +434,7 @@ mod tests { extract::connect_info::Connected, handler::{Handler, HandlerWithoutStateExt}, routing::get, + serve::ListenerExt, Router, }; @@ -452,14 +453,29 @@ mod tests { let addr = "0.0.0.0:0"; + let tcp_nodelay_listener = || async { + TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| { + if let Err(err) = tcp_stream.set_nodelay(true) { + eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + } + }) + }; + // router serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + serve(tcp_nodelay_listener().await, router.clone()) + .await + .unwrap(); serve(UnixListener::bind("").unwrap(), router.clone()); serve( TcpListener::bind(addr).await.unwrap(), router.clone().into_make_service(), ); + serve( + tcp_nodelay_listener().await, + router.clone().into_make_service(), + ); serve( UnixListener::bind("").unwrap(), router.clone().into_make_service(), @@ -471,6 +487,12 @@ mod tests { .clone() .into_make_service_with_connect_info::(), ); + serve( + tcp_nodelay_listener().await, + router + .clone() + .into_make_service_with_connect_info::(), + ); serve( UnixListener::bind("").unwrap(), router.into_make_service_with_connect_info::(), @@ -478,12 +500,17 @@ mod tests { // method router serve(TcpListener::bind(addr).await.unwrap(), get(handler)); + serve(tcp_nodelay_listener().await, get(handler)); serve(UnixListener::bind("").unwrap(), get(handler)); serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service(), ); + serve( + tcp_nodelay_listener().await, + get(handler).into_make_service(), + ); serve( UnixListener::bind("").unwrap(), get(handler).into_make_service(), @@ -493,6 +520,10 @@ mod tests { TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service_with_connect_info::(), ); + serve( + tcp_nodelay_listener().await, + get(handler).into_make_service_with_connect_info::(), + ); serve( UnixListener::bind("").unwrap(), get(handler).into_make_service_with_connect_info::(), @@ -503,24 +534,31 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_service(), ); + serve(tcp_nodelay_listener().await, handler.into_service()); serve(UnixListener::bind("").unwrap(), handler.into_service()); serve( TcpListener::bind(addr).await.unwrap(), handler.with_state(()), ); + serve(tcp_nodelay_listener().await, handler.with_state(())); serve(UnixListener::bind("").unwrap(), handler.with_state(())); serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service(), ); + serve(tcp_nodelay_listener().await, handler.into_make_service()); serve(UnixListener::bind("").unwrap(), handler.into_make_service()); serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service_with_connect_info::(), ); + serve( + tcp_nodelay_listener().await, + handler.into_make_service_with_connect_info::(), + ); serve( UnixListener::bind("").unwrap(), handler.into_make_service_with_connect_info::(),