Skip to content

Commit

Permalink
address unbounded channel in signal client (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
neonphog authored Nov 20, 2023
1 parent 40285ab commit c7c61df
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 118 deletions.
3 changes: 1 addition & 2 deletions crates/tx5-signal/examples/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ async fn main() {

tracing::info!(%sig_url);

let cli = tx5_signal::Cli::builder()
let (cli, _rcv) = tx5_signal::Cli::builder()
.with_url(sig_url)
.with_recv_cb(|_msg| {})
.build()
.await
.expect("expect can build tx5_signal::Cli");
Expand Down
106 changes: 45 additions & 61 deletions crates/tx5-signal/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@ pub enum SignalMsg {
},
}

type RecvCb = Box<dyn FnMut(SignalMsg) + 'static + Send>;

/// Builder for constructing a Cli instance.
pub struct CliBuilder {
recv_cb: RecvCb,
msg_limit: usize,
lair_client: Option<LairClient>,
lair_tag: Option<Arc<str>>,
url: Option<url::Url>,
Expand All @@ -65,7 +63,12 @@ pub struct CliBuilder {
impl Default for CliBuilder {
fn default() -> Self {
Self {
recv_cb: Box::new(|_| {}),
// if *every* message were 512 bytes (they are most often *far*
// smaller than this), this would represent
// 512 * 1024 = ~524 KiB of data, and 1024 should be plenty
// to address any concurrency concerns. This shouldn't ever
// need to be configurable, but we can easily make it so if needed.
msg_limit: 1024,
lair_client: None,
lair_tag: None,
url: None,
Expand All @@ -74,23 +77,6 @@ impl Default for CliBuilder {
}

impl CliBuilder {
/// Set the receiver callback.
pub fn set_recv_cb<Cb>(&mut self, cb: Cb)
where
Cb: FnMut(SignalMsg) + 'static + Send,
{
self.recv_cb = Box::new(cb);
}

/// Apply the receiver callback.
pub fn with_recv_cb<Cb>(mut self, cb: Cb) -> Self
where
Cb: FnMut(SignalMsg) + 'static + Send,
{
self.set_recv_cb(cb);
self
}

/// Set the LairClient.
pub fn set_lair_client(&mut self, lair_client: LairClient) {
self.lair_client = Some(lair_client);
Expand Down Expand Up @@ -125,7 +111,9 @@ impl CliBuilder {
}

/// Build the Srv instance.
pub async fn build(self) -> Result<Cli> {
pub async fn build(
self,
) -> Result<(Cli, tokio::sync::mpsc::Receiver<SignalMsg>)> {
Cli::priv_build(self).await
}
}
Expand Down Expand Up @@ -448,14 +436,18 @@ impl Cli {
}
}

async fn priv_build(builder: CliBuilder) -> Result<Self> {
async fn priv_build(
builder: CliBuilder,
) -> Result<(Self, tokio::sync::mpsc::Receiver<SignalMsg>)> {
let CliBuilder {
recv_cb,
msg_limit,
lair_client,
lair_tag,
url,
} = builder;

let (msg_send, msg_recv) = tokio::sync::mpsc::channel(msg_limit);

let mut lair_keystore = None;

let lair_tag = match lair_tag {
Expand Down Expand Up @@ -572,7 +564,7 @@ impl Cli {
con_url_versioned,
endpoint,
ice.clone(),
recv_cb,
msg_send,
x25519_pub,
lair_client.clone(),
write_send.clone(),
Expand Down Expand Up @@ -601,16 +593,19 @@ impl Cli {

let _ = init_recv.await;

Ok(Self {
addr: url,
hnd,
ice,
write_send,
seq: Seq::new(),
_lair_keystore: lair_keystore,
lair_client,
x25519_pub,
})
Ok((
Self {
addr: url,
hnd,
ice,
write_send,
seq: Seq::new(),
_lair_keystore: lair_keystore,
lair_client,
x25519_pub,
},
msg_recv,
))
}
}

Expand All @@ -621,15 +616,14 @@ async fn con_task(
con_url: String,
endpoint: String,
ice: Arc<Mutex<Arc<serde_json::Value>>>,
recv_cb: RecvCb,
msg_send: tokio::sync::mpsc::Sender<SignalMsg>,
x25519_pub: Id,
lair_client: LairClient,
write_send: WriteSend,
write_recv: WriteRecv,
init: tokio::sync::oneshot::Sender<()>,
) {
let mut init = Some(init);
let mut recv_cb = Some(recv_cb);
let mut write_recv = Some(write_recv);
loop {
if let Some(socket) = con_open_connection(
Expand All @@ -648,16 +642,15 @@ async fn con_task(
let _ = init.send(());
}

let (a_recv_cb, a_write_recv) = con_manage_connection(
let a_write_recv = con_manage_connection(
socket,
recv_cb.take().unwrap(),
msg_send.clone(),
x25519_pub,
&lair_client,
write_send.clone(),
write_recv.take().unwrap(),
)
.await;
recv_cb = Some(a_recv_cb);
write_recv = Some(a_write_recv);
}

Expand Down Expand Up @@ -854,16 +847,14 @@ async fn con_open_connection(

async fn con_manage_connection(
socket: Socket,
recv_cb: RecvCb,
msg_send: tokio::sync::mpsc::Sender<SignalMsg>,
x25519_pub: Id,
lair_client: &LairClient,
write_send: WriteSend,
write_recv: WriteRecv,
) -> (RecvCb, WriteRecv) {
let recv_cb = Arc::new(tokio::sync::Mutex::new(recv_cb));
) -> WriteRecv {
let write_recv = Arc::new(tokio::sync::Mutex::new(write_recv));

let recv_cb2 = recv_cb.clone();
let write_recv2 = write_recv.clone();

macro_rules! dbg_err {
Expand All @@ -883,7 +874,6 @@ async fn con_manage_connection(

tokio::select! {
_ = async move {
let mut recv_cb = recv_cb2.lock().await;
while let Some(msg) = read.next().await {
let msg = dbg_err!(msg);
if let Message::Pong(_) = &msg {
Expand All @@ -903,11 +893,11 @@ async fn con_manage_connection(
let msg = msg.into_data();
match dbg_err!(wire::Wire::decode(&msg)) {
wire::Wire::DemoV1 { rem_pub } => {
recv_cb(SignalMsg::Demo { rem_pub });
let _ = msg_send.send(SignalMsg::Demo { rem_pub }).await;
}
wire::Wire::FwdV1 { rem_pub, nonce, cipher } => {
if let Err(err) = decode_fwd(
&mut recv_cb,
&msg_send,
&mut seq_track,
&x25519_pub,
lair_client,
Expand Down Expand Up @@ -941,20 +931,14 @@ async fn con_manage_connection(
} => (),
};

(
Arc::try_unwrap(recv_cb)
.map_err(|_| ())
.unwrap()
.into_inner(),
Arc::try_unwrap(write_recv)
.map_err(|_| ())
.unwrap()
.into_inner(),
)
Arc::try_unwrap(write_recv)
.map_err(|_| ())
.unwrap()
.into_inner()
}

async fn decode_fwd(
recv_cb: &mut RecvCb,
msg_send: &tokio::sync::mpsc::Sender<SignalMsg>,
seq_track: &mut SeqTrack,
x25519_pub: &Id,
lair_client: &LairClient,
Expand Down Expand Up @@ -982,13 +966,13 @@ async fn decode_fwd(

match msg {
wire::FwdInnerV1::Offer { offer, .. } => {
recv_cb(SignalMsg::Offer { rem_pub, offer });
let _ = msg_send.send(SignalMsg::Offer { rem_pub, offer }).await;
}
wire::FwdInnerV1::Answer { answer, .. } => {
recv_cb(SignalMsg::Answer { rem_pub, answer });
let _ = msg_send.send(SignalMsg::Answer { rem_pub, answer }).await;
}
wire::FwdInnerV1::Ice { ice, .. } => {
recv_cb(SignalMsg::Ice { rem_pub, ice });
let _ = msg_send.send(SignalMsg::Ice { rem_pub, ice }).await;
}
}

Expand Down
71 changes: 26 additions & 45 deletions crates/tx5-signal/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ struct Test {
}

impl Test {
pub async fn new<Cb>(port: u16, recv_cb: Cb) -> Self
where
Cb: FnMut(SignalMsg) + 'static + Send,
{
pub async fn new(
port: u16,
) -> (Self, tokio::sync::mpsc::Receiver<SignalMsg>) {
let passphrase = sodoken::BufRead::new_no_lock(b"test-passphrase");
let keystore_config = PwHashLimits::Minimum
.with_exec(|| LairServerConfigInner::new("/", passphrase.clone()))
Expand All @@ -49,10 +48,9 @@ impl Test {
.await
.unwrap();

let cli = cli::Cli::builder()
let (cli, msg_recv) = cli::Cli::builder()
.with_lair_client(lair_client)
.with_lair_tag(tag)
.with_recv_cb(recv_cb)
.with_url(
url::Url::parse(&format!("ws://localhost:{}/tx5-ws", port))
.unwrap(),
Expand All @@ -61,10 +59,13 @@ impl Test {
.await
.unwrap();

Self {
_keystore: keystore,
cli,
}
(
Self {
_keystore: keystore,
cli,
},
msg_recv,
)
}
}

Expand Down Expand Up @@ -124,29 +125,12 @@ async fn sanity() {
}

async fn sanity_inner(srv_port: u16) {
#[derive(Debug)]
enum In {
Cli1(SignalMsg),
Cli2(SignalMsg),
}

let (in_send, mut in_recv) = tokio::sync::mpsc::unbounded_channel();

let cli1 = {
let in_send = in_send.clone();
Test::new(srv_port, move |msg| {
in_send.send(In::Cli1(msg)).unwrap();
})
.await
};
let (cli1, mut rcv1) = Test::new(srv_port).await;

let cli1_pk = *cli1.cli.local_id();
tracing::info!(%cli1_pk);

let cli2 = Test::new(srv_port, move |msg| {
in_send.send(In::Cli2(msg)).unwrap();
})
.await;
let (cli2, mut rcv2) = Test::new(srv_port).await;

let cli2_pk = *cli2.cli.local_id();
tracing::info!(%cli2_pk);
Expand All @@ -156,38 +140,35 @@ async fn sanity_inner(srv_port: u16) {
.await
.unwrap();

let msg = in_recv.recv().await;
let msg = rcv2.recv().await;
tracing::info!(?msg);
assert!(matches!(msg, Some(In::Cli2(SignalMsg::Offer { .. }))));
assert!(matches!(msg, Some(SignalMsg::Offer { .. })));

cli2.cli
.answer(cli1_pk, serde_json::json!({ "type": "answer" }))
.await
.unwrap();

let msg = in_recv.recv().await;
let msg = rcv1.recv().await;
tracing::info!(?msg);
assert!(matches!(msg, Some(In::Cli1(SignalMsg::Answer { .. }))));
assert!(matches!(msg, Some(SignalMsg::Answer { .. })));

cli1.cli
.ice(cli2_pk, serde_json::json!({ "type": "ice" }))
.await
.unwrap();

let msg = in_recv.recv().await;
let msg = rcv2.recv().await;
tracing::info!(?msg);
assert!(matches!(msg, Some(In::Cli2(SignalMsg::Ice { .. }))));
assert!(matches!(msg, Some(SignalMsg::Ice { .. })));

cli1.cli.demo();

for _ in 0..2 {
let msg = in_recv.recv().await;
tracing::info!(?msg);
let inner = match msg {
Some(In::Cli1(m)) => m,
Some(In::Cli2(m)) => m,
_ => panic!("unexpected eos"),
};
assert!(matches!(inner, SignalMsg::Demo { .. }));
}
let msg = rcv1.recv().await;
tracing::info!(?msg);
assert!(matches!(msg, Some(SignalMsg::Demo { .. })));

let msg = rcv2.recv().await;
tracing::info!(?msg);
assert!(matches!(msg, Some(SignalMsg::Demo { .. })));
}
3 changes: 1 addition & 2 deletions crates/tx5/examples/turn_doctor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,8 @@ async fn main() {
let sig_url = url::Url::parse(sig_url).unwrap();
println!("SIG_URL: {sig_url}");

let sig_cli = tx5_signal::Cli::builder()
let (sig_cli, _sig_rcv) = tx5_signal::Cli::builder()
.with_url(sig_url)
.with_recv_cb(|_msg| {})
.build()
.await
.expect("expect can build tx5_signal::Cli");
Expand Down
Loading

0 comments on commit c7c61df

Please sign in to comment.