From 956094780486d5145bde6ab80887d4e235f6e35f Mon Sep 17 00:00:00 2001 From: neonphog Date: Mon, 6 May 2024 14:01:33 -0600 Subject: [PATCH] test fix --- crates/tx5-connection/src/conn.rs | 116 +++++++++++++++++++----------- crates/tx5-connection/src/hub.rs | 15 ++-- crates/tx5-connection/src/test.rs | 46 ++++++++++++ crates/tx5-signal/src/conn.rs | 7 ++ crates/tx5-signal/src/wire.rs | 11 +++ crates/tx5/src/sig.rs | 1 + crates/tx5/src/test.rs | 15 ++-- 7 files changed, 159 insertions(+), 52 deletions(-) diff --git a/crates/tx5-connection/src/conn.rs b/crates/tx5-connection/src/conn.rs index 2f40f999..8ef26e59 100644 --- a/crates/tx5-connection/src/conn.rs +++ b/crates/tx5-connection/src/conn.rs @@ -22,18 +22,26 @@ pub struct Conn { pub_key: PubKey, client: Weak, conn_task: tokio::task::JoinHandle<()>, + keepalive_task: tokio::task::JoinHandle<()>, } impl Drop for Conn { fn drop(&mut self) { self.conn_task.abort(); + self.keepalive_task.abort(); } } impl Conn { + #[cfg(test)] + pub(crate) fn test_kill_keepalive_task(&self) { + self.keepalive_task.abort(); + } + pub(crate) fn priv_new( pub_key: PubKey, client: Weak, + config: Arc, ) -> (Arc, ConnRecv, Arc>) { // zero len semaphore.. we actually just wait for the close let ready = Arc::new(tokio::sync::Semaphore::new(0)); @@ -42,6 +50,23 @@ impl Conn { let (cmd_send, mut cmd_recv) = tokio::sync::mpsc::channel(32); let cmd_send = Arc::new(cmd_send); + let keepalive_dur = config.max_idle / 2; + let client2 = client.clone(); + let pub_key2 = pub_key.clone(); + let keepalive_task = tokio::task::spawn(async move { + loop { + tokio::time::sleep(keepalive_dur).await; + + if let Some(client) = client2.upgrade() { + if client.send_keepalive(&pub_key2).await.is_err() { + break; + } + } else { + break; + } + } + }); + let ready2 = ready.clone(); let client2 = client.clone(); let pub_key2 = pub_key.clone(); @@ -51,56 +76,51 @@ impl Conn { None => return, }; - match tokio::time::timeout( - std::time::Duration::from_secs(10), - async { - let nonce = client.send_handshake_req(&pub_key2).await?; - - let mut got_peer_res = false; - let mut sent_our_res = false; - - while let Some(cmd) = cmd_recv.recv().await { - match cmd { - ConnCmd::SigRecv(sig) => { - use tx5_signal::SignalMessage::*; - match sig { - HandshakeReq(oth_nonce) => { - client - .send_handshake_res( - &pub_key2, oth_nonce, - ) - .await?; - sent_our_res = true; - } - HandshakeRes(res_nonce) => { - if res_nonce != nonce { - return Err(Error::other( - "nonce mismatch", - )); - } - got_peer_res = true; - } - _ => { + match tokio::time::timeout(config.max_idle, async { + let nonce = client.send_handshake_req(&pub_key2).await?; + + let mut got_peer_res = false; + let mut sent_our_res = false; + + while let Some(cmd) = cmd_recv.recv().await { + match cmd { + ConnCmd::SigRecv(sig) => { + use tx5_signal::SignalMessage::*; + match sig { + HandshakeReq(oth_nonce) => { + client + .send_handshake_res( + &pub_key2, oth_nonce, + ) + .await?; + sent_our_res = true; + } + HandshakeRes(res_nonce) => { + if res_nonce != nonce { return Err(Error::other( - "invalid message during handshake", + "nonce mismatch", )); } + got_peer_res = true; + } + _ => { + return Err(Error::other( + "invalid message during handshake", + )); } - } - ConnCmd::Close => { - return Err(Error::other( - "close during handshake", - )) } } - if got_peer_res && sent_our_res { - break; + ConnCmd::Close => { + return Err(Error::other("close during handshake")) } } + if got_peer_res && sent_our_res { + break; + } + } - Result::Ok(()) - }, - ) + Result::Ok(()) + }) .await { Err(_) | Ok(Err(_)) => { @@ -115,12 +135,16 @@ impl Conn { // closing the semaphore causes all the acquire awaits to end ready2.close(); - while let Some(cmd) = cmd_recv.recv().await { + while let Ok(Some(cmd)) = + tokio::time::timeout(config.max_idle, cmd_recv.recv()).await + { match cmd { ConnCmd::SigRecv(sig) => { use tx5_signal::SignalMessage::*; #[allow(clippy::single_match)] // placeholder match sig { + // invalid + HandshakeReq(_) | HandshakeRes(_) => break, Message(msg) => { if msg_send.send(msg).await.is_err() { break; @@ -132,6 +156,13 @@ impl Conn { ConnCmd::Close => break, } } + + // explicitly close the peer + if let Some(client) = client2.upgrade() { + client.close_peer(&pub_key2).await; + }; + + // the receiver side is closed because msg_send is dropped. }); ( @@ -140,6 +171,7 @@ impl Conn { pub_key, client, conn_task, + keepalive_task, }), ConnRecv(msg_recv), cmd_send, diff --git a/crates/tx5-connection/src/hub.rs b/crates/tx5-connection/src/hub.rs index 5491aab4..95ceb044 100644 --- a/crates/tx5-connection/src/hub.rs +++ b/crates/tx5-connection/src/hub.rs @@ -28,6 +28,7 @@ async fn hub_map_assert( pub_key: PubKey, map: &mut HubMap, client: &Arc, + config: &Arc, ) -> Result<( Option, Arc, @@ -56,7 +57,7 @@ async fn hub_map_assert( // we're connected to the peer, create a connection let (conn, recv, cmd_send) = - Conn::priv_new(pub_key.clone(), Arc::downgrade(client)); + Conn::priv_new(pub_key.clone(), Arc::downgrade(client), config.clone()); let weak_conn = Arc::downgrade(&conn); @@ -112,7 +113,7 @@ impl Hub { config: Arc, ) -> Result<(Self, HubRecv)> { let (client, mut recv) = - tx5_signal::SignalConnection::connect(url, config).await?; + tx5_signal::SignalConnection::connect(url, config.clone()).await?; let client = Arc::new(client); tracing::debug!(%url, pub_key = ?client.pub_key(), "hub connected"); @@ -147,7 +148,7 @@ impl Hub { HubCmd::CliRecv { pub_key, msg } => { if let Some(client) = weak_client.upgrade() { let (recv, conn, cmd_send) = match hub_map_assert( - pub_key, &mut map, &client, + pub_key, &mut map, &client, &config, ) .await { @@ -171,9 +172,11 @@ impl Hub { HubCmd::Connect { pub_key, resp } => { if let Some(client) = weak_client.upgrade() { let _ = resp.send( - hub_map_assert(pub_key, &mut map, &client) - .await - .map(|(recv, conn, _)| (recv, conn)), + hub_map_assert( + pub_key, &mut map, &client, &config, + ) + .await + .map(|(recv, conn, _)| (recv, conn)), ); } else { break; diff --git a/crates/tx5-connection/src/test.rs b/crates/tx5-connection/src/test.rs index b264e54b..f2a9ceec 100644 --- a/crates/tx5-connection/src/test.rs +++ b/crates/tx5-connection/src/test.rs @@ -34,6 +34,7 @@ impl TestSrv { Arc::new(tx5_signal::SignalConfig { listener: true, allow_plain_text: true, + max_idle: std::time::Duration::from_secs(1), ..Default::default() }), ) @@ -47,6 +48,51 @@ impl TestSrv { } } +#[tokio::test(flavor = "multi_thread")] +async fn base_timeout() { + init_tracing(); + + let srv = TestSrv::new().await; + + let (hub1, _hubr1) = srv.hub().await; + let pk1 = hub1.pub_key().clone(); + + let (hub2, mut hubr2) = srv.hub().await; + let pk2 = hub2.pub_key().clone(); + + println!("connect"); + let (c1, mut r1) = hub1.connect(pk2).await.unwrap(); + c1.test_kill_keepalive_task(); + println!("accept"); + let (c2, mut r2) = hubr2.accept().await.unwrap(); + c2.test_kill_keepalive_task(); + + assert_eq!(&pk1, c2.pub_key()); + + println!("await ready"); + tokio::join!(c1.ready(), c2.ready()); + println!("ready"); + + c1.send(b"hello".to_vec()).await.unwrap(); + assert_eq!(b"hello", r2.recv().await.unwrap().as_slice()); + + c2.send(b"world".to_vec()).await.unwrap(); + assert_eq!(b"world", r1.recv().await.unwrap().as_slice()); + + match tokio::time::timeout(std::time::Duration::from_secs(3), async { + tokio::join!(r1.recv(), r2.recv()) + }) + .await + { + Err(_) => panic!("recv failed to time out"), + Ok((None, None)) => (), // correct, they both exited + _ => panic!("unexpected success"), + } + + assert!(c1.send(b"foo".to_vec()).await.is_err()); + assert!(c2.send(b"bar".to_vec()).await.is_err()); +} + #[tokio::test(flavor = "multi_thread")] async fn sanity() { init_tracing(); diff --git a/crates/tx5-signal/src/conn.rs b/crates/tx5-signal/src/conn.rs index 432dde81..ed25fee2 100644 --- a/crates/tx5-signal/src/conn.rs +++ b/crates/tx5-signal/src/conn.rs @@ -146,4 +146,11 @@ impl SignalConnection { self.client.send(pub_key, &msg).await?; Ok(()) } + + /// Keepalive. + pub async fn send_keepalive(&self, pub_key: &PubKey) -> Result<()> { + let msg = SignalMessage::keepalive(); + self.client.send(pub_key, &msg).await?; + Ok(()) + } } diff --git a/crates/tx5-signal/src/wire.rs b/crates/tx5-signal/src/wire.rs index 2acf6bd2..90ca88b0 100644 --- a/crates/tx5-signal/src/wire.rs +++ b/crates/tx5-signal/src/wire.rs @@ -7,6 +7,7 @@ const F_OFFR: &[u8] = b"offr"; const F_ANSW: &[u8] = b"answ"; const F_ICEM: &[u8] = b"icem"; const F_FMSG: &[u8] = b"fmsg"; +const F_KEEP: &[u8] = b"keep"; /// Parsed signal message. pub enum SignalMessage { @@ -31,6 +32,9 @@ pub enum SignalMessage { /// Pre-webrtc and webrtc failure fallback communication message. Message(Vec), + /// Keepalive + Keepalive, + /// Message type not understood by this client. Unknown, } @@ -45,6 +49,7 @@ impl std::fmt::Debug for SignalMessage { Self::Answer(_) => f.write_str("Answer"), Self::Ice(_) => f.write_str("Ice"), Self::Message(_) => f.write_str("Message"), + Self::Keepalive => f.write_str("Keepalive"), Self::Unknown => f.write_str("Unknown"), } } @@ -109,6 +114,11 @@ impl SignalMessage { Ok(msg) } + /// Keepalive. + pub(crate) fn keepalive() -> Vec { + F_KEEP.to_vec() + } + /// Parse a raw received buffer into a signal message. pub(crate) fn parse(mut b: Vec) -> Result { if b.len() < 4 { @@ -148,6 +158,7 @@ impl SignalMessage { let _ = b.drain(..4); Ok(SignalMessage::Message(b)) } + F_KEEP => Ok(SignalMessage::Keepalive), _ => Ok(SignalMessage::Unknown), } } diff --git a/crates/tx5/src/sig.rs b/crates/tx5/src/sig.rs index 8c4f563b..83be6b77 100644 --- a/crates/tx5/src/sig.rs +++ b/crates/tx5/src/sig.rs @@ -85,6 +85,7 @@ async fn connect_loop( let signal_config = Arc::new(SignalConfig { listener, allow_plain_text: config.signal_allow_plain_text, + max_idle: config.timeout, ..Default::default() }); diff --git a/crates/tx5/src/test.rs b/crates/tx5/src/test.rs index eb27740c..9ba9455c 100644 --- a/crates/tx5/src/test.rs +++ b/crates/tx5/src/test.rs @@ -248,9 +248,15 @@ async fn ep_sig_down() { ep1.send(ep2.peer_url(), b"hello".to_vec()).await.unwrap(); - let (from, msg) = ep2.recv().await.unwrap(); - assert_eq!(ep1.peer_url(), from); - assert_eq!(&b"hello"[..], &msg); + loop { + let (from, msg) = ep2.recv().await.unwrap(); + if &msg[..3] == b"<<<" { + continue; + } + assert_eq!(ep1.peer_url(), from); + assert_eq!(&b"hello"[..], &msg); + break; + } eprintln!("-- Done --"); } @@ -518,6 +524,7 @@ async fn ep_preflight_happy() { async fn ep_close_connection() { let config = Arc::new(Config { signal_allow_plain_text: true, + timeout: std::time::Duration::from_secs(2), ..Default::default() }); let test = Test::new().await; @@ -534,7 +541,7 @@ async fn ep_close_connection() { ep1.close(&ep2.peer_url()); let (url, message) = ep2_recv.recv().await.unwrap(); - assert_eq!(ep2.peer_url(), url); + assert_eq!(ep1.peer_url(), url); assert_eq!(&b"<<>>"[..], &message); }