Skip to content

Commit

Permalink
feat: ensure rustls errors propagate
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Nov 15, 2023
1 parent 60595ca commit 2806f5c
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@ use std::io::Read;
use std::io::Write;
use tokio::net::TcpStream;

/// Convert a [`rustls::Error`] to an [`io::Error`]
pub fn rustls_to_io_error(error: rustls::Error) -> io::Error {
io::Error::new(ErrorKind::InvalidData, error)
}

/// Clones an [`io::Result`], assuming the inner error, if any, is a [`rustls::Error`].
pub fn clone_result<T: Clone>(result: &io::Result<T>) -> io::Result<T> {
match result {
Ok(t) => Ok(t.clone()),
Err(e) => Err(clone_error(e)),
}
}

/// Clones an [`io::Error`], assuming the inner error, if any, is a [`rustls::Error`].
pub fn clone_error(e: &io::Error) -> io::Error {
let kind = e.kind();
match e.get_ref() {
Some(e) => match e.downcast_ref::<rustls::Error>() {
Some(e) => io::Error::new(kind, e.clone()),
None => kind.into(),
},
None => kind.into(),
}
}

#[inline(always)]
fn trace_error(error: io::Error) -> io::Error {
#[cfg(all(debug_assertions, feature = "trace"))]
Expand Down
25 changes: 8 additions & 17 deletions src/connection_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tokio::io::ReadBuf;
use tokio::net::TcpStream;

use crate::adapter::read_tls;
use crate::adapter::rustls_to_io_error;
use crate::adapter::write_tls;
use crate::trace;

Expand Down Expand Up @@ -118,15 +119,6 @@ impl ConnectionStream {
.unwrap_or_default()
}

#[cfg(test)]
pub fn tls_bytes_to_write(&self) -> usize {
self
.last_iostate
.as_ref()
.map(|iostate| iostate.tls_bytes_to_write())
.unwrap_or_default()
}

fn poll_read_only(&mut self, mut cx: PollContext) -> StreamProgress {
if self.rd_error.is_some() || self.rd_proto_error.is_some() {
StreamProgress::Error
Expand Down Expand Up @@ -231,9 +223,8 @@ impl ConnectionStream {
Err(err) if err.kind() == ErrorKind::WouldBlock => {
trace!("r*={err:?}");
// No data to read, but we need to make sure we don't have an error state here.
if self.rd_proto_error.is_some() {
// TODO: Should we expose the underlying TLS error?
Err(ErrorKind::InvalidData.into())
if let Some(err) = &self.rd_proto_error {
Err(rustls_to_io_error(err.clone()))
} else if let Some(err) = self.rd_error {
// We have a connection error
Err(err.into())
Expand Down Expand Up @@ -764,18 +755,18 @@ mod tests {
expect_write_1(&mut server).await;
assert_ne!(server.plaintext_bytes_to_read(), 0);

client
.into_inner()
.0
.write_all(b"THIS IS NOT A VALID TLS PACKET")
.await?;
let mut tcp = client.into_inner().0;
tcp.write_all(b"THIS IS NOT A VALID TLS PACKET").await?;

// One byte will read fine
expect_read_1(&mut server).await;

// The next byte will not
expect_read_1_err(&mut server, ErrorKind::InvalidData).await;

// Hold the TCP connection until here otherwise Windows may barf
drop(tcp);

Ok(())
}

Expand Down
43 changes: 34 additions & 9 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,20 @@ use std::sync::Arc;
use tokio::net::TcpStream;

use crate::adapter::read_tls;
use crate::adapter::rustls_to_io_error;
use crate::adapter::write_tls;
use crate::trace;
use crate::TestOptions;

#[inline(always)]
fn trace_result<T>(result: io::Result<T>) -> io::Result<T> {
#[cfg(feature = "trace")]
if let Err(err) = &result {
trace!("result = {err:?}");
}
result
}

fn try_read<'a, 'b>(
tcp: &'a TcpStream,
tls: &'b mut Connection,
Expand All @@ -20,17 +31,12 @@ fn try_read<'a, 'b>(
return Err(ErrorKind::UnexpectedEof.into());
}
Ok(_) => {
tls
.process_new_packets()
.map_err(|_| io::Error::from(ErrorKind::InvalidData))?;
tls.process_new_packets().map_err(rustls_to_io_error)?;
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
// Spurious wakeup
}
err @ Err(_) => {
// If we failed to read, try a last-gasp write to send a reason to the other side. This behaves in the
// same way that the rustls Connection::complete_io() method would.
_ = try_write(tcp, tls);
err?;
}
}
Expand Down Expand Up @@ -107,7 +113,7 @@ pub(crate) async fn handshake_task_internal(
break;
}
if tls.wants_write() {
tcp.writable().await?;
trace_result(tcp.writable().await)?;
#[cfg(test)]
if test_options.slow_handshake_write {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
Expand Down Expand Up @@ -153,12 +159,31 @@ pub(crate) async fn handshake_task_internal(
// this loop while we flush writes. Note that these signals changed subtly between rustls 0.20 and
// rustls 0.21 (in the former we didn't need the `tls.wants_read()` test).
if tls.is_handshaking() && tls.wants_read() {
tcp.readable().await?;
trace_result(tcp.readable().await)?;
#[cfg(test)]
if test_options.slow_handshake_read {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
try_read(&tcp, &mut tls)?;
match try_read(&tcp, &mut tls) {
Ok(_) => {}
Err(err) => {
trace!("read error {err:?}, starting last gasp write");
// If we failed to read, try a last-gasp write to send a reason to the other side. This behaves in the
// same way that the rustls Connection::complete_io() method would.
while tls.wants_write() {
trace_result(tcp.writable().await)?;
match try_write(&tcp, &mut tls) {
Err(err) if err.kind() == ErrorKind::WouldBlock => {
// Spurious wakeup
continue;
}
Err(_) => break,
Ok(_) => {}
}
}
return Err(err);
}
}
}
}
Ok(HandshakeResult(tcp, tls))
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct TestOptions {

macro_rules! trace {
($($args:expr),+) => {
if false && cfg!(feature="trace")
if cfg!(feature="trace")
{
println!($($args),+);
}
Expand Down
Loading

0 comments on commit 2806f5c

Please sign in to comment.