diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 3dd43f1a33..7c02a3888f 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **fixed:** Improve `debug_handler` on tuple response types ([#2201]) +- **breaking:** Make `serve` generic over the listener and IO types ([#2479]) [#2201]: https://github.com/tokio-rs/axum/pull/2201 +[#2479]: https://github.com/tokio-rs/axum/pull/2479 # 0.7.3 (29. December, 2023) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 1e188dce66..c450a03ea7 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -115,6 +115,7 @@ rustversion = "1.0.9" [dev-dependencies] anyhow = "1.0" axum-macros = { path = "../axum-macros", version = "0.4.0", features = ["__private"] } +hyper = { version = "1.1.0", features = ["client"] } quickcheck = "1.0" quickcheck_macros = "1.0" reqwest = { version = "0.11.14", default-features = false, features = ["json", "stream", "multipart"] } diff --git a/axum/src/docs/routing/into_make_service_with_connect_info.md b/axum/src/docs/routing/into_make_service_with_connect_info.md index 26d0602f31..088f21f9d4 100644 --- a/axum/src/docs/routing/into_make_service_with_connect_info.md +++ b/axum/src/docs/routing/into_make_service_with_connect_info.md @@ -35,6 +35,7 @@ use axum::{ serve::IncomingStream, Router, }; +use tokio::net::TcpListener; let app = Router::new().route("/", get(handler)); @@ -49,8 +50,8 @@ struct MyConnectInfo { // ... } -impl Connected> for MyConnectInfo { - fn connect_info(target: IncomingStream<'_>) -> Self { +impl Connected> for MyConnectInfo { + fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self { MyConnectInfo { // ... } diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 2390d886f8..48e9bfb4f5 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -80,16 +80,17 @@ where /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub trait Connected: Clone + Send + Sync + 'static { /// Create type holding information about the connection. - fn connect_info(target: T) -> Self; + fn connect_info(stream: T) -> Self; } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; + use tokio::net::TcpListener; - impl Connected> for SocketAddr { - fn connect_info(target: IncomingStream<'_>) -> Self { - target.remote_addr() + impl Connected> for SocketAddr { + fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self { + *stream.remote_addr() } } }; @@ -264,8 +265,8 @@ mod tests { value: &'static str, } - impl Connected> for MyConnectInfo { - fn connect_info(_target: IncomingStream<'_>) -> Self { + impl Connected> for MyConnectInfo { + fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self { Self { value: "it worked!", } diff --git a/axum/src/handler/service.rs b/axum/src/handler/service.rs index e6b8df9316..2090051978 100644 --- a/axum/src/handler/service.rs +++ b/axum/src/handler/service.rs @@ -180,12 +180,13 @@ where // for `axum::serve(listener, handler)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for HandlerService + impl Service> for HandlerService where H: Clone, S: Clone, + L: serve::Listener, { type Response = Self; type Error = Infallible; @@ -195,7 +196,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone())) } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 962a440111..a7b257c8ae 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1227,9 +1227,12 @@ where // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for MethodRouter<()> { + impl Service> for MethodRouter<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -1238,7 +1241,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone())) } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 13b5725549..ce7b531a56 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -430,9 +430,12 @@ impl Router { // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for Router<()> { + impl Service> for Router<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -441,7 +444,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone())) } } diff --git a/axum/src/serve.rs b/axum/src/serve.rs index b5c68571e0..452fc95de1 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -6,7 +6,6 @@ use std::{ future::{poll_fn, Future, IntoFuture}, io, marker::PhantomData, - net::SocketAddr, sync::Arc, time::Duration, }; @@ -19,12 +18,59 @@ use hyper_util::{ server::conn::auto::Builder, }; use tokio::{ + io::{AsyncRead, AsyncWrite}, net::{TcpListener, TcpStream}, sync::watch, }; use tower::util::ServiceExt; use tower_service::Service; +/// Types that can listen for connections. +pub trait Listener: Send + 'static { + /// The listener's IO type. + type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + /// The listener's address type. + type Addr: Send; + + /// Accept a new incoming connection to this listener + fn accept(&mut self) -> impl Future> + Send; + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> io::Result; +} + +impl Listener for TcpListener { + type Io = TcpStream; + type Addr = std::net::SocketAddr; + + #[inline] + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + Self::accept(self).await + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + +#[cfg(unix)] +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + #[inline] + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + Self::accept(self).await + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + /// Serve the service with the supplied listener. /// /// This method of running a service is intentionally simple and doesn't support any configuration. @@ -90,14 +136,15 @@ use tower_service::Service; /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub fn serve(tcp_listener: TcpListener, make_service: M) -> Serve +pub fn serve(listener: L, make_service: M) -> Serve where - M: for<'a> Service, Error = Infallible, Response = S>, + L: Listener, + M: for<'a> Service, Error = Infallible, Response = S>, S: Service + Clone + Send + 'static, S::Future: Send, { Serve { - tcp_listener, + listener, make_service, _marker: PhantomData, } @@ -105,14 +152,14 @@ where /// Future returned by [`serve`]. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub struct Serve { - tcp_listener: TcpListener, +pub struct Serve { + listener: L, make_service: M, _marker: PhantomData, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve { +impl Serve { /// Prepares a server to handle graceful shutdown when the provided future completes. /// /// # Example @@ -134,12 +181,12 @@ impl Serve { /// // ... /// } /// ``` - pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown + pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown where F: Future + Send + 'static, { WithGracefulShutdown { - tcp_listener: self.tcp_listener, + listener: self.listener, make_service: self.make_service, signal, _marker: PhantomData, @@ -148,29 +195,31 @@ impl Serve { } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for Serve +impl Debug for Serve where + L: Debug, M: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, _marker: _, } = self; f.debug_struct("Serve") - .field("tcp_listener", tcp_listener) + .field("listener", listener) .field("make_service", make_service) .finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for Serve +impl IntoFuture for Serve where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, { @@ -179,12 +228,12 @@ where fn into_future(self) -> Self::IntoFuture { let Self { - tcp_listener, + listener, make_service, _marker: _, } = self; - serve(tcp_listener, make_service) + serve(listener, make_service) .with_graceful_shutdown(std::future::pending()) .into_future() } @@ -192,30 +241,31 @@ where /// Serve future with graceful shutdown enabled. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub struct WithGracefulShutdown { - tcp_listener: TcpListener, +pub struct WithGracefulShutdown { + listener: L, make_service: M, signal: F, _marker: PhantomData, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for WithGracefulShutdown +impl Debug for WithGracefulShutdown where + L: Debug, M: Debug, S: Debug, F: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, signal, _marker: _, } = self; f.debug_struct("WithGracefulShutdown") - .field("tcp_listener", tcp_listener) + .field("listener", listener) .field("make_service", make_service) .field("signal", signal) .finish() @@ -223,10 +273,11 @@ where } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for WithGracefulShutdown +impl IntoFuture for WithGracefulShutdown where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, @@ -236,7 +287,7 @@ where fn into_future(self) -> Self::IntoFuture { let Self { - tcp_listener, + mut listener, mut make_service, signal, _marker: _, @@ -254,8 +305,8 @@ where private::ServeFuture(Box::pin(async move { loop { - let (tcp_stream, remote_addr) = tokio::select! { - conn = tcp_accept(&tcp_listener) => { + let (io, remote_addr) = tokio::select! { + conn = accept(&mut listener) => { match conn { Some(conn) => conn, None => continue, @@ -266,9 +317,7 @@ where break; } }; - let tcp_stream = TokioIo::new(tcp_stream); - - trace!("connection {remote_addr} accepted"); + let io = TokioIo::new(io); poll_fn(|cx| make_service.poll_ready(cx)) .await @@ -276,7 +325,7 @@ where let tower_service = make_service .call(IncomingStream { - tcp_stream: &tcp_stream, + io: &io, remote_addr, }) .await @@ -291,7 +340,7 @@ where tokio::spawn(async move { let builder = Builder::new(TokioExecutor::new()); - let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); + let conn = builder.serve_connection_with_upgrades(io, hyper_service); pin_mut!(conn); let signal_closed = signal_tx.closed().fuse(); @@ -312,14 +361,12 @@ where } } - trace!("connection {remote_addr} closed"); - drop(close_rx); }); } drop(close_rx); - drop(tcp_listener); + drop(listener); trace!( "waiting for {} task(s) to finish", @@ -341,7 +388,10 @@ fn is_connection_error(e: &io::Error) -> bool { ) } -async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { +async fn accept(listener: &mut L) -> Option<(L::Io, L::Addr)> +where + L: Listener, +{ match listener.accept().await { Ok(conn) => Some(conn), Err(e) => { @@ -367,6 +417,35 @@ async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { } } +/// An incoming stream. +/// +/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. +/// +/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo +#[derive(Debug)] +pub struct IncomingStream<'a, L> +where + L: Listener, +{ + io: &'a TokioIo, + remote_addr: L::Addr, +} + +impl IncomingStream<'_, L> +where + L: Listener, +{ + /// Get a reference to the inner IO type. + pub fn io(&self) -> &L::Io { + self.io.inner() + } + + /// Returns the remote address that this stream is bound to. + pub fn remote_addr(&self) -> &L::Addr { + &self.remote_addr + } +} + mod private { use std::{ future::Future, @@ -393,33 +472,15 @@ mod private { } } -/// An incoming stream. -/// -/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. -/// -/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo -#[derive(Debug)] -pub struct IncomingStream<'a> { - tcp_stream: &'a TokioIo, - remote_addr: SocketAddr, -} - -impl IncomingStream<'_> { - /// Returns the local address that this stream is bound to. - pub fn local_addr(&self) -> std::io::Result { - self.tcp_stream.inner().local_addr() - } - - /// Returns the remote address that this stream is bound to. - pub fn remote_addr(&self) -> SocketAddr { - self.remote_addr - } -} - #[cfg(test)] mod tests { + use http::StatusCode; + use tokio::net::UnixListener; + use super::*; use crate::{ + body::to_bytes, + extract::connect_info::Connected, handler::{Handler, HandlerWithoutStateExt}, routing::get, Router, @@ -427,30 +488,63 @@ mod tests { #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { + #[derive(Clone, Debug)] + struct UdsConnectInfo; + + impl Connected> for UdsConnectInfo { + fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self { + Self + } + } + let router: Router = Router::new(); let addr = "0.0.0.0:0"; // router serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + serve(UnixListener::bind("").unwrap(), router.clone()); + serve( TcpListener::bind(addr).await.unwrap(), router.clone().into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + router.clone().into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - router.into_make_service_with_connect_info::(), + router + .clone() + .into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + router.into_make_service_with_connect_info::(), ); // method router serve(TcpListener::bind(addr).await.unwrap(), get(handler)); + serve(UnixListener::bind("").unwrap(), get(handler)); + serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - get(handler).into_make_service_with_connect_info::(), + get(handler).into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service_with_connect_info::(), ); // handler @@ -458,19 +552,74 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_service()); + serve( TcpListener::bind(addr).await.unwrap(), handler.with_state(()), ); + serve(UnixListener::bind("").unwrap(), handler.with_state(())); + serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_make_service()); + serve( TcpListener::bind(addr).await.unwrap(), - handler.into_make_service_with_connect_info::(), + handler.into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + handler.into_make_service_with_connect_info::(), ); } async fn handler() {} + + #[crate::test] + async fn serving_on_custom_io_type() { + struct ReadyListener(Option); + + impl Listener for ReadyListener + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Io = T; + type Addr = (); + + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + match self.0.take() { + Some(server) => Ok((server, ())), + None => std::future::pending().await, + } + } + + fn local_addr(&self) -> io::Result { + Ok(()) + } + } + + let (client, server) = tokio::io::duplex(1024); + let listener = ReadyListener(Some(server)); + + let app = Router::new().route("/", get(|| async { "Hello, World!" })); + + tokio::spawn(serve(listener, app).into_future()); + + let stream = TokioIo::new(client); + let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap(); + tokio::spawn(conn); + + let request = Request::builder().body(Body::empty()).unwrap(); + + let response = sender.send_request(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let body = Body::new(response.into_body()); + let body = to_bytes(body, usize::MAX).await.unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert_eq!(body, "Hello, World!"); + } } diff --git a/examples/unix-domain-socket/src/main.rs b/examples/unix-domain-socket/src/main.rs index d11792dd70..d96a268417 100644 --- a/examples/unix-domain-socket/src/main.rs +++ b/examples/unix-domain-socket/src/main.rs @@ -22,17 +22,13 @@ mod unix { extract::connect_info::{self, ConnectInfo}, http::{Method, Request, StatusCode}, routing::get, + serve::IncomingStream, Router, }; use http_body_util::BodyExt; - use hyper::body::Incoming; - use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server, - }; - use std::{convert::Infallible, path::PathBuf, sync::Arc}; + use hyper_util::rt::TokioIo; + use std::{path::PathBuf, sync::Arc}; use tokio::net::{unix::UCred, UnixListener, UnixStream}; - use tower::Service; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub async fn server() { @@ -53,33 +49,11 @@ mod unix { let uds = UnixListener::bind(path.clone()).unwrap(); tokio::spawn(async move { - let app = Router::new().route("/", get(handler)); - - let mut make_service = app.into_make_service_with_connect_info::(); - - // See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for - // more details about this setup - loop { - let (socket, _remote_addr) = uds.accept().await.unwrap(); - - let tower_service = unwrap_infallible(make_service.call(&socket).await); - - tokio::spawn(async move { - let socket = TokioIo::new(socket); - - let hyper_service = - hyper::service::service_fn(move |request: Request| { - tower_service.clone().call(request) - }); + let app = Router::new() + .route("/", get(handler)) + .into_make_service_with_connect_info::(); - if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(socket, hyper_service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } - }); - } + axum::serve(uds, app).await.unwrap(); }); let stream = TokioIo::new(UnixStream::connect(path).await.unwrap()); @@ -118,22 +92,14 @@ mod unix { peer_cred: UCred, } - impl connect_info::Connected<&UnixStream> for UdsConnectInfo { - fn connect_info(target: &UnixStream) -> Self { - let peer_addr = target.peer_addr().unwrap(); - let peer_cred = target.peer_cred().unwrap(); - + impl connect_info::Connected> for UdsConnectInfo { + fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self { + let peer_addr = stream.io().peer_addr().unwrap(); + let peer_cred = stream.io().peer_cred().unwrap(); Self { peer_addr: Arc::new(peer_addr), peer_cred, } } } - - fn unwrap_infallible(result: Result) -> T { - match result { - Ok(value) => value, - Err(err) => match err {}, - } - } }