From c7c61df55676e5dc17075c85b6d658113e8867d9 Mon Sep 17 00:00:00 2001 From: David Braden Date: Mon, 20 Nov 2023 09:20:12 -0700 Subject: [PATCH] address unbounded channel in signal client (#70) --- crates/tx5-signal/examples/connect.rs | 3 +- crates/tx5-signal/src/cli.rs | 106 +++++++++++--------------- crates/tx5-signal/src/tests.rs | 71 +++++++---------- crates/tx5/examples/turn_doctor.rs | 3 +- crates/tx5/src/endpoint.rs | 11 +-- 5 files changed, 76 insertions(+), 118 deletions(-) diff --git a/crates/tx5-signal/examples/connect.rs b/crates/tx5-signal/examples/connect.rs index 3779619d..13aacd4f 100644 --- a/crates/tx5-signal/examples/connect.rs +++ b/crates/tx5-signal/examples/connect.rs @@ -21,9 +21,8 @@ async fn main() { tracing::info!(%sig_url); - let cli = tx5_signal::Cli::builder() + let (cli, _rcv) = tx5_signal::Cli::builder() .with_url(sig_url) - .with_recv_cb(|_msg| {}) .build() .await .expect("expect can build tx5_signal::Cli"); diff --git a/crates/tx5-signal/src/cli.rs b/crates/tx5-signal/src/cli.rs index 1ff0c94b..417ea4c1 100644 --- a/crates/tx5-signal/src/cli.rs +++ b/crates/tx5-signal/src/cli.rs @@ -52,11 +52,9 @@ pub enum SignalMsg { }, } -type RecvCb = Box; - /// Builder for constructing a Cli instance. pub struct CliBuilder { - recv_cb: RecvCb, + msg_limit: usize, lair_client: Option, lair_tag: Option>, url: Option, @@ -65,7 +63,12 @@ pub struct CliBuilder { impl Default for CliBuilder { fn default() -> Self { Self { - recv_cb: Box::new(|_| {}), + // if *every* message were 512 bytes (they are most often *far* + // smaller than this), this would represent + // 512 * 1024 = ~524 KiB of data, and 1024 should be plenty + // to address any concurrency concerns. This shouldn't ever + // need to be configurable, but we can easily make it so if needed. + msg_limit: 1024, lair_client: None, lair_tag: None, url: None, @@ -74,23 +77,6 @@ impl Default for CliBuilder { } impl CliBuilder { - /// Set the receiver callback. - pub fn set_recv_cb(&mut self, cb: Cb) - where - Cb: FnMut(SignalMsg) + 'static + Send, - { - self.recv_cb = Box::new(cb); - } - - /// Apply the receiver callback. - pub fn with_recv_cb(mut self, cb: Cb) -> Self - where - Cb: FnMut(SignalMsg) + 'static + Send, - { - self.set_recv_cb(cb); - self - } - /// Set the LairClient. pub fn set_lair_client(&mut self, lair_client: LairClient) { self.lair_client = Some(lair_client); @@ -125,7 +111,9 @@ impl CliBuilder { } /// Build the Srv instance. - pub async fn build(self) -> Result { + pub async fn build( + self, + ) -> Result<(Cli, tokio::sync::mpsc::Receiver)> { Cli::priv_build(self).await } } @@ -448,14 +436,18 @@ impl Cli { } } - async fn priv_build(builder: CliBuilder) -> Result { + async fn priv_build( + builder: CliBuilder, + ) -> Result<(Self, tokio::sync::mpsc::Receiver)> { let CliBuilder { - recv_cb, + msg_limit, lair_client, lair_tag, url, } = builder; + let (msg_send, msg_recv) = tokio::sync::mpsc::channel(msg_limit); + let mut lair_keystore = None; let lair_tag = match lair_tag { @@ -572,7 +564,7 @@ impl Cli { con_url_versioned, endpoint, ice.clone(), - recv_cb, + msg_send, x25519_pub, lair_client.clone(), write_send.clone(), @@ -601,16 +593,19 @@ impl Cli { let _ = init_recv.await; - Ok(Self { - addr: url, - hnd, - ice, - write_send, - seq: Seq::new(), - _lair_keystore: lair_keystore, - lair_client, - x25519_pub, - }) + Ok(( + Self { + addr: url, + hnd, + ice, + write_send, + seq: Seq::new(), + _lair_keystore: lair_keystore, + lair_client, + x25519_pub, + }, + msg_recv, + )) } } @@ -621,7 +616,7 @@ async fn con_task( con_url: String, endpoint: String, ice: Arc>>, - recv_cb: RecvCb, + msg_send: tokio::sync::mpsc::Sender, x25519_pub: Id, lair_client: LairClient, write_send: WriteSend, @@ -629,7 +624,6 @@ async fn con_task( init: tokio::sync::oneshot::Sender<()>, ) { let mut init = Some(init); - let mut recv_cb = Some(recv_cb); let mut write_recv = Some(write_recv); loop { if let Some(socket) = con_open_connection( @@ -648,16 +642,15 @@ async fn con_task( let _ = init.send(()); } - let (a_recv_cb, a_write_recv) = con_manage_connection( + let a_write_recv = con_manage_connection( socket, - recv_cb.take().unwrap(), + msg_send.clone(), x25519_pub, &lair_client, write_send.clone(), write_recv.take().unwrap(), ) .await; - recv_cb = Some(a_recv_cb); write_recv = Some(a_write_recv); } @@ -854,16 +847,14 @@ async fn con_open_connection( async fn con_manage_connection( socket: Socket, - recv_cb: RecvCb, + msg_send: tokio::sync::mpsc::Sender, x25519_pub: Id, lair_client: &LairClient, write_send: WriteSend, write_recv: WriteRecv, -) -> (RecvCb, WriteRecv) { - let recv_cb = Arc::new(tokio::sync::Mutex::new(recv_cb)); +) -> WriteRecv { let write_recv = Arc::new(tokio::sync::Mutex::new(write_recv)); - let recv_cb2 = recv_cb.clone(); let write_recv2 = write_recv.clone(); macro_rules! dbg_err { @@ -883,7 +874,6 @@ async fn con_manage_connection( tokio::select! { _ = async move { - let mut recv_cb = recv_cb2.lock().await; while let Some(msg) = read.next().await { let msg = dbg_err!(msg); if let Message::Pong(_) = &msg { @@ -903,11 +893,11 @@ async fn con_manage_connection( let msg = msg.into_data(); match dbg_err!(wire::Wire::decode(&msg)) { wire::Wire::DemoV1 { rem_pub } => { - recv_cb(SignalMsg::Demo { rem_pub }); + let _ = msg_send.send(SignalMsg::Demo { rem_pub }).await; } wire::Wire::FwdV1 { rem_pub, nonce, cipher } => { if let Err(err) = decode_fwd( - &mut recv_cb, + &msg_send, &mut seq_track, &x25519_pub, lair_client, @@ -941,20 +931,14 @@ async fn con_manage_connection( } => (), }; - ( - Arc::try_unwrap(recv_cb) - .map_err(|_| ()) - .unwrap() - .into_inner(), - Arc::try_unwrap(write_recv) - .map_err(|_| ()) - .unwrap() - .into_inner(), - ) + Arc::try_unwrap(write_recv) + .map_err(|_| ()) + .unwrap() + .into_inner() } async fn decode_fwd( - recv_cb: &mut RecvCb, + msg_send: &tokio::sync::mpsc::Sender, seq_track: &mut SeqTrack, x25519_pub: &Id, lair_client: &LairClient, @@ -982,13 +966,13 @@ async fn decode_fwd( match msg { wire::FwdInnerV1::Offer { offer, .. } => { - recv_cb(SignalMsg::Offer { rem_pub, offer }); + let _ = msg_send.send(SignalMsg::Offer { rem_pub, offer }).await; } wire::FwdInnerV1::Answer { answer, .. } => { - recv_cb(SignalMsg::Answer { rem_pub, answer }); + let _ = msg_send.send(SignalMsg::Answer { rem_pub, answer }).await; } wire::FwdInnerV1::Ice { ice, .. } => { - recv_cb(SignalMsg::Ice { rem_pub, ice }); + let _ = msg_send.send(SignalMsg::Ice { rem_pub, ice }).await; } } diff --git a/crates/tx5-signal/src/tests.rs b/crates/tx5-signal/src/tests.rs index 239c4c32..bd5daf28 100644 --- a/crates/tx5-signal/src/tests.rs +++ b/crates/tx5-signal/src/tests.rs @@ -19,10 +19,9 @@ struct Test { } impl Test { - pub async fn new(port: u16, recv_cb: Cb) -> Self - where - Cb: FnMut(SignalMsg) + 'static + Send, - { + pub async fn new( + port: u16, + ) -> (Self, tokio::sync::mpsc::Receiver) { let passphrase = sodoken::BufRead::new_no_lock(b"test-passphrase"); let keystore_config = PwHashLimits::Minimum .with_exec(|| LairServerConfigInner::new("/", passphrase.clone())) @@ -49,10 +48,9 @@ impl Test { .await .unwrap(); - let cli = cli::Cli::builder() + let (cli, msg_recv) = cli::Cli::builder() .with_lair_client(lair_client) .with_lair_tag(tag) - .with_recv_cb(recv_cb) .with_url( url::Url::parse(&format!("ws://localhost:{}/tx5-ws", port)) .unwrap(), @@ -61,10 +59,13 @@ impl Test { .await .unwrap(); - Self { - _keystore: keystore, - cli, - } + ( + Self { + _keystore: keystore, + cli, + }, + msg_recv, + ) } } @@ -124,29 +125,12 @@ async fn sanity() { } async fn sanity_inner(srv_port: u16) { - #[derive(Debug)] - enum In { - Cli1(SignalMsg), - Cli2(SignalMsg), - } - - let (in_send, mut in_recv) = tokio::sync::mpsc::unbounded_channel(); - - let cli1 = { - let in_send = in_send.clone(); - Test::new(srv_port, move |msg| { - in_send.send(In::Cli1(msg)).unwrap(); - }) - .await - }; + let (cli1, mut rcv1) = Test::new(srv_port).await; let cli1_pk = *cli1.cli.local_id(); tracing::info!(%cli1_pk); - let cli2 = Test::new(srv_port, move |msg| { - in_send.send(In::Cli2(msg)).unwrap(); - }) - .await; + let (cli2, mut rcv2) = Test::new(srv_port).await; let cli2_pk = *cli2.cli.local_id(); tracing::info!(%cli2_pk); @@ -156,38 +140,35 @@ async fn sanity_inner(srv_port: u16) { .await .unwrap(); - let msg = in_recv.recv().await; + let msg = rcv2.recv().await; tracing::info!(?msg); - assert!(matches!(msg, Some(In::Cli2(SignalMsg::Offer { .. })))); + assert!(matches!(msg, Some(SignalMsg::Offer { .. }))); cli2.cli .answer(cli1_pk, serde_json::json!({ "type": "answer" })) .await .unwrap(); - let msg = in_recv.recv().await; + let msg = rcv1.recv().await; tracing::info!(?msg); - assert!(matches!(msg, Some(In::Cli1(SignalMsg::Answer { .. })))); + assert!(matches!(msg, Some(SignalMsg::Answer { .. }))); cli1.cli .ice(cli2_pk, serde_json::json!({ "type": "ice" })) .await .unwrap(); - let msg = in_recv.recv().await; + let msg = rcv2.recv().await; tracing::info!(?msg); - assert!(matches!(msg, Some(In::Cli2(SignalMsg::Ice { .. })))); + assert!(matches!(msg, Some(SignalMsg::Ice { .. }))); cli1.cli.demo(); - for _ in 0..2 { - let msg = in_recv.recv().await; - tracing::info!(?msg); - let inner = match msg { - Some(In::Cli1(m)) => m, - Some(In::Cli2(m)) => m, - _ => panic!("unexpected eos"), - }; - assert!(matches!(inner, SignalMsg::Demo { .. })); - } + let msg = rcv1.recv().await; + tracing::info!(?msg); + assert!(matches!(msg, Some(SignalMsg::Demo { .. }))); + + let msg = rcv2.recv().await; + tracing::info!(?msg); + assert!(matches!(msg, Some(SignalMsg::Demo { .. }))); } diff --git a/crates/tx5/examples/turn_doctor.rs b/crates/tx5/examples/turn_doctor.rs index 271a616c..20a6b89f 100644 --- a/crates/tx5/examples/turn_doctor.rs +++ b/crates/tx5/examples/turn_doctor.rs @@ -175,9 +175,8 @@ async fn main() { let sig_url = url::Url::parse(sig_url).unwrap(); println!("SIG_URL: {sig_url}"); - let sig_cli = tx5_signal::Cli::builder() + let (sig_cli, _sig_rcv) = tx5_signal::Cli::builder() .with_url(sig_url) - .with_recv_cb(|_msg| {}) .build() .await .expect("expect can build tx5_signal::Cli"); diff --git a/crates/tx5/src/endpoint.rs b/crates/tx5/src/endpoint.rs index 64944ca6..3569b265 100644 --- a/crates/tx5/src/endpoint.rs +++ b/crates/tx5/src/endpoint.rs @@ -226,22 +226,17 @@ async fn new_sig_task( ) { tracing::debug!(%sig_url, "spawning new signal task"); - let (sig_snd, mut sig_rcv) = tokio::sync::mpsc::unbounded_channel(); - - let (sig, cli_url) = match async { - let sig = tx5_signal::Cli::builder() + let (sig, mut sig_rcv, cli_url) = match async { + let (sig, sig_rcv) = tx5_signal::Cli::builder() .with_lair_client(config.lair_client().clone()) .with_lair_tag(config.lair_tag().clone()) .with_url(sig_url.to_string().parse().unwrap()) - .with_recv_cb(move |msg| { - let _ = sig_snd.send(msg); - }) .build() .await?; let cli_url = Tx5Url::new(sig.local_addr())?; - Result::Ok((sig, cli_url)) + Result::Ok((sig, sig_rcv, cli_url)) } .await {