From 6bd78b6541074524030343150f11ddeec298cef6 Mon Sep 17 00:00:00 2001 From: Thayne McCombs Date: Thu, 21 Sep 2023 00:25:05 -0600 Subject: [PATCH 1/3] feat!: Add a new error type for handshake timeouts And make the errors enun non_exhaustive BREAKING CHANGE: Adds a new variant to the Error Enum BREAKING CHANGE: The Error enum is now non_exhaustive BREAKING CHANGE: Now returns an error if a handshake times out Fixes: #36 --- src/lib.rs | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a9862c4..bd06b3b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -142,6 +142,7 @@ pub struct Builder { /// Wraps errors from either the listener or the TLS Acceptor #[derive(Debug, Error)] +#[non_exhaustive] pub enum Error { /// An error that arose from the listener ([AsyncAccept::Error]) #[error("{0}")] @@ -149,6 +150,12 @@ pub enum Error { /// An error that occurred during the TLS accept handshake #[error("{0}")] TlsAcceptError(#[source] TE), + // TODO: is there any way we could include thee original connection, or maybe some + // info about it here? + /// The TLS handshake timed out + #[error("Timeout during TLS handshake")] + #[non_exhaustive] + HandshakeTimeout {}, } impl TlsListener @@ -219,16 +226,12 @@ where } } - loop { - return match this.waiting.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(conn))) => { - Poll::Ready(Some(conn.map_err(Error::TlsAcceptError))) - } - // The handshake timed out, try getting another connection from the - // queue - Poll::Ready(Some(Err(_))) => continue, - _ => Poll::Pending, - }; + match this.waiting.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(conn))) => Poll::Ready(Some(conn.map_err(Error::TlsAcceptError))), + // The handshake timed out, try getting another connection from the + // queue + Poll::Ready(Some(Err(_))) => Poll::Ready(Some(Err(Error::HandshakeTimeout()))), + _ => Poll::Pending, } } } From e920fbbb6dd0d586802ca5b22312b46f36bedcdf Mon Sep 17 00:00:00 2001 From: ahcodedthat <83854662+ahcodedthat@users.noreply.github.com> Date: Fri, 22 Sep 2023 21:43:59 -0700 Subject: [PATCH 2/3] feat!: Yield remote address upon accepting a connection, and include it in errors. BREAKING CHANGE: The enum variant `Error::ListenerError` is now struct-like instead of tuple-like, and is `non_exhaustive` like the enum itself. BREAKING CHANGE: `Error` now has three type parameters, not two. BREAKING CHANGE: `TlsListener::accept` and `::next` yields a tuple of (connection, remote address), not just the connection. BREAKING CHANGE: `AsyncAccept` now has an associated type `Address`, which `poll_accept` must now return along with the accepted connection. --- examples/echo-threads.rs | 6 +- examples/echo.rs | 10 ++- examples/http-change-certificate.rs | 8 +- examples/http-low-level.rs | 8 +- examples/http-stream.rs | 20 ++--- src/hyper.rs | 25 +++++-- src/lib.rs | 110 ++++++++++++++++++++++++---- src/net.rs | 10 ++- tests/basic.rs | 12 ++- tests/helper/mocks.rs | 19 ++++- tests/helper/mod.rs | 2 +- 11 files changed, 180 insertions(+), 50 deletions(-) diff --git a/examples/echo-threads.rs b/examples/echo-threads.rs index d510351..b2cee66 100644 --- a/examples/echo-threads.rs +++ b/examples/echo-threads.rs @@ -13,7 +13,7 @@ mod tls_config; use tls_config::tls_acceptor; #[inline] -async fn handle_stream(stream: TlsStream) { +async fn handle_stream(stream: TlsStream, _remote_addr: SocketAddr) { let (mut reader, mut writer) = split(stream); match copy(&mut reader, &mut writer).await { Ok(cnt) => eprintln!("Processed {} bytes", cnt), @@ -32,8 +32,8 @@ async fn main() -> Result<(), Box> { TlsListener::new(SpawningHandshakes(tls_acceptor()), listener) .for_each_concurrent(None, |s| async { match s { - Ok(stream) => { - handle_stream(stream).await; + Ok((stream, remote_addr)) => { + handle_stream(stream, remote_addr).await; } Err(e) => { eprintln!("Error: {:?}", e); diff --git a/examples/echo.rs b/examples/echo.rs index 1811316..31dd86c 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -22,7 +22,7 @@ mod tls_config; use tls_config::tls_acceptor; #[inline] -async fn handle_stream(stream: TlsStream) { +async fn handle_stream(stream: TlsStream, _remote_addr: SocketAddr) { let (mut reader, mut writer) = split(stream); match copy(&mut reader, &mut writer).await { Ok(cnt) => eprintln!("Processed {} bytes", cnt), @@ -41,10 +41,14 @@ async fn main() -> Result<(), Box> { TlsListener::new(tls_acceptor(), listener) .for_each_concurrent(None, |s| async { match s { - Ok(stream) => { - handle_stream(stream).await; + Ok((stream, remote_addr)) => { + handle_stream(stream, remote_addr).await; } Err(e) => { + if let Some(remote_addr) = e.remote_addr() { + eprint!("[client {remote_addr}] "); + } + eprintln!("Error accepting connection: {:?}", e); } } diff --git a/examples/http-change-certificate.rs b/examples/http-change-certificate.rs index c543ee6..6ba6742 100644 --- a/examples/http-change-certificate.rs +++ b/examples/http-change-certificate.rs @@ -30,18 +30,22 @@ async fn main() { tokio::select! { conn = listener.accept() => { match conn.expect("Tls listener stream should be infinite") { - Ok(conn) => { + Ok((conn, remote_addr)) => { let http = http.clone(); let tx = tx.clone(); let counter = counter.clone(); tokio::spawn(async move { let svc = service_fn(move |request| handle_request(tx.clone(), counter.clone(), request)); if let Err(err) = http.serve_connection(conn, svc).await { - eprintln!("Application error: {}", err); + eprintln!("Application error (client address: {remote_addr}): {}", err); } }); }, Err(e) => { + if let Some(remote_addr) = e.remote_addr() { + eprint!("[client {remote_addr}] "); + } + eprintln!("Bad connection: {}", e); } } diff --git a/examples/http-low-level.rs b/examples/http-low-level.rs index 678486e..74b3509 100644 --- a/examples/http-low-level.rs +++ b/examples/http-low-level.rs @@ -27,15 +27,19 @@ async fn main() { listener .for_each(|r| async { match r { - Ok(conn) => { + Ok((conn, remote_addr)) => { let http = http.clone(); tokio::spawn(async move { if let Err(err) = http.serve_connection(conn, svc).await { - eprintln!("Application error: {}", err); + eprintln!("[client {remote_addr}] Application error: {}", err); } }); } Err(err) => { + if let Some(remote_addr) = err.remote_addr() { + eprint!("[client {remote_addr}] "); + } + eprintln!("Error accepting connection: {}", err); } } diff --git a/examples/http-stream.rs b/examples/http-stream.rs index bfe9b5c..7ad623b 100644 --- a/examples/http-stream.rs +++ b/examples/http-stream.rs @@ -1,4 +1,4 @@ -use futures_util::stream::StreamExt; +use futures_util::stream::{StreamExt, TryStreamExt}; use hyper::server::accept; use hyper::server::conn::AddrIncoming; use hyper::service::{make_service_fn, service_fn}; @@ -22,14 +22,16 @@ async fn main() -> Result<(), Box> { }); // This uses a filter to handle errors with connecting - let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?).filter(|conn| { - if let Err(err) = conn { - eprintln!("Error: {:?}", err); - ready(false) - } else { - ready(true) - } - }); + let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?) + .filter(|conn| { + if let Err(err) = conn { + eprintln!("Error: {:?}", err); + ready(false) + } else { + ready(true) + } + }) + .map_ok(|(conn, _remote_addr)| conn); let server = Server::builder(accept::from_stream(incoming)).serve(new_svc); server.await?; diff --git a/src/hyper.rs b/src/hyper.rs index a63fe74..002b88c 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -7,12 +7,16 @@ use std::ops::{Deref, DerefMut}; impl AsyncAccept for AddrIncoming { type Connection = AddrStream; type Error = std::io::Error; + type Address = std::net::SocketAddr; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - ::poll_accept(self, cx) + ) -> Poll>> { + ::poll_accept(self, cx).map_ok(|conn| { + let remote_addr = conn.remote_addr(); + (conn, remote_addr) + }) } } @@ -22,6 +26,11 @@ pin_project! { /// Unfortunately, it isn't possible to use a blanket impl, due to coherence rules. /// At least until [RFC 1210](https://rust-lang.github.io/rfcs/1210-impl-specialization.html) /// (specialization) is stabilized. + /// + /// Note that, because `hyper::server::accept::Accept` does not expose the + /// remote address, the implementation of `AsyncAccept` for `WrappedAccept` + /// doesn't expose it either. That is, [`AsyncAccept::Address`] is `()` in + /// this case. //#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))] pub struct WrappedAccept { // sadly, pin-project-lite doesn't suport tuple structs :( @@ -46,12 +55,16 @@ where { type Connection = A::Conn; type Error = A::Error; + type Address = (); fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - self.project().inner.poll_accept(cx) + ) -> Poll>> { + self.project() + .inner + .poll_accept(cx) + .map_ok(|conn| (conn, ())) } } @@ -95,12 +108,12 @@ where T: AsyncTls, { type Conn = T::Stream; - type Error = Error; + type Error = Error; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.poll_next(cx) + self.poll_next(cx).map_ok(|(conn, _)| conn) } } diff --git a/src/lib.rs b/src/lib.rs index bd06b3b..3110070 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ use futures_util::stream::{FuturesUnordered, Stream, StreamExt}; use pin_project_lite::pin_project; #[cfg(feature = "rt")] pub use spawning_handshake::SpawningHandshakes; +use std::fmt::Debug; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -67,6 +68,11 @@ pub trait AsyncTls: Clone { pub trait AsyncAccept { /// The type of the connection that is accepted. type Connection: AsyncRead + AsyncWrite; + /// The type of the remote address, such as [`std::net::SocketAddr`]. + /// + /// If no remote address can be determined (such as for mock connections), + /// `()` or a similar dummy type can be used. + type Address: Debug; /// The type of error that may be returned. type Error; @@ -74,7 +80,7 @@ pub trait AsyncAccept { fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>; + ) -> Poll>>; /// Return a new `AsyncAccept` that stops accepting connections after /// `ender` completes. @@ -126,7 +132,7 @@ pin_project! { #[pin] listener: A, tls: T, - waiting: FuturesUnordered>, + waiting: FuturesUnordered, A::Address>>, max_handshakes: usize, timeout: Duration, } @@ -143,19 +149,35 @@ pub struct Builder { /// Wraps errors from either the listener or the TLS Acceptor #[derive(Debug, Error)] #[non_exhaustive] -pub enum Error { +// TODO: It would probably be more simple and more future-proof to use the +// `AsyncAccept` and `AsyncTls` implementations as the type parameters here, so +// that their associated types can be used in the fields +// (i.e. `error: A::Error, remote_addr: A::Address`), but that would require us +// to either hand-write `impl Debug` or use a proc-macro crate like +// `impl-tools` to derive `Debug` with custom bounds, +// due to https://github.com/rust-lang/rust/issues/26925 +pub enum Error { /// An error that arose from the listener ([AsyncAccept::Error]) #[error("{0}")] ListenerError(#[source] LE), /// An error that occurred during the TLS accept handshake - #[error("{0}")] - TlsAcceptError(#[source] TE), - // TODO: is there any way we could include thee original connection, or maybe some - // info about it here? + #[error("{error}")] + #[non_exhaustive] + TlsAcceptError { + /// The error that occurred. + #[source] + error: TE, + + /// The client's address and port. + remote_addr: A, + }, /// The TLS handshake timed out #[error("Timeout during TLS handshake")] #[non_exhaustive] - HandshakeTimeout {}, + HandshakeTimeout { + /// The client's address and port. + remote_addr: A, + }, } impl TlsListener @@ -207,7 +229,7 @@ where A::Error: std::error::Error, T: AsyncTls, { - type Item = Result>; + type Item = Result<(T::Stream, A::Address), Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); @@ -215,9 +237,11 @@ where while this.waiting.len() < *this.max_handshakes { match this.listener.as_mut().poll_accept(cx) { Poll::Pending => break, - Poll::Ready(Some(Ok(conn))) => { - this.waiting - .push(timeout(*this.timeout, this.tls.accept(conn))); + Poll::Ready(Some(Ok((conn, address)))) => { + this.waiting.push(FutureWithExtraData::new( + timeout(*this.timeout, this.tls.accept(conn)), + address, + )); } Poll::Ready(Some(Err(e))) => { return Poll::Ready(Some(Err(Error::ListenerError(e)))); @@ -227,10 +251,15 @@ where } match this.waiting.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(conn))) => Poll::Ready(Some(conn.map_err(Error::TlsAcceptError))), + Poll::Ready(Some((Ok(result), remote_addr))) => Poll::Ready(Some(match result { + Ok(conn) => Ok((conn, remote_addr)), + Err(error) => Err(Error::TlsAcceptError { error, remote_addr }), + })), // The handshake timed out, try getting another connection from the // queue - Poll::Ready(Some(Err(_))) => Poll::Ready(Some(Err(Error::HandshakeTimeout()))), + Poll::Ready(Some((Err(_), remote_addr))) => { + Poll::Ready(Some(Err(Error::HandshakeTimeout { remote_addr }))) + } _ => Poll::Pending, } } @@ -337,6 +366,19 @@ impl Builder { } } +impl Error { + /// Returns the client's address and port, if known. + pub fn remote_addr(&self) -> Option<&A> { + match self { + Self::ListenerError(_) => None, + + Self::TlsAcceptError { remote_addr, .. } | Self::HandshakeTimeout { remote_addr } => { + Some(remote_addr) + } + } + } +} + /// Create a new Builder for a TlsListener /// /// `server_config` will be used to configure the TLS sessions. @@ -361,11 +403,12 @@ pin_project! { impl AsyncAccept for Until { type Connection = A::Connection; type Error = A::Error; + type Address = A::Address; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { let this = self.project(); match this.ender.poll(cx) { @@ -374,3 +417,40 @@ impl AsyncAccept for Until { } } } + +pin_project! { + struct FutureWithExtraData { + #[pin] + future: Fut, + extra: Option, + } +} + +impl FutureWithExtraData { + fn new(future: Fut, extra: X) -> Self { + Self { + future, + extra: Some(extra), + } + } +} + +impl Future for FutureWithExtraData +where + Fut: Future, +{ + type Output = (Fut::Output, X); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let extra = this.extra; + + this.future.poll(cx).map(|output| { + let extra = extra + .take() + .expect("this future has already been polled to completion"); + + (output, extra) + }) + } +} diff --git a/src/net.rs b/src/net.rs index ea712b3..8270c38 100644 --- a/src/net.rs +++ b/src/net.rs @@ -10,13 +10,14 @@ use tokio::net::{UnixListener, UnixStream}; impl AsyncAccept for TcpListener { type Connection = TcpStream; type Error = io::Error; + type Address = std::net::SocketAddr; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { match (*self).poll_accept(cx) { - Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))), + Poll::Ready(Ok((stream, remote_addr))) => Poll::Ready(Some(Ok((stream, remote_addr)))), Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), Poll::Pending => Poll::Pending, } @@ -28,13 +29,14 @@ impl AsyncAccept for TcpListener { impl AsyncAccept for UnixListener { type Connection = UnixStream; type Error = io::Error; + type Address = tokio::net::unix::SocketAddr; fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { match (*self).poll_accept(cx) { - Poll::Ready(Ok((stream, _))) => Poll::Ready(Some(Ok(stream))), + Poll::Ready(Ok((stream, remote_addr))) => Poll::Ready(Some(Ok((stream, remote_addr)))), Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), Poll::Pending => Poll::Pending, } diff --git a/tests/basic.rs b/tests/basic.rs index e22b892..99133f6 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -19,6 +19,7 @@ async fn accept_connections() { spawn(listener.for_each_concurrent(None, |s| async { s.expect("unexpected error") + .0 .write_all(b"HELLO, WORLD!") .await .unwrap(); @@ -65,7 +66,13 @@ async fn tls_error() { spawn(async move { connect.send_data(b"foo").await }); let mut listener = TlsListener::new(ErrTls, accept); - assert_err!(listener.accept().await.unwrap(), TlsAcceptError(_)); + assert_err!( + listener.accept().await.unwrap(), + TlsAcceptError { + remote_addr: MockAddress(42), + .. + } + ); } #[tokio::test] @@ -77,7 +84,8 @@ async fn accept_ended() { }); let res = listener.accept().await; - if let Some(Ok(mut stream)) = res { + if let Some(Ok((mut stream, MockAddress(stream_id)))) = res { + assert_eq!(stream_id, 42); stream.write_all(b"ABC").await.unwrap(); } else { panic!("Failed to accept stream. Got {:?}", res); diff --git a/tests/helper/mocks.rs b/tests/helper/mocks.rs index 6bbecfc..189e405 100644 --- a/tests/helper/mocks.rs +++ b/tests/helper/mocks.rs @@ -2,6 +2,7 @@ use futures_util::future::{self, Ready}; use futures_util::ready; use std::io; use std::pin::Pin; +use std::sync::atomic::{self, AtomicU32}; use std::task::{Context, Poll}; use tokio::io::{ duplex, split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, @@ -10,7 +11,7 @@ use tokio::sync::mpsc; use tls_listener::{AsyncAccept, AsyncTls}; -type ConnResult = io::Result; +type ConnResult = io::Result<(DuplexStream, MockAddress)>; pub struct MockAccept { chan: mpsc::Receiver, @@ -18,17 +19,28 @@ pub struct MockAccept { pub struct MockConnect { chan: mpsc::Sender, + counter: AtomicU32, } +#[derive(Clone, Copy, Debug)] +pub struct MockAddress(pub u32); + pub fn accepting() -> (MockConnect, MockAccept) { let (tx, rx) = mpsc::channel(32); - (MockConnect { chan: tx }, MockAccept { chan: rx }) + ( + MockConnect { + chan: tx, + counter: AtomicU32::new(42), + }, + MockAccept { chan: rx }, + ) } impl MockConnect { pub async fn connect(&self) -> DuplexStream { let (tx, rx) = duplex(1024); - self.chan.send(Ok(rx)).await.unwrap(); + let count = self.counter.fetch_add(1, atomic::Ordering::Relaxed); + self.chan.send(Ok((rx, MockAddress(count)))).await.unwrap(); tx } @@ -56,6 +68,7 @@ impl MockConnect { impl AsyncAccept for MockAccept { type Connection = DuplexStream; type Error = io::Error; + type Address = MockAddress; fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::into_inner(self).chan.poll_recv(cx) diff --git a/tests/helper/mod.rs b/tests/helper/mod.rs index 4959562..ea3b4f2 100644 --- a/tests/helper/mod.rs +++ b/tests/helper/mod.rs @@ -18,7 +18,7 @@ pub fn setup_echo() -> (MockConnect, JoinHandle<()>) { let (connector, listener) = setup(); let handle = tokio::spawn(listener.for_each_concurrent(None, |s| async { - let (mut reader, mut writer) = split(s.expect("Unexpected error")); + let (mut reader, mut writer) = split(s.expect("Unexpected error").0); copy(&mut reader, &mut writer) .await .expect("Failed to copy"); From 23ca7ffccd3333615f28400300bd5e82835c2137 Mon Sep 17 00:00:00 2001 From: Thayne McCombs Date: Mon, 16 Oct 2023 01:20:49 -0600 Subject: [PATCH 3/3] feat!: More changes for including peer address in response This builds on the previous commit. In addition to some minor stylictic and naming changes (such as calling the address peer_addr instead of remote_addr to be more consistent with tokio and stdlib), the main change here is replacing the FutureWithExtraData with a more purpose-built Waiting struct encodes the state of a connection that is waitinf for a handshake to complete. BREAKING CHANGE: AsyncAccept::Error must implement std::error::Error BREAKING CHANGE: TlsAcceptError is now a struct form variant. Fixes: #36 --- Cargo.toml | 2 +- examples/echo.rs | 2 +- examples/http-change-certificate.rs | 4 +- examples/http-low-level.rs | 2 +- examples/http-stream.rs | 6 +- examples/tls_config/mod.rs | 3 + src/hyper.rs | 8 +- src/lib.rs | 169 +++++++++++++++------------- src/net.rs | 4 +- tests/basic.rs | 17 ++- tests/helper/mocks.rs | 6 +- 11 files changed, 128 insertions(+), 95 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a54e089..0d546f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ hyper-h2 = ["hyper", "hyper/http2"] [dependencies] futures-util = "0.3.8" hyper = { version = "0.14.1", features = ["server", "tcp"], optional = true } -pin-project-lite = "0.2.8" +pin-project-lite = "0.2.13" thiserror = "1.0.30" tokio = { version = "1.0", features = ["time"] } tokio-native-tls = { version = "0.3.0", optional = true } diff --git a/examples/echo.rs b/examples/echo.rs index 31dd86c..0921172 100644 --- a/examples/echo.rs +++ b/examples/echo.rs @@ -45,7 +45,7 @@ async fn main() -> Result<(), Box> { handle_stream(stream, remote_addr).await; } Err(e) => { - if let Some(remote_addr) = e.remote_addr() { + if let Some(remote_addr) = e.peer_addr() { eprint!("[client {remote_addr}] "); } diff --git a/examples/http-change-certificate.rs b/examples/http-change-certificate.rs index 6ba6742..f0294d1 100644 --- a/examples/http-change-certificate.rs +++ b/examples/http-change-certificate.rs @@ -37,12 +37,12 @@ async fn main() { tokio::spawn(async move { let svc = service_fn(move |request| handle_request(tx.clone(), counter.clone(), request)); if let Err(err) = http.serve_connection(conn, svc).await { - eprintln!("Application error (client address: {remote_addr}): {}", err); + eprintln!("Application error (client address: {remote_addr}): {err}"); } }); }, Err(e) => { - if let Some(remote_addr) = e.remote_addr() { + if let Some(remote_addr) = e.peer_addr() { eprint!("[client {remote_addr}] "); } diff --git a/examples/http-low-level.rs b/examples/http-low-level.rs index 74b3509..5aa228a 100644 --- a/examples/http-low-level.rs +++ b/examples/http-low-level.rs @@ -36,7 +36,7 @@ async fn main() { }); } Err(err) => { - if let Some(remote_addr) = err.remote_addr() { + if let Some(remote_addr) = err.peer_addr() { eprint!("[client {remote_addr}] "); } diff --git a/examples/http-stream.rs b/examples/http-stream.rs index 7ad623b..44b140f 100644 --- a/examples/http-stream.rs +++ b/examples/http-stream.rs @@ -1,4 +1,4 @@ -use futures_util::stream::{StreamExt, TryStreamExt}; +use futures_util::stream::StreamExt; use hyper::server::accept; use hyper::server::conn::AddrIncoming; use hyper::service::{make_service_fn, service_fn}; @@ -23,6 +23,7 @@ async fn main() -> Result<(), Box> { // This uses a filter to handle errors with connecting let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?) + .connections() .filter(|conn| { if let Err(err) = conn { eprintln!("Error: {:?}", err); @@ -30,8 +31,7 @@ async fn main() -> Result<(), Box> { } else { ready(true) } - }) - .map_ok(|(conn, _remote_addr)| conn); + }); let server = Server::builder(accept::from_stream(incoming)).serve(new_svc); server.await?; diff --git a/examples/tls_config/mod.rs b/examples/tls_config/mod.rs index e8cfc17..114e8a3 100644 --- a/examples/tls_config/mod.rs +++ b/examples/tls_config/mod.rs @@ -5,7 +5,9 @@ mod config { const CERT: &[u8] = include_bytes!("local.cert"); const PKEY: &[u8] = include_bytes!("local.key"); + #[allow(dead_code)] const CERT2: &[u8] = include_bytes!("local2.cert"); + #[allow(dead_code)] const PKEY2: &[u8] = include_bytes!("local2.key"); pub type Acceptor = tokio_rustls::TlsAcceptor; @@ -27,6 +29,7 @@ mod config { tls_acceptor_impl(PKEY, CERT) } + #[allow(dead_code)] pub fn tls_acceptor2() -> Acceptor { tls_acceptor_impl(PKEY2, CERT2) } diff --git a/src/hyper.rs b/src/hyper.rs index 002b88c..a5df30d 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -14,8 +14,8 @@ impl AsyncAccept for AddrIncoming { cx: &mut Context<'_>, ) -> Poll>> { ::poll_accept(self, cx).map_ok(|conn| { - let remote_addr = conn.remote_addr(); - (conn, remote_addr) + let peer_addr = conn.remote_addr(); + (conn, peer_addr) }) } } @@ -52,6 +52,7 @@ pub fn wrap(acceptor: A) -> WrappedAccept { impl AsyncAccept for WrappedAccept where A::Conn: AsyncRead + AsyncWrite, + A::Error: std::error::Error, { type Connection = A::Conn; type Error = A::Error; @@ -60,7 +61,7 @@ where fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { self.project() .inner .poll_accept(cx) @@ -91,6 +92,7 @@ impl WrappedAccept { impl TlsListener, T> where A::Conn: AsyncWrite + AsyncRead, + A::Error: std::error::Error, T: AsyncTls, { /// Create a `TlsListener` from a hyper [`Accept`](::hyper::server::accept::Accept) and tls diff --git a/src/lib.rs b/src/lib.rs index 3110070..919ee13 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,14 +15,14 @@ //! - `tokio-net`: Implementations for tokio socket types (default) //! - `rt`: Features that depend on the tokio runtime, such as [`SpawningHandshakes`] -use futures_util::stream::{FuturesUnordered, Stream, StreamExt}; +use futures_util::stream::{FuturesUnordered, Stream, StreamExt, TryStreamExt}; use pin_project_lite::pin_project; #[cfg(feature = "rt")] pub use spawning_handshake::SpawningHandshakes; use std::fmt::Debug; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::Duration; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; @@ -74,9 +74,12 @@ pub trait AsyncAccept { /// `()` or a similar dummy type can be used. type Address: Debug; /// The type of error that may be returned. - type Error; + type Error: std::error::Error; /// Poll to accept the next connection. + /// + /// On success return the new connection, and the address of the peer. + #[allow(clippy::type_complexity)] fn poll_accept( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -132,7 +135,7 @@ pin_project! { #[pin] listener: A, tls: T, - waiting: FuturesUnordered, A::Address>>, + waiting: FuturesUnordered>, max_handshakes: usize, timeout: Duration, } @@ -149,14 +152,7 @@ pub struct Builder { /// Wraps errors from either the listener or the TLS Acceptor #[derive(Debug, Error)] #[non_exhaustive] -// TODO: It would probably be more simple and more future-proof to use the -// `AsyncAccept` and `AsyncTls` implementations as the type parameters here, so -// that their associated types can be used in the fields -// (i.e. `error: A::Error, remote_addr: A::Address`), but that would require us -// to either hand-write `impl Debug` or use a proc-macro crate like -// `impl-tools` to derive `Debug` with custom bounds, -// due to https://github.com/rust-lang/rust/issues/26925 -pub enum Error { +pub enum Error { /// An error that arose from the listener ([AsyncAccept::Error]) #[error("{0}")] ListenerError(#[source] LE), @@ -164,19 +160,19 @@ pub enum Error { #[error("{error}")] #[non_exhaustive] TlsAcceptError { - /// The error that occurred. + /// The original error that occurred #[source] error: TE, - /// The client's address and port. - remote_addr: A, + /// Address of the other side of the connection + peer_addr: Addr, }, /// The TLS handshake timed out #[error("Timeout during TLS handshake")] #[non_exhaustive] HandshakeTimeout { - /// The client's address and port. - remote_addr: A, + /// Address of the other side of the connection + peer_addr: Addr, }, } @@ -190,10 +186,17 @@ where } } +/// Convenience type alias to get the proper error type from the type of the [`AsyncAccept`] and +/// [`AsyncTls`] used. +type TlsListenerError = Error< + ::Error, + ::Connection>>::Error, + ::Address, +>; + impl TlsListener where A: AsyncAccept, - A::Error: std::error::Error, T: AsyncTls, { /// Accept the next connection @@ -221,15 +224,24 @@ where pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) { *self.project().tls = acceptor; } + + /// Convert into a Stream of connections. + /// + /// This drops the address of the connection, but provides a more convenient API + /// if the address isn't needed. + /// + /// The address will still be included in errors. + pub fn connections(self) -> impl Stream>> { + self.map_ok(|(conn, _addr)| conn) + } } impl Stream for TlsListener where A: AsyncAccept, - A::Error: std::error::Error, T: AsyncTls, { - type Item = Result<(T::Stream, A::Address), Error>; + type Item = Result<(T::Stream, A::Address), TlsListenerError>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); @@ -237,11 +249,11 @@ where while this.waiting.len() < *this.max_handshakes { match this.listener.as_mut().poll_accept(cx) { Poll::Pending => break, - Poll::Ready(Some(Ok((conn, address)))) => { - this.waiting.push(FutureWithExtraData::new( - timeout(*this.timeout, this.tls.accept(conn)), - address, - )); + Poll::Ready(Some(Ok((conn, addr)))) => { + this.waiting.push(Waiting { + inner: timeout(*this.timeout, this.tls.accept(conn)), + peer_addr: Some(addr), + }); } Poll::Ready(Some(Err(e))) => { return Poll::Ready(Some(Err(Error::ListenerError(e)))); @@ -251,16 +263,11 @@ where } match this.waiting.poll_next_unpin(cx) { - Poll::Ready(Some((Ok(result), remote_addr))) => Poll::Ready(Some(match result { - Ok(conn) => Ok((conn, remote_addr)), - Err(error) => Err(Error::TlsAcceptError { error, remote_addr }), - })), - // The handshake timed out, try getting another connection from the - // queue - Poll::Ready(Some((Err(_), remote_addr))) => { - Poll::Ready(Some(Err(Error::HandshakeTimeout { remote_addr }))) - } - _ => Poll::Pending, + // If we don't have anything waiting yet, + // then we are still pending, + Poll::Ready(None) => Poll::Pending, + // Otherwise the result is already what we want + result => result, } } } @@ -367,14 +374,17 @@ impl Builder { } impl Error { - /// Returns the client's address and port, if known. - pub fn remote_addr(&self) -> Option<&A> { + /// Get the peer address from the connection that caused the error, if applicable. + /// + /// This will only return Some for errors that occur after an initial connection + /// is established, such as TlsAcceptError and HandshakeTimeout. And only if + /// the [`AsyncAccept`] implementation implements [`peer_addr`](AsyncAccept::peer_addr) + pub fn peer_addr(&self) -> Option<&A> { match self { - Self::ListenerError(_) => None, - - Self::TlsAcceptError { remote_addr, .. } | Self::HandshakeTimeout { remote_addr } => { - Some(remote_addr) + Error::TlsAcceptError { peer_addr, .. } | Self::HandshakeTimeout { peer_addr, .. } => { + Some(peer_addr) } + _ => None, } } } @@ -390,6 +400,46 @@ pub fn builder(tls: T) -> Builder { } } +pin_project! { + struct Waiting + where + A: AsyncAccept, + T: AsyncTls + { + #[pin] + inner: Timeout, + peer_addr: Option, + } +} + +impl Future for Waiting +where + A: AsyncAccept, + T: AsyncTls, +{ + type Output = Result<(T::Stream, A::Address), TlsListenerError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + let res = ready!(this.inner.as_mut().poll(cx)); + let addr = this + .peer_addr + .take() + .expect("this future has already been polled to completion"); + match res { + // We succesfully got a connection + Ok(Ok(conn)) => Poll::Ready(Ok((conn, addr))), + // The handshake failed + Ok(Err(e)) => Poll::Ready(Err(Error::TlsAcceptError { + error: e, + peer_addr: addr, + })), + // The handshake timed out + Err(_) => Poll::Ready(Err(Error::HandshakeTimeout { peer_addr: addr })), + } + } +} + pin_project! { /// See [`AsyncAccept::until`] pub struct Until { @@ -417,40 +467,3 @@ impl AsyncAccept for Until { } } } - -pin_project! { - struct FutureWithExtraData { - #[pin] - future: Fut, - extra: Option, - } -} - -impl FutureWithExtraData { - fn new(future: Fut, extra: X) -> Self { - Self { - future, - extra: Some(extra), - } - } -} - -impl Future for FutureWithExtraData -where - Fut: Future, -{ - type Output = (Fut::Output, X); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let extra = this.extra; - - this.future.poll(cx).map(|output| { - let extra = extra - .take() - .expect("this future has already been polled to completion"); - - (output, extra) - }) - } -} diff --git a/src/net.rs b/src/net.rs index 8270c38..17a41a7 100644 --- a/src/net.rs +++ b/src/net.rs @@ -17,7 +17,7 @@ impl AsyncAccept for TcpListener { cx: &mut Context<'_>, ) -> Poll>> { match (*self).poll_accept(cx) { - Poll::Ready(Ok((stream, remote_addr))) => Poll::Ready(Some(Ok((stream, remote_addr)))), + Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))), Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), Poll::Pending => Poll::Pending, } @@ -36,7 +36,7 @@ impl AsyncAccept for UnixListener { cx: &mut Context<'_>, ) -> Poll>> { match (*self).poll_accept(cx) { - Poll::Ready(Ok((stream, remote_addr))) => Poll::Ready(Some(Ok((stream, remote_addr)))), + Poll::Ready(Ok(conn)) => Poll::Ready(Some(Ok(conn))), Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), Poll::Pending => Poll::Pending, } diff --git a/tests/basic.rs b/tests/basic.rs index 99133f6..447daae 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -69,7 +69,7 @@ async fn tls_error() { assert_err!( listener.accept().await.unwrap(), TlsAcceptError { - remote_addr: MockAddress(42), + peer_addr: MockAddress(42), .. } ); @@ -123,3 +123,18 @@ async fn echo() { std::panic::resume_unwind(e.into_panic()); } } + +#[tokio::test] +async fn addr() { + let (connector, mut listener) = setup(); + + spawn(async move { + connector.send_data(b"hi").await.unwrap(); + connector.send_data(b"boo").await.unwrap(); + connector.send_data(b"test").await.unwrap(); + }); + + for i in 42..44 { + assert_eq!(listener.accept().await.unwrap().unwrap().1, MockAddress(i)); + } +} diff --git a/tests/helper/mocks.rs b/tests/helper/mocks.rs index 189e405..3adc6ab 100644 --- a/tests/helper/mocks.rs +++ b/tests/helper/mocks.rs @@ -2,7 +2,7 @@ use futures_util::future::{self, Ready}; use futures_util::ready; use std::io; use std::pin::Pin; -use std::sync::atomic::{self, AtomicU32}; +use std::sync::atomic::{AtomicU32, Ordering}; use std::task::{Context, Poll}; use tokio::io::{ duplex, split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, @@ -22,7 +22,7 @@ pub struct MockConnect { counter: AtomicU32, } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub struct MockAddress(pub u32); pub fn accepting() -> (MockConnect, MockAccept) { @@ -39,7 +39,7 @@ pub fn accepting() -> (MockConnect, MockAccept) { impl MockConnect { pub async fn connect(&self) -> DuplexStream { let (tx, rx) = duplex(1024); - let count = self.counter.fetch_add(1, atomic::Ordering::Relaxed); + let count = self.counter.fetch_add(1, Ordering::Relaxed); self.chan.send(Ok((rx, MockAddress(count)))).await.unwrap(); tx }