From a58d250b963cc928fa643e14c0edf38ddcbb80bd Mon Sep 17 00:00:00 2001 From: Danny Browning Date: Fri, 12 Apr 2019 20:34:52 -0600 Subject: [PATCH] TLS Peer Dependency (#4) * Native TLS Shadowing Addresses #1 and #2. Provides interfaces without needing to use native-tls library. --- .travis.yml | 2 +- Cargo.toml | 8 +- README.md | 5 +- examples/download-rust-lang.rs | 3 +- src/acceptor.rs | 133 +++++++++++++++++++++++++++ src/connector.rs | 163 +++++++++++++++++++++++++++++++++ src/errors.rs | 18 ++++ src/lib.rs | 153 +++---------------------------- src/pending.rs | 86 +++++++++++++++++ tests/bad.rs | 72 +++++++-------- tests/google.rs | 48 +++++----- tests/smoke.rs | 25 +++-- 12 files changed, 492 insertions(+), 224 deletions(-) create mode 100644 src/acceptor.rs create mode 100644 src/connector.rs create mode 100644 src/errors.rs create mode 100644 src/pending.rs diff --git a/.travis.yml b/.travis.yml index a64e4fa..df3f83a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ --- language: rust rust: - - nightly + - nightly-2019-04-07 sudo: false cache: - apt diff --git a/Cargo.toml b/Cargo.toml index fc75d35..ba10289 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" # - Update CHANGELOG.md. # - Update doc URL. # - Create "v0.1.x" git tag. -version = "0.3.0-alpha.2" +version = "0.3.0-alpha.3" license = "MIT" readme = "README.md" description = """ @@ -15,10 +15,12 @@ TLS support for AsyncRead/AsyncWrite using native-tls """ authors = ["Danny Browning ", "Carl Lerche "] categories = ["asynchronous", "network-programming"] -documentation = "https://docs.rs/tls-async/0.3.0-alpha.2/tls_async/" +documentation = "https://docs.rs/tls-async/0.3.0-alpha.3/tls_async/" repository = "https://github.com/dbcfd/tls-async" [dependencies] +failure = "0.1" +failure_derive = "0.1" log = "0.4.1" native-tls = "0.2" @@ -29,7 +31,7 @@ features = ["compat", "io-compat", "std"] [dev-dependencies] cfg-if = "0.1" -romio = "0.3.0-alpha.2" +romio = "0.3.0-alpha.3" tokio = "0.1" [dev-dependencies.env_logger] diff --git a/README.md b/README.md index 97dcba3..ece97ee 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ This is an experimental fork of [tokio-tls](https://github.com/tokio-rs/tokio/tr An implementation of TLS/SSL streams for [Futures 0.3](https://github.com/rust-lang-nursery/futures-rs) built on top of the [`native-tls` crate] -[Documentation](https://docs.rs/tls-async/0.3.0-alpha.1/) +[Documentation](https://docs.rs/tls-async/0.3.0-alpha.3/) [`native-tls` crate]: https://github.com/sfackler/rust-native-tls @@ -29,8 +29,7 @@ First, add this to your `Cargo.toml`: ```toml [dependencies] -native-tls = "0.2" -tls-async = "0.3.0-alpha.1" +tls-async = "0.3.0-alpha.3" ``` Next, add this to your crate: diff --git a/examples/download-rust-lang.rs b/examples/download-rust-lang.rs index 82cbf33..3892c45 100644 --- a/examples/download-rust-lang.rs +++ b/examples/download-rust-lang.rs @@ -3,8 +3,8 @@ use std::net::ToSocketAddrs; use futures::{FutureExt, TryFutureExt}; use futures::io::{AsyncReadExt, AsyncWriteExt}; -use native_tls::TlsConnector; use romio::TcpStream; +use tls_async::TlsConnector; use tokio::runtime::Runtime; fn main() { @@ -19,7 +19,6 @@ fn main() { let socket = await!(TcpStream::connect(&addr)).expect("Could not connect"); let cx = TlsConnector::builder().build().expect("Could not build"); - let cx = tls_async::TlsConnector::from(cx); let mut socket = await!(cx.connect("www.rust-lang.org", socket)).expect("Could not form tls connection"); let _ = await!(socket.write_all(b"\ diff --git a/src/acceptor.rs b/src/acceptor.rs new file mode 100644 index 0000000..04b367f --- /dev/null +++ b/src/acceptor.rs @@ -0,0 +1,133 @@ +use crate::errors::Error; +use crate::pending::PendingTlsStream; +use crate::{Identity, Protocol}; + +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +/// A builder for `TlsAcceptor`s. +pub struct TlsAcceptorBuilder { + inner: native_tls::TlsAcceptorBuilder, +} + +impl TlsAcceptorBuilder { + /// Sets the minimum supported protocol version. + /// + /// A value of `None` enables support for the oldest protocols supported by the implementation. + /// + /// Defaults to `Some(Protocol::Tlsv10)`. + pub fn min_protocol_version(&mut self, protocol: Option) -> &mut TlsAcceptorBuilder { + self.inner.min_protocol_version(protocol); + self + } + + /// Sets the maximum supported protocol version. + /// + /// A value of `None` enables support for the newest protocols supported by the implementation. + /// + /// Defaults to `None`. + pub fn max_protocol_version(&mut self, protocol: Option) -> &mut TlsAcceptorBuilder { + self.inner.max_protocol_version(protocol); + self + } + + /// Creates a new `TlsAcceptor`. + pub fn build(&self) -> Result { + let acceptor = self.inner.build().map_err(Error::Acceptor)?; + Ok(TlsAcceptor { + inner: acceptor + }) + } +} + +/// A builder for server-side TLS connections. +/// +/// # Examples +/// +/// ```rust,no_run +/// #![feature(async_await, await_macro, futures_api)] +/// use futures::StreamExt; +/// use futures::io::AsyncRead; +/// use tls_async::{Identity, TlsAcceptor, TlsStream}; +/// use std::fs::File; +/// use std::io::{Read}; +/// use romio::{TcpListener, TcpStream}; +/// use std::sync::Arc; +/// use std::thread; +/// +/// let mut file = File::open("identity.pfx").unwrap(); +/// let mut identity = vec![]; +/// file.read_to_end(&mut identity).unwrap(); +/// let identity = Identity::from_pkcs12(&identity, "hunter2").unwrap(); +/// +/// let mut listener = TcpListener::bind(&"0.0.0.0:8443".parse().unwrap()).unwrap(); +/// let acceptor = TlsAcceptor::new(identity).unwrap(); +/// let acceptor = Arc::new(acceptor); +/// +/// fn handle_client(stream: S) { +/// // ... +/// } +/// +/// let mut incoming = listener.incoming(); +/// # futures::executor::block_on(async { +/// for stream in await!(incoming.next()) { +/// match stream { +/// Ok(stream) => { +/// let acceptor = acceptor.clone(); +/// let stream = await!(acceptor.accept(stream)).unwrap(); +/// handle_client(stream); +/// } +/// Err(e) => { /* connection failed */ } +/// } +/// } +/// # }) +/// ``` +#[derive(Clone)] +pub struct TlsAcceptor { + inner: native_tls::TlsAcceptor, +} + +impl TlsAcceptor { + /// Creates a acceptor with default settings. + /// + /// The identity acts as the server's private key/certificate chain. + pub fn new(identity: Identity) -> Result { + let native_acceptor = native_tls::TlsAcceptor::new(identity).map_err(Error::Acceptor)?; + Ok(TlsAcceptor { + inner: native_acceptor, + }) + } + + /// Returns a new builder for a `TlsAcceptor`. + /// + /// The identity acts as the server's private key/certificate chain. + pub fn builder(identity: Identity) -> TlsAcceptorBuilder { + let builder = native_tls::TlsAcceptor::builder(identity); + TlsAcceptorBuilder { + inner: builder, + } + } + + /// Accepts a new client connection with the provided stream. + /// + /// This function will internally call `TlsAcceptor::accept` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used after a new socket has been accepted from a + /// `TcpListener`. That socket is then passed to this function to perform + /// the server half of accepting a client connection. + pub fn accept(&self, stream: S) -> PendingTlsStream + where S: AsyncRead + AsyncWrite, + { + PendingTlsStream::new(self.inner.accept(stream.compat())) + } +} + +impl From for TlsAcceptor { + fn from(inner: native_tls::TlsAcceptor) -> Self { + Self { + inner, + } + } +} \ No newline at end of file diff --git a/src/connector.rs b/src/connector.rs new file mode 100644 index 0000000..f69efc0 --- /dev/null +++ b/src/connector.rs @@ -0,0 +1,163 @@ +use crate::errors::Error; +use crate::pending::PendingTlsStream; +use crate::{Certificate, Identity, Protocol}; + +use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite}; + +/// A builder for `TlsConnector`s. +pub struct TlsConnectorBuilder { + inner: native_tls::TlsConnectorBuilder, +} + +impl TlsConnectorBuilder { + /// Sets the identity to be used for client certificate authentication. + pub fn identity(&mut self, identity: Identity) -> &mut TlsConnectorBuilder { + self.inner.identity(identity); + self + } + + /// Sets the minimum supported protocol version. + /// + /// A value of `None` enables support for the oldest protocols supported by the implementation. + /// + /// Defaults to `Some(Protocol::Tlsv10)`. + pub fn min_protocol_version(&mut self, protocol: Option) -> &mut TlsConnectorBuilder { + self.inner.min_protocol_version(protocol); + self + } + + /// Sets the maximum supported protocol version. + /// + /// A value of `None` enables support for the newest protocols supported by the implementation. + /// + /// Defaults to `None`. + pub fn max_protocol_version(&mut self, protocol: Option) -> &mut TlsConnectorBuilder { + self.inner.max_protocol_version(protocol); + self + } + + /// Adds a certificate to the set of roots that the connector will trust. + /// + /// The connector will use the system's trust root by default. This method can be used to add + /// to that set when communicating with servers not trusted by the system. + /// + /// Defaults to an empty set. + pub fn add_root_certificate(&mut self, cert: Certificate) -> &mut TlsConnectorBuilder { + self.inner.add_root_certificate(cert); + self + } + + /// Controls the use of certificate validation. + /// + /// Defaults to `false`. + /// + /// # Warning + /// + /// You should think very carefully before using this method. If invalid certificates are trusted, *any* + /// certificate for *any* site will be trusted for use. This includes expired certificates. This introduces + /// significant vulnerabilities, and should only be used as a last resort. + pub fn danger_accept_invalid_certs( + &mut self, + accept_invalid_certs: bool, + ) -> &mut TlsConnectorBuilder { + self.inner.danger_accept_invalid_certs(accept_invalid_certs); + self + } + + /// Controls the use of Server Name Indication (SNI). + /// + /// Defaults to `true`. + pub fn use_sni(&mut self, use_sni: bool) -> &mut TlsConnectorBuilder { + self.inner.use_sni(use_sni); + self + } + + /// Controls the use of hostname verification. + /// + /// Defaults to `false`. + /// + /// # Warning + /// + /// You should think very carefully before using this method. If invalid hostnames are trusted, *any* valid + /// certificate for *any* site will be trusted for use. This introduces significant vulnerabilities, and should + /// only be used as a last resort. + pub fn danger_accept_invalid_hostnames( + &mut self, + accept_invalid_hostnames: bool, + ) -> &mut TlsConnectorBuilder { + self.inner.danger_accept_invalid_hostnames(accept_invalid_hostnames); + self + } + + /// Creates a new `TlsConnector`. + pub fn build(&self) -> Result { + let connector = self.inner.build().map_err(Error::Connector)?; + Ok(TlsConnector { + inner: connector + }) + } +} + +/// +/// # Examples +/// +/// ```rust,no_run +/// #![feature(async_await, await_macro, futures_api)] +/// use futures::io::{AsyncReadExt, AsyncWriteExt}; +/// use tls_async::TlsConnector; +/// use std::io::{Read, Write}; +/// use std::net::ToSocketAddrs; +/// use romio::TcpStream; +/// +/// # futures::executor::block_on(async { +/// let connector = TlsConnector::new().unwrap(); +/// +/// let addr = "google.com:443".to_socket_addrs().unwrap().next().unwrap(); +/// let stream = await!(TcpStream::connect(&addr)).unwrap(); +/// let mut stream = await!(connector.connect("google.com", stream)).unwrap(); +/// +/// await!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n")).unwrap(); +/// let mut res = vec![]; +/// await!(stream.read_to_end(&mut res)).unwrap(); +/// println!("{}", String::from_utf8_lossy(&res)); +/// # }) +/// ``` +#[derive(Clone)] +pub struct TlsConnector { + inner: native_tls::TlsConnector, +} + +impl TlsConnector { + /// Returns a new connector with default settings. + pub fn new() -> Result { + let native_connector = native_tls::TlsConnector::new().map_err(Error::Connector)?; + Ok( TlsConnector { + inner: native_connector, + }) + } + + /// Returns a new builder for a `TlsConnector`. + pub fn builder() -> TlsConnectorBuilder { + TlsConnectorBuilder { + inner: native_tls::TlsConnector::builder(), + } + } + + /// Connects the provided stream with this connector, assuming the provided + /// domain. + /// + /// This function will internally call `TlsConnector::connect` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used for clients who have already established, for + /// example, a TCP connection to a remote server. That stream is then + /// provided here to perform the client half of a connection to a + /// TLS-powered server. + pub fn connect<'a, S>(&'a self, domain: &'a str, stream: S) -> PendingTlsStream + where S: AsyncRead + AsyncWrite, + { + PendingTlsStream::new(self.inner.connect(domain, stream.compat())) + } +} \ No newline at end of file diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..f54583a --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,18 @@ +use failure::Fail; + +#[derive(Debug, Fail)] +pub enum Error { + #[fail(display="NativeTls Acceptor Error")] + Acceptor(#[cause] native_tls::Error), + #[fail(display="NativeTls Connector Error")] + Connector(#[cause] native_tls::Error), + #[fail(display="Error during handshake")] + Handshake(#[cause] native_tls::Error), + #[fail(display="NativeTls Error")] + Native(#[cause] native_tls::Error), + #[fail(display="Cannot repeat handshake")] + RepeatedHandshake, +} + +unsafe impl Sync for Error {} +unsafe impl Send for Error {} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 8dff503..fd1eedc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(async_await, await_macro, futures_api)] //! Async TLS streams //! //! This library is an implementation of TLS streams using the most appropriate @@ -15,19 +14,23 @@ //! functionality provided by the `native-tls` crate, on which this crate is //! built. Configuration of TLS parameters is still primarily done through the //! `native-tls` crate. +#![feature(async_await, await_macro, futures_api)] +mod acceptor; +mod connector; +mod errors; +mod pending; + +pub use acceptor::TlsAcceptor as TlsAcceptor; +pub use connector::TlsConnector as TlsConnector; +pub use errors::Error as Error; use std::io::{self, Read, Write}; -use std::pin::Pin; use std::task::Waker; -use futures::Future; use futures::compat::Compat; -use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite}; +use futures::io::{AsyncRead, AsyncWrite}; use futures::Poll; -use log::debug; -use native_tls::{Error, HandshakeError, MidHandshakeTlsStream, TlsStream as NativeTlsStream}; - -pub type NativeWrapperStream = NativeTlsStream>; +pub use native_tls::{Certificate as Certificate, Identity as Identity, Protocol as Protocol}; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -38,19 +41,19 @@ pub type NativeWrapperStream = NativeTlsStream>; /// to a `TlsStream` are encrypted when passing through to `S`. #[derive(Debug)] pub struct TlsStream { - inner: NativeWrapperStream, + inner: native_tls::TlsStream>, } impl TlsStream { /// Get access to the internal `native_tls::TlsStream` stream which also /// transitively allows access to `S`. - pub fn get_ref(&self) -> &NativeWrapperStream { + pub fn get_ref(&self) -> &native_tls::TlsStream> { &self.inner } /// Get mutable access to the internal `native_tls::TlsStream` stream which /// also transitively allows mutable access to `S`. - pub fn get_mut(&mut self) -> &mut NativeWrapperStream { + pub fn get_mut(&mut self) -> &mut native_tls::TlsStream> { &mut self.inner } } @@ -99,130 +102,4 @@ impl AsyncWrite for TlsStream { Err(e) => Poll::Ready(Err(e)) } } -} - -/// A wrapper around a `native_tls::TlsConnector`, providing an async `connect` -/// method. -#[derive(Clone)] -pub struct TlsConnector { - inner: native_tls::TlsConnector, -} - -impl TlsConnector { - /// Connects the provided stream with this connector, assuming the provided - /// domain. - /// - /// This function will internally call `TlsConnector::connect` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used for clients who have already established, for - /// example, a TCP connection to a remote server. That stream is then - /// provided here to perform the client half of a connection to a - /// TLS-powered server. - pub fn connect<'a, S>(&'a self, domain: &'a str, stream: S) -> PendingTlsStream - where S: AsyncRead + AsyncWrite, - { - let connect_result = self.inner.connect(domain, stream.compat()).map(|i| { - Handshake::Completed(i) - }); - PendingTlsStream { - inner: Some(connect_result) - } - } -} - -impl From for TlsConnector { - fn from(inner: native_tls::TlsConnector) -> Self { - Self { - inner, - } - } -} - -/// A wrapper around a `native_tls::TlsAcceptor`, providing an async `accept` -/// method. -#[derive(Clone)] -pub struct TlsAcceptor { - inner: native_tls::TlsAcceptor, -} - -impl TlsAcceptor { - /// Accepts a new client connection with the provided stream. - /// - /// This function will internally call `TlsAcceptor::accept` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used after a new socket has been accepted from a - /// `TcpListener`. That socket is then passed to this function to perform - /// the server half of accepting a client connection. - pub fn accept(&self, stream: S) -> PendingTlsStream - where S: AsyncRead + AsyncWrite, - { - let accept_result = self.inner.accept(stream.compat()).map(|i| { - Handshake::Completed(i) - }); - PendingTlsStream { - inner: Some(accept_result) - } - } -} - -impl From for TlsAcceptor { - fn from(inner: native_tls::TlsAcceptor) -> Self { - Self { - inner, - } - } -} - -pub enum Handshake { - Completed(NativeWrapperStream), - Midhandshake(MidHandshakeTlsStream>) -} - -pub struct PendingTlsStream { - inner: Option, HandshakeError>>> -} - -impl Future for PendingTlsStream { - type Output = Result, Error>; - - fn poll(mut self: Pin<&mut Self>, _lw: &Waker) -> Poll { - let this: &mut Self = &mut *self; - let inner = std::mem::replace(&mut this.inner, None); - - match inner.expect("Cannot poll handshake twice") { - Ok(Handshake::Completed(native_stream)) => { - debug!("Connection was completed"); - Poll::Ready(Ok(TlsStream { inner: native_stream })) - } - Ok(Handshake::Midhandshake(midhandshake_stream)) => { - debug!("Connection was interrupted mid handshake, attempting handshake"); - match midhandshake_stream.handshake() { - Ok(native_stream) => { - debug!("Handshake completed, connection established"); - Poll::Ready(Ok(TlsStream { inner: native_stream })) - }, - Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), - Err(HandshakeError::WouldBlock(midhandshake_stream)) => { - debug!("Handshake interrupted, {:?}", midhandshake_stream); - std::mem::replace(&mut this.inner, Some(Ok(Handshake::Midhandshake(midhandshake_stream)))); - Poll::Pending - } - } - } - Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)), - Err(HandshakeError::WouldBlock(midhandshake_stream)) => { - debug!("Handshake interrupted, {:?}", midhandshake_stream); - std::mem::replace(&mut this.inner, Some(Ok(Handshake::Midhandshake(midhandshake_stream)))); - Poll::Pending - } - } - } -} - - +} \ No newline at end of file diff --git a/src/pending.rs b/src/pending.rs new file mode 100644 index 0000000..07cd238 --- /dev/null +++ b/src/pending.rs @@ -0,0 +1,86 @@ +use crate::errors::Error; +use crate::TlsStream; + +use std::pin::Pin; +use std::task::Waker; + +use futures::Future; +use futures::compat::Compat; +use futures::io::{AsyncRead, AsyncWrite}; +use futures::Poll; +use log::debug; +use native_tls::{HandshakeError, MidHandshakeTlsStream}; +pub use native_tls::{TlsConnector as NativeTlsConnector, TlsStream as NativeTlsStream}; + +enum Handshake { + Error(Error), + Midhandshake(MidHandshakeTlsStream>), + Completed(NativeTlsStream>), +} + +impl Handshake { + pub fn was_pending(&self) -> bool { + if let Handshake::Midhandshake(_) = self { + true + } else { + false + } + } +} + +type NativeHandshake = Result>, HandshakeError>>; + +impl From> for Handshake { + fn from(v: NativeHandshake) -> Self { + match v { + Ok(native_stream) => Handshake::Completed(native_stream), + Err(HandshakeError::Failure(e)) => Handshake::Error(Error::Handshake(e)), + Err(HandshakeError::WouldBlock(midhandshake_stream)) => Handshake::Midhandshake(midhandshake_stream), + } + } +} + +pub struct PendingTlsStream { + inner: Handshake, +} + +impl PendingTlsStream { + pub fn new(inner: NativeHandshake) -> Self { + PendingTlsStream { + inner: Handshake::from(inner) + } + } + fn inner<'a>(self: Pin<&'a mut Self>) -> &'a mut Handshake { + unsafe { + &mut Pin::get_unchecked_mut(self).inner + } + } +} + +impl Future for PendingTlsStream { + type Output = Result, Error>; + + fn poll(mut self: Pin<&mut Self>, _lw: &Waker) -> Poll { + loop { + let handshake = std::mem::replace(self.as_mut().inner(), Handshake::Error(Error::RepeatedHandshake)); + match handshake { + Handshake::Error(e) => return Poll::Ready(Err(e)), + Handshake::Midhandshake(midhandshake_stream) => { + debug!("Connection was interrupted mid handshake, attempting handshake"); + let res = Handshake::from(midhandshake_stream.handshake()); + let was_pending = res.was_pending(); + *self.as_mut().inner() = res; + if was_pending { + return Poll::Pending; + } + } + Handshake::Completed(native_stream) => { + debug!("Connection was completed"); + return Poll::Ready(Ok(TlsStream { inner: native_stream })) + } + } + } + } +} + + diff --git a/tests/bad.rs b/tests/bad.rs index 3f2f4a3..785e3b1 100644 --- a/tests/bad.rs +++ b/tests/bad.rs @@ -1,11 +1,20 @@ #![feature(async_await, await_macro, futures_api)] -use std::io::Error; use std::net::ToSocketAddrs; use cfg_if::cfg_if; use futures::{FutureExt, TryFutureExt}; -use native_tls::TlsConnector; use romio::TcpStream; +use tls_async::{Error, TlsConnector}; + +fn check_cause(err: Error, s: &str) { + match err { + Error::Handshake(e) => { + let err = e.to_string(); + assert!(e.to_string().contains(s), "Error {} did not contain {}", err, s); + } + _ => panic!("Error {:?} was not a handshake error") + } +} macro_rules! t { ($e:expr) => (match $e { @@ -16,25 +25,20 @@ macro_rules! t { cfg_if! { if #[cfg(feature = "force-rustls")] { - fn verify_failed(err: &Error, s: &str) { - let err = err.to_string(); - assert!(err.contains(s), "bad error: {}", err); + fn assert_expired_error(err: Error) { + check_cause(err, "CertExpired"); } - fn assert_expired_error(err: &Error) { - verify_failed(err, "CertExpired"); + fn assert_wrong_host(err: Error) { + check_cause(err, "CertNotValidForName"); } - fn assert_wrong_host(err: &Error) { - verify_failed(err, "CertNotValidForName"); + fn assert_self_signed(err: Error) { + check_cause(err, "UnknownIssuer"); } - fn assert_self_signed(err: &Error) { - verify_failed(err, "UnknownIssuer"); - } - - fn assert_untrusted_root(err: &Error) { - verify_failed(err, "UnknownIssuer"); + fn assert_untrusted_root(err: Error) { + check_cause(err, "UnknownIssuer"); } } else if #[cfg(any(feature = "force-openssl", all(not(target_os = "macos"), @@ -42,8 +46,8 @@ cfg_if! { not(target_os = "ios"))))] { use openssl; - fn verify_failed(err: &Error) { - assert!(format!("{}", err).contains("certificate verify failed")) + fn verify_failed(err: Error) { + check_cause(err, "certificate verify failed") ; } use self::verify_failed as assert_expired_error; @@ -52,8 +56,8 @@ cfg_if! { use self::verify_failed as assert_untrusted_root; } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - fn assert_invalid_cert_chain(err: &Error) { - assert!(format!("{}", err).contains("was not trusted.")) + fn assert_invalid_cert_chain(err: Error) { + check_cause(err, "was not trusted."); } use self::assert_invalid_cert_chain as assert_expired_error; @@ -61,29 +65,22 @@ cfg_if! { use self::assert_invalid_cert_chain as assert_self_signed; use self::assert_invalid_cert_chain as assert_untrusted_root; } else { - fn assert_expired_error(err: &Error) { - let s = err.to_string(); - assert!(s.contains("system clock"), "error = {:?}", s); + fn assert_expired_error(err: Error) { + check_cause(err, "system clock"); } - fn assert_wrong_host(err: &Error) { - let s = err.to_string(); - assert!(s.contains("CN name"), "error = {:?}", s); + fn assert_wrong_host(err: Error) { + check_cause(err, "CN name"); } - fn assert_self_signed(err: &Error) { - let s = err.to_string(); - assert!(s.contains("root certificate which is not trusted"), "error = {:?}", s); + fn assert_self_signed(err: Error) { + check_cause(err, "root certificate which is not trusted"); } use self::assert_self_signed as assert_untrusted_root; } } -fn native2io(e: native_tls::Error) -> Error { - Error::new(std::io::ErrorKind::Other, e) -} - async fn get_host(host: String) -> Result<(), Error> { drop(env_logger::try_init()); @@ -93,8 +90,7 @@ async fn get_host(host: String) -> Result<(), Error> { let socket = t!(await!(TcpStream::connect(&addr))); let builder = TlsConnector::builder(); let cx = t!(builder.build()); - let cx = tls_async::TlsConnector::from(cx); - await!(cx.connect(&host, socket)).map_err(native2io)?; + await!(cx.connect(&host, socket))?; Ok(()) } @@ -107,7 +103,7 @@ fn expired() { let res = rt.block_on(fut_res.boxed().compat()); assert!(res.is_err()); - assert_expired_error(&res.err().unwrap()); + assert_expired_error(res.err().unwrap()); } // TODO: the OSX builders on Travis apparently fail this tests spuriously? @@ -122,7 +118,7 @@ fn wrong_host() { let res = rt.block_on(fut_res.boxed().compat()); assert!(res.is_err()); - assert_wrong_host(&res.err().unwrap()); + assert_wrong_host(res.err().unwrap()); } #[test] @@ -134,7 +130,7 @@ fn self_signed() { let res = rt.block_on(fut_res.boxed().compat()); assert!(res.is_err()); - assert_self_signed(&res.err().unwrap()); + assert_self_signed(res.err().unwrap()); } #[test] @@ -146,5 +142,5 @@ fn untrusted_root() { let res = rt.block_on(fut_res.boxed().compat()); assert!(res.is_err()); - assert_untrusted_root(&res.err().unwrap()); + assert_untrusted_root(res.err().unwrap()); } diff --git a/tests/google.rs b/tests/google.rs index a597294..7387d7a 100644 --- a/tests/google.rs +++ b/tests/google.rs @@ -1,13 +1,21 @@ #![feature(async_await, await_macro, futures_api)] - -use std::io; use std::net::ToSocketAddrs; use cfg_if::cfg_if; use futures::{FutureExt, TryFutureExt}; use futures::io::{AsyncReadExt, AsyncWriteExt}; -use native_tls::TlsConnector; use romio::TcpStream; +use tls_async::{Error, TlsConnector}; + +fn check_cause(err: Error, s: &str) { + match err { + Error::Handshake(e) => { + let err = e.to_string(); + assert!(e.to_string().contains(s), "Error {} did not contain {}", err, s); + } + _ => panic!("Error {:?} was not a handshake error") + } +} macro_rules! t { ($e:expr) => (match $e { @@ -18,9 +26,8 @@ macro_rules! t { cfg_if! { if #[cfg(feature = "force-rustls")] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.to_string(); - assert!(err.contains("CertNotValidForName"), "bad error: {}", err); + fn assert_bad_hostname_error(err: Error) { + check_cause(err, "CertNotValidForName"); } } else if #[cfg(any(feature = "force-openssl", all(not(target_os = "macos"), @@ -28,30 +35,21 @@ cfg_if! { not(target_os = "ios"))))] { extern crate openssl; - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("certificate verify failed")); + fn assert_bad_hostname_error(err: Error) { + check_cause(err, "certificate verify failed"); } } else if #[cfg(any(target_os = "macos", target_os = "ios"))] { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("was not trusted.")); + fn assert_bad_hostname_error(err: Error) { + check_cause(err, "was not trusted."); } } else { - fn assert_bad_hostname_error(err: &io::Error) { - let err = err.get_ref().unwrap(); - let err = err.downcast_ref::().unwrap(); - assert!(format!("{}", err).contains("CN name")); + fn assert_bad_hostname_error(err: Error) { + let err = err.compat().to_string(); + check_cause(err, "CN name"); } } } -fn native2io(e: native_tls::Error) -> io::Error { - io::Error::new(io::ErrorKind::Other, e) -} - #[test] fn fetch_google() { drop(env_logger::try_init()); @@ -67,8 +65,7 @@ fn fetch_google() { // Send off the request by first negotiating an SSL handshake, then writing // of our request, then flushing, then finally read off the response. let builder = TlsConnector::builder(); - let cx = t!(builder.build()); - let connector = tls_async::TlsConnector::from(cx); + let connector = t!(builder.build()); println!("Attempting tls connection"); @@ -105,7 +102,6 @@ fn wrong_hostname_error() { let socket = t!(await!(TcpStream::connect(&addr))); let builder = TlsConnector::builder(); let connector = t!(builder.build()); - let connector = tls_async::TlsConnector::from(connector); await!(connector.connect("rust-lang.org", socket)) }; @@ -113,5 +109,5 @@ fn wrong_hostname_error() { let res = rt.block_on(fut_result.fuse().boxed().compat()); assert!(res.is_err()); - assert_bad_hostname_error(&native2io(res.err().unwrap())); + assert_bad_hostname_error(res.err().unwrap()); } diff --git a/tests/smoke.rs b/tests/smoke.rs index ff12d6a..850096f 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -2,11 +2,10 @@ use std::io::Write; use std::process::Command; -use tls_async::{TlsAcceptor as TlsAsyncAcceptor, TlsConnector as TlsAsyncConnector}; +use tls_async::{Identity, TlsAcceptor, TlsConnector}; use cfg_if::cfg_if; use futures::io::{AsyncReadExt, AsyncWriteExt}; use futures::{FutureExt, StreamExt, TryFutureExt}; -use native_tls::{TlsConnector as NativeTlsConnector, TlsAcceptor as NativeTlsAcceptor, Identity}; use romio::{TcpStream, TcpListener}; macro_rules! t { @@ -207,15 +206,15 @@ cfg_if! { use std::env; use std::sync::{Once, ONCE_INIT}; - fn contexts() -> (TlsAsyncAcceptor, TlsAsyncConnector) { + fn contexts() -> (TlsAcceptor, TlsConnector) { let keys = openssl_keys(); let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = NativeTlsAcceptor::builder(pkcs12); + let srv = TlsAcceptor::builder(pkcs12); - let cert = t!(native_tls::Certificate::from_der(&keys.cert_der)); + let cert = t!(tls_async::Certificate::from_der(&keys.cert_der)); - let mut client = NativeTlsConnector::builder(); + let mut client = TlsConnector::builder(); t!(client.add_root_certificate(cert).build()); (t!(srv.build()).into(), t!(client.build()).into()) @@ -226,14 +225,14 @@ cfg_if! { use std::env; use std::fs::File; - fn contexts() -> (TlsAsyncAcceptor, TlsAsyncConnector) { + fn contexts() -> (TlsAcceptor, TlsConnector) { let keys = openssl_keys(); let pkcs12 = t!(Identity::from_pkcs12(&keys.pkcs12_der, "foobar")); - let srv = NativeTlsAcceptor::builder(pkcs12); + let srv = TlsAcceptor::builder(pkcs12); - let cert = native_tls::Certificate::from_der(&keys.cert_der).unwrap(); - let mut client = NativeTlsConnector::builder(); + let cert = tls_async::Certificate::from_der(&keys.cert_der).unwrap(); + let mut client = TlsConnector::builder(); client.add_root_certificate(cert); (t!(srv.build()).into(), t!(client.build()).into()) @@ -262,15 +261,15 @@ cfg_if! { const FRIENDLY_NAME: &'static str = "tls-async localhost testing cert"; - fn contexts() -> (TlsAsyncAcceptor, TlsAsyncConnector) { + fn contexts() -> (TlsAcceptor, TlsConnector) { let cert = localhost_cert(); let mut store = t!(Memory::new()).into_store(); t!(store.add_cert(&cert, CertAdd::Always)); let pkcs12_der = t!(store.export_pkcs12("foobar")); let pkcs12 = t!(Identity::from_pkcs12(&pkcs12_der, "foobar")); - let srv = NativeTlsAcceptor::builder(pkcs12); - let client = NativeTlsConnector::builder(); + let srv = TlsAcceptor::builder(pkcs12); + let client = TlsConnector::builder(); (t!(srv.build()).into(), t!(client.build()).into()) }