From 45efb09afdea44706b774fcc324d0b6baabf40d8 Mon Sep 17 00:00:00 2001 From: Sam Clark <3758302+goatgoose@users.noreply.github.com> Date: Mon, 18 Nov 2024 12:24:18 -0500 Subject: [PATCH] test(s2n-tls-hyper): Add localhost http tests (#4838) --- bindings/rust/s2n-tls-hyper/Cargo.toml | 1 + .../rust/s2n-tls-hyper/tests/common/echo.rs | 85 +++++++++++ .../rust/s2n-tls-hyper/tests/common/mod.rs | 27 ++++ bindings/rust/s2n-tls-hyper/tests/http.rs | 140 ++++++++++++++++++ 4 files changed, 253 insertions(+) create mode 100644 bindings/rust/s2n-tls-hyper/tests/common/echo.rs create mode 100644 bindings/rust/s2n-tls-hyper/tests/common/mod.rs create mode 100644 bindings/rust/s2n-tls-hyper/tests/http.rs diff --git a/bindings/rust/s2n-tls-hyper/Cargo.toml b/bindings/rust/s2n-tls-hyper/Cargo.toml index a83970e12d3..8cdf850b3bd 100644 --- a/bindings/rust/s2n-tls-hyper/Cargo.toml +++ b/bindings/rust/s2n-tls-hyper/Cargo.toml @@ -23,4 +23,5 @@ http = { version= "1" } [dev-dependencies] tokio = { version = "1", features = ["macros", "test-util"] } http-body-util = "0.1" +hyper-util = { version = "0.1", features = ["server"] } bytes = "1" diff --git a/bindings/rust/s2n-tls-hyper/tests/common/echo.rs b/bindings/rust/s2n-tls-hyper/tests/common/echo.rs new file mode 100644 index 00000000000..044a99d775d --- /dev/null +++ b/bindings/rust/s2n-tls-hyper/tests/common/echo.rs @@ -0,0 +1,85 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use bytes::Bytes; +use http::{Request, Response}; +use http_body_util::{combinators::BoxBody, BodyExt}; +use hyper::service::service_fn; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use s2n_tls::connection::Builder; +use s2n_tls_tokio::TlsAcceptor; +use std::{error::Error, future::Future}; +use tokio::net::TcpListener; + +async fn echo( + req: Request, +) -> Result>, hyper::Error> { + Ok(Response::new(req.into_body().boxed())) +} + +async fn serve_echo( + tcp_listener: TcpListener, + builder: B, +) -> Result<(), Box> +where + B: Builder, + ::Output: Unpin + Send + Sync + 'static, +{ + let (tcp_stream, _) = tcp_listener.accept().await?; + let acceptor = TlsAcceptor::new(builder); + let tls_stream = acceptor.accept(tcp_stream).await?; + let io = TokioIo::new(tls_stream); + + let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + if let Err(err) = server.serve_connection(io, service_fn(echo)).await { + // The hyper client doesn't gracefully terminate by waiting for the server's shutdown. + // Instead, the client sends its shutdown and then immediately closes the socket. This can + // cause a NotConnected error to be emitted when the server attempts to send its shutdown. + // + // For now, NotConnected errors are ignored. After the hyper client can be configured to + // gracefully shutdown, this exception can be removed: + // https://github.com/aws/s2n-tls/issues/4855 + // + // Also, it's possible that a NotConnected error could occur during some operation other + // than a shutdown. Ideally, these NotConnected errors wouldn't be ignored. However, it's + // not currently possible to distinguish between shutdown vs non-shutdown errors: + // https://github.com/aws/s2n-tls/issues/4856 + if let Some(hyper_err) = err.downcast_ref::() { + if let Some(source) = hyper_err.source() { + if let Some(io_err) = source.downcast_ref::() { + if io_err.kind() == tokio::io::ErrorKind::NotConnected { + return Ok(()); + } + } + } + } + + return Err(err); + } + + Ok(()) +} + +pub async fn make_echo_request( + server_builder: B, + send_client_request: F, +) -> Result<(), Box> +where + B: Builder + Send + Sync + 'static, + ::Output: Unpin + Send + Sync + 'static, + F: FnOnce(u16) -> Fut, + Fut: Future>> + Send + 'static, +{ + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let mut tasks = tokio::task::JoinSet::new(); + tasks.spawn(serve_echo(listener, server_builder)); + tasks.spawn(send_client_request(addr.port())); + + while let Some(res) = tasks.join_next().await { + res.unwrap()?; + } + + Ok(()) +} diff --git a/bindings/rust/s2n-tls-hyper/tests/common/mod.rs b/bindings/rust/s2n-tls-hyper/tests/common/mod.rs new file mode 100644 index 00000000000..148462d2d12 --- /dev/null +++ b/bindings/rust/s2n-tls-hyper/tests/common/mod.rs @@ -0,0 +1,27 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_tls::{callbacks::VerifyHostNameCallback, config, error::Error, security::DEFAULT_TLS13}; + +pub mod echo; + +/// NOTE: this certificate and key are used for testing purposes only! +pub static CERT_PEM: &[u8] = + include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../certs/cert.pem")); +pub static KEY_PEM: &[u8] = + include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../certs/key.pem")); + +pub fn config() -> Result { + let mut builder = config::Config::builder(); + builder.set_security_policy(&DEFAULT_TLS13)?; + builder.trust_pem(CERT_PEM)?; + builder.load_pem(CERT_PEM, KEY_PEM)?; + Ok(builder) +} + +pub struct InsecureAcceptAllCertificatesHandler {} +impl VerifyHostNameCallback for InsecureAcceptAllCertificatesHandler { + fn verify_host_name(&self, _host_name: &str) -> bool { + true + } +} diff --git a/bindings/rust/s2n-tls-hyper/tests/http.rs b/bindings/rust/s2n-tls-hyper/tests/http.rs new file mode 100644 index 00000000000..0a5469b45de --- /dev/null +++ b/bindings/rust/s2n-tls-hyper/tests/http.rs @@ -0,0 +1,140 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::common::InsecureAcceptAllCertificatesHandler; +use bytes::Bytes; +use http::{Method, Request, Uri}; +use http_body_util::{BodyExt, Empty, Full}; +use hyper_util::{client::legacy::Client, rt::TokioExecutor}; +use s2n_tls::{ + callbacks::{ClientHelloCallback, ConnectionFuture}, + connection::Connection, +}; +use s2n_tls_hyper::connector::HttpsConnector; +use std::{error::Error, pin::Pin, str::FromStr}; + +pub mod common; + +const TEST_DATA: &[u8] = "hello world".as_bytes(); + +// The maximum TLS record payload is 2^14 bytes. +// Send more to ensure multiple records. +const LARGE_TEST_DATA: &[u8] = &[5; (1 << 15)]; + +#[tokio::test] +async fn test_get_request() -> Result<(), Box> { + let config = common::config()?.build()?; + common::echo::make_echo_request(config.clone(), |port| async move { + let connector = HttpsConnector::new(config.clone()); + let client: Client<_, Empty> = + Client::builder(TokioExecutor::new()).build(connector); + + let uri = Uri::from_str(format!("https://localhost:{}", port).as_str())?; + let response = client.get(uri).await?; + assert_eq!(response.status(), 200); + + Ok(()) + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_http_methods() -> Result<(), Box> { + let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE]; + for method in methods { + let config = common::config()?.build()?; + common::echo::make_echo_request(config.clone(), |port| async move { + let connector = HttpsConnector::new(config.clone()); + let client: Client<_, Full> = + Client::builder(TokioExecutor::new()).build(connector); + let request: Request> = Request::builder() + .method(method) + .uri(Uri::from_str( + format!("https://localhost:{}", port).as_str(), + )?) + .body(Full::from(TEST_DATA))?; + + let response = client.request(request).await?; + assert_eq!(response.status(), 200); + + let body = response.into_body().collect().await?.to_bytes(); + assert_eq!(body.to_vec().as_slice(), TEST_DATA); + + Ok(()) + }) + .await?; + } + + Ok(()) +} + +#[tokio::test] +async fn test_large_request() -> Result<(), Box> { + let config = common::config()?.build()?; + common::echo::make_echo_request(config.clone(), |port| async move { + let connector = HttpsConnector::new(config.clone()); + let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); + let request: Request> = Request::builder() + .method(Method::POST) + .uri(Uri::from_str( + format!("https://localhost:{}", port).as_str(), + )?) + .body(Full::from(LARGE_TEST_DATA))?; + + let response = client.request(request).await?; + assert_eq!(response.status(), 200); + + let body = response.into_body().collect().await?.to_bytes(); + assert_eq!(body.to_vec().as_slice(), LARGE_TEST_DATA); + + Ok(()) + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sni() -> Result<(), Box> { + struct TestClientHelloHandler { + expected_server_name: &'static str, + } + impl ClientHelloCallback for TestClientHelloHandler { + fn on_client_hello( + &self, + connection: &mut Connection, + ) -> Result>>, s2n_tls::error::Error> { + let server_name = connection.server_name().unwrap(); + assert_eq!(server_name, self.expected_server_name); + Ok(None) + } + } + + for hostname in ["localhost", "127.0.0.1"] { + let mut config = common::config()?; + config.set_client_hello_callback(TestClientHelloHandler { + // Ensure that the HttpsConnector correctly sets the SNI according to the hostname in + // the URI. + expected_server_name: hostname, + })?; + config.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; + let config = config.build()?; + + common::echo::make_echo_request(config.clone(), |port| async move { + let connector = HttpsConnector::new(config.clone()); + let client: Client<_, Empty> = + Client::builder(TokioExecutor::new()).build(connector); + + let uri = Uri::from_str(format!("https://{}:{}", hostname, port).as_str())?; + let response = client.get(uri).await?; + assert_eq!(response.status(), 200); + + Ok(()) + }) + .await?; + } + + Ok(()) +}