From d36e4bb19ad32c19e8ac7af50261bcaec5ed9f41 Mon Sep 17 00:00:00 2001 From: ThetaSinner Date: Tue, 26 Mar 2024 17:33:55 +0000 Subject: [PATCH] Allow connections to be closed from the public API (#84) * Allow connections to be closed from the public API * Remove debugging trace * Review comments and clippy --- crates/tx5-core/src/evt.rs | 4 +++- crates/tx5-go-pion/src/evt.rs | 2 +- crates/tx5-go-pion/src/lib.rs | 4 ++-- crates/tx5-go-pion/src/peer_con.rs | 25 ++++++++++++++++++++- crates/tx5/src/back_buf.rs | 4 ++-- crates/tx5/src/ep3.rs | 36 +++++++++++++++++++++++++++++- crates/tx5/src/ep3/peer.rs | 11 +++++++-- crates/tx5/src/ep3/sig.rs | 2 +- crates/tx5/src/ep3/test.rs | 33 +++++++++++++++++++++++++++ flake.lock | 18 +++++++-------- 10 files changed, 119 insertions(+), 20 deletions(-) diff --git a/crates/tx5-core/src/evt.rs b/crates/tx5-core/src/evt.rs index f115f801..1090ebaa 100644 --- a/crates/tx5-core/src/evt.rs +++ b/crates/tx5-core/src/evt.rs @@ -2,7 +2,9 @@ use crate::{Error, Result}; use std::sync::Arc; /// Permit for sending on the channel. -pub struct EventPermit(Option); +pub struct EventPermit( + #[allow(dead_code)] Option, +); impl std::fmt::Debug for EventPermit { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/crates/tx5-go-pion/src/evt.rs b/crates/tx5-go-pion/src/evt.rs index 64ed195f..1daf83af 100644 --- a/crates/tx5-go-pion/src/evt.rs +++ b/crates/tx5-go-pion/src/evt.rs @@ -5,7 +5,7 @@ use tx5_go_pion_sys::Event as SysEvent; use tx5_go_pion_sys::API; /// PeerConnectionState events. -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum PeerConnectionState { /// New = 0x01, diff --git a/crates/tx5-go-pion/src/lib.rs b/crates/tx5-go-pion/src/lib.rs index 454cec8d..359cd255 100644 --- a/crates/tx5-go-pion/src/lib.rs +++ b/crates/tx5-go-pion/src/lib.rs @@ -310,9 +310,9 @@ mod tests { println!("close data 2"); data2.close(Error::id("").into()); println!("close peer 1"); - peer1.close(Error::id("").into()); + peer1.close(Error::id("")); println!("close peer 2"); - peer2.close(Error::id("").into()); + peer2.close(Error::id("")); println!("close turn"); turn.stop().await.unwrap(); diff --git a/crates/tx5-go-pion/src/peer_con.rs b/crates/tx5-go-pion/src/peer_con.rs index b9015d1d..9d2e463d 100644 --- a/crates/tx5-go-pion/src/peer_con.rs +++ b/crates/tx5-go-pion/src/peer_con.rs @@ -95,6 +95,7 @@ impl From<&AnswerConfig> for GoBufRef<'static> { pub(crate) struct PeerConCore { peer_con_id: usize, + con_state: PeerConnectionState, evt_send: tokio::sync::mpsc::UnboundedSender, drop_err: Error, } @@ -118,6 +119,7 @@ impl PeerConCore { ) -> Self { Self { peer_con_id, + con_state: PeerConnectionState::New, evt_send, drop_err: Error::id("PeerConnectionDropped").into(), } @@ -200,8 +202,29 @@ impl PeerConnection { .await? } + /// Set the connection state. This should only be set based on connection state events + /// coming from the underlying webrtc library. + pub fn set_con_state(&self, con_state: PeerConnectionState) { + let mut lock = self.0.lock().unwrap(); + if let Ok(core) = &mut *lock { + core.con_state = con_state; + } else { + tracing::warn!( + ?con_state, + "Unable to set peer connection state: {:?}", + self.get_peer_con_id() + ); + } + } + + /// Get the connection state. + pub fn get_con_state(&self) -> Result { + peer_con_strong_core!(self.0, core, { Ok(core.con_state) }) + } + /// Close this connection. - pub fn close(&self, err: Error) { + pub fn close>(&self, err: E) { + let err = err.into(); let mut tmp = Err(err.clone()); { diff --git a/crates/tx5/src/back_buf.rs b/crates/tx5/src/back_buf.rs index 886a0203..5d355c2b 100644 --- a/crates/tx5/src/back_buf.rs +++ b/crates/tx5/src/back_buf.rs @@ -64,10 +64,10 @@ impl std::io::Read for BackBuf { /// Conversion type facilitating Into<&mut BackBuf>. pub(crate) enum BackBufRef<'lt> { /// An owned BackBuf. - Owned(Result), + Owned(#[allow(dead_code)] Result), /// A borrowed BackBuf. - Borrowed(Result<&'lt mut BackBuf>), + Borrowed(#[allow(dead_code)] Result<&'lt mut BackBuf>), } impl From for BackBufRef<'static> { diff --git a/crates/tx5/src/ep3.rs b/crates/tx5/src/ep3.rs index 9ec6367c..f5a017b6 100644 --- a/crates/tx5/src/ep3.rs +++ b/crates/tx5/src/ep3.rs @@ -424,13 +424,47 @@ impl Ep3 { if let Ok(sig) = fut.await { // see if we are still banning this id. if ep.ban_map.lock().unwrap().is_banned(rem_id) { - sig.ban(rem_id); + sig.close(rem_id); } } }); } } + /// Request that the peer connection identified by the given `peer_url` is closed. + pub fn close(&self, peer_url: PeerUrl) -> Result<()> { + if !peer_url.is_client() { + return Err(Error::str("Expected PeerUrl, got SigUrl")); + } + + let peer_id = peer_url.id().unwrap(); + + let sig_url = peer_url.to_server(); + match self._sig_map.lock().unwrap().get(&sig_url) { + Some((_, fut)) => { + let fut = fut.clone(); + tokio::task::spawn(async move { + match fut.await { + Ok(sig) => sig.close(peer_id), + Err(e) => { + tracing::debug!( + ?e, + "Unable to close peer connection", + ); + } + } + }); + } + None => { + return Err(Error::str( + "No connections held for this signal server", + )); + } + } + + Ok(()) + } + /// Send data to a remote on this tx5 endpoint. /// The future returned from this method will resolve when /// the data is handed off to our networking backend. diff --git a/crates/tx5/src/ep3/peer.rs b/crates/tx5/src/ep3/peer.rs index a08b3aac..a407a836 100644 --- a/crates/tx5/src/ep3/peer.rs +++ b/crates/tx5/src/ep3/peer.rs @@ -58,7 +58,6 @@ pub(crate) struct Peer { cmd_task: tokio::task::JoinHandle<()>, recv_task: tokio::task::JoinHandle<()>, data_task: tokio::task::JoinHandle<()>, - #[allow(dead_code)] peer: Arc, data_chan: Arc, send_limit: Arc, @@ -70,6 +69,7 @@ pub(crate) struct Peer { impl Drop for Peer { fn drop(&mut self) { + self.peer.close("Close"); let evt_send = self.sig.evt_send.clone(); let msg = Ep3Event::Disconnected { peer_url: self.peer_url.clone(), @@ -355,6 +355,7 @@ impl Peer { }; let recv_task = { + let weak_peer = Arc::downgrade(&peer); let sig = sig.clone(); tokio::task::spawn(async move { while let Some(evt) = peer_recv.recv().await { @@ -364,7 +365,13 @@ impl Peer { tracing::warn!(?err); break; } - Evt::State(_state) => (), + Evt::State(state) => { + if let Some(peer) = weak_peer.upgrade() { + peer.set_con_state(state); + } else { + break; + } + } Evt::ICECandidate(mut ice) => { let ice = match ice.as_json() { Err(err) => { diff --git a/crates/tx5/src/ep3/sig.rs b/crates/tx5/src/ep3/sig.rs index 9978a3d6..880c82b8 100644 --- a/crates/tx5/src/ep3/sig.rs +++ b/crates/tx5/src/ep3/sig.rs @@ -311,7 +311,7 @@ impl Sig { } } - pub fn ban(&self, id: Id) { + pub fn close(&self, id: Id) { let r = self.peer_map.lock().unwrap().get(&id).cloned(); if let Some((uniq, _, _, _)) = r { close_peer(&self.sig.weak_peer_map, id, uniq); diff --git a/crates/tx5/src/ep3/test.rs b/crates/tx5/src/ep3/test.rs index b8736d3a..03375558 100644 --- a/crates/tx5/src/ep3/test.rs +++ b/crates/tx5/src/ep3/test.rs @@ -474,6 +474,39 @@ async fn ep3_preflight_happy() { assert_eq!(true, did_valid.load(std::sync::atomic::Ordering::SeqCst)); } +#[tokio::test(flavor = "multi_thread")] +async fn ep3_close_connection() { + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, _ep2, mut ep2_recv) = test.ep(config).await; + + ep1.send(cli_url2.clone(), b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"hello"[..], &message); + } + _ => panic!(), + } + + ep1.close(cli_url2).unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Disconnected { .. } => (), + e => panic!("Actually got event {:?}", e), + } +} + #[tokio::test(flavor = "multi_thread")] async fn ep3_ban_after_connected_outgoing_side() { let config = Arc::new(Config3::default()); diff --git a/flake.lock b/flake.lock index 09c2d980..fac48001 100644 --- a/flake.lock +++ b/flake.lock @@ -188,11 +188,11 @@ "versions": "versions" }, "locked": { - "lastModified": 1708595708, - "narHash": "sha256-coOhtMii+epTQobSAj1qGfVYbN9Rs0oB+Rj6ZePqKIU=", + "lastModified": 1711434707, + "narHash": "sha256-amuVlpnud2qmdZutl0Y9GEOVxLJBChk2Q70Cwe6riD4=", "owner": "holochain", "repo": "holochain", - "rev": "e2fd7138bfeb1185a245421eefb2f83d237eccef", + "rev": "7999dc663d7849b33454b922c9ba91b1341471aa", "type": "github" }, "original": { @@ -252,11 +252,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1708475490, - "narHash": "sha256-g1v0TsWBQPX97ziznfJdWhgMyMGtoBFs102xSYO4syU=", + "lastModified": 1711163522, + "narHash": "sha256-YN/Ciidm+A0fmJPWlHBGvVkcarYWSC+s3NTPk/P+q3c=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "0e74ca98a74bc7270d28838369593635a5db3260", + "rev": "44d0940ea560dee511026a53f0e2e2cde489b4d4", "type": "github" }, "original": { @@ -329,11 +329,11 @@ ] }, "locked": { - "lastModified": 1708567842, - "narHash": "sha256-tJmra4795ji+hWZTq9UfbHISu+0/V8kdfAj2VYFk6xc=", + "lastModified": 1711419061, + "narHash": "sha256-+5M/czgYGqs/jKmi8bvYC+JUYboUKNTfkRiesXopeXQ=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "0b5394f1da0e50715d36a22d4912cb3b02e6b72a", + "rev": "4c11d2f698ff1149f76b69e72852d5d75f492d0c", "type": "github" }, "original": {