Skip to content

Commit

Permalink
Merge pull request #37 from tmccombs/timeout-errors
Browse files Browse the repository at this point in the history
feat!: Add a new error type for handshake timeouts
  • Loading branch information
tmccombs authored Oct 17, 2023
2 parents 8eed9e9 + 23ca7ff commit c6a68b4
Show file tree
Hide file tree
Showing 13 changed files with 226 additions and 60 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ hyper-h2 = ["hyper", "hyper/http2"]
[dependencies]
futures-util = "0.3.8"
hyper = { version = "0.14.1", features = ["server", "tcp"], optional = true }
pin-project-lite = "0.2.8"
pin-project-lite = "0.2.13"
thiserror = "1.0.30"
tokio = { version = "1.0", features = ["time"] }
tokio-native-tls = { version = "0.3.0", optional = true }
Expand Down
6 changes: 3 additions & 3 deletions examples/echo-threads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -32,8 +32,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(SpawningHandshakes(tls_acceptor()), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
eprintln!("Error: {:?}", e);
Expand Down
10 changes: 7 additions & 3 deletions examples/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -41,10 +41,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(tls_acceptor(), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
if let Some(remote_addr) = e.peer_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {:?}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-change-certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,22 @@ async fn main() {
tokio::select! {
conn = listener.accept() => {
match conn.expect("Tls listener stream should be infinite") {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
let tx = tx.clone();
let counter = counter.clone();
tokio::spawn(async move {
let svc = service_fn(move |request| handle_request(tx.clone(), counter.clone(), request));
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("Application error (client address: {remote_addr}): {err}");
}
});
},
Err(e) => {
if let Some(remote_addr) = e.peer_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Bad connection: {}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-low-level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ async fn main() {
listener
.for_each(|r| async {
match r {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
tokio::spawn(async move {
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("[client {remote_addr}] Application error: {}", err);
}
});
}
Err(err) => {
if let Some(remote_addr) = err.peer_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {}", err);
}
}
Expand Down
18 changes: 10 additions & 8 deletions examples/http-stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

// This uses a filter to handle errors with connecting
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?).filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
});
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?)
.connections()
.filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
});

let server = Server::builder(accept::from_stream(incoming)).serve(new_svc);
server.await?;
Expand Down
3 changes: 3 additions & 0 deletions examples/tls_config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ mod config {

const CERT: &[u8] = include_bytes!("local.cert");
const PKEY: &[u8] = include_bytes!("local.key");
#[allow(dead_code)]
const CERT2: &[u8] = include_bytes!("local2.cert");
#[allow(dead_code)]
const PKEY2: &[u8] = include_bytes!("local2.key");

pub type Acceptor = tokio_rustls::TlsAcceptor;
Expand All @@ -27,6 +29,7 @@ mod config {
tls_acceptor_impl(PKEY, CERT)
}

#[allow(dead_code)]
pub fn tls_acceptor2() -> Acceptor {
tls_acceptor_impl(PKEY2, CERT2)
}
Expand Down
27 changes: 21 additions & 6 deletions src/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ use std::ops::{Deref, DerefMut};
impl AsyncAccept for AddrIncoming {
type Connection = AddrStream;
type Error = std::io::Error;
type Address = std::net::SocketAddr;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx)
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx).map_ok(|conn| {
let peer_addr = conn.remote_addr();
(conn, peer_addr)
})
}
}

Expand All @@ -22,6 +26,11 @@ pin_project! {
/// Unfortunately, it isn't possible to use a blanket impl, due to coherence rules.
/// At least until [RFC 1210](https://rust-lang.github.io/rfcs/1210-impl-specialization.html)
/// (specialization) is stabilized.
///
/// Note that, because `hyper::server::accept::Accept` does not expose the
/// remote address, the implementation of `AsyncAccept` for `WrappedAccept`
/// doesn't expose it either. That is, [`AsyncAccept::Address`] is `()` in
/// this case.
//#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))]
pub struct WrappedAccept<A> {
// sadly, pin-project-lite doesn't suport tuple structs :(
Expand All @@ -43,15 +52,20 @@ pub fn wrap<A: HyperAccept>(acceptor: A) -> WrappedAccept<A> {
impl<A: HyperAccept> AsyncAccept for WrappedAccept<A>
where
A::Conn: AsyncRead + AsyncWrite,
A::Error: std::error::Error,
{
type Connection = A::Conn;
type Error = A::Error;
type Address = ();

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
self.project().inner.poll_accept(cx)
) -> Poll<Option<Result<(Self::Connection, ()), Self::Error>>> {
self.project()
.inner
.poll_accept(cx)
.map_ok(|conn| (conn, ()))
}
}

Expand All @@ -78,6 +92,7 @@ impl<A: HyperAccept> WrappedAccept<A> {
impl<A: HyperAccept, T> TlsListener<WrappedAccept<A>, T>
where
A::Conn: AsyncWrite + AsyncRead,
A::Error: std::error::Error,
T: AsyncTls<A::Conn>,
{
/// Create a `TlsListener` from a hyper [`Accept`](::hyper::server::accept::Accept) and tls
Expand All @@ -95,12 +110,12 @@ where
T: AsyncTls<A::Connection>,
{
type Conn = T::Stream;
type Error = Error<A::Error, T::Error>;
type Error = Error<A::Error, T::Error, A::Address>;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
self.poll_next(cx)
self.poll_next(cx).map_ok(|(conn, _)| conn)
}
}
Loading

0 comments on commit c6a68b4

Please sign in to comment.