diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 54a8d77582..2c7866f9f6 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -92,6 +92,17 @@ const _: () = { *stream.remote_addr() } } + + impl<'a, L, F> Connected>> for L::Addr + where + L: serve::Listener, + L::Addr: Clone + Sync + 'static, + F: FnMut(&mut L::Io) + Send + 'static, + { + fn connect_info(stream: serve::IncomingStream<'a, serve::TapIo>) -> Self { + stream.remote_addr().clone() + } + } }; impl Connected for SocketAddr { diff --git a/axum/src/serve.rs b/axum/src/serve.rs index aa205c8d91..a1358b1502 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -434,6 +434,7 @@ mod tests { extract::connect_info::Connected, handler::{Handler, HandlerWithoutStateExt}, routing::get, + serve::ListenerExt, Router, }; @@ -452,14 +453,29 @@ mod tests { let addr = "0.0.0.0:0"; + let tcp_nodelay_listener = || async { + TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| { + if let Err(err) = tcp_stream.set_nodelay(true) { + eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + } + }) + }; + // router serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + serve(tcp_nodelay_listener().await, router.clone()) + .await + .unwrap(); serve(UnixListener::bind("").unwrap(), router.clone()); serve( TcpListener::bind(addr).await.unwrap(), router.clone().into_make_service(), ); + serve( + tcp_nodelay_listener().await, + router.clone().into_make_service(), + ); serve( UnixListener::bind("").unwrap(), router.clone().into_make_service(), @@ -471,6 +487,12 @@ mod tests { .clone() .into_make_service_with_connect_info::(), ); + serve( + tcp_nodelay_listener().await, + router + .clone() + .into_make_service_with_connect_info::(), + ); serve( UnixListener::bind("").unwrap(), router.into_make_service_with_connect_info::(), @@ -478,12 +500,17 @@ mod tests { // method router serve(TcpListener::bind(addr).await.unwrap(), get(handler)); + serve(tcp_nodelay_listener().await, get(handler)); serve(UnixListener::bind("").unwrap(), get(handler)); serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service(), ); + serve( + tcp_nodelay_listener().await, + get(handler).into_make_service(), + ); serve( UnixListener::bind("").unwrap(), get(handler).into_make_service(), @@ -493,6 +520,10 @@ mod tests { TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service_with_connect_info::(), ); + serve( + tcp_nodelay_listener().await, + get(handler).into_make_service_with_connect_info::(), + ); serve( UnixListener::bind("").unwrap(), get(handler).into_make_service_with_connect_info::(), @@ -503,24 +534,31 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_service(), ); + serve(tcp_nodelay_listener().await, handler.into_service()); serve(UnixListener::bind("").unwrap(), handler.into_service()); serve( TcpListener::bind(addr).await.unwrap(), handler.with_state(()), ); + serve(tcp_nodelay_listener().await, handler.with_state(())); serve(UnixListener::bind("").unwrap(), handler.with_state(())); serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service(), ); + serve(tcp_nodelay_listener().await, handler.into_make_service()); serve(UnixListener::bind("").unwrap(), handler.into_make_service()); serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service_with_connect_info::(), ); + serve( + tcp_nodelay_listener().await, + handler.into_make_service_with_connect_info::(), + ); serve( UnixListener::bind("").unwrap(), handler.into_make_service_with_connect_info::(), diff --git a/axum/src/serve/listener.rs b/axum/src/serve/listener.rs index 8458b9572d..91effae5bd 100644 --- a/axum/src/serve/listener.rs +++ b/axum/src/serve/listener.rs @@ -102,7 +102,7 @@ impl ListenerExt for L {} /// Return type of [`ListenerExt::tap_io`]. /// /// See that method for details. -pub struct TapIo { +pub struct TapIo { listener: L, tap_fn: F, }