Skip to content

Commit

Permalink
feat: ensure rustls errors propagate (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac authored Nov 15, 2023
1 parent 60595ca commit 71d5fe9
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@ use std::io::Read;
use std::io::Write;
use tokio::net::TcpStream;

/// Convert a [`rustls::Error`] to an [`io::Error`]
pub fn rustls_to_io_error(error: rustls::Error) -> io::Error {
io::Error::new(ErrorKind::InvalidData, error)
}

/// Clones an [`io::Result`], assuming the inner error, if any, is a [`rustls::Error`].
pub fn clone_result<T: Clone>(result: &io::Result<T>) -> io::Result<T> {
match result {
Ok(t) => Ok(t.clone()),
Err(e) => Err(clone_error(e)),
}
}

/// Clones an [`io::Error`], assuming the inner error, if any, is a [`rustls::Error`].
pub fn clone_error(e: &io::Error) -> io::Error {
let kind = e.kind();
match e.get_ref() {
Some(e) => match e.downcast_ref::<rustls::Error>() {
Some(e) => io::Error::new(kind, e.clone()),
None => kind.into(),
},
None => kind.into(),
}
}

#[inline(always)]
fn trace_error(error: io::Error) -> io::Error {
#[cfg(all(debug_assertions, feature = "trace"))]
Expand Down
42 changes: 21 additions & 21 deletions src/connection_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tokio::io::ReadBuf;
use tokio::net::TcpStream;

use crate::adapter::read_tls;
use crate::adapter::rustls_to_io_error;
use crate::adapter::write_tls;
use crate::trace;

Expand Down Expand Up @@ -118,15 +119,6 @@ impl ConnectionStream {
.unwrap_or_default()
}

#[cfg(test)]
pub fn tls_bytes_to_write(&self) -> usize {
self
.last_iostate
.as_ref()
.map(|iostate| iostate.tls_bytes_to_write())
.unwrap_or_default()
}

fn poll_read_only(&mut self, mut cx: PollContext) -> StreamProgress {
if self.rd_error.is_some() || self.rd_proto_error.is_some() {
StreamProgress::Error
Expand Down Expand Up @@ -231,9 +223,8 @@ impl ConnectionStream {
Err(err) if err.kind() == ErrorKind::WouldBlock => {
trace!("r*={err:?}");
// No data to read, but we need to make sure we don't have an error state here.
if self.rd_proto_error.is_some() {
// TODO: Should we expose the underlying TLS error?
Err(ErrorKind::InvalidData.into())
if let Some(err) = &self.rd_proto_error {
Err(rustls_to_io_error(err.clone()))
} else if let Some(err) = self.rd_error {
// We have a connection error
Err(err.into())
Expand Down Expand Up @@ -536,6 +527,7 @@ mod tests {
use crate::tests::server_name;
use crate::tests::tcp_pair;
use crate::tests::TestResult;
use crate::TestOptions;
use futures::future::poll_fn;
use futures::task::noop_waker_ref;
use rustls::ClientConnection;
Expand Down Expand Up @@ -596,8 +588,16 @@ mod tests {
ClientConnection::new(client_config().into(), server_name())
.unwrap()
.into();
let server = spawn(handshake_task(server.into(), tls_server));
let client = spawn(handshake_task(client.into(), tls_client));
let server = spawn(handshake_task(
server.into(),
tls_server,
TestOptions::default(),
));
let client = spawn(handshake_task(
client.into(),
tls_client,
TestOptions::default(),
));
let (tcp_client, tls_client) = client.await.unwrap().unwrap().reclaim();
let (tcp_server, tls_server) = server.await.unwrap().unwrap().reclaim();
assert!(!tls_client.is_handshaking());
Expand Down Expand Up @@ -764,18 +764,18 @@ mod tests {
expect_write_1(&mut server).await;
assert_ne!(server.plaintext_bytes_to_read(), 0);

client
.into_inner()
.0
.write_all(b"THIS IS NOT A VALID TLS PACKET")
.await?;
let mut tcp = client.into_inner().0;
tcp.write_all(b"THIS IS NOT A VALID TLS PACKET").await?;

// One byte will read fine
expect_read_1(&mut server).await;

// The next byte will not
expect_read_1_err(&mut server, ErrorKind::InvalidData).await;

// Hold the TCP connection until here otherwise Windows may barf
drop(tcp);

Ok(())
}

Expand All @@ -791,7 +791,7 @@ mod tests {
total += n.unwrap();
}

server.tls.writer().write(b"final")?;
_ = server.tls.writer().write(b"final")?;
let iostate = server.tls.process_new_packets().unwrap();

assert!(iostate.tls_bytes_to_write() > 0);
Expand Down Expand Up @@ -829,7 +829,7 @@ mod tests {
total += n.unwrap();
}

server.tls.writer().write(b"final")?;
_ = server.tls.writer().write(b"final")?;
let iostate = server.tls.process_new_packets().unwrap();

assert!(iostate.tls_bytes_to_write() > 0);
Expand Down
83 changes: 61 additions & 22 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,40 @@ use std::sync::Arc;
use tokio::net::TcpStream;

use crate::adapter::read_tls;
use crate::adapter::rustls_to_io_error;
use crate::adapter::write_tls;
use crate::trace;
use crate::TestOptions;

fn try_read<'a, 'b>(
tcp: &'a TcpStream,
tls: &'b mut Connection,
) -> io::Result<()> {
#[inline(always)]
fn trace_result<T>(result: io::Result<T>) -> io::Result<T> {
#[cfg(feature = "trace")]
if let Err(err) = &result {
trace!("result = {err:?}");
}
result
}

fn try_read(tcp: &TcpStream, tls: &mut Connection) -> io::Result<()> {
match read_tls(tcp, tls) {
Ok(0) => {
// EOF during handshake
return Err(ErrorKind::UnexpectedEof.into());
}
Ok(_) => {
tls
.process_new_packets()
.map_err(|_| io::Error::from(ErrorKind::InvalidData))?;
tls.process_new_packets().map_err(rustls_to_io_error)?;
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
// Spurious wakeup
}
err @ Err(_) => {
// If we failed to read, try a last-gasp write to send a reason to the other side. This behaves in the
// same way that the rustls Connection::complete_io() method would.
_ = try_write(tcp, tls);
err?;
}
}
Ok(())
}

fn try_write<'a, 'b>(
tcp: &'a TcpStream,
tls: &'b mut Connection,
) -> io::Result<()> {
fn try_write(tcp: &TcpStream, tls: &mut Connection) -> io::Result<()> {
match write_tls(tcp, tls) {
Ok(_) => {}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
Expand Down Expand Up @@ -81,15 +81,27 @@ impl HandshakeResult {
}

/// Performs a handshake and returns the [`TcpStream`]/[`Connection`] pair if successful.
#[cfg(test)]
pub(crate) async fn handshake_task(
tcp: Arc<TcpStream>,
tls: Connection,
test_options: TestOptions,
) -> io::Result<HandshakeResult> {
handshake_task_internal(tcp, tls, TestOptions::default()).await
let res = handshake_task_internal(tcp, tls, test_options).await;
// Ensure consistency in handshake errors
match res {
#[cfg(windows)]
Err(err) if err.kind() == ErrorKind::ConnectionAborted => {
Err(ErrorKind::UnexpectedEof.into())
}
#[cfg(target_os = "macos")]
Err(err) if err.kind() == ErrorKind::ConnectionReset => {
Err(ErrorKind::UnexpectedEof.into())
}
r => r,
}
}

pub(crate) async fn handshake_task_internal(
async fn handshake_task_internal(
tcp: Arc<TcpStream>,
mut tls: Connection,
test_options: TestOptions,
Expand All @@ -107,7 +119,7 @@ pub(crate) async fn handshake_task_internal(
break;
}
if tls.wants_write() {
tcp.writable().await?;
trace_result(tcp.writable().await)?;
#[cfg(test)]
if test_options.slow_handshake_write {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
Expand Down Expand Up @@ -153,12 +165,31 @@ pub(crate) async fn handshake_task_internal(
// this loop while we flush writes. Note that these signals changed subtly between rustls 0.20 and
// rustls 0.21 (in the former we didn't need the `tls.wants_read()` test).
if tls.is_handshaking() && tls.wants_read() {
tcp.readable().await?;
trace_result(tcp.readable().await)?;
#[cfg(test)]
if test_options.slow_handshake_read {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
try_read(&tcp, &mut tls)?;
match try_read(&tcp, &mut tls) {
Ok(_) => {}
Err(err) => {
trace!("read error {err:?}, starting last gasp write");
// If we failed to read, try a last-gasp write to send a reason to the other side. This behaves in the
// same way that the rustls Connection::complete_io() method would.
while tls.wants_write() {
trace_result(tcp.writable().await)?;
match try_write(&tcp, &mut tls) {
Err(err) if err.kind() == ErrorKind::WouldBlock => {
// Spurious wakeup
continue;
}
Err(_) => break,
Ok(_) => {}
}
}
return Err(err);
}
}
}
}
Ok(HandshakeResult(tcp, tls))
Expand All @@ -185,8 +216,16 @@ mod tests {
ClientConnection::new(client_config().into(), server_name())
.unwrap()
.into();
let server = spawn(handshake_task(server.into(), tls_server));
let client = spawn(handshake_task(client.into(), tls_client));
let server = spawn(handshake_task(
server.into(),
tls_server,
TestOptions::default(),
));
let client = spawn(handshake_task(
client.into(),
tls_client,
TestOptions::default(),
));
let (tcp_client, tls_client) = client.await.unwrap().unwrap().reclaim();
let (tcp_server, tls_server) = server.await.unwrap().unwrap().reclaim();
assert!(!tls_client.is_handshaking());
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct TestOptions {

macro_rules! trace {
($($args:expr),+) => {
if false && cfg!(feature="trace")
if cfg!(feature="trace")
{
println!($($args),+);
}
Expand Down
Loading

0 comments on commit 71d5fe9

Please sign in to comment.