From 8c030a3aee6ece3bb4ccb9cc2546ee247083faa8 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 30 Nov 2024 15:30:46 +0100 Subject: [PATCH 1/5] Make `serve` generic over the listener and IO types Co-authored-by: David Pedersen --- axum/CHANGELOG.md | 2 + .../into_make_service_with_connect_info.md | 5 +- axum/src/extract/connect_info.rs | 15 +- axum/src/handler/service.rs | 7 +- axum/src/routing/method_routing.rs | 9 +- axum/src/routing/mod.rs | 9 +- axum/src/serve.rs | 335 +++++++++++++----- examples/unix-domain-socket/src/main.rs | 56 +-- 8 files changed, 296 insertions(+), 142 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index ce98bca697..51654fa420 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`) This new variant captures both `key`, `value`, and `message` from named path parameters parse errors, instead of only deserialization error message in `ErrorKind::Message`. ([#2720]) +- **breaking:** Make `serve` generic over the listener and IO types ([#2941]) [#2897]: https://github.com/tokio-rs/axum/pull/2897 [#2903]: https://github.com/tokio-rs/axum/pull/2903 @@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#2992]: https://github.com/tokio-rs/axum/pull/2992 [#2720]: https://github.com/tokio-rs/axum/pull/2720 [#3039]: https://github.com/tokio-rs/axum/pull/3039 +[#2941]: https://github.com/tokio-rs/axum/pull/2941 # 0.8.0 diff --git a/axum/src/docs/routing/into_make_service_with_connect_info.md b/axum/src/docs/routing/into_make_service_with_connect_info.md index 26d0602f31..088f21f9d4 100644 --- a/axum/src/docs/routing/into_make_service_with_connect_info.md +++ b/axum/src/docs/routing/into_make_service_with_connect_info.md @@ -35,6 +35,7 @@ use axum::{ serve::IncomingStream, Router, }; +use tokio::net::TcpListener; let app = Router::new().route("/", get(handler)); @@ -49,8 +50,8 @@ struct MyConnectInfo { // ... } -impl Connected> for MyConnectInfo { - fn connect_info(target: IncomingStream<'_>) -> Self { +impl Connected> for MyConnectInfo { + fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self { MyConnectInfo { // ... } diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 3d8f9a0163..54a8d77582 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -79,16 +79,17 @@ where /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub trait Connected: Clone + Send + Sync + 'static { /// Create type holding information about the connection. - fn connect_info(target: T) -> Self; + fn connect_info(stream: T) -> Self; } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; + use tokio::net::TcpListener; - impl Connected> for SocketAddr { - fn connect_info(target: IncomingStream<'_>) -> Self { - target.remote_addr() + impl Connected> for SocketAddr { + fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self { + *stream.remote_addr() } } }; @@ -261,8 +262,8 @@ mod tests { value: &'static str, } - impl Connected> for MyConnectInfo { - fn connect_info(_target: IncomingStream<'_>) -> Self { + impl Connected> for MyConnectInfo { + fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self { Self { value: "it worked!", } diff --git a/axum/src/handler/service.rs b/axum/src/handler/service.rs index e6b8df9316..2090051978 100644 --- a/axum/src/handler/service.rs +++ b/axum/src/handler/service.rs @@ -180,12 +180,13 @@ where // for `axum::serve(listener, handler)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for HandlerService + impl Service> for HandlerService where H: Clone, S: Clone, + L: serve::Listener, { type Response = Self; type Error = Infallible; @@ -195,7 +196,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone())) } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index cfc47e1f7f..2c03e7d5d0 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1313,9 +1313,12 @@ where // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for MethodRouter<()> { + impl Service> for MethodRouter<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -1324,7 +1327,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone().with_state(()))) } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index d1e84d6aa9..6404ce7db3 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -518,9 +518,12 @@ impl Router { // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for Router<()> { + impl Service> for Router<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -529,7 +532,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { // call `Router::with_state` such that everything is turned into `Route` eagerly // rather than doing that per request std::future::ready(Ok(self.clone().with_state(()))) diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 87b103ee90..30715455af 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -1,12 +1,12 @@ //! Serve services. use std::{ + any::TypeId, convert::Infallible, fmt::Debug, future::{poll_fn, Future, IntoFuture}, io, marker::PhantomData, - net::SocketAddr, sync::Arc, time::Duration, }; @@ -18,12 +18,59 @@ use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(any(feature = "http1", feature = "http2"))] use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use tokio::{ + io::{AsyncRead, AsyncWrite}, net::{TcpListener, TcpStream}, sync::watch, }; use tower::ServiceExt as _; use tower_service::Service; +/// Types that can listen for connections. +pub trait Listener: Send + 'static { + /// The listener's IO type. + type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + /// The listener's address type. + type Addr: Send; + + /// Accept a new incoming connection to this listener + fn accept(&mut self) -> impl Future> + Send; + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> io::Result; +} + +impl Listener for TcpListener { + type Io = TcpStream; + type Addr = std::net::SocketAddr; + + #[inline] + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + Self::accept(self).await + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + +#[cfg(unix)] +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + #[inline] + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + Self::accept(self).await + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + /// Serve the service with the supplied listener. /// /// This method of running a service is intentionally simple and doesn't support any configuration. @@ -89,14 +136,15 @@ use tower_service::Service; /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub fn serve(tcp_listener: TcpListener, make_service: M) -> Serve +pub fn serve(listener: L, make_service: M) -> Serve where - M: for<'a> Service, Error = Infallible, Response = S>, + L: Listener, + M: for<'a> Service, Error = Infallible, Response = S>, S: Service + Clone + Send + 'static, S::Future: Send, { Serve { - tcp_listener, + listener, make_service, tcp_nodelay: None, _marker: PhantomData, @@ -106,15 +154,18 @@ where /// Future returned by [`serve`]. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct Serve { - tcp_listener: TcpListener, +pub struct Serve { + listener: L, make_service: M, tcp_nodelay: Option, _marker: PhantomData, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve { +impl Serve +where + L: Listener, +{ /// Prepares a server to handle graceful shutdown when the provided future completes. /// /// # Example @@ -136,12 +187,12 @@ impl Serve { /// // ... /// } /// ``` - pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown + pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown where F: Future + Send + 'static, { WithGracefulShutdown { - tcp_listener: self.tcp_listener, + listener: self.listener, make_service: self.make_service, signal, tcp_nodelay: self.tcp_nodelay, @@ -149,6 +200,14 @@ impl Serve { } } + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl Serve { /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. /// /// See also [`TcpStream::set_nodelay`]. @@ -173,39 +232,41 @@ impl Serve { ..self } } - - /// Returns the local address this server is bound to. - pub fn local_addr(&self) -> io::Result { - self.tcp_listener.local_addr() - } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for Serve +impl Debug for Serve where + L: Debug + 'static, M: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, tcp_nodelay, _marker: _, } = self; - f.debug_struct("Serve") - .field("tcp_listener", tcp_listener) - .field("make_service", make_service) - .field("tcp_nodelay", tcp_nodelay) - .finish() + let mut s = f.debug_struct("Serve"); + s.field("listener", listener) + .field("make_service", make_service); + + if TypeId::of::() == TypeId::of::() { + s.field("tcp_nodelay", tcp_nodelay); + } + + s.finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for Serve +impl IntoFuture for Serve where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, { @@ -221,15 +282,27 @@ where /// Serve future with graceful shutdown enabled. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct WithGracefulShutdown { - tcp_listener: TcpListener, +pub struct WithGracefulShutdown { + listener: L, make_service: M, signal: F, tcp_nodelay: Option, _marker: PhantomData, } -impl WithGracefulShutdown { +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl WithGracefulShutdown +where + L: Listener, +{ + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl WithGracefulShutdown { /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. /// /// See also [`TcpStream::set_nodelay`]. @@ -259,43 +332,45 @@ impl WithGracefulShutdown { ..self } } - - /// Returns the local address this server is bound to. - pub fn local_addr(&self) -> io::Result { - self.tcp_listener.local_addr() - } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for WithGracefulShutdown +impl Debug for WithGracefulShutdown where + L: Debug + 'static, M: Debug, S: Debug, F: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, signal, tcp_nodelay, _marker: _, } = self; - f.debug_struct("WithGracefulShutdown") - .field("tcp_listener", tcp_listener) + let mut s = f.debug_struct("WithGracefulShutdown"); + s.field("listener", listener) .field("make_service", make_service) - .field("signal", signal) - .field("tcp_nodelay", tcp_nodelay) - .finish() + .field("signal", signal); + + if TypeId::of::() == TypeId::of::() { + s.field("tcp_nodelay", tcp_nodelay); + } + + s.finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for WithGracefulShutdown +impl IntoFuture for WithGracefulShutdown where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, @@ -305,7 +380,7 @@ where fn into_future(self) -> Self::IntoFuture { let Self { - tcp_listener, + mut listener, mut make_service, signal, tcp_nodelay, @@ -324,8 +399,8 @@ where let (close_tx, close_rx) = watch::channel(()); loop { - let (tcp_stream, remote_addr) = tokio::select! { - conn = tcp_accept(&tcp_listener) => { + let (io, remote_addr) = tokio::select! { + conn = accept(&mut listener) => { match conn { Some(conn) => conn, None => continue, @@ -338,14 +413,16 @@ where }; if let Some(nodelay) = tcp_nodelay { + let tcp_stream: &tokio::net::TcpStream = ::downcast_ref(&io) + .expect("internal error: tcp_nodelay used with the wrong type of listener"); if let Err(err) = tcp_stream.set_nodelay(nodelay) { trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); } } - let tcp_stream = TokioIo::new(tcp_stream); + let io = TokioIo::new(io); - trace!("connection {remote_addr} accepted"); + trace!("connection {remote_addr:?} accepted"); poll_fn(|cx| make_service.poll_ready(cx)) .await @@ -353,7 +430,7 @@ where let tower_service = make_service .call(IncomingStream { - tcp_stream: &tcp_stream, + io: &io, remote_addr, }) .await @@ -372,7 +449,7 @@ where // CONNECT protocol needed for HTTP/2 websockets #[cfg(feature = "http2")] builder.http2().enable_connect_protocol(); - let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); + let conn = builder.serve_connection_with_upgrades(io, hyper_service); pin_mut!(conn); let signal_closed = signal_tx.closed().fuse(); @@ -393,14 +470,12 @@ where } } - trace!("connection {remote_addr} closed"); - drop(close_rx); }); } drop(close_rx); - drop(tcp_listener); + drop(listener); trace!( "waiting for {} task(s) to finish", @@ -422,7 +497,10 @@ fn is_connection_error(e: &io::Error) -> bool { ) } -async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { +async fn accept(listener: &mut L) -> Option<(L::Io, L::Addr)> +where + L: Listener, +{ match listener.accept().await { Ok(conn) => Some(conn), Err(e) => { @@ -448,6 +526,35 @@ async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { } } +/// An incoming stream. +/// +/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. +/// +/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo +#[derive(Debug)] +pub struct IncomingStream<'a, L> +where + L: Listener, +{ + io: &'a TokioIo, + remote_addr: L::Addr, +} + +impl IncomingStream<'_, L> +where + L: Listener, +{ + /// Get a reference to the inner IO type. + pub fn io(&self) -> &L::Io { + self.io.inner() + } + + /// Returns the remote address that this stream is bound to. + pub fn remote_addr(&self) -> &L::Addr { + &self.remote_addr + } +} + mod private { use std::{ future::Future, @@ -474,33 +581,15 @@ mod private { } } -/// An incoming stream. -/// -/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. -/// -/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo -#[derive(Debug)] -pub struct IncomingStream<'a> { - tcp_stream: &'a TokioIo, - remote_addr: SocketAddr, -} - -impl IncomingStream<'_> { - /// Returns the local address that this stream is bound to. - pub fn local_addr(&self) -> std::io::Result { - self.tcp_stream.inner().local_addr() - } - - /// Returns the remote address that this stream is bound to. - pub fn remote_addr(&self) -> SocketAddr { - self.remote_addr - } -} - #[cfg(test)] mod tests { + use http::StatusCode; + use tokio::net::UnixListener; + use super::*; use crate::{ + body::to_bytes, + extract::connect_info::Connected, handler::{Handler, HandlerWithoutStateExt}, routing::get, Router, @@ -512,30 +601,63 @@ mod tests { #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { + #[derive(Clone, Debug)] + struct UdsConnectInfo; + + impl Connected> for UdsConnectInfo { + fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self { + Self + } + } + let router: Router = Router::new(); let addr = "0.0.0.0:0"; // router serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + serve(UnixListener::bind("").unwrap(), router.clone()); + serve( TcpListener::bind(addr).await.unwrap(), router.clone().into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + router.clone().into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - router.into_make_service_with_connect_info::(), + router + .clone() + .into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + router.into_make_service_with_connect_info::(), ); // method router serve(TcpListener::bind(addr).await.unwrap(), get(handler)); + serve(UnixListener::bind("").unwrap(), get(handler)); + serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - get(handler).into_make_service_with_connect_info::(), + get(handler).into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service_with_connect_info::(), ); // handler @@ -543,17 +665,27 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_service()); + serve( TcpListener::bind(addr).await.unwrap(), handler.with_state(()), ); + serve(UnixListener::bind("").unwrap(), handler.with_state(())); + serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_make_service()); + serve( TcpListener::bind(addr).await.unwrap(), - handler.into_make_service_with_connect_info::(), + handler.into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + handler.into_make_service_with_connect_info::(), ); // nodelay @@ -613,4 +745,49 @@ mod tests { // Call Serve::into_future outside of a tokio context. This used to panic. _ = serve(listener, router).into_future(); } + + #[crate::test] + async fn serving_on_custom_io_type() { + struct ReadyListener(Option); + + impl Listener for ReadyListener + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Io = T; + type Addr = (); + + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + match self.0.take() { + Some(server) => Ok((server, ())), + None => std::future::pending().await, + } + } + + fn local_addr(&self) -> io::Result { + Ok(()) + } + } + + let (client, server) = tokio::io::duplex(1024); + let listener = ReadyListener(Some(server)); + + let app = Router::new().route("/", get(|| async { "Hello, World!" })); + + tokio::spawn(serve(listener, app).into_future()); + + let stream = TokioIo::new(client); + let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap(); + tokio::spawn(conn); + + let request = Request::builder().body(Body::empty()).unwrap(); + + let response = sender.send_request(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let body = Body::new(response.into_body()); + let body = to_bytes(body, usize::MAX).await.unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert_eq!(body, "Hello, World!"); + } } diff --git a/examples/unix-domain-socket/src/main.rs b/examples/unix-domain-socket/src/main.rs index 697f31a557..07b38d9191 100644 --- a/examples/unix-domain-socket/src/main.rs +++ b/examples/unix-domain-socket/src/main.rs @@ -21,17 +21,13 @@ mod unix { extract::connect_info::{self, ConnectInfo}, http::{Method, Request, StatusCode}, routing::get, + serve::IncomingStream, Router, }; use http_body_util::BodyExt; - use hyper::body::Incoming; - use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server, - }; - use std::{convert::Infallible, path::PathBuf, sync::Arc}; + use hyper_util::rt::TokioIo; + use std::{path::PathBuf, sync::Arc}; use tokio::net::{unix::UCred, UnixListener, UnixStream}; - use tower::Service; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub async fn server() { @@ -52,33 +48,11 @@ mod unix { let uds = UnixListener::bind(path.clone()).unwrap(); tokio::spawn(async move { - let app = Router::new().route("/", get(handler)); - - let mut make_service = app.into_make_service_with_connect_info::(); - - // See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for - // more details about this setup - loop { - let (socket, _remote_addr) = uds.accept().await.unwrap(); - - let tower_service = unwrap_infallible(make_service.call(&socket).await); - - tokio::spawn(async move { - let socket = TokioIo::new(socket); - - let hyper_service = - hyper::service::service_fn(move |request: Request| { - tower_service.clone().call(request) - }); + let app = Router::new() + .route("/", get(handler)) + .into_make_service_with_connect_info::(); - if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(socket, hyper_service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } - }); - } + axum::serve(uds, app).await.unwrap(); }); let stream = TokioIo::new(UnixStream::connect(path).await.unwrap()); @@ -117,22 +91,14 @@ mod unix { peer_cred: UCred, } - impl connect_info::Connected<&UnixStream> for UdsConnectInfo { - fn connect_info(target: &UnixStream) -> Self { - let peer_addr = target.peer_addr().unwrap(); - let peer_cred = target.peer_cred().unwrap(); - + impl connect_info::Connected> for UdsConnectInfo { + fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self { + let peer_addr = stream.io().peer_addr().unwrap(); + let peer_cred = stream.io().peer_cred().unwrap(); Self { peer_addr: Arc::new(peer_addr), peer_cred, } } } - - fn unwrap_infallible(result: Result) -> T { - match result { - Ok(value) => value, - Err(err) => match err {}, - } - } } From 0b34a1434abb5c408442e7899ad5454a7c5a1586 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 30 Nov 2024 15:31:28 +0100 Subject: [PATCH 2/5] Handle accept errors inside Listener::accept --- axum/src/serve.rs | 81 +++++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 30715455af..d137a6c5b7 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -33,8 +33,11 @@ pub trait Listener: Send + 'static { /// The listener's address type. type Addr: Send; - /// Accept a new incoming connection to this listener - fn accept(&mut self) -> impl Future> + Send; + /// Accept a new incoming connection to this listener. + /// + /// If the underlying accept call can return an error, this function must + /// take care of logging and retrying. + fn accept(&mut self) -> impl Future + Send; /// Returns the local address that this listener is bound to. fn local_addr(&self) -> io::Result; @@ -44,9 +47,13 @@ impl Listener for TcpListener { type Io = TcpStream; type Addr = std::net::SocketAddr; - #[inline] - async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { - Self::accept(self).await + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + match Self::accept(self).await { + Ok(tup) => return tup, + Err(e) => handle_accept_error(e).await, + } + } } #[inline] @@ -60,9 +67,13 @@ impl Listener for tokio::net::UnixListener { type Io = tokio::net::UnixStream; type Addr = tokio::net::unix::SocketAddr; - #[inline] - async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { - Self::accept(self).await + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + match Self::accept(self).await { + Ok(tup) => return tup, + Err(e) => handle_accept_error(e).await, + } + } } #[inline] @@ -400,12 +411,7 @@ where loop { let (io, remote_addr) = tokio::select! { - conn = accept(&mut listener) => { - match conn { - Some(conn) => conn, - None => continue, - } - } + conn = listener.accept() => conn, _ = signal_tx.closed() => { trace!("signal received, not accepting new connections"); break; @@ -497,33 +503,24 @@ fn is_connection_error(e: &io::Error) -> bool { ) } -async fn accept(listener: &mut L) -> Option<(L::Io, L::Addr)> -where - L: Listener, -{ - match listener.accept().await { - Ok(conn) => Some(conn), - Err(e) => { - if is_connection_error(&e) { - return None; - } - - // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) - // - // > A possible scenario is that the process has hit the max open files - // > allowed, and so trying to accept a new connection will fail with - // > `EMFILE`. In some cases, it's preferable to just wait for some time, if - // > the application will likely close some files (or connections), and try - // > to accept the connection again. If this option is `true`, the error - // > will be logged at the `error` level, since it is still a big deal, - // > and then the listener will sleep for 1 second. - // - // hyper allowed customizing this but axum does not. - error!("accept error: {e}"); - tokio::time::sleep(Duration::from_secs(1)).await; - None - } +async fn handle_accept_error(e: io::Error) { + if is_connection_error(&e) { + return; } + + // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) + // + // > A possible scenario is that the process has hit the max open files + // > allowed, and so trying to accept a new connection will fail with + // > `EMFILE`. In some cases, it's preferable to just wait for some time, if + // > the application will likely close some files (or connections), and try + // > to accept the connection again. If this option is `true`, the error + // > will be logged at the `error` level, since it is still a big deal, + // > and then the listener will sleep for 1 second. + // + // hyper allowed customizing this but axum does not. + error!("accept error: {e}"); + tokio::time::sleep(Duration::from_secs(1)).await; } /// An incoming stream. @@ -757,9 +754,9 @@ mod tests { type Io = T; type Addr = (); - async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + async fn accept(&mut self) -> (Self::Io, Self::Addr) { match self.0.take() { - Some(server) => Ok((server, ())), + Some(server) => (server, ()), None => std::future::pending().await, } } From c0002434855cc5ec9ae1228cb46a25ecf5420824 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 30 Nov 2024 15:58:35 +0100 Subject: [PATCH 3/5] Remove tcp_nodelay from serve --- axum/src/serve.rs | 100 ++-------------------------------------------- 1 file changed, 4 insertions(+), 96 deletions(-) diff --git a/axum/src/serve.rs b/axum/src/serve.rs index d137a6c5b7..c87e2dc1d3 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -206,7 +206,6 @@ where listener: self.listener, make_service: self.make_service, signal, - tcp_nodelay: self.tcp_nodelay, _marker: PhantomData, } } @@ -217,34 +216,6 @@ where } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve { - /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. - /// - /// See also [`TcpStream::set_nodelay`]. - /// - /// # Example - /// ``` - /// use axum::{Router, routing::get}; - /// - /// # async { - /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); - /// - /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - /// axum::serve(listener, router) - /// .tcp_nodelay(true) - /// .await - /// .unwrap(); - /// # }; - /// ``` - pub fn tcp_nodelay(self, nodelay: bool) -> Self { - Self { - tcp_nodelay: Some(nodelay), - ..self - } - } -} - #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for Serve where @@ -297,7 +268,6 @@ pub struct WithGracefulShutdown { listener: L, make_service: M, signal: F, - tcp_nodelay: Option, _marker: PhantomData, } @@ -312,39 +282,6 @@ where } } -#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl WithGracefulShutdown { - /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. - /// - /// See also [`TcpStream::set_nodelay`]. - /// - /// # Example - /// ``` - /// use axum::{Router, routing::get}; - /// - /// # async { - /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); - /// - /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - /// axum::serve(listener, router) - /// .with_graceful_shutdown(shutdown_signal()) - /// .tcp_nodelay(true) - /// .await - /// .unwrap(); - /// # }; - /// - /// async fn shutdown_signal() { - /// // ... - /// } - /// ``` - pub fn tcp_nodelay(self, nodelay: bool) -> Self { - Self { - tcp_nodelay: Some(nodelay), - ..self - } - } -} - #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for WithGracefulShutdown where @@ -358,20 +295,14 @@ where listener, make_service, signal, - tcp_nodelay, _marker: _, } = self; - let mut s = f.debug_struct("WithGracefulShutdown"); - s.field("listener", listener) + f.debug_struct("WithGracefulShutdown") + .field("listener", listener) .field("make_service", make_service) - .field("signal", signal); - - if TypeId::of::() == TypeId::of::() { - s.field("tcp_nodelay", tcp_nodelay); - } - - s.finish() + .field("signal", signal) + .finish() } } @@ -394,7 +325,6 @@ where mut listener, mut make_service, signal, - tcp_nodelay, _marker: _, } = self; @@ -418,14 +348,6 @@ where } }; - if let Some(nodelay) = tcp_nodelay { - let tcp_stream: &tokio::net::TcpStream = ::downcast_ref(&io) - .expect("internal error: tcp_nodelay used with the wrong type of listener"); - if let Err(err) = tcp_stream.set_nodelay(nodelay) { - trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); - } - } - let io = TokioIo::new(io); trace!("connection {remote_addr:?} accepted"); @@ -684,20 +606,6 @@ mod tests { UnixListener::bind("").unwrap(), handler.into_make_service_with_connect_info::(), ); - - // nodelay - serve( - TcpListener::bind(addr).await.unwrap(), - handler.into_service(), - ) - .tcp_nodelay(true); - - serve( - TcpListener::bind(addr).await.unwrap(), - handler.into_service(), - ) - .with_graceful_shutdown(async { /*...*/ }) - .tcp_nodelay(true); } async fn handler() {} From e8644d85adda4e926044b00b49a6c747cbd07679 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 30 Nov 2024 16:12:45 +0100 Subject: [PATCH 4/5] Move serve::Listener into its own module --- axum/src/serve.rs | 113 ++++++------------------------------- axum/src/serve/listener.rs | 92 ++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 97 deletions(-) create mode 100644 axum/src/serve/listener.rs diff --git a/axum/src/serve.rs b/axum/src/serve.rs index c87e2dc1d3..9ebf4048f4 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -8,7 +8,6 @@ use std::{ io, marker::PhantomData, sync::Arc, - time::Duration, }; use axum_core::{body::Body, extract::Request, response::Response}; @@ -17,70 +16,13 @@ use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(any(feature = "http1", feature = "http2"))] use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - net::{TcpListener, TcpStream}, - sync::watch, -}; +use tokio::{net::TcpListener, sync::watch}; use tower::ServiceExt as _; use tower_service::Service; -/// Types that can listen for connections. -pub trait Listener: Send + 'static { - /// The listener's IO type. - type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static; - - /// The listener's address type. - type Addr: Send; - - /// Accept a new incoming connection to this listener. - /// - /// If the underlying accept call can return an error, this function must - /// take care of logging and retrying. - fn accept(&mut self) -> impl Future + Send; - - /// Returns the local address that this listener is bound to. - fn local_addr(&self) -> io::Result; -} - -impl Listener for TcpListener { - type Io = TcpStream; - type Addr = std::net::SocketAddr; - - async fn accept(&mut self) -> (Self::Io, Self::Addr) { - loop { - match Self::accept(self).await { - Ok(tup) => return tup, - Err(e) => handle_accept_error(e).await, - } - } - } - - #[inline] - fn local_addr(&self) -> io::Result { - Self::local_addr(self) - } -} +mod listener; -#[cfg(unix)] -impl Listener for tokio::net::UnixListener { - type Io = tokio::net::UnixStream; - type Addr = tokio::net::unix::SocketAddr; - - async fn accept(&mut self) -> (Self::Io, Self::Addr) { - loop { - match Self::accept(self).await { - Ok(tup) => return tup, - Err(e) => handle_accept_error(e).await, - } - } - } - - #[inline] - fn local_addr(&self) -> io::Result { - Self::local_addr(self) - } -} +pub use self::listener::Listener; /// Serve the service with the supplied listener. /// @@ -416,35 +358,6 @@ where } } -fn is_connection_error(e: &io::Error) -> bool { - matches!( - e.kind(), - io::ErrorKind::ConnectionRefused - | io::ErrorKind::ConnectionAborted - | io::ErrorKind::ConnectionReset - ) -} - -async fn handle_accept_error(e: io::Error) { - if is_connection_error(&e) { - return; - } - - // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) - // - // > A possible scenario is that the process has hit the max open files - // > allowed, and so trying to accept a new connection will fail with - // > `EMFILE`. In some cases, it's preferable to just wait for some time, if - // > the application will likely close some files (or connections), and try - // > to accept the connection again. If this option is `true`, the error - // > will be logged at the `error` level, since it is still a big deal, - // > and then the listener will sleep for 1 second. - // - // hyper allowed customizing this but axum does not. - error!("accept error: {e}"); - tokio::time::sleep(Duration::from_secs(1)).await; -} - /// An incoming stream. /// /// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. @@ -502,10 +415,20 @@ mod private { #[cfg(test)] mod tests { + use std::{ + future::{pending, IntoFuture as _}, + net::{IpAddr, Ipv4Addr}, + }; + + use axum_core::{body::Body, extract::Request}; use http::StatusCode; - use tokio::net::UnixListener; + use hyper_util::rt::TokioIo; + use tokio::{ + io::{self, AsyncRead, AsyncWrite}, + net::{TcpListener, UnixListener}, + }; - use super::*; + use super::{serve, IncomingStream, Listener}; use crate::{ body::to_bytes, extract::connect_info::Connected, @@ -513,10 +436,6 @@ mod tests { routing::get, Router, }; - use std::{ - future::pending, - net::{IpAddr, Ipv4Addr}, - }; #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { @@ -674,7 +593,7 @@ mod tests { } } - let (client, server) = tokio::io::duplex(1024); + let (client, server) = io::duplex(1024); let listener = ReadyListener(Some(server)); let app = Router::new().route("/", get(|| async { "Hello, World!" })); diff --git a/axum/src/serve/listener.rs b/axum/src/serve/listener.rs new file mode 100644 index 0000000000..0f8754e908 --- /dev/null +++ b/axum/src/serve/listener.rs @@ -0,0 +1,92 @@ +use std::{future::Future, time::Duration}; + +use tokio::{ + io::{self, AsyncRead, AsyncWrite}, + net::{TcpListener, TcpStream}, +}; + +/// Types that can listen for connections. +pub trait Listener: Send + 'static { + /// The listener's IO type. + type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + /// The listener's address type. + type Addr: Send; + + /// Accept a new incoming connection to this listener. + /// + /// If the underlying accept call can return an error, this function must + /// take care of logging and retrying. + fn accept(&mut self) -> impl Future + Send; + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> io::Result; +} + +impl Listener for TcpListener { + type Io = TcpStream; + type Addr = std::net::SocketAddr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + match Self::accept(self).await { + Ok(tup) => return tup, + Err(e) => handle_accept_error(e).await, + } + } + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + +#[cfg(unix)] +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + match Self::accept(self).await { + Ok(tup) => return tup, + Err(e) => handle_accept_error(e).await, + } + } + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + +async fn handle_accept_error(e: io::Error) { + if is_connection_error(&e) { + return; + } + + // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) + // + // > A possible scenario is that the process has hit the max open files + // > allowed, and so trying to accept a new connection will fail with + // > `EMFILE`. In some cases, it's preferable to just wait for some time, if + // > the application will likely close some files (or connections), and try + // > to accept the connection again. If this option is `true`, the error + // > will be logged at the `error` level, since it is still a big deal, + // > and then the listener will sleep for 1 second. + // + // hyper allowed customizing this but axum does not. + error!("accept error: {e}"); + tokio::time::sleep(Duration::from_secs(1)).await; +} + +fn is_connection_error(e: &io::Error) -> bool { + matches!( + e.kind(), + io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset + ) +} From c6df5f434c8af852864b6f056b44941357b09315 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 30 Nov 2024 16:29:01 +0100 Subject: [PATCH 5/5] Add serve::ListenerExt with tap_io method --- axum/src/serve.rs | 2 +- axum/src/serve/listener.rs | 77 +++++++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 9ebf4048f4..aa205c8d91 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -22,7 +22,7 @@ use tower_service::Service; mod listener; -pub use self::listener::Listener; +pub use self::listener::{Listener, ListenerExt, TapIo}; /// Serve the service with the supplied listener. /// diff --git a/axum/src/serve/listener.rs b/axum/src/serve/listener.rs index 0f8754e908..8458b9572d 100644 --- a/axum/src/serve/listener.rs +++ b/axum/src/serve/listener.rs @@ -1,4 +1,4 @@ -use std::{future::Future, time::Duration}; +use std::{fmt, future::Future, time::Duration}; use tokio::{ io::{self, AsyncRead, AsyncWrite}, @@ -62,6 +62,81 @@ impl Listener for tokio::net::UnixListener { } } +/// Extensions to [`Listener`]. +pub trait ListenerExt: Listener + Sized { + /// Run a mutable closure on every accepted `Io`. + /// + /// # Example + /// + /// ``` + /// use axum::{Router, routing::get, serve::ListenerExt}; + /// use tracing::trace; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + /// .await + /// .unwrap() + /// .tap_io(|tcp_stream| { + /// if let Err(err) = tcp_stream.set_nodelay(true) { + /// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + /// } + /// }); + /// axum::serve(listener, router).await.unwrap(); + /// # }; + /// ``` + fn tap_io(self, tap_fn: F) -> TapIo + where + F: FnMut(&mut Self::Io) + Send + 'static, + { + TapIo { + listener: self, + tap_fn, + } + } +} + +impl ListenerExt for L {} + +/// Return type of [`ListenerExt::tap_io`]. +/// +/// See that method for details. +pub struct TapIo { + listener: L, + tap_fn: F, +} + +impl fmt::Debug for TapIo +where + L: Listener + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TapIo") + .field("listener", &self.listener) + .finish_non_exhaustive() + } +} + +impl Listener for TapIo +where + L: Listener, + F: FnMut(&mut L::Io) + Send + 'static, +{ + type Io = L::Io; + type Addr = L::Addr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + let (mut io, addr) = self.listener.accept().await; + (self.tap_fn)(&mut io); + (io, addr) + } + + fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } +} + async fn handle_accept_error(e: io::Error) { if is_connection_error(&e) { return;