diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 9ebf4048f4c..0c7e790e7a8 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -22,7 +22,7 @@ use tower_service::Service; mod listener; -pub use self::listener::Listener; +pub use self::listener::{Listener, ListenerExt}; /// Serve the service with the supplied listener. /// diff --git a/axum/src/serve/listener.rs b/axum/src/serve/listener.rs index 0f8754e9082..2fa37d8c8ad 100644 --- a/axum/src/serve/listener.rs +++ b/axum/src/serve/listener.rs @@ -1,4 +1,4 @@ -use std::{future::Future, time::Duration}; +use std::{fmt, future::Future, time::Duration}; use tokio::{ io::{self, AsyncRead, AsyncWrite}, @@ -62,6 +62,78 @@ impl Listener for tokio::net::UnixListener { } } +/// Extensions to [`Listener`]. +pub trait ListenerExt: Listener + Sized { + /// Run a mutable closure on every accepted `Io`. + /// + /// # Example + /// + /// ``` + /// use axum::{Router, routing::get, serve::ListenerExt}; + /// use tracing::trace; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + /// .await + /// .unwrap() + /// .tap_io(|tcp_stream| { + /// if let Err(err) = tcp_stream.set_nodelay(true) { + /// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + /// } + /// }); + /// axum::serve(listener, router).await.unwrap(); + /// # }; + /// ``` + fn tap_io(self, tap_fn: F) -> TapIo + where + F: FnMut(&mut Self::Io) + Send + 'static, + { + TapIo { + listener: self, + tap_fn, + } + } +} + +impl ListenerExt for L {} + +pub struct TapIo { + listener: L, + tap_fn: F, +} + +impl fmt::Debug for TapIo +where + L: Listener + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TapIo") + .field("listener", &self.listener) + .finish_non_exhaustive() + } +} + +impl Listener for TapIo +where + L: Listener, + F: FnMut(&mut L::Io) + Send + 'static, +{ + type Io = L::Io; + type Addr = L::Addr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + let (mut io, addr) = self.listener.accept().await; + (self.tap_fn)(&mut io); + (io, addr) + } + + fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } +} + async fn handle_accept_error(e: io::Error) { if is_connection_error(&e) { return;