Skip to content

Commit

Permalink
Add serve::ListenerExt with tap_io method
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Nov 30, 2024
1 parent e8644d8 commit a5f7c1d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
2 changes: 1 addition & 1 deletion axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use tower_service::Service;

mod listener;

pub use self::listener::Listener;
pub use self::listener::{Listener, ListenerExt};

/// Serve the service with the supplied listener.
///
Expand Down
74 changes: 73 additions & 1 deletion axum/src/serve/listener.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{future::Future, time::Duration};
use std::{fmt, future::Future, time::Duration};

use tokio::{
io::{self, AsyncRead, AsyncWrite},
Expand Down Expand Up @@ -62,6 +62,78 @@ impl Listener for tokio::net::UnixListener {
}
}

/// Extensions to [`Listener`].
pub trait ListenerExt: Listener + Sized {
/// Run a mutable closure on every accepted `Io`.
///
/// # Example
///
/// ```
/// use axum::{Router, routing::get, serve::ListenerExt};
/// use tracing::trace;
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
/// .await
/// .unwrap()
/// .tap_io(|tcp_stream| {
/// if let Err(err) = tcp_stream.set_nodelay(true) {
/// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
/// }
/// });
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
where
F: FnMut(&mut Self::Io) + Send + 'static,
{
TapIo {
listener: self,
tap_fn,
}
}
}

impl<L: Listener> ListenerExt for L {}

pub struct TapIo<L: Listener, F> {
listener: L,
tap_fn: F,
}

impl<L, F> fmt::Debug for TapIo<L, F>
where
L: Listener + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TapIo")
.field("listener", &self.listener)
.finish_non_exhaustive()
}
}

impl<L, F> Listener for TapIo<L, F>
where
L: Listener,
F: FnMut(&mut L::Io) + Send + 'static,
{
type Io = L::Io;
type Addr = L::Addr;

async fn accept(&mut self) -> (Self::Io, Self::Addr) {
let (mut io, addr) = self.listener.accept().await;
(self.tap_fn)(&mut io);
(io, addr)
}

fn local_addr(&self) -> io::Result<Self::Addr> {
self.listener.local_addr()
}
}

async fn handle_accept_error(e: io::Error) {
if is_connection_error(&e) {
return;
Expand Down

0 comments on commit a5f7c1d

Please sign in to comment.