Skip to content

Commit

Permalink
test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
neonphog committed May 6, 2024
1 parent 30da062 commit 9560947
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 52 deletions.
116 changes: 74 additions & 42 deletions crates/tx5-connection/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,26 @@ pub struct Conn {
pub_key: PubKey,
client: Weak<tx5_signal::SignalConnection>,
conn_task: tokio::task::JoinHandle<()>,
keepalive_task: tokio::task::JoinHandle<()>,
}

impl Drop for Conn {
fn drop(&mut self) {
self.conn_task.abort();
self.keepalive_task.abort();
}
}

impl Conn {
#[cfg(test)]
pub(crate) fn test_kill_keepalive_task(&self) {
self.keepalive_task.abort();
}

pub(crate) fn priv_new(
pub_key: PubKey,
client: Weak<tx5_signal::SignalConnection>,
config: Arc<tx5_signal::SignalConfig>,
) -> (Arc<Self>, ConnRecv, Arc<tokio::sync::mpsc::Sender<ConnCmd>>) {
// zero len semaphore.. we actually just wait for the close
let ready = Arc::new(tokio::sync::Semaphore::new(0));
Expand All @@ -42,6 +50,23 @@ impl Conn {
let (cmd_send, mut cmd_recv) = tokio::sync::mpsc::channel(32);
let cmd_send = Arc::new(cmd_send);

let keepalive_dur = config.max_idle / 2;
let client2 = client.clone();
let pub_key2 = pub_key.clone();
let keepalive_task = tokio::task::spawn(async move {
loop {
tokio::time::sleep(keepalive_dur).await;

if let Some(client) = client2.upgrade() {
if client.send_keepalive(&pub_key2).await.is_err() {
break;
}
} else {
break;
}
}
});

let ready2 = ready.clone();
let client2 = client.clone();
let pub_key2 = pub_key.clone();
Expand All @@ -51,56 +76,51 @@ impl Conn {
None => return,
};

match tokio::time::timeout(
std::time::Duration::from_secs(10),
async {
let nonce = client.send_handshake_req(&pub_key2).await?;

let mut got_peer_res = false;
let mut sent_our_res = false;

while let Some(cmd) = cmd_recv.recv().await {
match cmd {
ConnCmd::SigRecv(sig) => {
use tx5_signal::SignalMessage::*;
match sig {
HandshakeReq(oth_nonce) => {
client
.send_handshake_res(
&pub_key2, oth_nonce,
)
.await?;
sent_our_res = true;
}
HandshakeRes(res_nonce) => {
if res_nonce != nonce {
return Err(Error::other(
"nonce mismatch",
));
}
got_peer_res = true;
}
_ => {
match tokio::time::timeout(config.max_idle, async {
let nonce = client.send_handshake_req(&pub_key2).await?;

let mut got_peer_res = false;
let mut sent_our_res = false;

while let Some(cmd) = cmd_recv.recv().await {
match cmd {
ConnCmd::SigRecv(sig) => {
use tx5_signal::SignalMessage::*;
match sig {
HandshakeReq(oth_nonce) => {
client
.send_handshake_res(
&pub_key2, oth_nonce,
)
.await?;
sent_our_res = true;
}
HandshakeRes(res_nonce) => {
if res_nonce != nonce {
return Err(Error::other(
"invalid message during handshake",
"nonce mismatch",
));
}
got_peer_res = true;
}
_ => {
return Err(Error::other(
"invalid message during handshake",
));
}
}
ConnCmd::Close => {
return Err(Error::other(
"close during handshake",
))
}
}
if got_peer_res && sent_our_res {
break;
ConnCmd::Close => {
return Err(Error::other("close during handshake"))
}
}
if got_peer_res && sent_our_res {
break;
}
}

Result::Ok(())
},
)
Result::Ok(())
})
.await
{
Err(_) | Ok(Err(_)) => {
Expand All @@ -115,12 +135,16 @@ impl Conn {
// closing the semaphore causes all the acquire awaits to end
ready2.close();

while let Some(cmd) = cmd_recv.recv().await {
while let Ok(Some(cmd)) =
tokio::time::timeout(config.max_idle, cmd_recv.recv()).await
{
match cmd {
ConnCmd::SigRecv(sig) => {
use tx5_signal::SignalMessage::*;
#[allow(clippy::single_match)] // placeholder
match sig {
// invalid
HandshakeReq(_) | HandshakeRes(_) => break,
Message(msg) => {
if msg_send.send(msg).await.is_err() {
break;
Expand All @@ -132,6 +156,13 @@ impl Conn {
ConnCmd::Close => break,
}
}

// explicitly close the peer
if let Some(client) = client2.upgrade() {
client.close_peer(&pub_key2).await;
};

// the receiver side is closed because msg_send is dropped.
});

(
Expand All @@ -140,6 +171,7 @@ impl Conn {
pub_key,
client,
conn_task,
keepalive_task,
}),
ConnRecv(msg_recv),
cmd_send,
Expand Down
15 changes: 9 additions & 6 deletions crates/tx5-connection/src/hub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async fn hub_map_assert(
pub_key: PubKey,
map: &mut HubMap,
client: &Arc<tx5_signal::SignalConnection>,
config: &Arc<tx5_signal::SignalConfig>,
) -> Result<(
Option<ConnRecv>,
Arc<Conn>,
Expand Down Expand Up @@ -56,7 +57,7 @@ async fn hub_map_assert(
// we're connected to the peer, create a connection

let (conn, recv, cmd_send) =
Conn::priv_new(pub_key.clone(), Arc::downgrade(client));
Conn::priv_new(pub_key.clone(), Arc::downgrade(client), config.clone());

let weak_conn = Arc::downgrade(&conn);

Expand Down Expand Up @@ -112,7 +113,7 @@ impl Hub {
config: Arc<tx5_signal::SignalConfig>,
) -> Result<(Self, HubRecv)> {
let (client, mut recv) =
tx5_signal::SignalConnection::connect(url, config).await?;
tx5_signal::SignalConnection::connect(url, config.clone()).await?;
let client = Arc::new(client);

tracing::debug!(%url, pub_key = ?client.pub_key(), "hub connected");
Expand Down Expand Up @@ -147,7 +148,7 @@ impl Hub {
HubCmd::CliRecv { pub_key, msg } => {
if let Some(client) = weak_client.upgrade() {
let (recv, conn, cmd_send) = match hub_map_assert(
pub_key, &mut map, &client,
pub_key, &mut map, &client, &config,
)
.await
{
Expand All @@ -171,9 +172,11 @@ impl Hub {
HubCmd::Connect { pub_key, resp } => {
if let Some(client) = weak_client.upgrade() {
let _ = resp.send(
hub_map_assert(pub_key, &mut map, &client)
.await
.map(|(recv, conn, _)| (recv, conn)),
hub_map_assert(
pub_key, &mut map, &client, &config,
)
.await
.map(|(recv, conn, _)| (recv, conn)),
);
} else {
break;
Expand Down
46 changes: 46 additions & 0 deletions crates/tx5-connection/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl TestSrv {
Arc::new(tx5_signal::SignalConfig {
listener: true,
allow_plain_text: true,
max_idle: std::time::Duration::from_secs(1),
..Default::default()
}),
)
Expand All @@ -47,6 +48,51 @@ impl TestSrv {
}
}

#[tokio::test(flavor = "multi_thread")]
async fn base_timeout() {
init_tracing();

let srv = TestSrv::new().await;

let (hub1, _hubr1) = srv.hub().await;
let pk1 = hub1.pub_key().clone();

let (hub2, mut hubr2) = srv.hub().await;
let pk2 = hub2.pub_key().clone();

println!("connect");
let (c1, mut r1) = hub1.connect(pk2).await.unwrap();
c1.test_kill_keepalive_task();
println!("accept");
let (c2, mut r2) = hubr2.accept().await.unwrap();
c2.test_kill_keepalive_task();

assert_eq!(&pk1, c2.pub_key());

println!("await ready");
tokio::join!(c1.ready(), c2.ready());
println!("ready");

c1.send(b"hello".to_vec()).await.unwrap();
assert_eq!(b"hello", r2.recv().await.unwrap().as_slice());

c2.send(b"world".to_vec()).await.unwrap();
assert_eq!(b"world", r1.recv().await.unwrap().as_slice());

match tokio::time::timeout(std::time::Duration::from_secs(3), async {
tokio::join!(r1.recv(), r2.recv())
})
.await
{
Err(_) => panic!("recv failed to time out"),
Ok((None, None)) => (), // correct, they both exited
_ => panic!("unexpected success"),
}

assert!(c1.send(b"foo".to_vec()).await.is_err());
assert!(c2.send(b"bar".to_vec()).await.is_err());
}

#[tokio::test(flavor = "multi_thread")]
async fn sanity() {
init_tracing();
Expand Down
7 changes: 7 additions & 0 deletions crates/tx5-signal/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,11 @@ impl SignalConnection {
self.client.send(pub_key, &msg).await?;
Ok(())
}

/// Keepalive.
pub async fn send_keepalive(&self, pub_key: &PubKey) -> Result<()> {
let msg = SignalMessage::keepalive();
self.client.send(pub_key, &msg).await?;
Ok(())
}
}
11 changes: 11 additions & 0 deletions crates/tx5-signal/src/wire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const F_OFFR: &[u8] = b"offr";
const F_ANSW: &[u8] = b"answ";
const F_ICEM: &[u8] = b"icem";
const F_FMSG: &[u8] = b"fmsg";
const F_KEEP: &[u8] = b"keep";

/// Parsed signal message.
pub enum SignalMessage {
Expand All @@ -31,6 +32,9 @@ pub enum SignalMessage {
/// Pre-webrtc and webrtc failure fallback communication message.
Message(Vec<u8>),

/// Keepalive
Keepalive,

/// Message type not understood by this client.
Unknown,
}
Expand All @@ -45,6 +49,7 @@ impl std::fmt::Debug for SignalMessage {
Self::Answer(_) => f.write_str("Answer"),
Self::Ice(_) => f.write_str("Ice"),
Self::Message(_) => f.write_str("Message"),
Self::Keepalive => f.write_str("Keepalive"),
Self::Unknown => f.write_str("Unknown"),
}
}
Expand Down Expand Up @@ -109,6 +114,11 @@ impl SignalMessage {
Ok(msg)
}

/// Keepalive.
pub(crate) fn keepalive() -> Vec<u8> {
F_KEEP.to_vec()
}

/// Parse a raw received buffer into a signal message.
pub(crate) fn parse(mut b: Vec<u8>) -> Result<Self> {
if b.len() < 4 {
Expand Down Expand Up @@ -148,6 +158,7 @@ impl SignalMessage {
let _ = b.drain(..4);
Ok(SignalMessage::Message(b))
}
F_KEEP => Ok(SignalMessage::Keepalive),
_ => Ok(SignalMessage::Unknown),
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/tx5/src/sig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async fn connect_loop(
let signal_config = Arc::new(SignalConfig {
listener,
allow_plain_text: config.signal_allow_plain_text,
max_idle: config.timeout,
..Default::default()
});

Expand Down
15 changes: 11 additions & 4 deletions crates/tx5/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,15 @@ async fn ep_sig_down() {

ep1.send(ep2.peer_url(), b"hello".to_vec()).await.unwrap();

let (from, msg) = ep2.recv().await.unwrap();
assert_eq!(ep1.peer_url(), from);
assert_eq!(&b"hello"[..], &msg);
loop {
let (from, msg) = ep2.recv().await.unwrap();
if &msg[..3] == b"<<<" {
continue;
}
assert_eq!(ep1.peer_url(), from);
assert_eq!(&b"hello"[..], &msg);
break;
}

eprintln!("-- Done --");
}
Expand Down Expand Up @@ -518,6 +524,7 @@ async fn ep_preflight_happy() {
async fn ep_close_connection() {
let config = Arc::new(Config {
signal_allow_plain_text: true,
timeout: std::time::Duration::from_secs(2),
..Default::default()
});
let test = Test::new().await;
Expand All @@ -534,7 +541,7 @@ async fn ep_close_connection() {
ep1.close(&ep2.peer_url());

let (url, message) = ep2_recv.recv().await.unwrap();
assert_eq!(ep2.peer_url(), url);
assert_eq!(ep1.peer_url(), url);
assert_eq!(&b"<<<test-disconnect>>>"[..], &message);
}

Expand Down

0 comments on commit 9560947

Please sign in to comment.