diff --git a/crates/tx5-connection/src/conn.rs b/crates/tx5-connection/src/conn.rs index d612b2b6..657e14b3 100644 --- a/crates/tx5-connection/src/conn.rs +++ b/crates/tx5-connection/src/conn.rs @@ -4,27 +4,35 @@ pub(crate) enum ConnCmd { SigRecv(tx5_signal::SignalMessage), } +/// Receive messages from a tx5 connection. +pub struct ConnRecv(tokio::sync::mpsc::Receiver>); + +impl ConnRecv { + /// Receive up to 16KiB of message data. + pub async fn recv(&mut self) -> Option> { + self.0.recv().await + } +} + /// A tx5 connection. -pub struct Tx5Connection { +pub struct Conn { ready: Arc, pub_key: PubKey, client: Weak, - pub(crate) cmd_send: tokio::sync::mpsc::Sender, - msg_recv: tokio::sync::Mutex>>, conn_task: tokio::task::JoinHandle<()>, } -impl Drop for Tx5Connection { +impl Drop for Conn { fn drop(&mut self) { self.conn_task.abort(); } } -impl Tx5Connection { +impl Conn { pub(crate) fn priv_new( pub_key: PubKey, client: Weak, - ) -> Arc { + ) -> (Arc, ConnRecv, tokio::sync::mpsc::Sender) { // zero len semaphore.. we actually just wait for the close let ready = Arc::new(tokio::sync::Semaphore::new(0)); @@ -117,14 +125,16 @@ impl Tx5Connection { } }); - Arc::new(Self { - ready, - pub_key, - client, + ( + Arc::new(Self { + ready, + pub_key, + client, + conn_task, + }), + ConnRecv(msg_recv), cmd_send, - msg_recv: tokio::sync::Mutex::new(msg_recv), - conn_task, - }) + ) } /// Wait until this connection is ready to send / receive data. @@ -146,9 +156,4 @@ impl Tx5Connection { Err(Error::other("closed")) } } - - /// Receive up to 16KiB of message data. - pub async fn recv(&self) -> Option> { - self.msg_recv.lock().await.recv().await - } } diff --git a/crates/tx5-connection/src/framed.rs b/crates/tx5-connection/src/framed.rs index d0c38b7f..5b863e7f 100644 --- a/crates/tx5-connection/src/framed.rs +++ b/crates/tx5-connection/src/framed.rs @@ -7,32 +7,43 @@ enum Cmd { got_permit: tokio::sync::oneshot::Sender<()>, }, RemotePermit(tokio::sync::OwnedSemaphorePermit, u32), + Close, +} + +/// Receive a framed message on the connection. +pub struct FramedConnRecv(tokio::sync::mpsc::Receiver>); + +impl FramedConnRecv { + /// Receive a framed message on the connection. + pub async fn recv(&mut self) -> Option> { + self.0.recv().await + } } /// A framed wrapper that can send and receive larger messages than /// the base connection. -pub struct Tx5ConnFramed { +pub struct FramedConn { pub_key: PubKey, - conn: tokio::sync::Mutex>, + conn: tokio::sync::Mutex>, cmd_send: tokio::sync::mpsc::Sender, recv_task: tokio::task::JoinHandle<()>, cmd_task: tokio::task::JoinHandle<()>, - msg_recv: tokio::sync::Mutex>>, } -impl Drop for Tx5ConnFramed { +impl Drop for FramedConn { fn drop(&mut self) { self.recv_task.abort(); self.cmd_task.abort(); } } -impl Tx5ConnFramed { +impl FramedConn { /// Construct a new framed wrapper around the base connection. pub async fn new( - conn: Arc, + conn: Arc, + mut conn_recv: ConnRecv, recv_limit: Arc, - ) -> Result { + ) -> Result<(Self, FramedConnRecv)> { conn.ready().await; let (a, b, c, d) = crate::proto::PROTO_VER_2.encode()?; @@ -42,17 +53,14 @@ impl Tx5ConnFramed { let (msg_send, msg_recv) = tokio::sync::mpsc::channel(32); let cmd_send2 = cmd_send.clone(); - let weak_conn = Arc::downgrade(&conn); let recv_task = tokio::task::spawn(async move { - while let Some(conn) = weak_conn.upgrade() { - if let Some(msg) = conn.recv().await { - if cmd_send2.send(Cmd::Recv(msg)).await.is_err() { - break; - } - } else { + while let Some(msg) = conn_recv.recv().await { + if cmd_send2.send(Cmd::Recv(msg)).await.is_err() { break; } } + + let _ = cmd_send2.send(Cmd::Close).await; }); let cmd_send2 = cmd_send.clone(); @@ -125,20 +133,23 @@ impl Tx5ConnFramed { break; } } + Cmd::Close => break, } } }); let pub_key = conn.pub_key().clone(); - Ok(Self { - pub_key, - conn: tokio::sync::Mutex::new(conn), - cmd_send, - recv_task, - cmd_task, - msg_recv: tokio::sync::Mutex::new(msg_recv), - }) + Ok(( + Self { + pub_key, + conn: tokio::sync::Mutex::new(conn), + cmd_send, + recv_task, + cmd_task, + }, + FramedConnRecv(msg_recv), + )) } /// The pub key of the remote peer this is connected to. @@ -146,11 +157,6 @@ impl Tx5ConnFramed { &self.pub_key } - /// Receive a message on the connection. - pub async fn recv(&self) -> Option> { - self.msg_recv.lock().await.recv().await - } - /// Send a message on the connection. pub async fn send(&self, msg: Vec) -> Result<()> { let conn = self.conn.lock().await; diff --git a/crates/tx5-connection/src/hub.rs b/crates/tx5-connection/src/hub.rs index 0e9a310b..15ad9160 100644 --- a/crates/tx5-connection/src/hub.rs +++ b/crates/tx5-connection/src/hub.rs @@ -1,38 +1,45 @@ pub use super::*; -type HubMap = HashMap>; +type HubMap = HashMap, tokio::sync::mpsc::Sender)>; async fn hub_map_assert( pub_key: PubKey, map: &mut HubMap, client: &Arc, -) -> Result<(bool, Arc)> { +) -> Result<( + Option, + Arc, + tokio::sync::mpsc::Sender, +)> { let mut found_during_prune = None; map.retain(|_, c| { - if let Some(c) = c.upgrade() { - found_during_prune = Some(c.clone()); + if let Some(f) = c.0.upgrade() { + if f.pub_key() == &pub_key { + found_during_prune = Some((f, c.1.clone())); + } true } else { false } }); - if let Some(found) = found_during_prune { - return Ok((false, found)); + if let Some((conn, cmd_send)) = found_during_prune { + return Ok((None, conn, cmd_send)); } client.assert(&pub_key).await?; // we're connected to the peer, create a connection - let conn = Tx5Connection::priv_new(pub_key.clone(), Arc::downgrade(client)); + let (conn, recv, cmd_send) = + Conn::priv_new(pub_key.clone(), Arc::downgrade(client)); let weak_conn = Arc::downgrade(&conn); - map.insert(pub_key, weak_conn); + map.insert(pub_key, (weak_conn, cmd_send.clone())); - Ok((true, conn)) + Ok((Some(recv), conn, cmd_send)) } enum HubCmd { @@ -42,20 +49,30 @@ enum HubCmd { }, Connect { pub_key: PubKey, - resp: tokio::sync::oneshot::Sender>>, + resp: + tokio::sync::oneshot::Sender, Arc)>>, }, + Close, +} + +/// A stream of incoming p2p connections. +pub struct HubRecv(tokio::sync::mpsc::Receiver<(Arc, ConnRecv)>); + +impl HubRecv { + /// Receive an incoming p2p connection. + pub async fn accept(&mut self) -> Option<(Arc, ConnRecv)> { + self.0.recv().await + } } /// A signal server connection from which we can establish tx5 connections. -pub struct Tx5ConnectionHub { +pub struct Hub { client: Arc, cmd_send: tokio::sync::mpsc::Sender, - conn_recv: - tokio::sync::Mutex>>, task_list: Vec>, } -impl Drop for Tx5ConnectionHub { +impl Drop for Hub { fn drop(&mut self) { for task in self.task_list.iter() { task.abort(); @@ -63,11 +80,16 @@ impl Drop for Tx5ConnectionHub { } } -impl Tx5ConnectionHub { - /// Create a new Tx5ConnectionHub based off a connected tx5 signal client. +impl Hub { + /// Create a new Hub based off a connected tx5 signal client. /// Note, if this is not a "listener" client, /// you do not need to ever call accept. - pub fn new(client: tx5_signal::SignalConnection) -> Self { + pub async fn new( + url: &str, + config: Arc, + ) -> Result<(Self, HubRecv)> { + let (client, mut recv) = + tx5_signal::SignalConnection::connect(url, config).await?; let client = Arc::new(client); let mut task_list = Vec::new(); @@ -75,21 +97,18 @@ impl Tx5ConnectionHub { let (cmd_send, mut cmd_recv) = tokio::sync::mpsc::channel(32); let cmd_send2 = cmd_send.clone(); - let weak_client = Arc::downgrade(&client); task_list.push(tokio::task::spawn(async move { - while let Some(client) = weak_client.upgrade() { - if let Some((pub_key, msg)) = client.recv_message().await { - if cmd_send2 - .send(HubCmd::CliRecv { pub_key, msg }) - .await - .is_err() - { - break; - } - } else { + while let Some((pub_key, msg)) = recv.recv_message().await { + if cmd_send2 + .send(HubCmd::CliRecv { pub_key, msg }) + .await + .is_err() + { break; } } + + let _ = cmd_send2.send(HubCmd::Close).await; })); let (conn_send, conn_recv) = tokio::sync::mpsc::channel(32); @@ -100,18 +119,23 @@ impl Tx5ConnectionHub { match cmd { HubCmd::CliRecv { pub_key, msg } => { if let Some(client) = weak_client.upgrade() { - let (did_create, conn) = match hub_map_assert( + let (recv, conn, cmd_send) = match hub_map_assert( pub_key, &mut map, &client, ) .await { - Err(_) => continue, - Ok(conn) => conn, + Err(err) => { + tracing::debug!( + ?err, + "failed to accept incoming connection" + ); + continue; + } + Ok(r) => r, }; - let _ = - conn.cmd_send.send(ConnCmd::SigRecv(msg)).await; - if did_create { - let _ = conn_send.send(conn).await; + let _ = cmd_send.send(ConnCmd::SigRecv(msg)).await; + if let Some(recv) = recv { + let _ = conn_send.send((conn, recv)).await; } } else { break; @@ -122,22 +146,29 @@ impl Tx5ConnectionHub { let _ = resp.send( hub_map_assert(pub_key, &mut map, &client) .await - .map(|(_, conn)| conn), + .map(|(recv, conn, _)| (recv, conn)), ); } else { break; } } + HubCmd::Close => break, } } + + if let Some(client) = weak_client.upgrade() { + client.close().await; + } })); - Self { - client, - cmd_send, - conn_recv: tokio::sync::Mutex::new(conn_recv), - task_list, - } + Ok(( + Self { + client, + cmd_send, + task_list, + }, + HubRecv(conn_recv), + )) } /// Get the pub_key used by this hub. @@ -146,19 +177,20 @@ impl Tx5ConnectionHub { } /// Establish a connection to a remote peer. - /// Note, if there is already an open connection, this Arc will point - /// to that same connection instance. - pub async fn connect(&self, pub_key: PubKey) -> Result> { + pub async fn connect( + &self, + pub_key: PubKey, + ) -> Result<(Arc, ConnRecv)> { let (s, r) = tokio::sync::oneshot::channel(); self.cmd_send .send(HubCmd::Connect { pub_key, resp: s }) .await .map_err(|_| Error::other("closed"))?; - r.await.map_err(|_| Error::other("closed"))? - } - - /// Accept an incoming tx5 connection. - pub async fn accept(&self) -> Option> { - self.conn_recv.lock().await.recv().await + let (recv, conn) = r.await.map_err(|_| Error::other("closed"))??; + if let Some(recv) = recv { + Ok((conn, recv)) + } else { + Err(Error::other("already connected")) + } } } diff --git a/crates/tx5-connection/src/test.rs b/crates/tx5-connection/src/test.rs index 7b84d2ba..ad08657d 100644 --- a/crates/tx5-connection/src/test.rs +++ b/crates/tx5-connection/src/test.rs @@ -1,5 +1,16 @@ use super::*; +fn init_tracing() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter( + tracing_subscriber::filter::EnvFilter::from_default_env(), + ) + .with_file(true) + .with_line_number(true) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); +} + pub struct TestSrv { server: sbd_server::SbdServer, } @@ -16,9 +27,9 @@ impl TestSrv { Self { server } } - pub async fn hub(&self) -> Tx5ConnectionHub { + pub async fn hub(&self) -> (Hub, HubRecv) { for addr in self.server.bind_addrs() { - if let Ok(sig) = tx5_signal::SignalConnection::connect( + if let Ok(r) = Hub::new( &format!("ws://{addr}"), Arc::new(tx5_signal::SignalConfig { listener: true, @@ -28,7 +39,7 @@ impl TestSrv { ) .await { - return Tx5ConnectionHub::new(sig); + return r; } } @@ -38,18 +49,20 @@ impl TestSrv { #[tokio::test(flavor = "multi_thread")] async fn sanity() { + init_tracing(); + let srv = TestSrv::new().await; - let hub1 = srv.hub().await; + let (hub1, _hubr1) = srv.hub().await; let pk1 = hub1.pub_key().clone(); - let hub2 = srv.hub().await; + let (hub2, mut hubr2) = srv.hub().await; let pk2 = hub2.pub_key().clone(); println!("connect"); - let c1 = hub1.connect(pk2).await.unwrap(); + let (c1, mut r1) = hub1.connect(pk2).await.unwrap(); println!("accept"); - let c2 = hub2.accept().await.unwrap(); + let (c2, mut r2) = hubr2.accept().await.unwrap(); assert_eq!(&pk1, c2.pub_key()); @@ -58,43 +71,134 @@ async fn sanity() { println!("ready"); c1.send(b"hello".to_vec()).await.unwrap(); - assert_eq!(b"hello", c2.recv().await.unwrap().as_slice()); + assert_eq!(b"hello", r2.recv().await.unwrap().as_slice()); c2.send(b"world".to_vec()).await.unwrap(); - assert_eq!(b"world", c1.recv().await.unwrap().as_slice()); + assert_eq!(b"world", r1.recv().await.unwrap().as_slice()); } #[tokio::test(flavor = "multi_thread")] async fn framed_sanity() { + init_tracing(); + + let srv = TestSrv::new().await; + + let (hub1, _hubr1) = srv.hub().await; + let pk1 = hub1.pub_key().clone(); + + let (hub2, mut hubr2) = srv.hub().await; + let pk2 = hub2.pub_key().clone(); + + let ((c1, mut r1), (c2, mut r2)) = tokio::join!( + async { + let (c1, r1) = hub1.connect(pk2).await.unwrap(); + let limit = + Arc::new(tokio::sync::Semaphore::new(512 * 1024 * 1024)); + let f = FramedConn::new(c1, r1, limit).await.unwrap(); + f + }, + async { + let (c2, r2) = hubr2.accept().await.unwrap(); + assert_eq!(&pk1, c2.pub_key()); + let limit = + Arc::new(tokio::sync::Semaphore::new(512 * 1024 * 1024)); + let f = FramedConn::new(c2, r2, limit).await.unwrap(); + f + }, + ); + + c1.send(b"hello".to_vec()).await.unwrap(); + assert_eq!(b"hello", r2.recv().await.unwrap().as_slice()); + + c2.send(b"world".to_vec()).await.unwrap(); + assert_eq!(b"world", r1.recv().await.unwrap().as_slice()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn base_end_when_disconnected() { + init_tracing(); + + let srv = TestSrv::new().await; + + let (hub1, mut hubr1) = srv.hub().await; + let pk1 = hub1.pub_key().clone(); + + let (hub2, mut hubr2) = srv.hub().await; + let pk2 = hub2.pub_key().clone(); + + println!("connect"); + let (c1, mut r1) = hub1.connect(pk2.clone()).await.unwrap(); + println!("accept"); + let (c2, mut r2) = hubr2.accept().await.unwrap(); + + println!("await ready"); + tokio::join!(c1.ready(), c2.ready()); + println!("ready"); + + assert_eq!(&pk1, c2.pub_key()); + + c1.send(b"hello".to_vec()).await.unwrap(); + assert_eq!(b"hello", r2.recv().await.unwrap().as_slice()); + + c2.send(b"world".to_vec()).await.unwrap(); + assert_eq!(b"world", r1.recv().await.unwrap().as_slice()); + + drop(srv); + + assert!(r1.recv().await.is_none()); + assert!(r2.recv().await.is_none()); + assert!(hubr1.accept().await.is_none()); + assert!(hubr2.accept().await.is_none()); + assert!(c1.send(b"hello".to_vec()).await.is_err()); + assert!(c2.send(b"hello".to_vec()).await.is_err()); + assert!(hub1.connect(pk2).await.is_err()); + assert!(hub2.connect(pk1).await.is_err()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn framed_end_when_disconnected() { + init_tracing(); + let srv = TestSrv::new().await; - let hub1 = srv.hub().await; + let (hub1, mut hubr1) = srv.hub().await; let pk1 = hub1.pub_key().clone(); - let hub2 = srv.hub().await; + let (hub2, mut hubr2) = srv.hub().await; let pk2 = hub2.pub_key().clone(); - let (c1, c2) = tokio::join!( + let ((c1, mut r1), (c2, mut r2)) = tokio::join!( async { - let c1 = hub1.connect(pk2).await.unwrap(); + let (c1, r2) = hub1.connect(pk2.clone()).await.unwrap(); let limit = Arc::new(tokio::sync::Semaphore::new(512 * 1024 * 1024)); - let c1 = Tx5ConnFramed::new(c1, limit).await.unwrap(); - c1 + let f = FramedConn::new(c1, r2, limit).await.unwrap(); + f }, async { - let c2 = hub2.accept().await.unwrap(); + let (c2, r2) = hubr2.accept().await.unwrap(); assert_eq!(&pk1, c2.pub_key()); let limit = Arc::new(tokio::sync::Semaphore::new(512 * 1024 * 1024)); - let c2 = Tx5ConnFramed::new(c2, limit).await.unwrap(); - c2 + let f = FramedConn::new(c2, r2, limit).await.unwrap(); + f }, ); c1.send(b"hello".to_vec()).await.unwrap(); - assert_eq!(b"hello", c2.recv().await.unwrap().as_slice()); + assert_eq!(b"hello", r2.recv().await.unwrap().as_slice()); c2.send(b"world".to_vec()).await.unwrap(); - assert_eq!(b"world", c1.recv().await.unwrap().as_slice()); + assert_eq!(b"world", r1.recv().await.unwrap().as_slice()); + + drop(srv); + + assert!(r1.recv().await.is_none()); + assert!(r2.recv().await.is_none()); + assert!(hubr1.accept().await.is_none()); + assert!(hubr2.accept().await.is_none()); + assert!(c1.send(b"hello".to_vec()).await.is_err()); + assert!(c2.send(b"hello".to_vec()).await.is_err()); + assert!(hub1.connect(pk2).await.is_err()); + assert!(hub2.connect(pk1).await.is_err()); } diff --git a/crates/tx5-signal/src/conn.rs b/crates/tx5-signal/src/conn.rs index ab5443d1..432dde81 100644 --- a/crates/tx5-signal/src/conn.rs +++ b/crates/tx5-signal/src/conn.rs @@ -3,36 +3,36 @@ use crate::*; /// Tx5 signal connection configuration. pub type SignalConfig = sbd_e2e_crypto_client::Config; -/// A client connection to a tx5 signal server. -pub struct SignalConnection { - client: Arc, +/// Receive messages from the signal server. +pub struct MsgRecv { + client: Weak, + recv: sbd_e2e_crypto_client::MsgRecv, } -impl std::ops::Deref for SignalConnection { - type Target = sbd_e2e_crypto_client::SbdClientCrypto; +impl std::ops::Deref for MsgRecv { + type Target = sbd_e2e_crypto_client::MsgRecv; fn deref(&self) -> &Self::Target { - &self.client + &self.recv } } -impl SignalConnection { - /// Establish a new client connection to a tx5 signal server. - pub async fn connect(url: &str, config: Arc) -> Result { - let client = - sbd_e2e_crypto_client::SbdClientCrypto::new(url, config).await?; - let client = Arc::new(client); - - Ok(Self { client }) +impl std::ops::DerefMut for MsgRecv { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.recv } +} - /// Receive a message from a remote peer. - pub async fn recv_message(&self) -> Option<(PubKey, SignalMessage)> { +impl MsgRecv { + /// Receive messages from the signal server. + pub async fn recv_message(&mut self) -> Option<(PubKey, SignalMessage)> { loop { - let (pub_key, msg) = self.client.recv().await?; + let (pub_key, msg) = self.recv.recv().await?; match SignalMessage::parse(msg) { Err(_) => { - self.client.close_peer(&pub_key).await; + if let Some(client) = self.client.upgrade() { + client.close_peer(&pub_key).await; + } continue; } Ok(SignalMessage::Unknown) => continue, @@ -40,6 +40,40 @@ impl SignalConnection { } } } +} + +/// A client connection to a tx5 signal server. +pub struct SignalConnection { + client: Arc, +} + +impl std::ops::Deref for SignalConnection { + type Target = sbd_e2e_crypto_client::SbdClientCrypto; + + fn deref(&self) -> &Self::Target { + &self.client + } +} + +impl SignalConnection { + /// Establish a new client connection to a tx5 signal server. + pub async fn connect( + url: &str, + config: Arc, + ) -> Result<(Self, MsgRecv)> { + let (client, recv) = + sbd_e2e_crypto_client::SbdClientCrypto::new(url, config).await?; + let client = Arc::new(client); + let weak_client = Arc::downgrade(&client); + + Ok(( + Self { client }, + MsgRecv { + client: weak_client, + recv, + }, + )) + } /// Send a handshake request to a peer. Returns the nonce sent. pub async fn send_handshake_req( diff --git a/crates/tx5-signal/src/lib.rs b/crates/tx5-signal/src/lib.rs index e5623b96..b645e6eb 100644 --- a/crates/tx5-signal/src/lib.rs +++ b/crates/tx5-signal/src/lib.rs @@ -7,7 +7,7 @@ //! This is a thin wrapper around an SBD e2e crypto client. use std::io::{Error, Result}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; pub use sbd_e2e_crypto_client::PubKey; diff --git a/crates/tx5/src/ep.rs b/crates/tx5/src/ep.rs index ba39c544..ab919ee4 100644 --- a/crates/tx5/src/ep.rs +++ b/crates/tx5/src/ep.rs @@ -39,25 +39,26 @@ pub enum EndpointEvent { impl std::fmt::Debug for EndpointEvent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::ListeningAddressOpen { local_url } => { - f.debug_struct("ListeningAddressOpen").field("peer_url", local_url).finish() - } - Self::ListeningAddressClosed { local_url } => { - f.debug_struct("ListeningAddressClosed") - .field("peer_url", local_url) - .finish() - } - Self::Connected { peer_url } => { - f.debug_struct("Connected").field("peer_url", peer_url).finish() - } - Self::Disconnected { peer_url } => { - f.debug_struct("Disconnected") - .field("peer_url", peer_url) - .finish() - } - Self::Message { peer_url, .. } => { - f.debug_struct("Message").field("peer_url", peer_url).finish() - } + Self::ListeningAddressOpen { local_url } => f + .debug_struct("ListeningAddressOpen") + .field("peer_url", local_url) + .finish(), + Self::ListeningAddressClosed { local_url } => f + .debug_struct("ListeningAddressClosed") + .field("peer_url", local_url) + .finish(), + Self::Connected { peer_url } => f + .debug_struct("Connected") + .field("peer_url", peer_url) + .finish(), + Self::Disconnected { peer_url } => f + .debug_struct("Disconnected") + .field("peer_url", peer_url) + .finish(), + Self::Message { peer_url, .. } => f + .debug_struct("Message") + .field("peer_url", peer_url) + .finish(), } } } @@ -73,7 +74,21 @@ pub(crate) struct EpInner { impl EpInner { pub fn drop_sig(&mut self, sig: Arc) { - self.sig_map.retain(|_, s| !Arc::ptr_eq(s, &sig)) + let sig_url = sig.sig_url.clone(); + let listener = sig.listener; + + let should_remove = match self.sig_map.get(&sig_url) { + Some(s) => Arc::ptr_eq(s, &sig), + None => false, + }; + + if should_remove { + self.sig_map.remove(&sig_url); + + if listener { + self.assert_sig(sig_url, listener); + } + } } pub fn assert_sig(&mut self, sig_url: SigUrl, listener: bool) -> Arc { @@ -95,7 +110,7 @@ impl EpInner { self.peer_map.retain(|_, p| !Arc::ptr_eq(p, &peer)) } - pub fn assert_peer(&mut self, peer_url: PeerUrl) -> Arc { + pub fn connect_peer(&mut self, peer_url: PeerUrl) -> Arc { if let Some(peer) = self.peer_map.get(&peer_url) { return peer.clone(); } @@ -114,10 +129,11 @@ impl EpInner { .clone() } - pub fn insert_peer( + pub fn accept_peer( &mut self, peer_url: PeerUrl, - peer: Arc, + conn: Arc, + conn_recv: tx5_connection::ConnRecv, ) { self.peer_map.entry(peer_url.clone()).or_insert_with(|| { Peer::new_accept( @@ -125,7 +141,8 @@ impl EpInner { self.recv_limit.clone(), self.this.clone(), peer_url, - peer, + conn, + conn_recv, self.evt_send.clone(), ) }); @@ -184,7 +201,7 @@ impl Endpoint { /// the data is handed off to our networking backend. pub async fn send(&self, peer_url: PeerUrl, data: Vec) -> Result<()> { tokio::time::timeout(self.config.timeout, async { - let peer = self.inner.lock().unwrap().assert_peer(peer_url); + let peer = self.inner.lock().unwrap().connect_peer(peer_url); peer.ready().await; peer.send(data).await }) diff --git a/crates/tx5/src/peer.rs b/crates/tx5/src/peer.rs index dea597ce..c448a2f8 100644 --- a/crates/tx5/src/peer.rs +++ b/crates/tx5/src/peer.rs @@ -3,7 +3,7 @@ use crate::*; use tx5_connection::*; enum MaybeReady { - Ready(Arc), + Ready(Arc), Wait(Arc), } @@ -49,7 +49,8 @@ impl Peer { recv_limit: Arc, ep: Weak>, peer_url: PeerUrl, - conn: Arc, + conn: Arc, + conn_recv: ConnRecv, evt_send: tokio::sync::mpsc::Sender, ) -> Arc { Arc::new_cyclic(|this| { @@ -61,7 +62,7 @@ impl Peer { recv_limit, ep, this.clone(), - conn, + Some((conn, conn_recv)), peer_url, evt_send, ready.clone(), @@ -98,36 +99,31 @@ async fn connect( evt_send: tokio::sync::mpsc::Sender, ready: Arc>, ) { - let conn = connect_loop(config.clone(), ep.clone(), peer_url.clone()).await; + tracing::trace!(?peer_url, "peer try connect"); - task( - config, recv_limit, ep, this, conn, peer_url, evt_send, ready, - ) - .await; -} - -async fn connect_loop( - config: Arc, - ep: Weak>, - peer_url: PeerUrl, -) -> Arc { - let mut wait = config.backoff_start; - - loop { - if let Some(ep) = ep.upgrade() { + let conn = if let Some(ep) = ep.upgrade() { + match tokio::time::timeout(config.timeout, async { let sig = ep.lock().unwrap().assert_sig(peer_url.to_sig(), false); sig.ready().await; - if let Ok(conn) = sig.connect(peer_url.pub_key().clone()).await { - return conn; + sig.connect(peer_url.pub_key().clone()).await + }) + .await + .map_err(Error::other) + { + Ok(Ok(conn)) => Some(conn), + Err(err) | Ok(Err(err)) => { + tracing::debug!(?err, "peer connect error"); + None } } + } else { + None + }; - wait *= 2; - if wait > config.backoff_max { - wait = config.backoff_max; - } - tokio::time::sleep(wait).await; - } + task( + config, recv_limit, ep, this, conn, peer_url, evt_send, ready, + ) + .await; } struct DropPeer { @@ -151,28 +147,32 @@ async fn task( recv_limit: Arc, ep: Weak>, this: Weak, - conn: Arc, + conn: Option<(Arc, ConnRecv)>, peer_url: PeerUrl, evt_send: tokio::sync::mpsc::Sender, ready: Arc>, ) { let _drop = DropPeer { ep, peer: this }; - conn.ready().await; - - let conn = match Tx5ConnFramed::new(conn, recv_limit).await { - Ok(conn) => Arc::new(conn), - Err(_) => return, + let (conn, conn_recv) = match conn { + None => return, + Some(conn) => conn, }; - let weak_conn = Arc::downgrade(&conn); + conn.ready().await; + + let (conn, mut conn_recv) = + match FramedConn::new(conn, conn_recv, recv_limit).await { + Ok(conn) => conn, + Err(_) => return, + }; { let mut lock = ready.lock().unwrap(); if let MaybeReady::Wait(w) = &*lock { w.close(); } - *lock = MaybeReady::Ready(conn); + *lock = MaybeReady::Ready(Arc::new(conn)); } drop(ready); @@ -183,23 +183,25 @@ async fn task( }) .await; - while let Some(conn) = weak_conn.upgrade() { - if let Some(msg) = conn.recv().await { - let _ = evt_send - .send(EndpointEvent::Message { - peer_url: peer_url.clone(), - message: msg, - }) - .await; - } else { - break; - } + tracing::info!(?peer_url, "peer connected"); + + while let Some(msg) = conn_recv.recv().await { + let _ = evt_send + .send(EndpointEvent::Message { + peer_url: peer_url.clone(), + message: msg, + }) + .await; } // wait at the end to account for a delay before the next try tokio::time::sleep(config.backoff_start).await; let _ = evt_send - .send(EndpointEvent::Disconnected { peer_url }) + .send(EndpointEvent::Disconnected { + peer_url: peer_url.clone(), + }) .await; + + tracing::debug!(?peer_url, "peer closed"); } diff --git a/crates/tx5/src/sig.rs b/crates/tx5/src/sig.rs index bc114498..8c4f563b 100644 --- a/crates/tx5/src/sig.rs +++ b/crates/tx5/src/sig.rs @@ -4,11 +4,13 @@ use tx5_connection::tx5_signal::*; use tx5_connection::*; enum MaybeReady { - Ready(Arc), + Ready(Arc), Wait(Arc), } pub(crate) struct Sig { + pub(crate) listener: bool, + pub(crate) sig_url: SigUrl, ready: Arc>, task: tokio::task::JoinHandle<()>, } @@ -35,13 +37,18 @@ impl Sig { ep, this.clone(), config, - sig_url, + sig_url.clone(), listener, evt_send, ready.clone(), )); - Self { ready, task } + Self { + listener, + sig_url, + ready, + task, + } }) } @@ -54,7 +61,10 @@ impl Sig { let _ = w.acquire().await; } - pub async fn connect(&self, pub_key: PubKey) -> Result> { + pub async fn connect( + &self, + pub_key: PubKey, + ) -> Result<(Arc, ConnRecv)> { let hub = match &*self.ready.lock().unwrap() { MaybeReady::Ready(h) => h.clone(), _ => return Err(Error::other("not ready")), @@ -67,7 +77,9 @@ async fn connect_loop( config: Arc, sig_url: SigUrl, listener: bool, -) -> Tx5ConnectionHub { +) -> (Hub, HubRecv) { + tracing::trace!(?config, ?sig_url, ?listener, "signal try connect"); + let mut wait = config.backoff_start; let signal_config = Arc::new(SignalConfig { @@ -77,10 +89,17 @@ async fn connect_loop( }); loop { - if let Ok(sig) = - SignalConnection::connect(&sig_url, signal_config.clone()).await + match tokio::time::timeout( + config.timeout, + Hub::new(&sig_url, signal_config.clone()), + ) + .await + .map_err(Error::other) { - return Tx5ConnectionHub::new(sig); + Ok(Ok(r)) => return r, + Err(err) | Ok(Err(err)) => { + tracing::debug!(?err, "signal connect error") + } } wait *= 2; @@ -120,12 +139,12 @@ async fn task( sig: this, }; - let hub = connect_loop(config.clone(), sig_url.clone(), listener).await; + let (hub, mut hub_recv) = + connect_loop(config.clone(), sig_url.clone(), listener).await; let local_url = sig_url.to_peer(hub.pub_key().clone()); let hub = Arc::new(hub); - let weak_hub = Arc::downgrade(&hub); { let mut lock = ready.lock().unwrap(); @@ -143,21 +162,25 @@ async fn task( }) .await; - while let Some(hub) = weak_hub.upgrade() { - if let Some(conn) = hub.accept().await { - if let Some(ep) = ep.upgrade() { - let peer_url = sig_url.to_peer(conn.pub_key().clone()); - ep.lock().unwrap().insert_peer(peer_url, conn); - } - } else { - break; + tracing::info!(?local_url, "signal connected"); + + while let Some((conn, conn_recv)) = hub_recv.accept().await { + if let Some(ep) = ep.upgrade() { + let peer_url = sig_url.to_peer(conn.pub_key().clone()); + ep.lock().unwrap().accept_peer(peer_url, conn, conn_recv); } } + tracing::trace!(?local_url, "signal closing"); + // wait at the end to account for a delay before the next try tokio::time::sleep(config.backoff_start).await; let _ = evt_send - .send(EndpointEvent::ListeningAddressClosed { local_url }) + .send(EndpointEvent::ListeningAddressClosed { + local_url: local_url.clone(), + }) .await; + + tracing::debug!(?local_url, "signal closed"); } diff --git a/crates/tx5/src/test.rs b/crates/tx5/src/test.rs index 2964268b..96003a6e 100644 --- a/crates/tx5/src/test.rs +++ b/crates/tx5/src/test.rs @@ -1,5 +1,84 @@ use crate::*; +struct TestEp { + ep: Arc, + task: tokio::task::JoinHandle<()>, + recv: tokio::sync::mpsc::UnboundedReceiver<(PeerUrl, Vec)>, + peer_url: Arc>, +} + +impl std::ops::Deref for TestEp { + type Target = Endpoint; + + fn deref(&self) -> &Self::Target { + &self.ep + } +} + +impl Drop for TestEp { + fn drop(&mut self) { + self.task.abort(); + } +} + +impl TestEp { + pub async fn new(ep: Endpoint) -> Self { + let ep = Arc::new(ep); + let (send, recv) = tokio::sync::mpsc::unbounded_channel(); + + let peer_url = Arc::new(Mutex::new( + PeerUrl::parse( + "ws://bad/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", + ) + .unwrap(), + )); + let (s, r) = tokio::sync::oneshot::channel(); + let mut s = Some(s); + + let weak = Arc::downgrade(&ep); + let peer_url2 = peer_url.clone(); + let task = tokio::task::spawn(async move { + while let Some(ep) = weak.upgrade() { + if let Some(evt) = ep.recv().await { + match evt { + EndpointEvent::ListeningAddressOpen { local_url } => { + *peer_url2.lock().unwrap() = local_url; + if let Some(s) = s.take() { + let _ = s.send(()); + } + } + EndpointEvent::Message { peer_url, message } => { + if send.send((peer_url, message)).is_err() { + break; + } + } + _ => (), + } + } else { + break; + } + } + }); + + r.await.unwrap(); + + Self { + ep, + task, + recv, + peer_url, + } + } + + fn peer_url(&self) -> PeerUrl { + self.peer_url.lock().unwrap().clone() + } + + async fn recv(&mut self) -> Option<(PeerUrl, Vec)> { + self.recv.recv().await + } +} + struct Test { sig_srv_hnd: Option, sig_port: Option, @@ -29,24 +108,13 @@ impl Test { this } - pub async fn ep( - &self, - config: Arc, - ) -> (PeerUrl, Endpoint) { + pub async fn ep(&self, config: Arc) -> TestEp { let sig_url = self.sig_url.clone().unwrap(); let ep = Endpoint::new(config); ep.listen(sig_url); - loop { - match ep.recv().await { - None => panic!(), - Some(EndpointEvent::ListeningAddressOpen { local_url }) => { - return (local_url, ep); - } - _ => (), - } - } + TestEp::new(ep).await } pub fn drop_sig(&mut self) { @@ -60,10 +128,7 @@ impl Test { let port = self.sig_port.unwrap_or(0); - let bind = vec![ - format!("127.0.0.1:{port}"), - format!("[::1]:{port}"), - ]; + let bind = vec![format!("127.0.0.1:{port}"), format!("[::1]:{port}")]; let config = Arc::new(sbd_server::Config { bind, @@ -84,7 +149,7 @@ impl Test { } } - tracing::info!(%sig_url); + eprintln!("sig_url: {sig_url}"); self.sig_url = Some(sig_url); self.sig_srv_hnd = Some(server); @@ -92,68 +157,53 @@ impl Test { } #[tokio::test(flavor = "multi_thread")] -async fn ep3_sanity() { - let config = Arc::new(Config::default()); +async fn ep_sanity() { + let config = Arc::new(Config { + signal_allow_plain_text: true, + ..Default::default() + }); let test = Test::new().await; - let (_cli_url1, ep1) = test.ep(config.clone()).await; - let (cli_url2, ep2) = test.ep(config).await; + let ep1 = test.ep(config.clone()).await; + let mut ep2 = test.ep(config).await; - ep1.send(cli_url2, b"hello".to_vec()).await.unwrap(); + ep1.send(ep2.peer_url(), b"hello".to_vec()).await.unwrap(); - let res = ep2.recv().await.unwrap(); - match res { - EndpointEvent::Connected { .. } => (), - _ => panic!(), - } - - let res = ep2.recv().await.unwrap(); - match res { - EndpointEvent::Message { message, .. } => { - assert_eq!(&b"hello"[..], &message); - } - oth => panic!("{oth:?}"), - } + let (from, msg) = ep2.recv().await.unwrap(); + assert_eq!(ep1.peer_url(), from); + assert_eq!(&b"hello"[..], &msg); //let stats = ep1.get_stats().await; //println!("STATS: {}", serde_json::to_string_pretty(&stats).unwrap()); } -/* #[tokio::test(flavor = "multi_thread")] -async fn ep3_sig_down() { +async fn ep_sig_down() { eprintln!("-- STARTUP --"); const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5); - let mut config = Config3::default(); - config.timeout = TIMEOUT * 2; - config.backoff_start = std::time::Duration::from_millis(200); - config.backoff_max = std::time::Duration::from_millis(200); + let config = Config { + signal_allow_plain_text: true, + timeout: TIMEOUT * 2, + backoff_start: std::time::Duration::from_millis(2000), + backoff_max: std::time::Duration::from_millis(2000), + ..Default::default() + }; let config = Arc::new(config); let mut test = Test::new().await; - let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; - let (cli_url2, ep2, mut ep2_recv) = test.ep(config.clone()).await; + let ep1 = test.ep(config.clone()).await; + let mut ep2 = test.ep(config.clone()).await; eprintln!("-- Establish Connection --"); - ep1.send(cli_url2.clone(), b"hello").await.unwrap(); - - let res = ep2_recv.recv().await.unwrap(); - match res { - Ep3Event::Connected { .. } => (), - _ => panic!(), - } + ep1.send(ep2.peer_url(), b"hello".to_vec()).await.unwrap(); - let res = ep2_recv.recv().await.unwrap(); - match res { - Ep3Event::Message { message, .. } => { - assert_eq!(&b"hello"[..], &message); - } - _ => panic!(), - } + let (from, msg) = ep2.recv().await.unwrap(); + assert_eq!(ep1.peer_url(), from); + assert_eq!(&b"hello"[..], &msg); eprintln!("-- Drop Sig --"); @@ -161,23 +211,11 @@ async fn ep3_sig_down() { tokio::time::sleep(TIMEOUT).await; - // need to trigger another signal message so we know the connection is down - let (cli_url3, _ep3, _ep3_recv) = test.ep(config).await; - - let (a, b) = tokio::join!( - ep1.send(cli_url3.clone(), b"hello",), - ep2.send(cli_url3, b"hello"), - ); - - a.unwrap_err(); - b.unwrap_err(); - - tokio::time::sleep(TIMEOUT).await; - - // now a send to cli_url2 should *also* fail eprintln!("-- Send Should Fail --"); - ep1.send(cli_url2.clone(), b"hello").await.unwrap_err(); + ep1.send(ep2.peer_url(), b"hello".to_vec()) + .await + .unwrap_err(); eprintln!("-- Restart Sig --"); @@ -187,82 +225,49 @@ async fn ep3_sig_down() { eprintln!("-- Send Should Succeed --"); - ep1.send(cli_url2.clone(), b"hello").await.unwrap(); - - let res = ep2_recv.recv().await.unwrap(); - match res { - Ep3Event::Disconnected { .. } => (), - oth => panic!("{oth:?}"), - } - - let res = ep2_recv.recv().await.unwrap(); - match res { - Ep3Event::Connected { .. } => (), - oth => panic!("{oth:?}"), - } + ep1.send(ep2.peer_url(), b"hello".to_vec()).await.unwrap(); - let res = ep2_recv.recv().await.unwrap(); - match res { - Ep3Event::Message { message, .. } => { - assert_eq!(&b"hello"[..], &message); - } - oth => panic!("{oth:?}"), - } + let (from, msg) = ep2.recv().await.unwrap(); + assert_eq!(ep1.peer_url(), from); + assert_eq!(&b"hello"[..], &msg); eprintln!("-- Done --"); } #[tokio::test(flavor = "multi_thread")] -async fn ep3_drop() { - let config = Arc::new(Config3::default()); +async fn ep_drop() { + let config = Arc::new(Config { + signal_allow_plain_text: true, + ..Default::default() + }); let test = Test::new().await; - let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; - let (cli_url2, ep2, mut ep2_recv) = test.ep(config.clone()).await; + let ep1 = test.ep(config.clone()).await; + let mut ep2 = test.ep(config.clone()).await; - ep1.send(cli_url2, b"hello").await.unwrap(); + ep1.send(ep2.peer_url(), b"hello".to_vec()).await.unwrap(); - let res = ep2_recv.recv().await.unwrap(); - match res { - Ep3Event::Connected { .. } => (), - _ => panic!(), - } - - let res = ep2_recv.recv().await.unwrap(); - match res { - Ep3Event::Message { message, .. } => { - assert_eq!(&b"hello"[..], &message); - } - _ => panic!(), - } + let (from, msg) = ep2.recv().await.unwrap(); + assert_eq!(ep1.peer_url(), from); + assert_eq!(&b"hello"[..], &msg); drop(ep2); - drop(ep2_recv); - let (cli_url3, _ep3, mut ep3_recv) = test.ep(config).await; + let mut ep3 = test.ep(config).await; - ep1.send(cli_url3, b"world").await.unwrap(); + ep1.send(ep3.peer_url(), b"world".to_vec()).await.unwrap(); - let res = ep3_recv.recv().await.unwrap(); - match res { - Ep3Event::Connected { .. } => (), - _ => panic!(), - } - - let res = ep3_recv.recv().await.unwrap(); - match res { - Ep3Event::Message { message, .. } => { - assert_eq!(&b"world"[..], &message); - } - _ => panic!(), - } + let (from, msg) = ep3.recv().await.unwrap(); + assert_eq!(ep1.peer_url(), from); + assert_eq!(&b"world"[..], &msg); } +/* /// Test negotiation (polite / impolite node logic) by setting up a lot /// of nodes and having them all try to make connections to each other /// at the same time and see if we get all the messages. #[tokio::test(flavor = "multi_thread")] -async fn ep3_negotiation() { +async fn ep_negotiation() { const NODE_COUNT: usize = 9; let mut url_list = Vec::new(); @@ -314,7 +319,7 @@ async fn ep3_negotiation() { } #[tokio::test(flavor = "multi_thread")] -async fn ep3_messages_contiguous() { +async fn ep_messages_contiguous() { let config = Arc::new(Config3::default()); let test = Test::new().await; @@ -435,7 +440,7 @@ async fn ep3_messages_contiguous() { } #[tokio::test(flavor = "multi_thread")] -async fn ep3_preflight_happy() { +async fn ep_preflight_happy() { use rand::Rng; let did_send = Arc::new(std::sync::atomic::AtomicBool::new(false)); @@ -493,7 +498,7 @@ async fn ep3_preflight_happy() { } #[tokio::test(flavor = "multi_thread")] -async fn ep3_close_connection() { +async fn ep_close_connection() { let config = Arc::new(Config3::default()); let test = Test::new().await; @@ -526,7 +531,7 @@ async fn ep3_close_connection() { } #[tokio::test(flavor = "multi_thread")] -async fn ep3_ban_after_connected_outgoing_side() { +async fn ep_ban_after_connected_outgoing_side() { let config = Arc::new(Config3::default()); let test = Test::new().await; @@ -559,7 +564,7 @@ async fn ep3_ban_after_connected_outgoing_side() { } #[tokio::test(flavor = "multi_thread")] -async fn ep3_recon_after_ban() { +async fn ep_recon_after_ban() { let config = Arc::new(Config3::default()); let test = Test::new().await; @@ -612,7 +617,7 @@ async fn ep3_recon_after_ban() { } #[tokio::test(flavor = "multi_thread")] -async fn ep3_broadcast_happy() { +async fn ep_broadcast_happy() { let config = Arc::new(Config3::default()); let test = Test::new().await;