Skip to content

Commit

Permalink
Fix test flakiness from TCP port contention (#388)
Browse files Browse the repository at this point in the history
Eliminate the lock contention on test service TCP sockets that led to random,
frequent test failures.
  • Loading branch information
DanGould authored Dec 31, 2024
2 parents e54e51f + 3c20429 commit d940ed2
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 97 deletions.
4 changes: 2 additions & 2 deletions Cargo-minimal.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1450,9 +1450,9 @@ dependencies = [

[[package]]
name = "ohttp-relay"
version = "0.0.8"
version = "0.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632"
checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2"
dependencies = [
"futures",
"http",
Expand Down
4 changes: 2 additions & 2 deletions Cargo-recent.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1450,9 +1450,9 @@ dependencies = [

[[package]]
name = "ohttp-relay"
version = "0.0.8"
version = "0.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7850c40a0aebcba289d3252c0a45f93cba6ad4b0c46b88a5fc51dba6ddce8632"
checksum = "4f8e8aef13b8327b680aaaca807aa11ba5979fc5858203e7b77c68128ede61a2"
dependencies = [
"futures",
"http",
Expand Down
2 changes: 1 addition & 1 deletion payjoin-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ url = { version = "2.3.1", features = ["serde"] }
[dev-dependencies]
bitcoind = { version = "0.36.0", features = ["0_21_2"] }
http = "1"
ohttp-relay = "0.0.8"
ohttp-relay = { version = "0.0.9", features = ["_test-util"] }
once_cell = "1"
payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] }
testcontainers = "0.15.0"
Expand Down
47 changes: 30 additions & 17 deletions payjoin-cli/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ mod e2e {
payjoin_sent.unwrap().unwrap_or(Some(false)).unwrap(),
"Payjoin send was not detected"
);

/// Asks the OS for an unused TCP port on the loopback interface and returns its number.
///
/// The probing listener is dropped before this function returns, so the port is
/// released again: nothing stops another process from claiming it before the
/// caller binds it.
///
/// # Panics
///
/// Panics if the loopback socket cannot be bound or its local address cannot be read.
fn find_free_port() -> u16 {
    std::net::TcpListener::bind("127.0.0.1:0")
        .and_then(|probe| probe.local_addr())
        .map(|addr| addr.port())
        .unwrap()
}
}

#[cfg(feature = "v2")]
Expand All @@ -170,6 +175,7 @@ mod e2e {
use url::Url;

type Error = Box<dyn std::error::Error + 'static>;
type BoxSendSyncError = Box<dyn std::error::Error + Send + Sync>;
type Result<T> = std::result::Result<T, Error>;

static INIT_TRACING: OnceCell<()> = OnceCell::new();
Expand All @@ -178,18 +184,26 @@ mod e2e {

init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
let (port, directory_handle) =
init_directory(db_host, (cert.clone(), key)).await.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();

let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
let (ohttp_relay_port, ohttp_relay_handle) =
ohttp_relay::listen_tcp_on_free_port(gateway_origin)
.await
.expect("Failed to init ohttp relay");
let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();

let temp_dir = env::temp_dir();
let receiver_db_path = temp_dir.join("receiver_db");
let sender_db_path = temp_dir.join("sender_db");
let result: Result<()> = tokio::select! {
res = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => Err(format!("Ohttp relay is long running: {:?}", res).into()),
res = init_directory(directory_port, (cert.clone(), key)) => Err(format!("Directory server is long running: {:?}", res).into()),
res = ohttp_relay_handle => Err(format!("Ohttp relay is long running: {:?}", res).into()),
res = directory_handle => Err(format!("Directory server is long running: {:?}", res).into()),
res = send_receive_cli_async(ohttp_relay, directory, cert, receiver_db_path.clone(), sender_db_path.clone()) => res.map_err(|e| format!("send_receive failed: {:?}", e).into()),
};

Expand Down Expand Up @@ -479,13 +493,17 @@ mod e2e {
Err("Timeout waiting for service to be ready".into())
}

async fn init_directory(port: u16, local_cert_key: (Vec<u8>, Vec<u8>)) -> Result<()> {
let docker: Cli = Cli::default();
async fn init_directory(
db_host: String,
local_cert_key: (Vec<u8>, Vec<u8>),
) -> std::result::Result<
(u16, tokio::task::JoinHandle<std::result::Result<(), BoxSendSyncError>>),
BoxSendSyncError,
> {
println!("Database running on {}", db_host);
let timeout = Duration::from_secs(2);
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key)
.await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down Expand Up @@ -524,11 +542,6 @@ mod e2e {
}
}

fn find_free_port() -> u16 {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
listener.local_addr().unwrap().port()
}

async fn cleanup_temp_file(path: &std::path::Path) {
if let Err(e) = fs::remove_dir_all(path).await {
eprintln!("Failed to remove {:?}: {}", path, e);
Expand Down
101 changes: 66 additions & 35 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,67 @@ const ID_LENGTH: usize = 13;
mod db;
use crate::db::DbPool;

/// Boxed `Send + Sync` error type used in the return types of the TLS listener
/// functions in this file.
#[cfg(feature = "_danger-local-https")]
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Starts the TLS directory server on a port chosen by the operating system.
///
/// Binds `[::]:0`, prints the bound address, and hands the already-bound listener to
/// [`listen_tcp_with_tls_on_listener`], so the socket is never released between
/// picking the port and serving on it.
///
/// Returns the bound port together with the handle of the spawned server task.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound, its local address cannot be read,
/// or [`listen_tcp_with_tls_on_listener`] fails.
#[cfg(feature = "_danger-local-https")]
pub async fn listen_tcp_with_tls_on_free_port(
    db_host: String,
    timeout: Duration,
    cert_key: (Vec<u8>, Vec<u8>),
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
    let socket = tokio::net::TcpListener::bind("[::]:0").await?;
    let bound_addr = socket.local_addr()?;
    println!("Directory server binding to port {}", bound_addr);
    listen_tcp_with_tls_on_listener(socket, db_host, timeout, cert_key)
        .await
        .map(|server_task| (bound_addr.port(), server_task))
}

/// Serves the payjoin directory over TLS on an already-bound `listener`.
///
/// The database pool, OHTTP state and TLS acceptor are created first, so a failure in
/// any of them is returned to the caller before a task is spawned. The accept
/// loop then runs in its own Tokio task, and every accepted connection is served in a
/// further task of its own.
///
/// Returns the handle of the task running the accept loop.
///
/// # Errors
///
/// Returns an error if `DbPool::new`, `init_ohttp` or `init_tls_acceptor` fails.
/// The spawned task itself resolves to `Err` when accepting a connection fails; a
/// failed TLS handshake or a failed HTTP connection is only logged and does not stop
/// the server.
#[cfg(feature = "_danger-local-https")]
async fn listen_tcp_with_tls_on_listener(
    listener: tokio::net::TcpListener,
    db_host: String,
    timeout: Duration,
    tls_config: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
    let pool = DbPool::new(timeout, db_host).await?;
    let ohttp = Arc::new(Mutex::new(init_ohttp()?));
    let tls_acceptor = init_tls_acceptor(tls_config)?;
    // Spawn the connection handling loop in a separate task
    let handle = tokio::spawn(async move {
        loop {
            // Report accept failures instead of silently ending the loop with `Ok(())`,
            // so whoever awaits the handle learns why the server stopped.
            let (stream, _) = match listener.accept().await {
                Ok(connection) => connection,
                Err(e) => {
                    error!("Error accepting connection: {}", e);
                    return Err::<(), BoxError>(e.into());
                }
            };
            let pool = pool.clone();
            let ohttp = ohttp.clone();
            let tls_acceptor = tls_acceptor.clone();
            tokio::spawn(async move {
                let tls_stream = match tls_acceptor.accept(stream).await {
                    Ok(tls_stream) => tls_stream,
                    Err(e) => {
                        error!("TLS accept error: {}", e);
                        return;
                    }
                };
                if let Err(err) = http1::Builder::new()
                    .serve_connection(
                        TokioIo::new(tls_stream),
                        service_fn(move |req| {
                            serve_payjoin_directory(req, pool.clone(), ohttp.clone())
                        }),
                    )
                    .with_upgrades()
                    .await
                {
                    error!("Error serving connection: {:?}", err);
                }
            });
        }
    });
    Ok(handle)
}

// NOTE: `listen_tcp_with_tls` (further below) delegates to `listen_tcp_with_tls_on_listener`.
pub async fn listen_tcp(
port: u16,
db_host: String,
Expand Down Expand Up @@ -73,41 +134,11 @@ pub async fn listen_tcp_with_tls(
port: u16,
db_host: String,
timeout: Duration,
tls_config: (Vec<u8>, Vec<u8>),
) -> Result<(), Box<dyn std::error::Error>> {
let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
let tls_acceptor = init_tls_acceptor(tls_config)?;
let listener = TcpListener::bind(bind_addr).await?;
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(e) => {
error!("TLS accept error: {}", e);
return;
}
};
if let Err(err) = http1::Builder::new()
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}

Ok(())
cert_key: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let addr = format!("0.0.0.0:{}", port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await
}

#[cfg(feature = "_danger-local-https")]
Expand Down
2 changes: 1 addition & 1 deletion payjoin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ serde_json = "1.0.108"
bitcoind = { version = "0.36.0", features = ["0_21_2"] }
http = "1"
payjoin-directory = { path = "../payjoin-directory", features = ["_danger-local-https"] }
ohttp-relay = "0.0.8"
ohttp-relay = { version = "0.0.9", features = ["_test-util"] }
once_cell = "1"
rcgen = { version = "0.11" }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
Expand Down
Loading

0 comments on commit d940ed2

Please sign in to comment.