diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 2b7312b0..59346854 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -11,19 +11,26 @@ on: - main jobs: test: - runs-on: macos-latest + runs-on: ubuntu-latest strategy: + fail-fast: false matrix: #android-arch: ["arm64-v8a", "x86_64"] android-arch: ["x86_64"] - android-api-level: [29] + android-api-level: [29, 33] android-ndk-version: ["26.0.10792818"] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 + + - name: Enable KVM + run: | + echo 'KERNEL=="kvm", GROUP="kvm", MODE="0666", OPTIONS+="static_node=kvm"' | sudo tee /etc/udev/rules.d/99-kvm4all.rules + sudo udevadm control --reload-rules + sudo udevadm trigger --name-match=kvm - name: Go Toolchain - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: '=1.20.0' diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml index f953618e..f2e0bb53 100644 --- a/.github/workflows/static.yml +++ b/.github/workflows/static.yml @@ -13,20 +13,21 @@ jobs: static-analysis: runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ ubuntu-latest, ] toolchain: [ stable, - 1.71.1 + 1.75.0 ] steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Go Toolchain - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: '=1.20.0' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 16ca668f..a65eb6ca 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,6 +14,7 @@ jobs: name: Test runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ ubuntu-latest, @@ -25,10 +26,10 @@ jobs: ] steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Go Toolchain - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: '=1.20.0' @@ -45,5 +46,6 @@ jobs: - name: Cargo Test env: + RUST_LOG: error RUST_BACKTRACE: 1 
run: cargo test -- --nocapture diff --git a/Cargo.lock b/Cargo.lock index 1aa0a8f9..8685cffc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -225,6 +225,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "app_dirs2" +version = "2.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7e7b35733e3a8c1ccb90385088dd5b6eaa61325cb4d1ad56e683b5224ff352e" +dependencies = [ + "jni", + "ndk-context", + "winapi", + "xdg", +] + [[package]] name = "arc-swap" version = "1.6.0" @@ -379,6 +391,12 @@ dependencies = [ "serde", ] +[[package]] +name = "bit_field" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" + [[package]] name = "bitflags" version = "1.3.2" @@ -491,6 +509,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.0" @@ -609,6 +633,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "console" version = "0.15.7" @@ -1401,9 +1435,9 @@ checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" [[package]] name = "hc_seed_bundle" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "397f372672597305f39af2b2abf389a077077489bc4f412e7ed574214c268208" +checksum = "6a160a11a63c03f180389bcbdb8f4823f84805bd07169160e4a0ede592512338" dependencies = [ "futures", 
"one_err", @@ -1818,6 +1852,28 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "jobserver" version = "0.1.27" @@ -1838,9 +1894,9 @@ dependencies = [ [[package]] name = "lair_keystore_api" -version = "0.4.0" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af33456610975b9138a8a8b2051fdeab69743f63d32adbb3f47d5fb8ce329663" +checksum = "4e485174546b1a95bbf801ce241e1039d182ec99c8ab4a5c0b90af0c8d8c651b" dependencies = [ "base64 0.13.1", "dunce", @@ -1855,8 +1911,7 @@ dependencies = [ "serde_yaml", "time", "tokio", - "toml 0.5.11", - "toml 0.7.8", + "toml", "tracing", "url", "winapi", @@ -1871,9 +1926,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.149" +version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "libflate" @@ -1911,9 +1966,9 @@ dependencies = [ [[package]] name = "libsodium-sys-stable" -version = "1.20.3" +version = "1.20.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfc31f983531631496f4e621110cd81468ab78b65dee0046cfddea83caa2c327" +checksum = 
"d1d164bc6f9139c5f95efb4f0be931b2bd5a9edf7e4e3c945d26b95ab8fa669b" dependencies = [ "cc", "libc", @@ -2120,6 +2175,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + [[package]] name = "nix" version = "0.24.3" @@ -2221,9 +2282,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "one_err" @@ -3371,9 +3432,9 @@ dependencies = [ [[package]] name = "sodoken" -version = "0.0.9" +version = "0.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebd7d30290221181652f7a08112f5e7871e3deffde718dfa621025aa0e9c290" +checksum = "308b58141d2ddac517b5e606e6f94d72e7a64c1383c9439c4acc712267f2be52" dependencies = [ "libc", "libsodium-sys-stable", @@ -3619,9 +3680,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.33.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ "backtrace", "bytes", @@ -3638,9 +3699,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", @@ -3731,15 +3792,6 @@ dependencies = [ "tracing", ] 
-[[package]] -name = "toml" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" -dependencies = [ - "serde", -] - [[package]] name = "toml" version = "0.7.8" @@ -3965,10 +4017,11 @@ dependencies = [ [[package]] name = "tx5" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "backtrace", "better-panic", + "bit_field", "bytes", "criterion", "futures", @@ -3997,10 +4050,10 @@ dependencies = [ [[package]] name = "tx5-core" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ + "app_dirs2", "base64 0.13.1", - "dirs", "once_cell", "rand", "serde", @@ -4014,7 +4067,7 @@ dependencies = [ [[package]] name = "tx5-demo" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "base64 0.13.1", "bytes", @@ -4036,7 +4089,7 @@ dependencies = [ [[package]] name = "tx5-go-pion" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "futures", "parking_lot", @@ -4051,7 +4104,7 @@ dependencies = [ [[package]] name = "tx5-go-pion-sys" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "Inflector", "base64 0.13.1", @@ -4069,7 +4122,7 @@ dependencies = [ [[package]] name = "tx5-go-pion-turn" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "base64 0.13.1", "dirs", @@ -4086,7 +4139,7 @@ dependencies = [ [[package]] name = "tx5-online" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "once_cell", "rand", @@ -4097,7 +4150,7 @@ dependencies = [ [[package]] name = "tx5-signal" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "futures", "lair_keystore_api", @@ -4127,7 +4180,7 @@ dependencies = [ [[package]] name = "tx5-signal-srv" -version = "0.0.6-alpha" +version = "0.0.7-alpha" dependencies = [ "clap", "dirs", @@ -4226,9 +4279,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.8.0" 
+version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5ccd538d4a604753ebc2f17cd9946e89b77bf87f6a8e2309667c6f2e87855e3" +checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97" dependencies = [ "base64 0.21.5", "log", @@ -4945,6 +4998,12 @@ dependencies = [ "libc", ] +[[package]] +name = "xdg" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213b7324336b53d2414b2db8537e56544d981803139155afa84f76eeebb7a546" + [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index fddd8b27..9417d0e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,11 @@ members = [ resolver = "2" [workspace.dependencies] +app_dirs2 = "2.5.5" backtrace = "0.3.69" base64 = "0.13.0" better-panic = "0.3.0" +bit_field = "0.10.2" bytes = "1.4.0" clap = { version = "4.4.6", features = [ "derive", "wrap_help" ] } criterion = { version = "0.5.1", features = [ "async_tokio" ] } @@ -27,7 +29,7 @@ Inflector = "0.11.4" influxive-otel-atomic-obs = "=0.0.2-alpha.1" influxive-child-svc = "=0.0.2-alpha.1" influxive = "=0.0.2-alpha.1" -lair_keystore_api = "0.4.0" +lair_keystore_api = "0.4.3" libc = "0.2.141" libloading = "0.8.0" once_cell = "1.17.1" @@ -47,22 +49,22 @@ serde = { version = "1.0.160", features = [ "derive", "rc" ] } serde_json = { version = "1.0.96", features = [ "preserve_order" ] } sha2 = "0.10.6" socket2 = { version = "0.5.2", features = [ "all" ] } -sodoken = "0.0.9" +sodoken = "0.0.10" tempfile = "3.8.0" -tokio = { version = "1.27.0" } +tokio = { version = "1.35.1" } tokio-rustls = "0.23.4" tokio-tungstenite = { version = "0.18.0", features = [ "rustls-tls-native-roots" ] } tracing = "0.1.37" tracing-appender = "0.2.2" tracing-subscriber = { version = "0.3.16", features = [ "env-filter" ] } trust-dns-resolver = "0.22.0" -tx5-core = { version = "0.0.6-alpha", default-features = false, path = "crates/tx5-core" } -tx5-go-pion-turn = { version = 
"0.0.6-alpha", path = "crates/tx5-go-pion-turn" } -tx5-go-pion-sys = { version = "0.0.6-alpha", path = "crates/tx5-go-pion-sys" } -tx5-go-pion = { version = "0.0.6-alpha", path = "crates/tx5-go-pion" } -tx5-signal-srv = { version = "0.0.6-alpha", path = "crates/tx5-signal-srv" } -tx5-signal = { version = "0.0.6-alpha", path = "crates/tx5-signal" } -tx5 = { version = "0.0.6-alpha", path = "crates/tx5" } +tx5-core = { version = "0.0.7-alpha", default-features = false, path = "crates/tx5-core" } +tx5-go-pion-turn = { version = "0.0.7-alpha", path = "crates/tx5-go-pion-turn" } +tx5-go-pion-sys = { version = "0.0.7-alpha", path = "crates/tx5-go-pion-sys" } +tx5-go-pion = { version = "0.0.7-alpha", path = "crates/tx5-go-pion" } +tx5-signal-srv = { version = "0.0.7-alpha", path = "crates/tx5-signal-srv" } +tx5-signal = { version = "0.0.7-alpha", path = "crates/tx5-signal" } +tx5 = { version = "0.0.7-alpha", path = "crates/tx5" } url = { version = "2.3.1", features = [ "serde" ] } warp = { version = "0.3.4", features = [ "websocket" ] } webpki-roots = { version = "0.23.0" } diff --git a/Makefile b/Makefile index 58a07730..459bfce6 100644 --- a/Makefile +++ b/Makefile @@ -75,7 +75,7 @@ bump: test: static tools cargo build --all-targets - RUST_BACKTRACE=1 cargo test -- --nocapture + RUST_BACKTRACE=1 RUST_LOG=error cargo test -- --nocapture static: docs tools cargo fmt -- --check diff --git a/android-run-tests.bash b/android-run-tests.bash index 8d64045d..241722f1 100755 --- a/android-run-tests.bash +++ b/android-run-tests.bash @@ -2,7 +2,16 @@ set -eEuxo pipefail +trap 'cleanup' ERR EXIT +cleanup() { + for i in $(cat output-test-executables); do + adb shell rm -f /data/local/tmp/$(basename $i) + done +} + for i in $(cat output-test-executables); do adb push $i /data/local/tmp/$(basename $i) - adb shell RUST_BACKTRACE=1 /data/local/tmp/$(basename $i) --test-threads 1 --nocapture + adb shell chmod 500 /data/local/tmp/$(basename $i) + adb shell RUST_LOG=error RUST_BACKTRACE=1 
/data/local/tmp/$(basename $i) --test-threads 1 --nocapture + adb shell rm -f /data/local/tmp/$(basename $i) done diff --git a/crates/tx5-core/Cargo.toml b/crates/tx5-core/Cargo.toml index 7af3f328..d541ee78 100644 --- a/crates/tx5-core/Cargo.toml +++ b/crates/tx5-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-core" -version = "0.0.6-alpha" +version = "0.0.7-alpha" edition = "2021" description = "Holochain WebRTC P2P Communication Ecosystem Core Types" license = "MIT OR Apache-2.0" @@ -16,11 +16,12 @@ once_cell = { workspace = true } rand = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +tokio = { workspace = true, features = [ "sync" ] } tracing = { workspace = true } url = { workspace = true } # file_check deps -dirs = { workspace = true, optional = true } +app_dirs2 = { workspace = true, optional = true } sha2 = { workspace = true, optional = true } tempfile = { workspace = true, optional = true } @@ -32,4 +33,4 @@ default = [ "file_check" ] # A couple crates that depend on tx5-core need to be able to write/verify # files on system. Enable this `file_check` feature to provide that ability. -file_check = [ "dirs", "sha2", "tempfile" ] +file_check = [ "app_dirs2", "sha2", "tempfile" ] diff --git a/crates/tx5-core/src/error.rs b/crates/tx5-core/src/error.rs index 0c5b1ecb..f9de7c2e 100644 --- a/crates/tx5-core/src/error.rs +++ b/crates/tx5-core/src/error.rs @@ -8,16 +8,6 @@ pub struct Error { pub info: String, } -impl From<()> for Error { - #[inline] - fn from(_: ()) -> Self { - Self { - id: "Error".into(), - info: String::default(), - } - } -} - impl From for Error { #[inline] fn from(id: String) -> Self { diff --git a/crates/tx5-core/src/evt.rs b/crates/tx5-core/src/evt.rs new file mode 100644 index 00000000..f115f801 --- /dev/null +++ b/crates/tx5-core/src/evt.rs @@ -0,0 +1,107 @@ +use crate::{Error, Result}; +use std::sync::Arc; + +/// Permit for sending on the channel. 
+pub struct EventPermit(Option); + +impl std::fmt::Debug for EventPermit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EventPermit").finish() + } +} + +/// Sender side of an explicitly bounded channel that lets us send +/// bounded (backpressured) events, but unbounded error messages. +pub struct EventSend> { + limit: Arc, + send: tokio::sync::mpsc::UnboundedSender<(E, EventPermit)>, +} + +impl> Clone for EventSend { + fn clone(&self) -> Self { + Self { + limit: self.limit.clone(), + send: self.send.clone(), + } + } +} + +impl> EventSend { + /// Construct a new event channel with given bound. + pub fn new(limit: u32) -> (Self, EventRecv) { + let limit = Arc::new(tokio::sync::Semaphore::new(limit as usize)); + let (send, recv) = tokio::sync::mpsc::unbounded_channel(); + (EventSend { limit, send }, EventRecv(recv)) + } + + /// Try to get a send permit. + pub fn try_permit(&self) -> Option { + match self.limit.clone().try_acquire_owned() { + Ok(p) => Some(EventPermit(Some(p))), + _ => None, + } + } + + /// Send an event. + pub async fn send(&self, evt: E) -> Result<()> { + let permit = self + .limit + .clone() + .acquire_owned() + .await + .map_err(|_| Error::id("Closed"))?; + self.send + .send((evt, EventPermit(Some(permit)))) + .map_err(|_| Error::id("Closed")) + } + + /// Send an event with a previously acquired permit. + pub fn send_permit(&self, evt: E, permit: EventPermit) -> Result<()> { + self.send + .send((evt, permit)) + .map_err(|_| Error::id("Closed")) + } + + /// Send an error. + pub fn send_err(&self, err: impl Into) { + let _ = self.send.send((err.into().into(), EventPermit(None))); + } +} + +/// Receiver side of an explicitly bounded channel that lets us send +/// bounded (backpressured) events, but unbounded error messages. 
+pub struct EventRecv>( + tokio::sync::mpsc::UnboundedReceiver<(E, EventPermit)>, +); + +impl> std::fmt::Debug for EventRecv { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EventRecv").finish() + } +} + +impl> EventRecv { + /// Receive incoming PeerConnection events. + pub async fn recv(&mut self) -> Option { + self.0.recv().await.map(|r| r.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn event_limit() { + let (s, _r) = >::new(1); + + s.send(Error::id("yo").into()).await.unwrap(); + + assert!(tokio::time::timeout( + std::time::Duration::from_millis(10), + s.send(Error::id("yo").into()), + ) + .await + .is_err()); + } +} diff --git a/crates/tx5-core/src/file_check.rs b/crates/tx5-core/src/file_check.rs index bf3251aa..9df432cb 100644 --- a/crates/tx5-core/src/file_check.rs +++ b/crates/tx5-core/src/file_check.rs @@ -17,6 +17,19 @@ impl FileCheck { } } +fn get_user_cache_dir() -> Option { + match app_dirs2::app_root( + app_dirs2::AppDataType::UserCache, + &app_dirs2::AppInfo { + name: "host.holo.tx5", + author: "host.holo.tx5", + }, + ) { + Ok(dir) => Some(dir), + _ => None, + } +} + /// Write a file if needed, verify the file, and return a handle to that file. 
pub fn file_check( file_data: &[u8], @@ -26,72 +39,79 @@ pub fn file_check( ) -> Result { let file_name = format!("{file_name_prefix}-{file_hash}{file_name_ext}"); - let mut pref_path = - dirs::data_local_dir().expect("failed to get data_local_dir"); - pref_path.push(&file_name); + let pref_path = get_user_cache_dir().map(|mut d| { + d.push(&file_name); + d + }); - if let Ok(file) = validate(&pref_path, file_hash) { - return Ok(FileCheck { - path: pref_path, - _file: Some(file), - }); + if let Some(pref_path) = pref_path.as_ref() { + if let Ok(file) = validate(pref_path, file_hash) { + return Ok(FileCheck { + path: pref_path.clone(), + _file: Some(file), + }); + } } - let tmp = write(file_data)?; + let mut tmp = write(file_data)?; - // NOTE: This is NOT atomic, nor secure, but being able to validate the - // file hash post-op mitigates this a bit. And we can let the os - // clean up a dangling tmp file if it failed to unlink. - match tmp.persist_noclobber(&pref_path) { - Ok(mut file) => { - set_perms(&mut file)?; + if let Some(pref_path) = pref_path.as_ref() { + // NOTE: This is NOT atomic, nor secure, but being able to validate the + // file hash post-op mitigates this a bit. And we can let the os + // clean up a dangling tmp file if it failed to unlink. + match tmp.persist_noclobber(pref_path) { + Ok(mut file) => { + set_perms(&mut file)?; - drop(file); + drop(file); - let file = validate(&pref_path, file_hash)?; - - Ok(FileCheck { - path: pref_path, - _file: Some(file), - }) - } - Err(err) => { - let tempfile::PersistError { file: tmp, .. } = err; - - // First, check to see if a different process wrote correctly - if let Ok(file) = validate(&pref_path, file_hash) { - // we no longer need the tmp file, clean it up - let _ = tmp.close(); + let file = validate(pref_path, file_hash)?; return Ok(FileCheck { - path: pref_path, + path: pref_path.clone(), _file: Some(file), }); } + Err(err) => { + let tempfile::PersistError { file, .. 
} = err; + tmp = file; + } + } - // we're just going to use the tmp file, do what we need to - // do to make sure it isn't deleted when the handle drops. - - let path = tmp.path().to_owned(); - let tmp = tmp.into_temp_path(); - - // This seems wrong, but it is how tempfile internally goes - // about doing persist/keep, so we're using it already, - // and it's only once-ish per process... - std::mem::forget(tmp); - - let file = validate(&path, file_hash)?; + // before we go on to just using the tmp file, + // check to see if a different process wrote correctly + if let Ok(file) = validate(pref_path, file_hash) { + // we no longer need the tmp file, clean it up + let _ = tmp.close(); - Ok(FileCheck { - path, + return Ok(FileCheck { + path: pref_path.clone(), _file: Some(file), - }) + }); } } + + // we're just going to use the tmp file, do what we need to + // do to make sure it isn't deleted when the handle drops. + + let path = tmp.path().to_owned(); + let tmp = tmp.into_temp_path(); + + // This seems wrong, but it is how tempfile internally goes + // about doing persist/keep, so we're using it already, + // and it's only once-ish per process... + std::mem::forget(tmp); + + let file = validate(&path, file_hash)?; + + Ok(FileCheck { + path, + _file: Some(file), + }) } /// Validate a file. 
-fn validate(path: &std::path::Path, hash: &str) -> Result { +fn validate(path: &std::path::PathBuf, hash: &str) -> Result { use std::io::Read; let mut file = std::fs::OpenOptions::new().read(true).open(path)?; diff --git a/crates/tx5-core/src/lib.rs b/crates/tx5-core/src/lib.rs index 2620db67..e4664d09 100644 --- a/crates/tx5-core/src/lib.rs +++ b/crates/tx5-core/src/lib.rs @@ -26,6 +26,9 @@ pub use uniq::*; mod url; pub use crate::url::*; +mod evt; +pub use evt::*; + #[cfg(feature = "file_check")] pub mod file_check; diff --git a/crates/tx5-demo/Cargo.toml b/crates/tx5-demo/Cargo.toml index 6deb2482..5f193d25 100644 --- a/crates/tx5-demo/Cargo.toml +++ b/crates/tx5-demo/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-demo" -version = "0.0.6-alpha" +version = "0.0.7-alpha" edition = "2021" description = "Demo crate showing off Tx5 WebRTC functionality" license = "MIT OR Apache-2.0" diff --git a/crates/tx5-demo/src/main.rs b/crates/tx5-demo/src/main.rs index 6bd02b32..327fd829 100644 --- a/crates/tx5-demo/src/main.rs +++ b/crates/tx5-demo/src/main.rs @@ -5,10 +5,10 @@ const DASH_TX5: &[u8] = include_bytes!("influxive-dashboards/tx5.json"); -use bytes::Buf; use clap::Parser; use std::collections::HashMap; -use tx5::{Ep, EpEvt, Result, Tx5Url}; +use std::sync::Arc; +use tx5::{Config3, Ep3, Ep3Event, Result, Tx5Url}; #[derive(Debug, Parser)] #[clap(name = "tx5-demo", version, about = "Holochain Tx5 WebRTC Demo Cli")] @@ -47,15 +47,12 @@ enum Message { } impl Message { - pub fn encode(&self) -> Result { - let b = serde_json::to_vec(self)?; - let mut o = bytes::BytesMut::with_capacity(b.len()); - o.extend_from_slice(&b); - Ok(o.freeze()) + pub fn encode(&self) -> Result> { + serde_json::to_vec(&self).map_err(tx5::Error::err) } - pub fn decode(data: bytes::Bytes) -> Result { - Ok(serde_json::from_slice(&data)?) + pub fn decode(data: &[u8]) -> Result { + Ok(serde_json::from_slice(data)?) 
} pub fn hello(known_peers: &HashMap) -> Self { @@ -177,13 +174,15 @@ impl Node { } } - pub fn send(&self, ep: &Ep, rem_url: &Tx5Url, data: bytes::Bytes) { + pub fn send(&self, ep: &Arc, rem_url: &Tx5Url, data: Vec) { let ep = ep.clone(); let rem_url = rem_url.clone(); tokio::task::spawn(async move { - let len = data.remaining(); + let len = data.len(); + let id = rem_url.id().unwrap(); - if let Err(err) = ep.send(rem_url, data).await { + + if let Err(err) = ep.send(rem_url, &data).await { d!( error, "SEND_ERROR", @@ -195,7 +194,7 @@ impl Node { }); } - pub fn broadcast_hello(&self, ep: &Ep) -> Result<()> { + pub fn broadcast_hello(&self, ep: &Arc) -> Result<()> { let hello = Message::hello(&self.known_peers).encode()?; for url in self.known_peers.keys() { if url == &self.this_url { @@ -213,7 +212,7 @@ impl Node { Ok(()) } - pub fn five_sec(&mut self, ep: &Ep) -> Result<()> { + pub fn five_sec(&mut self, ep: &Arc) -> Result<()> { { let this = self.known_peers.get_mut(&self.this_url).unwrap(); this.last_seen = std::time::Instant::now(); @@ -238,7 +237,7 @@ impl Node { Ok(()) } - pub fn thirty_sec(&mut self, ep: &Ep) -> Result<()> { + pub fn thirty_sec(&mut self, ep: &Arc) -> Result<()> { let mut v = Vec::new(); for peer in self.known_peers.keys() { @@ -317,7 +316,8 @@ async fn main_err() -> Result<()> { let sig_url = Tx5Url::new(sig_url)?; - let (ep, mut evt) = tx5::Ep::new().await?; + let (ep, mut evt) = tx5::Ep3::new(Arc::new(Config3::default())).await; + let ep = Arc::new(ep); let this_addr = ep.listen(sig_url.clone()).await?; let mut node = Node::new(this_addr.clone(), peer_urls); @@ -337,7 +337,7 @@ async fn main_err() -> Result<()> { enum Cmd { FiveSec, ThirtySec, - EpEvt(Result), + EpEvt(Ep3Event), } let (cmd_s, mut cmd_r) = tokio::sync::mpsc::unbounded_channel(); @@ -384,31 +384,28 @@ async fn main_err() -> Result<()> { d!(info, "FIVE_SEC", "{node:?}"); tracing::info!( "{}", - serde_json::to_string(&ep.get_stats().await.unwrap()) - .unwrap() + 
serde_json::to_string(&ep.get_stats().await).unwrap() ); } Cmd::ThirtySec => node.thirty_sec(&ep)?, - Cmd::EpEvt(Err(err)) => panic!("{err:?}"), - Cmd::EpEvt(Ok(EpEvt::Connected { rem_cli_url })) => { - node.add_known_peer(rem_cli_url); + Cmd::EpEvt(Ep3Event::Error(err)) => panic!("{err:?}"), + Cmd::EpEvt(Ep3Event::Connected { peer_url }) => { + node.add_known_peer(peer_url); } - Cmd::EpEvt(Ok(EpEvt::Disconnected { rem_cli_url })) => { - d!(info, "DISCONNECTED", "{:?}", rem_cli_url.id().unwrap()); + Cmd::EpEvt(Ep3Event::Disconnected { peer_url }) => { + d!(info, "DISCONNECTED", "{:?}", peer_url.id().unwrap()); } - Cmd::EpEvt(Ok(EpEvt::Data { - rem_cli_url, - mut data, - .. - })) => { - node.add_known_peer(rem_cli_url.clone()); - match Message::decode(data.copy_to_bytes(data.remaining())) { + Cmd::EpEvt(Ep3Event::Message { + peer_url, message, .. + }) => { + node.add_known_peer(peer_url.clone()); + match Message::decode(&message) { Err(err) => d!(error, "RECV_ERROR", "{err:?}"), Ok(Message::Hello { known_peers: kp }) => { for peer in kp { node.add_known_peer(Tx5Url::new(peer).unwrap()); } - node.recv_hello(rem_cli_url)?; + node.recv_hello(peer_url)?; } Ok(Message::Big(d)) => { d!( @@ -416,12 +413,11 @@ async fn main_err() -> Result<()> { "RECV_BIG", "len:{} {:?}", d.as_bytes().len(), - rem_cli_url.id().unwrap() + peer_url.id().unwrap() ); } } } - Cmd::EpEvt(_) => (), } } diff --git a/crates/tx5-go-pion-sys/Cargo.toml b/crates/tx5-go-pion-sys/Cargo.toml index 2e674ad8..d148f23b 100644 --- a/crates/tx5-go-pion-sys/Cargo.toml +++ b/crates/tx5-go-pion-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-go-pion-sys" -version = "0.0.6-alpha" +version = "0.0.7-alpha" edition = "2021" description = "Low level rust bindings to the go pion webrtc library" license = "MIT OR Apache-2.0" diff --git a/crates/tx5-go-pion-sys/go.mod b/crates/tx5-go-pion-sys/go.mod index 13709d11..65f49f55 100644 --- a/crates/tx5-go-pion-sys/go.mod +++ b/crates/tx5-go-pion-sys/go.mod @@ -4,30 +4,30 
@@ go 1.20 require ( github.com/pion/logging v0.2.2 - github.com/pion/webrtc/v3 v3.2.21 + github.com/pion/webrtc/v3 v3.2.24 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/google/uuid v1.3.1 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/pion/datachannel v1.5.5 // indirect - github.com/pion/dtls/v2 v2.2.7 // indirect - github.com/pion/ice/v2 v2.3.11 // indirect - github.com/pion/interceptor v0.1.19 // indirect + github.com/pion/dtls/v2 v2.2.9 // indirect + github.com/pion/ice/v2 v2.3.12 // indirect + github.com/pion/interceptor v0.1.25 // indirect github.com/pion/mdns v0.0.9 // indirect github.com/pion/randutil v0.1.0 // indirect - github.com/pion/rtcp v1.2.10 // indirect - github.com/pion/rtp v1.8.2 // indirect + github.com/pion/rtcp v1.2.13 // indirect + github.com/pion/rtp v1.8.3 // indirect github.com/pion/sctp v1.8.9 // indirect github.com/pion/sdp/v3 v3.0.6 // indirect - github.com/pion/srtp/v2 v2.0.17 // indirect + github.com/pion/srtp/v2 v2.0.18 // indirect github.com/pion/stun v0.6.1 // indirect github.com/pion/transport/v2 v2.2.4 // indirect github.com/pion/turn/v2 v2.1.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/testify v1.8.4 // indirect - golang.org/x/crypto v0.13.0 // indirect - golang.org/x/net v0.15.0 // indirect - golang.org/x/sys v0.12.0 // indirect + golang.org/x/crypto v0.18.0 // indirect + golang.org/x/net v0.20.0 // indirect + golang.org/x/sys v0.16.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/crates/tx5-go-pion-sys/go.sum b/crates/tx5-go-pion-sys/go.sum index 604b4184..4cd4ecb9 100644 --- a/crates/tx5-go-pion-sys/go.sum +++ b/crates/tx5-go-pion-sys/go.sum @@ -17,8 +17,9 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -36,13 +37,14 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/pion/datachannel v1.5.5 h1:10ef4kwdjije+M9d7Xm9im2Y3O6A6ccQb0zcqZcJew8= github.com/pion/datachannel v1.5.5/go.mod h1:iMz+lECmfdCMqFRhXhcA/219B0SQlbpoR2V118yimL0= -github.com/pion/dtls/v2 v2.2.7 h1:cSUBsETxepsCSFSxC3mc/aDo14qQLMSL+O6IjG28yV8= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= -github.com/pion/ice/v2 v2.3.11 h1:rZjVmUwyT55cmN8ySMpL7rsS8KYsJERsrxJLLxpKhdw= +github.com/pion/dtls/v2 v2.2.9 h1:K+D/aVf9/REahQvqk6G5JavdrD8W1PWDKC11UlwN7ts= +github.com/pion/dtls/v2 v2.2.9/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/ice/v2 v2.3.11/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E= -github.com/pion/interceptor v0.1.18/go.mod h1:tpvvF4cPM6NGxFA1DUMbhabzQBxdWMATDGEUYOR9x6I= -github.com/pion/interceptor v0.1.19 h1:tq0TGBzuZQqipyBhaC1mVUCfCh8XjDKUuibq9rIl5t4= -github.com/pion/interceptor v0.1.19/go.mod h1:VANhFxdJezB8mwToMMmrmyHyP9gym6xLqIUch31xryg= +github.com/pion/ice/v2 v2.3.12 h1:NWKW2b3+oSZS3klbQMIEWQ0i52Kuo0KBg505a5kQv4s= +github.com/pion/ice/v2 v2.3.12/go.mod h1:hPcLC3kxMa+JGRzMHqQzjoSj3xtE9F+eoncmXLlCL4E= +github.com/pion/interceptor v0.1.25 
h1:pwY9r7P6ToQ3+IF0bajN0xmk/fNw/suTgaTdlwTDmhc= +github.com/pion/interceptor v0.1.25/go.mod h1:wkbPYAak5zKsfpVDYMtEfWEy8D4zL+rpxCxPImLOg3Y= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.8/go.mod h1:hYE72WX8WDveIhg7fmXgMKivD3Puklk0Ymzog0lSyaI= @@ -50,19 +52,21 @@ github.com/pion/mdns v0.0.9 h1:7Ue5KZsqq8EuqStnpPWV33vYYEH0+skdDN5L7EiEsI4= github.com/pion/mdns v0.0.9/go.mod h1:2JA5exfxwzXiCihmxpTKgFUpiQws2MnipoPK09vecIc= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.10 h1:nkr3uj+8Sp97zyItdN60tE/S6vk4al5CPRR6Gejsdjc= github.com/pion/rtcp v1.2.10/go.mod h1:ztfEwXZNLGyF1oQDttz/ZKIBaeeg/oWbRYqzBM9TL1I= -github.com/pion/rtp v1.8.1/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= -github.com/pion/rtp v1.8.2 h1:oKMM0K1/QYQ5b5qH+ikqDSZRipP5mIxPJcgcvw5sH0w= +github.com/pion/rtcp v1.2.12/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= +github.com/pion/rtcp v1.2.13 h1:+EQijuisKwm/8VBs8nWllr0bIndR7Lf7cZG200mpbNo= +github.com/pion/rtcp v1.2.13/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= github.com/pion/rtp v1.8.2/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= +github.com/pion/rtp v1.8.3 h1:VEHxqzSVQxCkKDSHro5/4IUUG1ea+MFdqR2R3xSpNU8= +github.com/pion/rtp v1.8.3/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= github.com/pion/sctp v1.8.5/go.mod h1:SUFFfDpViyKejTAdwD1d/HQsCu+V/40cCs2nZIvC3s0= github.com/pion/sctp v1.8.8/go.mod h1:igF9nZBrjh5AtmKc7U30jXltsFHicFCXSmWA2GWRaWs= github.com/pion/sctp v1.8.9 h1:TP5ZVxV5J7rz7uZmbyvnUvsn7EJ2x/5q9uhsTtXbI3g= github.com/pion/sctp v1.8.9/go.mod h1:cMLT45jqw3+jiJCrtHVwfQLnfR0MGZ4rgOJwUOIqLkI= github.com/pion/sdp/v3 v3.0.6 h1:WuDLhtuFUUVpTfus9ILC4HRyHsW6TdugjEX/QY9OiUw= github.com/pion/sdp/v3 v3.0.6/go.mod 
h1:iiFWFpQO8Fy3S5ldclBkpXqmWy02ns78NOKoLLL0YQw= -github.com/pion/srtp/v2 v2.0.17 h1:ECuOk+7uIpY6HUlTb0nXhfvu4REG2hjtC4ronYFCZE4= -github.com/pion/srtp/v2 v2.0.17/go.mod h1:y5WSHcJY4YfNB/5r7ca5YjHeIr1H3LM1rKArGGs8jMc= +github.com/pion/srtp/v2 v2.0.18 h1:vKpAXfawO9RtTRKZJbG4y0v1b11NZxQnxRl85kGuUlo= +github.com/pion/srtp/v2 v2.0.18/go.mod h1:0KJQjA99A6/a0DOVTu1PhDSw0CXF2jTkqOoMg3ODqdA= github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= @@ -77,8 +81,8 @@ github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9 github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= github.com/pion/turn/v2 v2.1.4 h1:2xn8rduI5W6sCZQkEnIUDAkrBQNl2eYIBCHMZ3QMmP8= github.com/pion/turn/v2 v2.1.4/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= -github.com/pion/webrtc/v3 v3.2.21 h1:c8fy5JcqJkAQBwwy3Sk9huQLTBUSqaggyRlv9Lnh2zY= -github.com/pion/webrtc/v3 v3.2.21/go.mod h1:vVURQTBOG5BpWKOJz3nlr23NfTDeyKVmubRNqzQp+Tg= +github.com/pion/webrtc/v3 v3.2.24 h1:MiFL5DMo2bDaaIFWr0DDpwiV/L4EGbLZb+xoRvfEo1Y= +github.com/pion/webrtc/v3 v3.2.24/go.mod h1:1CaT2fcZzZ6VZA+O1i9yK2DU4EOcXVvSbWG9pr5jefs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= @@ -102,8 +106,9 @@ golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.13.0 
h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -122,8 +127,9 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -151,8 +157,9 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= diff --git a/crates/tx5-go-pion-sys/vendor.zip b/crates/tx5-go-pion-sys/vendor.zip index e8c67d0d..a74edf64 100644 Binary files a/crates/tx5-go-pion-sys/vendor.zip and b/crates/tx5-go-pion-sys/vendor.zip differ diff --git a/crates/tx5-go-pion-turn/Cargo.toml b/crates/tx5-go-pion-turn/Cargo.toml index c20eba92..33d4062f 100644 --- a/crates/tx5-go-pion-turn/Cargo.toml +++ b/crates/tx5-go-pion-turn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-go-pion-turn" -version = "0.0.6-alpha" +version = "0.0.7-alpha" edition = "2021" description = "Rust process wrapper around tx5-go-pion-turn executable" license = "MIT OR Apache-2.0" diff --git a/crates/tx5-go-pion/Cargo.toml b/crates/tx5-go-pion/Cargo.toml index 55e27e3f..f70d7341 100644 --- a/crates/tx5-go-pion/Cargo.toml +++ b/crates/tx5-go-pion/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-go-pion" -version = "0.0.6-alpha" +version = "0.0.7-alpha" edition = "2021" description = "Rust bindings to the go pion webrtc library" license = "MIT OR Apache-2.0" @@ -13,13 +13,12 @@ categories = ["network-programming"] [dependencies] futures = { workspace = true } parking_lot = { workspace = true } -tokio = { workspace = true, features = [ "rt" ] } +tokio = { workspace = true, features = [ "full" ] } tx5-go-pion-sys = { workspace = true } tracing = { workspace = true } url = { workspace = true } [dev-dependencies] tx5-core = { workspace = true } -tokio = { workspace = true, features = [ "full" ] } tracing-subscriber = { workspace = true } 
tx5-go-pion-turn = { workspace = true } diff --git a/crates/tx5-go-pion/examples/turn-stress.rs b/crates/tx5-go-pion/examples/turn-stress.rs index 4d77c7ee..e49064ca 100644 --- a/crates/tx5-go-pion/examples/turn-stress.rs +++ b/crates/tx5-go-pion/examples/turn-stress.rs @@ -68,7 +68,7 @@ async fn main() { let start = std::time::Instant::now(); - let (mut c1, mut evt1) = spawn_peer(config.clone()).await; + let (c1, mut evt1) = spawn_peer(config.clone()).await; tokio::task::spawn(async move { while let Some(evt) = evt1.recv().await { o2o_snd.send(Cmd::PeerEvt(evt)).unwrap(); @@ -82,14 +82,20 @@ async fn main() { let rcv_done1 = rcv_done.clone(); tokio::task::spawn(async move { - let seed = c1 + let (data_chan, data_recv) = c1 .create_data_channel(DataChannelConfig { label: Some("data".into()), }) .await .unwrap(); - tokio::task::spawn(spawn_chan(seed, start, chan_ready1, rcv_done1)); + tokio::task::spawn(spawn_chan( + data_chan, + data_recv, + start, + chan_ready1, + rcv_done1, + )); let mut offer = c1.create_offer(OfferConfig::default()).await.unwrap(); @@ -139,7 +145,7 @@ async fn main() { let mut ice_buf = Some(Vec::new()); - let (mut c2, mut evt2) = spawn_peer(config.clone()).await; + let (c2, mut evt2) = spawn_peer(config.clone()).await; tokio::task::spawn(async move { while let Some(evt) = evt2.recv().await { t2t_snd.send(Cmd::PeerEvt(evt)).unwrap(); @@ -159,9 +165,13 @@ async fn main() { t2o_snd.send(Cmd::Ice(ice)).unwrap(); } } - Cmd::PeerEvt(PeerConnectionEvent::DataChannel(seed)) => { + Cmd::PeerEvt(PeerConnectionEvent::DataChannel( + data_chan, + data_recv, + )) => { tokio::task::spawn(spawn_chan( - seed, + data_chan, + data_recv, start, chan_ready.clone(), rcv_done.clone(), @@ -201,49 +211,24 @@ async fn spawn_peer( PeerConnection, tokio::sync::mpsc::UnboundedReceiver, ) { - let (snd, rcv) = tokio::sync::mpsc::unbounded_channel(); - let con = PeerConnection::new(config, move |evt| { - let _ = snd.send(evt); - }) - .await - .unwrap(); + let (con, rcv) = 
PeerConnection::new(config).await.unwrap(); (con, rcv) } async fn spawn_chan( - seed: DataChannelSeed, + data_chan: DataChannel, + mut data_recv: tokio::sync::mpsc::UnboundedReceiver, start: std::time::Instant, chan_ready: Arc, rcv_done: Arc, ) { - let (s_o, r_o) = tokio::sync::oneshot::channel(); - let s_o = std::sync::Mutex::new(Some(s_o)); - let (s_d, r_d) = tokio::sync::oneshot::channel(); - let s_d = std::sync::Mutex::new(Some(s_d)); - let c = std::sync::atomic::AtomicUsize::new(1); - let mut chan = seed.handle(move |evt| match evt { - DataChannelEvent::Close | DataChannelEvent::BufferedAmountLow => (), - DataChannelEvent::Open => { - if let Some(s_o) = s_o.lock().unwrap().take() { - let _ = s_o.send(()); - } - } - DataChannelEvent::Message(mut buf) => { - assert_eq!(1024, buf.len().unwrap()); - std::io::Write::write_all(&mut std::io::stdout(), b".").unwrap(); - std::io::Write::flush(&mut std::io::stdout()).unwrap(); - let cnt = c.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - if cnt == MSG_CNT { - if let Some(s_d) = s_d.lock().unwrap().take() { - let _ = s_d.send(()); - } - } + loop { + match data_recv.recv().await { + Some(DataChannelEvent::Open) => break, + Some(DataChannelEvent::BufferedAmountLow) => (), + oth => panic!("{oth:?}"), } - }); - - //if chan.ready_state().unwrap() < 2 { - r_o.await.unwrap(); - //} + } println!("chan ready"); @@ -253,11 +238,27 @@ async fn spawn_chan( for _ in 0..MSG_CNT { let buf = GoBuf::from_slice(ONE_KB).unwrap(); - chan.send(buf).await.unwrap(); + data_chan.send(buf).await.unwrap(); } - // we've received all our messages - r_d.await.unwrap(); + let mut cnt = 0; + + loop { + match data_recv.recv().await { + Some(DataChannelEvent::BufferedAmountLow) => (), + Some(DataChannelEvent::Message(mut buf)) => { + assert_eq!(1024, buf.len().unwrap()); + std::io::Write::write_all(&mut std::io::stdout(), b".") + .unwrap(); + std::io::Write::flush(&mut std::io::stdout()).unwrap(); + cnt += 1; + if cnt == MSG_CNT { + break; + } + } 
+ oth => panic!("{oth:?}"), + } + } rcv_done.wait().await; diff --git a/crates/tx5-go-pion/src/data_chan.rs b/crates/tx5-go-pion/src/data_chan.rs index f8cd5a9e..3c0f5c4e 100644 --- a/crates/tx5-go-pion/src/data_chan.rs +++ b/crates/tx5-go-pion/src/data_chan.rs @@ -1,106 +1,189 @@ use crate::*; -use parking_lot::Mutex; -use std::sync::Arc; +use std::sync::{Arc, Mutex, Weak}; use tx5_go_pion_sys::API; -/// A precursor go pion webrtc DataChannel, awaiting an event handler. -#[derive(Debug)] -pub struct DataChannelSeed(pub(crate) usize, Arc>>); +pub(crate) struct DataChanCore { + data_chan_id: usize, + evt_send: tokio::sync::mpsc::UnboundedSender, + drop_err: Error, +} -impl Drop for DataChannelSeed { +impl Drop for DataChanCore { fn drop(&mut self) { - if self.0 != 0 { - unsafe { API.data_chan_free(self.0) } + let _ = self + .evt_send + .send(DataChannelEvent::Error(self.drop_err.clone())); + unregister_data_chan(self.data_chan_id); + unsafe { + API.data_chan_free(self.data_chan_id); } } } -impl DataChannelSeed { - pub(crate) fn new(data_chan_id: usize) -> Self { - let hold = Arc::new(Mutex::new(Vec::new())); - { - let hold = hold.clone(); - register_data_chan_evt_cb( - data_chan_id, - Arc::new(move |evt| { - hold.lock().push(evt); - }), - ); +impl DataChanCore { + pub fn new( + data_chan_id: usize, + evt_send: tokio::sync::mpsc::UnboundedSender, + ) -> Self { + Self { + data_chan_id, + evt_send, + drop_err: Error::id("DataChannelDropped").into(), } - Self(data_chan_id, hold) } - /// Construct a real DataChannel by providing an event handler. - pub fn handle(mut self, cb: Cb) -> DataChannel - where - Cb: Fn(DataChannelEvent) + 'static + Send + Sync, - { - let cb: DataChanEvtCb = Arc::new(cb); - let data_chan_id = self.0; - self.0 = 0; - replace_data_chan_evt_cb(data_chan_id, move || { - for evt in self.1.lock().drain(..) 
{ - cb(evt); - } - cb - }); - DataChannel(data_chan_id) + pub fn close(&mut self, err: Error) { + // self.evt_send.send_err() is called in Drop impl + self.drop_err = err; + } +} + +#[derive(Clone)] +pub(crate) struct WeakDataChan( + pub(crate) Weak>>, +); + +macro_rules! data_chan_strong_core { + ($inner:expr, $ident:ident, $block:block) => { + match &mut *$inner.lock().unwrap() { + Ok($ident) => $block, + Err(err) => Result::Err(err.clone().into()), + } + }; +} + +macro_rules! data_chan_weak_core { + ($inner:expr, $ident:ident, $block:block) => { + match $inner.upgrade() { + Some(strong) => data_chan_strong_core!(strong, $ident, $block), + None => Result::Err(Error::id("DataChannelClosed")), + } + }; +} + +impl WeakDataChan { + pub fn send_evt(&self, evt: DataChannelEvent) -> Result<()> { + data_chan_weak_core!(self.0, core, { + core.evt_send + .send(evt) + .map_err(|_| Error::id("DataChannelClosed")) + }) } } /// A go pion webrtc DataChannel. -#[derive(Debug)] -pub struct DataChannel(pub(crate) usize); +pub struct DataChannel(Arc>>); -impl Drop for DataChannel { - fn drop(&mut self) { - unregister_data_chan_evt_cb(self.0); - unsafe { API.data_chan_free(self.0) } +impl std::fmt::Debug for DataChannel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let res = match &*self.0.lock().unwrap() { + Ok(core) => format!("DataChannel(open, {})", core.data_chan_id), + Err(err) => format!("DataChannel(closed, {:?})", err), + }; + f.write_str(&res) } } impl DataChannel { + pub(crate) fn new( + data_chan_id: usize, + ) -> (Self, tokio::sync::mpsc::UnboundedReceiver) { + let (evt_send, evt_recv) = tokio::sync::mpsc::unbounded_channel(); + + let strong = Arc::new(Mutex::new(Ok(DataChanCore::new( + data_chan_id, + evt_send.clone(), + )))); + + let weak = WeakDataChan(Arc::downgrade(&strong)); + + register_data_chan(data_chan_id, weak); + + // we might have missed some state callbacks + if let Ok(ready_state) = + unsafe { 
API.data_chan_ready_state(data_chan_id) } + { + // this is a terrible suggestion clippy + #[allow(clippy::comparison_chain)] + if ready_state == 2 { + let _ = evt_send.send(DataChannelEvent::Open); + } else if ready_state > 2 { + let _ = evt_send.send(DataChannelEvent::Close); + } + } + + (Self(strong), evt_recv) + } + + /// Close this data channel. + pub fn close(&self, err: Error) { + let mut tmp = Err(err.clone()); + + { + let mut lock = self.0.lock().unwrap(); + let mut do_swap = false; + if let Ok(core) = &mut *lock { + core.close(err); + do_swap = true; + } + if do_swap { + std::mem::swap(&mut *lock, &mut tmp); + } + } + + // make sure the above lock is released before this is dropped + drop(tmp); + } + + fn get_data_chan_id(&self) -> Result { + data_chan_strong_core!(self.0, core, { Ok(core.data_chan_id) }) + } + /// Get the label of this DataChannel. #[inline] - pub fn label(&mut self) -> Result { - unsafe { Ok(GoBuf(API.data_chan_label(self.0)?)) } + pub fn label(&self) -> Result { + unsafe { Ok(GoBuf(API.data_chan_label(self.get_data_chan_id()?)?)) } } /// Get the ready state of this DataChannel. #[inline] - pub fn ready_state(&mut self) -> Result { - unsafe { API.data_chan_ready_state(self.0) } + pub fn ready_state(&self) -> Result { + unsafe { API.data_chan_ready_state(self.get_data_chan_id()?) } } /// Set the buffered amount low threshold. /// Returns the current BufferedAmount. #[inline] pub fn set_buffered_amount_low_threshold( - &mut self, + &self, threshold: usize, ) -> Result { unsafe { - API.data_chan_set_buffered_amount_low_threshold(self.0, threshold) + API.data_chan_set_buffered_amount_low_threshold( + self.get_data_chan_id()?, + threshold, + ) } } /// Returns the current BufferedAmount. #[inline] - pub fn buffered_amount(&mut self) -> Result { - unsafe { API.data_chan_buffered_amount(self.0) } + pub fn buffered_amount(&self) -> Result { + unsafe { API.data_chan_buffered_amount(self.get_data_chan_id()?) 
} } /// Send data to the remote peer on this DataChannel. /// Returns the current BufferedAmount. - pub async fn send<'a, B>(&mut self, data: B) -> Result + pub async fn send<'a, B>(&self, data: B) -> Result where B: Into>, { // TODO - use OnBufferedAmountLow signal to implement backpressure - let data_chan = self.0; + let data_chan_id = self.get_data_chan_id()?; + r2id!(data); tokio::task::spawn_blocking(move || unsafe { - API.data_chan_send(data_chan, data) + API.data_chan_send(data_chan_id, data) }) .await? } diff --git a/crates/tx5-go-pion/src/evt.rs b/crates/tx5-go-pion/src/evt.rs index b2229d4a..64ed195f 100644 --- a/crates/tx5-go-pion/src/evt.rs +++ b/crates/tx5-go-pion/src/evt.rs @@ -1,8 +1,6 @@ use crate::*; use once_cell::sync::Lazy; -use parking_lot::Mutex; use std::collections::HashMap; -use std::sync::Arc; use tx5_go_pion_sys::Event as SysEvent; use tx5_go_pion_sys::API; @@ -46,8 +44,8 @@ impl PeerConnectionState { /// Incoming events related to a PeerConnection. #[derive(Debug)] pub enum PeerConnectionEvent { - /// PeerConnection Error. - Error(std::io::Error), + /// PeerConnection error. + Error(Error), /// PeerConnectionState event. State(PeerConnectionState), @@ -56,150 +54,279 @@ pub enum PeerConnectionEvent { ICECandidate(GoBuf), /// Received an incoming data channel. - DataChannel(DataChannelSeed), + /// Warning: This returns an unbounded channel, + /// you should process this as quickly and synchronously as possible + /// to avoid a backlog filling up memory. + DataChannel( + DataChannel, + tokio::sync::mpsc::UnboundedReceiver, + ), +} + +impl From for PeerConnectionEvent { + fn from(err: Error) -> Self { + Self::Error(err) + } } /// Incoming events related to a DataChannel. #[derive(Debug)] pub enum DataChannelEvent { - /// DataChannel is ready to send / receive. + /// Data channel error. + Error(Error), + + /// Data channel has transitioned to "open". Open, - /// DataChannel is closed. + /// Data channel has transitioned to "closed". 
Close, - /// DataChannel incoming message. + /// Received incoming message on the data channel. Message(GoBuf), - /// DataChannel buffered amount is now low. + /// Data channel buffered amount is now low. BufferedAmountLow, } +impl From for DataChannelEvent { + fn from(err: Error) -> Self { + Self::Error(err) + } +} + #[inline] pub(crate) fn init_evt_manager() { // ensure initialization - MANAGER.is_locked(); + let _ = &*MANAGER; } -pub(crate) fn register_peer_con_evt_cb(id: usize, cb: PeerConEvtCb) { - MANAGER.lock().peer_con.insert(id, cb); +enum Cmd { + Evt(SysEvent), + PeerReg(usize, peer_con::WeakPeerCon), + PeerUnreg(usize), + DataReg(usize, data_chan::WeakDataChan), + DataUnreg(usize), } -pub(crate) fn unregister_peer_con_evt_cb(id: usize) { - MANAGER.lock().peer_con.remove(&id); +struct Manager { + rt_send: std::sync::mpsc::Sender, } -pub(crate) fn register_data_chan_evt_cb(id: usize, cb: DataChanEvtCb) { - MANAGER.lock().data_chan.insert(id, cb); -} +impl Manager { + pub fn new() -> Self { + struct D; -pub(crate) fn replace_data_chan_evt_cb(id: usize, f: F) -where - F: FnOnce() -> DataChanEvtCb, -{ - let mut lock = MANAGER.lock(); - let cb = f(); - lock.data_chan.insert(id, cb); -} + impl Drop for D { + fn drop(&mut self) { + tracing::error!("tx5-go-pion offload EVENT LOOP EXITED"); + } + } -pub(crate) fn unregister_data_chan_evt_cb(id: usize) { - MANAGER.lock().data_chan.remove(&id); -} + let drop_trace = D; -pub(crate) type PeerConEvtCb = - Arc; + let (rt_send, rt_recv) = std::sync::mpsc::channel(); -pub(crate) type DataChanEvtCb = - Arc; + // we need to offload the event handling to another thread + // because it's not safe to invoke other go apis within the + // same sync call: + // https://github.com/pion/webrtc/issues/2404 + std::thread::Builder::new() + .name("tx5-evt-offload".to_string()) + .spawn(move || { + let _drop_trace = drop_trace; -static MANAGER: Lazy> = Lazy::new(|| { - unsafe { - // TODO!!! 
MEORY LEAK - // we need else cases througout here, if there isn't a - // registered callback, we need to call _free functions - // on incoming events like DataChannel and OnMessage buffers. - API.on_event(|sys_evt| match sys_evt { - SysEvent::Error(_error) => (), - SysEvent::PeerConICECandidate { - peer_con_id, - candidate, - } => { - let maybe_cb = - MANAGER.lock().peer_con.get(&peer_con_id).cloned(); - if let Some(cb) = maybe_cb { - cb(PeerConnectionEvent::ICECandidate(GoBuf(candidate))); - } - } - SysEvent::PeerConStateChange { - peer_con_id, - peer_con_state, - } => { - let maybe_cb = - MANAGER.lock().peer_con.get(&peer_con_id).cloned(); - if let Some(cb) = maybe_cb { - cb(PeerConnectionEvent::State( - PeerConnectionState::from_raw(peer_con_state), - )); - } - } - SysEvent::PeerConDataChan { - peer_con_id, - data_chan_id, - } => { - let maybe_cb = - MANAGER.lock().peer_con.get(&peer_con_id).cloned(); - if let Some(cb) = maybe_cb { - cb(PeerConnectionEvent::DataChannel(DataChannelSeed::new( - data_chan_id, - ))); + let mut peer_map: HashMap = + HashMap::new(); + let mut data_map: HashMap = + HashMap::new(); + + trait Wk { + fn se(&self, evt: E) -> Result<()>; } - } - SysEvent::DataChanClose(data_chan_id) => { - let maybe_cb = - MANAGER.lock().data_chan.get(&data_chan_id).cloned(); - if let Some(cb) = maybe_cb { - cb(DataChannelEvent::Close); + + impl Wk for peer_con::WeakPeerCon { + fn se(&self, evt: PeerConnectionEvent) -> Result<()> { + self.send_evt(evt) + } } - } - SysEvent::DataChanOpen(data_chan_id) => { - let maybe_cb = - MANAGER.lock().data_chan.get(&data_chan_id).cloned(); - if let Some(cb) = maybe_cb { - cb(DataChannelEvent::Open); + + impl Wk for data_chan::WeakDataChan { + fn se(&self, evt: DataChannelEvent) -> Result<()> { + self.send_evt(evt) + } } - } - SysEvent::DataChanMessage { - data_chan_id, - buffer_id, - } => { - let maybe_cb = - MANAGER.lock().data_chan.get(&data_chan_id).cloned(); - if let Some(cb) = maybe_cb { - 
cb(DataChannelEvent::Message(GoBuf(buffer_id))); + + fn smap>( + map: &mut HashMap, + id: usize, + evt: E, + ) { + if let std::collections::hash_map::Entry::Occupied(o) = + map.entry(id) + { + if o.get().se(evt).is_err() { + o.remove(); + } + } } - } - SysEvent::DataChanBufferedAmountLow(data_chan_id) => { - let maybe_cb = - MANAGER.lock().data_chan.get(&data_chan_id).cloned(); - if let Some(cb) = maybe_cb { - cb(DataChannelEvent::BufferedAmountLow); + + while let Ok(cmd) = rt_recv.recv() { + match cmd { + Cmd::Evt(evt) => match evt { + SysEvent::Error(error) => { + tracing::error!( + ?error, + "tx5-go-pion error event" + ); + } + SysEvent::PeerConICECandidate { + peer_con_id, + candidate, + } => { + smap( + &mut peer_map, + peer_con_id, + PeerConnectionEvent::ICECandidate(GoBuf( + candidate, + )), + ); + } + SysEvent::PeerConStateChange { + peer_con_id, + peer_con_state, + } => { + smap( + &mut peer_map, + peer_con_id, + PeerConnectionEvent::State( + PeerConnectionState::from_raw( + peer_con_state, + ), + ), + ); + } + SysEvent::PeerConDataChan { + peer_con_id, + data_chan_id, + } => { + let (chan, recv) = + DataChannel::new(data_chan_id); + smap( + &mut peer_map, + peer_con_id, + PeerConnectionEvent::DataChannel( + chan, recv, + ), + ); + } + SysEvent::DataChanClose(data_chan_id) => { + smap( + &mut data_map, + data_chan_id, + DataChannelEvent::Close, + ); + } + SysEvent::DataChanOpen(data_chan_id) => { + smap( + &mut data_map, + data_chan_id, + DataChannelEvent::Open, + ); + } + SysEvent::DataChanMessage { + data_chan_id, + buffer_id, + } => { + let buf = GoBuf(buffer_id); + smap( + &mut data_map, + data_chan_id, + DataChannelEvent::Message(buf), + ); + } + SysEvent::DataChanBufferedAmountLow( + data_chan_id, + ) => { + smap( + &mut data_map, + data_chan_id, + DataChannelEvent::BufferedAmountLow, + ); + } + }, + Cmd::PeerReg(id, peer_con) => { + peer_map.insert(id, peer_con); + } + Cmd::PeerUnreg(id) => { + peer_map.remove(&id); + } + Cmd::DataReg(id, data_chan) 
=> { + data_map.insert(id, data_chan); + } + Cmd::DataUnreg(id) => { + data_map.remove(&id); + } + } } - } - }); + }) + .expect("Failed to spawn offload thread"); + + Self { rt_send } } - Manager::new() -}); -struct Manager { - peer_con: HashMap, - data_chan: HashMap, -} + pub fn register_peer_con( + &self, + id: usize, + peer_con: peer_con::WeakPeerCon, + ) { + let _ = self.rt_send.send(Cmd::PeerReg(id, peer_con)); + } -impl Manager { - pub fn new() -> Mutex { - Mutex::new(Self { - peer_con: HashMap::new(), - data_chan: HashMap::new(), - }) + pub fn unregister_peer_con(&self, id: usize) { + let _ = self.rt_send.send(Cmd::PeerUnreg(id)); + } + + pub fn register_data_chan( + &self, + id: usize, + data_chan: data_chan::WeakDataChan, + ) { + let _ = self.rt_send.send(Cmd::DataReg(id, data_chan)); + } + + pub fn unregister_data_chan(&self, id: usize) { + let _ = self.rt_send.send(Cmd::DataUnreg(id)); + } + + pub fn handle_event(&self, evt: SysEvent) { + let _ = self.rt_send.send(Cmd::Evt(evt)); } } +pub(crate) fn register_peer_con(id: usize, peer_con: peer_con::WeakPeerCon) { + MANAGER.register_peer_con(id, peer_con); +} + +pub(crate) fn unregister_peer_con(id: usize) { + MANAGER.unregister_peer_con(id); +} + +pub(crate) fn register_data_chan( + id: usize, + data_chan: data_chan::WeakDataChan, +) { + MANAGER.register_data_chan(id, data_chan); +} + +pub(crate) fn unregister_data_chan(id: usize) { + MANAGER.unregister_data_chan(id); +} + +static MANAGER: Lazy = Lazy::new(|| { + unsafe { + API.on_event(move |evt| { + MANAGER.handle_event(evt); + }); + } + + Manager::new() +}); diff --git a/crates/tx5-go-pion/src/lib.rs b/crates/tx5-go-pion/src/lib.rs index 5de8ca7e..454cec8d 100644 --- a/crates/tx5-go-pion/src/lib.rs +++ b/crates/tx5-go-pion/src/lib.rs @@ -99,314 +99,224 @@ mod tests { serde_json::from_str(&format!("{{\"iceServers\":[{ice}]}}")) .unwrap(); - let ice1 = Arc::new(parking_lot::Mutex::new(Vec::new())); - let ice2 = 
Arc::new(parking_lot::Mutex::new(Vec::new())); + let (peer1, mut prcv1) = PeerConnection::new(&config).await.unwrap(); + let (peer2, mut prcv2) = PeerConnection::new(&config).await.unwrap(); - #[derive(Debug)] - enum Cmd { - Shutdown, - Stats(tokio::sync::oneshot::Sender), - ICE(GoBuf), - Offer(GoBuf), - Answer(GoBuf), - } - - #[derive(Debug)] - enum Res { - Chan1(DataChannelSeed), - Chan2(DataChannelSeed), - } - - let (cmd_send_1, mut cmd_recv_1) = - tokio::sync::mpsc::unbounded_channel(); - - let (cmd_send_2, mut cmd_recv_2) = - tokio::sync::mpsc::unbounded_channel(); - - let (res_send, mut res_recv) = tokio::sync::mpsc::unbounded_channel(); - - // -- spawn thread for peer connection 1 -- // - - let hnd1 = { - let config = config.clone(); - let res_send = res_send.clone(); - let cmd_send_2 = cmd_send_2.clone(); - let ice1 = ice1.clone(); - tokio::task::spawn(async move { - let mut peer1 = { - let cmd_send_2 = cmd_send_2.clone(); - PeerConnection::new(&config, move |evt| match evt { - PeerConnectionEvent::Error(err) => { - panic!("{:?}", err); - } - PeerConnectionEvent::State(state) => { - println!("peer1 state: {state:?}"); - } - PeerConnectionEvent::ICECandidate(mut candidate) => { - println!( - "peer1 in-ice: {}", - String::from_utf8_lossy( - &candidate.to_vec().unwrap() - ) - ); - ice1.lock().push(candidate.mut_clone()); - // ok if these are lost during test shutdown - let _ = cmd_send_2.send(Cmd::ICE(candidate)); - } - PeerConnectionEvent::DataChannel(chan) => { - println!("peer1 in-chan: {:?}", chan); - } - }) - .await - .unwrap() - }; - - println!("peer1 about to create data channel"); - let chan1 = peer1 - .create_data_channel(DataChannelConfig { - label: Some("data".into()), - }) - .await - .unwrap(); - res_send.send(Res::Chan1(chan1)).unwrap(); - println!("peer1 create data channel complete"); - - println!("peer1 about to create offer"); - let mut offer = - peer1.create_offer(OfferConfig::default()).await.unwrap(); - peer1.set_local_description(&mut 
offer).await.unwrap(); - cmd_send_2.send(Cmd::Offer(offer)).unwrap(); - println!("peer1 offer complete"); - - while let Some(cmd) = cmd_recv_1.recv().await { - match cmd { - Cmd::ICE(ice) => { - // ok if these are lost during test shutdown - let _ = peer1.add_ice_candidate(ice).await; - } - Cmd::Answer(mut answer) => { - println!( - "peer1 recv answer: {}", - String::from_utf8_lossy( - &answer.to_vec().unwrap() - ) - ); - peer1.set_remote_description(answer).await.unwrap(); - } - Cmd::Stats(rsp) => { - let _ = rsp.send(peer1.stats().await.unwrap()); - } - _ => break, - } - } + let (data1, mut drcv1) = peer1 + .create_data_channel(DataChannelConfig { + label: Some("data".into()), }) - }; - - // -- spawn thread for peer connection 2 -- // - - let hnd2 = { - let config = config.clone(); - let res_send = res_send.clone(); - let cmd_send_1 = cmd_send_1.clone(); - let ice2 = ice2.clone(); - tokio::task::spawn(async move { - let mut peer2 = { - let cmd_send_1 = cmd_send_1.clone(); - PeerConnection::new(&config, move |evt| match evt { - PeerConnectionEvent::Error(err) => { - panic!("{:?}", err); - } - PeerConnectionEvent::State(state) => { - println!("peer2 state: {state:?}"); - } - PeerConnectionEvent::ICECandidate(mut candidate) => { - println!( - "peer2 in-ice: {}", - String::from_utf8_lossy( - &candidate.to_vec().unwrap() - ) - ); - ice2.lock().push(candidate.mut_clone()); - // ok if these are lost during test shutdown - let _ = cmd_send_1.send(Cmd::ICE(candidate)); - } - PeerConnectionEvent::DataChannel(chan) => { - println!("peer2 in-chan: {:?}", chan); - res_send.send(Res::Chan2(chan)).unwrap(); - } - }) - .await - .unwrap() - }; - - while let Some(cmd) = cmd_recv_2.recv().await { - match cmd { - Cmd::ICE(ice) => { - // ok if these are lost during test shutdown - let _ = peer2.add_ice_candidate(ice).await; - } - Cmd::Offer(mut offer) => { - println!( - "peer2 recv offer: {}", - String::from_utf8_lossy( - &offer.to_vec().unwrap() - ) - ); - 
peer2.set_remote_description(offer).await.unwrap(); - - println!("peer2 about to create answer"); - let mut answer = peer2 - .create_answer(AnswerConfig::default()) - .await - .unwrap(); - peer2 - .set_local_description(&mut answer) - .await - .unwrap(); - cmd_send_1.send(Cmd::Answer(answer)).unwrap(); - println!("peer2 answer complete"); - } - Cmd::Stats(rsp) => { - let _ = rsp.send(peer2.stats().await.unwrap()); - } - _ => break, + .await + .unwrap(); + + let mut offer = + peer1.create_offer(OfferConfig::default()).await.unwrap(); + peer1 + .set_local_description(offer.try_clone().unwrap()) + .await + .unwrap(); + peer2.set_remote_description(offer).await.unwrap(); + let mut answer = + peer2.create_answer(AnswerConfig::default()).await.unwrap(); + peer2 + .set_local_description(answer.try_clone().unwrap()) + .await + .unwrap(); + peer1.set_remote_description(answer).await.unwrap(); + + let (data2, mut drcv2) = loop { + if let Some(evt) = prcv2.recv().await { + match evt { + PeerConnectionEvent::Error(err) => panic!("{err:?}"), + PeerConnectionEvent::State(_) => (), + PeerConnectionEvent::ICECandidate(ice) => { + peer1.add_ice_candidate(ice).await.unwrap(); + } + PeerConnectionEvent::DataChannel(data2, drcv2) => { + break (data2, drcv2); } } - }) + } else { + panic!("receiver ended"); + } }; - // -- retrieve our data channels -- // + #[derive(Debug)] + enum FinishState { + Start, + Msg1, + Msg2, + Done, + } - let mut chan1 = None; - let mut chan2 = None; + impl FinishState { + fn is_done(&self) -> bool { + matches!(self, Self::Done) + } - for _ in 0..2 { - match res_recv.recv().await.unwrap() { - Res::Chan1(chan) => chan1 = Some(chan), - Res::Chan2(chan) => chan2 = Some(chan), + fn msg1(&self) -> Self { + match self { + Self::Start => Self::Msg1, + Self::Msg1 => Self::Msg1, + Self::Msg2 => Self::Done, + oth => panic!("expected not Done, got: {oth:?}"), + } + } + + fn msg2(&self) -> Self { + match self { + Self::Start => Self::Msg2, + Self::Msg1 => Self::Done, + 
oth => panic!("expected Start or Msg1, got: {oth:?}"), + } } } - let (s_open, r_open) = std::sync::mpsc::sync_channel(32); - let (s_data, r_data) = std::sync::mpsc::sync_channel(32); + let mut state = FinishState::Start; - println!("got data channels"); + loop { + tokio::select! { + evt = prcv1.recv() => match evt { + Some(PeerConnectionEvent::State(_)) => (), + Some(PeerConnectionEvent::ICECandidate(ice)) => { + peer2.add_ice_candidate(ice).await.unwrap(); + } + oth => panic!("unexpected: {oth:?}"), + }, + evt = prcv2.recv() => match evt { + Some(PeerConnectionEvent::State(_)) => (), + Some(PeerConnectionEvent::ICECandidate(ice)) => { + peer1.add_ice_candidate(ice).await.unwrap(); + } + oth => panic!("unexpected: {oth:?}"), + }, + evt = drcv1.recv() => match evt { + Some(DataChannelEvent::BufferedAmountLow) => (), + Some(DataChannelEvent::Open) => { + assert_eq!( + "data", + &String::from_utf8_lossy( + &data2.label().unwrap().to_vec().unwrap()), + ); + println!( + "data1 pre-send buffered amount: {}", + data1.set_buffered_amount_low_threshold(5).unwrap(), + ); + println!( + "data1 post-send buffered amount: {}", + data1.send(GoBuf::from_slice(b"hello").unwrap()).await.unwrap(), + ); + } + Some(DataChannelEvent::Message(mut buf)) => { + assert_eq!( + "world", + &String::from_utf8_lossy(&buf.to_vec().unwrap()), + ); - // -- setup event handler for data channel 1 -- // + state = state.msg1(); + } + oth => panic!("unexpected: {oth:?}"), + }, + evt = drcv2.recv() => match evt { + Some(DataChannelEvent::BufferedAmountLow) => (), + Some(DataChannelEvent::Open) => { + assert_eq!( + "data", + &String::from_utf8_lossy( + &data2.label().unwrap().to_vec().unwrap()), + ); + println!( + "data2 pre-send buffered amount: {}", + data2.set_buffered_amount_low_threshold(5).unwrap(), + ); + println!( + "data2 post-send buffered amount: {}", + data2.send(GoBuf::from_slice(b"world").unwrap()).await.unwrap(), + ); + } + Some(DataChannelEvent::Message(mut buf)) => { + assert_eq!( + 
"hello", + &String::from_utf8_lossy(&buf.to_vec().unwrap()), + ); - let s_open1 = s_open.clone(); - let s_data1 = s_data.clone(); - let mut chan1 = chan1.unwrap().handle(move |evt| { - println!("chan1: {:?}", evt); - if let DataChannelEvent::Open = evt { - s_open1.send(()).unwrap(); + state = state.msg2(); + } + oth => panic!("unexpected: {oth:?}"), + }, } - if let DataChannelEvent::Message(mut msg) = evt { - msg.access(|data| { - assert_eq!(b"world", data.unwrap()); - Ok(()) - }) - .unwrap(); - s_data1.send(()).unwrap(); + if state.is_done() { + println!( + "peer1: {}", + String::from_utf8_lossy( + &peer1.stats().await.unwrap().to_vec().unwrap() + ), + ); + println!( + "peer2: {}", + String::from_utf8_lossy( + &peer1.stats().await.unwrap().to_vec().unwrap() + ), + ); + break; } - }); + } - // -- setup event handler for data channel 2 -- // + let data1 = Arc::new(data1); + + let mut all = Vec::new(); + + const COUNT: usize = 10; + + let bar = Arc::new(tokio::sync::Barrier::new(COUNT)); + + for i in 0..COUNT { + let hnd = tokio::runtime::Handle::current(); + let bar = bar.clone(); + let data1 = data1.clone(); + all.push(std::thread::spawn(move || { + hnd.block_on(async move { + println!("send {i}"); + bar.wait().await; + data1 + .send(GoBuf::from_slice(b"hello").unwrap()) + .await + .unwrap(); + println!("sent {i}"); + }); + })); + } - let mut chan2 = chan2.unwrap().handle(move |evt| { - println!("chan2: {:?}", evt); - if let DataChannelEvent::Open = evt { - s_open.send(()).unwrap(); - } - if let DataChannelEvent::Message(mut msg) = evt { - msg.access(|data| { - assert_eq!(b"hello", data.unwrap()); - Ok(()) - }) - .unwrap(); - s_data.send(()).unwrap(); + let mut r_count = 0; + while let Some(evt) = drcv2.recv().await { + println!("{evt:?}"); + if matches!(evt, DataChannelEvent::Message(_)) { + r_count += 1; + println!("got {r_count}"); + if r_count >= COUNT { + break; + } } - }); - - // -- make sure the channels are ready / open -- // - - let chan1ready = 
chan1.ready_state().unwrap(); - println!("chan1 ready_state: {}", chan1ready); - let chan2ready = chan2.ready_state().unwrap(); - println!("chan2 ready_state: {}", chan2ready); - - let mut need_open_cnt = 0; - if chan1ready < 2 { - need_open_cnt += 1; - } - if chan2ready < 2 { - need_open_cnt += 1; } - for _ in 0..need_open_cnt { - r_open.recv().unwrap(); + for (i, t) in all.into_iter().enumerate() { + println!("await thread {i}"); + t.join().unwrap(); + println!("thread {i} complete"); } - // -- check the channel labels -- // - - let lbl1 = - String::from_utf8_lossy(&chan1.label().unwrap().to_vec().unwrap()) - .to_string(); - let lbl2 = - String::from_utf8_lossy(&chan2.label().unwrap().to_vec().unwrap()) - .to_string(); - tracing::info!(%lbl1, %lbl2); - assert_eq!("data", &lbl1); - assert_eq!("data", &lbl2); - - // -- set the buffered amount low thresholds -- // - - let b = chan1.set_buffered_amount_low_threshold(5).unwrap(); - println!("chan1 pre-send buffered amount: {b}"); - let b = chan2.set_buffered_amount_low_threshold(5).unwrap(); - println!("chan2 pre-send buffered amount: {b}"); - - // -- send data on the data channels -- // - - let mut buf = GoBuf::new().unwrap(); - buf.extend(b"hello").unwrap(); - let b = chan1.send(buf).await.unwrap(); - println!("chan1 post-send buffered amount: {b}"); + println!("close data 1"); + data1.close(Error::id("").into()); + println!("close data 2"); + data2.close(Error::id("").into()); + println!("close peer 1"); + peer1.close(Error::id("").into()); + println!("close peer 2"); + peer2.close(Error::id("").into()); - let mut buf = GoBuf::new().unwrap(); - buf.extend(b"world").unwrap(); - let b = chan2.send(buf).await.unwrap(); - println!("chan2 post-send buffered amount: {b}"); - - // -- await receiving data on the data channels -- // - - for _ in 0..2 { - r_data.recv().unwrap(); - } - - // -- get stats -- // - - let (s, r) = tokio::sync::oneshot::channel(); - cmd_send_1.send(Cmd::Stats(s)).unwrap(); - println!( - 
"peer_con_1: {}", - String::from_utf8_lossy(&r.await.unwrap().to_vec().unwrap()) - ); - let (s, r) = tokio::sync::oneshot::channel(); - cmd_send_2.send(Cmd::Stats(s)).unwrap(); - println!( - "peer_con_2: {}", - String::from_utf8_lossy(&r.await.unwrap().to_vec().unwrap()) - ); - - // -- cleanup -- // - - drop(chan1); - drop(chan2); - cmd_send_1.send(Cmd::Shutdown).unwrap(); - cmd_send_2.send(Cmd::Shutdown).unwrap(); - hnd1.await.unwrap(); - hnd2.await.unwrap(); + println!("close turn"); turn.stop().await.unwrap(); + + println!("all done."); } } diff --git a/crates/tx5-go-pion/src/peer_con.rs b/crates/tx5-go-pion/src/peer_con.rs index 7c4114a8..b9015d1d 100644 --- a/crates/tx5-go-pion/src/peer_con.rs +++ b/crates/tx5-go-pion/src/peer_con.rs @@ -1,5 +1,5 @@ use crate::*; -use std::sync::Arc; +use std::sync::{Arc, Mutex, Weak}; use tx5_go_pion_sys::API; /// ICE server configuration. @@ -93,126 +93,236 @@ impl From<&AnswerConfig> for GoBufRef<'static> { } } -/// A go pion webrtc PeerConnection. -#[derive(Debug)] -pub struct PeerConnection(usize); +pub(crate) struct PeerConCore { + peer_con_id: usize, + evt_send: tokio::sync::mpsc::UnboundedSender, + drop_err: Error, +} -impl Drop for PeerConnection { +impl Drop for PeerConCore { fn drop(&mut self) { + let _ = self + .evt_send + .send(PeerConnectionEvent::Error(self.drop_err.clone())); + unregister_peer_con(self.peer_con_id); unsafe { - unregister_peer_con_evt_cb(self.0); - API.peer_con_free(self.0); + API.peer_con_free(self.peer_con_id); + } + } +} + +impl PeerConCore { + pub fn new( + peer_con_id: usize, + evt_send: tokio::sync::mpsc::UnboundedSender, + ) -> Self { + Self { + peer_con_id, + evt_send, + drop_err: Error::id("PeerConnectionDropped").into(), + } + } + + pub fn close(&mut self, err: Error) { + // self.evt_send.send_err() is called in Drop impl + self.drop_err = err; + } +} + +#[derive(Clone)] +pub(crate) struct WeakPeerCon( + pub(crate) Weak>>, +); + +macro_rules! 
peer_con_strong_core { + ($inner:expr, $ident:ident, $block:block) => { + match &mut *$inner.lock().unwrap() { + Ok($ident) => $block, + Err(err) => Result::Err(err.clone().into()), + } + }; +} + +macro_rules! peer_con_weak_core { + ($inner:expr, $ident:ident, $block:block) => { + match $inner.upgrade() { + Some(strong) => peer_con_strong_core!(strong, $ident, $block), + None => Result::Err(Error::id("PeerConnectionClosed")), } + }; +} + +impl WeakPeerCon { + pub fn send_evt(&self, evt: PeerConnectionEvent) -> Result<()> { + peer_con_weak_core!(self.0, core, { + core.evt_send + .send(evt) + .map_err(|_| Error::id("PeerConnectionClosed")) + }) } } +/// A go pion webrtc PeerConnection. +pub struct PeerConnection(Arc>>); + impl PeerConnection { /// Construct a new PeerConnection. - pub async fn new<'a, B, Cb>(config: B, cb: Cb) -> Result + /// Warning: This returns an unbounded channel, + /// you should process this as quickly and synchronously as possible + /// to avoid a backlog filling up memory. + pub async fn new<'a, B>( + config: B, + ) -> Result<( + Self, + tokio::sync::mpsc::UnboundedReceiver, + )> where B: Into>, - Cb: Fn(PeerConnectionEvent) + 'static + Send + Sync, { tx5_init().await.map_err(Error::err)?; init_evt_manager(); r2id!(config); - let cb: PeerConEvtCb = Arc::new(cb); tokio::task::spawn_blocking(move || unsafe { let peer_con_id = API.peer_con_alloc(config)?; - register_peer_con_evt_cb(peer_con_id, cb); - Ok(Self(peer_con_id)) + let (evt_send, evt_recv) = tokio::sync::mpsc::unbounded_channel(); + + let strong = Arc::new(Mutex::new(Ok(PeerConCore::new( + peer_con_id, + evt_send, + )))); + + let weak = WeakPeerCon(Arc::downgrade(&strong)); + + register_peer_con(peer_con_id, weak); + + Ok((Self(strong), evt_recv)) }) .await? } + /// Close this connection. 
+ pub fn close(&self, err: Error) { + let mut tmp = Err(err.clone()); + + { + let mut lock = self.0.lock().unwrap(); + let mut do_swap = false; + if let Ok(core) = &mut *lock { + core.close(err.clone()); + do_swap = true; + } + if do_swap { + std::mem::swap(&mut *lock, &mut tmp); + } + } + + // make sure the above lock is released before this is dropped + drop(tmp); + } + + fn get_peer_con_id(&self) -> Result { + peer_con_strong_core!(self.0, core, { Ok(core.peer_con_id) }) + } + /// Get stats. - pub async fn stats(&mut self) -> Result { - let peer_con = self.0; + pub async fn stats(&self) -> Result { + let peer_con_id = self.get_peer_con_id()?; + tokio::task::spawn_blocking(move || unsafe { - API.peer_con_stats(peer_con).map(GoBuf) + API.peer_con_stats(peer_con_id).map(GoBuf) }) .await? } /// Create offer. - pub async fn create_offer<'a, B>(&mut self, config: B) -> Result + pub async fn create_offer<'a, B>(&self, config: B) -> Result where B: Into>, { - let peer_con = self.0; + let peer_con_id = self.get_peer_con_id()?; + r2id!(config); tokio::task::spawn_blocking(move || unsafe { - API.peer_con_create_offer(peer_con, config).map(GoBuf) + API.peer_con_create_offer(peer_con_id, config).map(GoBuf) }) .await? } /// Create answer. - pub async fn create_answer<'a, B>(&mut self, config: B) -> Result + pub async fn create_answer<'a, B>(&self, config: B) -> Result where B: Into>, { - let peer_con = self.0; + let peer_con_id = self.get_peer_con_id()?; + r2id!(config); tokio::task::spawn_blocking(move || unsafe { - API.peer_con_create_answer(peer_con, config).map(GoBuf) + API.peer_con_create_answer(peer_con_id, config).map(GoBuf) }) .await? } /// Set local description. 
- pub async fn set_local_description<'a, B>(&mut self, desc: B) -> Result<()> + pub async fn set_local_description<'a, B>(&self, desc: B) -> Result<()> where B: Into>, { - let peer_con = self.0; + let peer_con_id = self.get_peer_con_id()?; + r2id!(desc); tokio::task::spawn_blocking(move || unsafe { - API.peer_con_set_local_desc(peer_con, desc) + API.peer_con_set_local_desc(peer_con_id, desc) }) .await? } /// Set remote description. - pub async fn set_remote_description<'a, B>(&mut self, desc: B) -> Result<()> + pub async fn set_remote_description<'a, B>(&self, desc: B) -> Result<()> where B: Into>, { - let peer_con = self.0; + let peer_con_id = self.get_peer_con_id()?; + r2id!(desc); tokio::task::spawn_blocking(move || unsafe { - API.peer_con_set_rem_desc(peer_con, desc) + API.peer_con_set_rem_desc(peer_con_id, desc) }) .await? } /// Add ice candidate. - pub async fn add_ice_candidate<'a, B>(&mut self, ice: B) -> Result<()> + pub async fn add_ice_candidate<'a, B>(&self, ice: B) -> Result<()> where B: Into>, { - let peer_con = self.0; + let peer_con_id = self.get_peer_con_id()?; + r2id!(ice); tokio::task::spawn_blocking(move || unsafe { - API.peer_con_add_ice_candidate(peer_con, ice) + API.peer_con_add_ice_candidate(peer_con_id, ice) }) .await? } /// Create data channel. pub async fn create_data_channel<'a, B>( - &mut self, + &self, config: B, - ) -> Result + ) -> Result<( + DataChannel, + tokio::sync::mpsc::UnboundedReceiver, + )> where B: Into>, { - let peer_con = self.0; + let peer_con_id = + peer_con_strong_core!(self.0, core, { Ok(core.peer_con_id) })?; + r2id!(config); tokio::task::spawn_blocking(move || unsafe { let data_chan_id = - API.peer_con_create_data_chan(peer_con, config)?; - Ok(DataChannelSeed::new(data_chan_id)) + API.peer_con_create_data_chan(peer_con_id, config)?; + Ok(DataChannel::new(data_chan_id)) }) .await? 
} diff --git a/crates/tx5-go-pion/tests/limit-ports.rs b/crates/tx5-go-pion/tests/limit-ports.rs index 2c849151..89d86e55 100644 --- a/crates/tx5-go-pion/tests/limit-ports.rs +++ b/crates/tx5-go-pion/tests/limit-ports.rs @@ -7,13 +7,8 @@ async fn limit_ports() { .set_as_global_default() .unwrap(); - let (s, mut r) = tokio::sync::mpsc::unbounded_channel(); - - let mut con = tx5_go_pion::PeerConnection::new( + let (con, mut r) = tx5_go_pion::PeerConnection::new( tx5_go_pion::PeerConnectionConfig::default(), - move |evt| { - let _ = s.send(evt); - }, ) .await .unwrap(); diff --git a/crates/tx5-online/Cargo.toml b/crates/tx5-online/Cargo.toml index 2285a32e..7ac96fb9 100644 --- a/crates/tx5-online/Cargo.toml +++ b/crates/tx5-online/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-online" -version = "0.0.6-alpha" +version = "0.0.7-alpha" edition = "2021" description = "Holochain WebRTC P2P Communication Ecosystem Online Connectivity Events" license = "MIT OR Apache-2.0" diff --git a/crates/tx5-signal-srv/Cargo.toml b/crates/tx5-signal-srv/Cargo.toml index c03ba9be..015b7a36 100644 --- a/crates/tx5-signal-srv/Cargo.toml +++ b/crates/tx5-signal-srv/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-signal-srv" -version = "0.0.6-alpha" +version = "0.0.7-alpha" description = "holochain webrtc signal server" license = "MIT OR Apache-2.0" homepage = "https://github.com/holochain/tx5" diff --git a/crates/tx5-signal-srv/src/bin/tx5-signal-srv.rs b/crates/tx5-signal-srv/src/bin/tx5-signal-srv.rs index af1c4123..9c8c530e 100644 --- a/crates/tx5-signal-srv/src/bin/tx5-signal-srv.rs +++ b/crates/tx5-signal-srv/src/bin/tx5-signal-srv.rs @@ -46,7 +46,7 @@ async fn main_err() -> Result<()> { ConfigPerOpt::ConfigLoaded(config) => config, }; - let (driver, addr_list, err_list) = exec_tx5_signal_srv(config)?; + let (_hnd, addr_list, err_list) = exec_tx5_signal_srv(config).await?; for err in err_list { println!("# tx5-signal-srv ERR {err:?}"); @@ -58,7 +58,5 @@ async fn main_err() -> 
Result<()> { println!("# tx5-signal-srv START"); - driver.await; - - Ok(()) + futures::future::pending().await } diff --git a/crates/tx5-signal-srv/src/server.rs b/crates/tx5-signal-srv/src/server.rs index 4163b057..33a06b21 100644 --- a/crates/tx5-signal-srv/src/server.rs +++ b/crates/tx5-signal-srv/src/server.rs @@ -4,9 +4,8 @@ use futures::sink::SinkExt; use futures::stream::StreamExt; use sodoken::crypto_box::curve25519xsalsa20poly1305 as crypto_box; use std::collections::HashMap; -use std::future::Future; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use tx5_core::wire; struct IntGaugeGuard(prometheus::IntGauge); @@ -24,9 +23,20 @@ impl IntGaugeGuard { } } -/// [exec_tx5_signal_srv] will return this driver future. -pub type ServerDriver = - std::pin::Pin + 'static + Send>>; +/// A handle to the server instance. This has no functionality +/// except that when it is dropped, the server will shut down. +pub struct SrvHnd { + task_list: Vec>, + _srv_cmd: Arc, +} + +impl Drop for SrvHnd { + fn drop(&mut self) { + for t in self.task_list.iter() { + t.abort(); + } + } +} /// Protocol Version #[derive(Debug)] @@ -36,9 +46,9 @@ enum ProtoVer { /// The main entrypoint tx5-signal-server logic task. 
#[allow(deprecated)] -pub fn exec_tx5_signal_srv( +pub async fn exec_tx5_signal_srv( config: Config, -) -> Result<(ServerDriver, Vec, Vec)> { +) -> Result<(SrvHnd, Vec, Vec)> { // make sure our metrics are initialized let _ = &*METRICS_REQ_COUNT; let _ = &*METRICS_REQ_TIME_S; @@ -50,11 +60,12 @@ pub fn exec_tx5_signal_srv( let _ = &*REQ_DEMO_CNT; let ice_servers = Arc::new(config.ice_servers); - let srv_hnd = Srv::spawn(); + let srv_cmd = Srv::spawn(); use warp::Filter; use warp::Reply; + let srv_cmd_weak = Arc::downgrade(&srv_cmd); let tx5_ws = warp::path!("tx5-ws" / String / String) .and(warp::ws()) .map(move |ver: String, client_pub: String, ws: warp::ws::Ws| { @@ -72,12 +83,12 @@ pub fn exec_tx5_signal_srv( Ok(client_pub) => client_pub, }; let ice_servers = ice_servers.clone(); - let srv_hnd = srv_hnd.clone(); + let srv_cmd_weak = srv_cmd_weak.clone(); ws.max_send_queue(tx5_core::ws::MAX_SEND_QUEUE) .max_message_size(tx5_core::ws::MAX_MESSAGE_SIZE) .max_frame_size(tx5_core::ws::MAX_FRAME_SIZE) .on_upgrade(move |ws| async move { - client_task(ws, ver, client_pub, ice_servers, srv_hnd) + client_task(ws, ver, client_pub, ice_servers, srv_cmd_weak) .await; drop(active_ws_g); }) @@ -98,9 +109,10 @@ pub fn exec_tx5_signal_srv( let routes = tx5_ws.or(prometheus).with(warp::trace::request()); - let mut drv_out = Vec::new(); + let mut task_list = Vec::new(); let mut add_out = Vec::new(); let mut err_out = Vec::new(); + let mut pend_list = Vec::new(); for addr in config .interfaces @@ -114,18 +126,20 @@ pub fn exec_tx5_signal_srv( } Ok(addr) => addr, }; + let (pend_send, pend_recv) = tokio::sync::oneshot::channel(); + pend_list.push(pend_recv); match warp::serve(routes.clone()) .try_bind_ephemeral((addr, config.port)) .map_err(Error::err) - .map(|(addr, fut)| { - let fut: ServerDriver = Box::pin(fut); - (addr, fut) - }) { + { Err(err) => { err_out.push(err); } Ok((addr, drv)) => { - drv_out.push(drv); + task_list.push(tokio::task::spawn(async move { + let _ = 
pend_send.send(()); + drv.await; + })); add_out.append(&mut tx_addr(addr)?); } } @@ -135,11 +149,16 @@ pub fn exec_tx5_signal_srv( return Err(Error::str(format!("{err_out:?}"))); } - let drv_out = Box::pin(async move { - let _ = futures::future::join_all(drv_out).await; - }); + for pend_recv in pend_list { + let _ = pend_recv.await; + } - Ok((drv_out, add_out, err_out)) + let srv_hnd = SrvHnd { + task_list, + _srv_cmd: srv_cmd, + }; + + Ok((srv_hnd, add_out, err_out)) } fn tx_addr(addr: std::net::SocketAddr) -> Result> { @@ -196,7 +215,7 @@ async fn client_task( ver: ProtoVer, client_pub: sodoken::BufReadSized<{ crypto_box::PUBLICKEYBYTES }>, ice_servers: Arc, - srv_hnd: Arc, + srv_cmd: Weak, ) { macro_rules! dbg_err { ($e:expr) => { @@ -217,13 +236,17 @@ async fn client_task( let (mut tx, mut rx) = ws.split(); let (out_send, mut out_recv) = tokio::sync::mpsc::unbounded_channel(); - dbg_err!(srv_hnd.register(client_id, out_send.clone()).await); + if let Some(srv_cmd) = srv_cmd.upgrade() { + dbg_err!(srv_cmd.register(client_id, out_send.clone()).await); + } else { + return; + } CLIENT_AUTH_WS_COUNT.inc(); tracing::info!(?client_id, ?ver, "Accepted Incoming Connection"); - let srv_hnd_read = srv_hnd.clone(); + let srv_cmd_read = srv_cmd.clone(); let client_id_read = client_id; tokio::select! 
{ res = async move { @@ -277,7 +300,11 @@ async fn client_task( }) => { // TODO - pay attention to demo config flag // right now we just always honor demos - srv_hnd_read.broadcast(msg); + if let Some(srv_cmd) = srv_cmd_read.upgrade() { + srv_cmd.broadcast(msg); + } else { + break; + } REQ_DEMO_CNT.inc(); } Ok(wire::Wire::FwdV1 { rem_pub, nonce, cipher }) => { @@ -286,19 +313,23 @@ async fn client_task( nonce, cipher, }.encode()); - match srv_hnd_read.forward(rem_pub, data).await { - Ok(fut) => { - match fut.await { - Ok(Ok(())) => (), - Ok(Err(err)) => { - tracing::trace!(?err); + if let Some(srv_cmd) = srv_cmd_read.upgrade() { + match srv_cmd.forward(rem_pub, data).await { + Ok(fut) => { + match fut.await { + Ok(Ok(())) => (), + Ok(Err(err)) => { + tracing::trace!(?err); + } + Err(_) => (), } - Err(_) => (), + } + Err(err) => { + tracing::trace!(?err); } } - Err(err) => { - tracing::trace!(?err); - } + } else { + break; } REQ_FWD_CNT.inc(); } @@ -312,7 +343,9 @@ async fn client_task( }; tracing::debug!("ConShutdown"); - dbg_err!(srv_hnd.unregister(client_id).await); + if let Some(srv_cmd) = srv_cmd.upgrade() { + dbg_err!(srv_cmd.unregister(client_id).await); + } } async fn authenticate( @@ -382,7 +415,7 @@ type DataSend = tokio::sync::mpsc::UnboundedSender<( OneSend>, )>; -enum SrvCmd { +enum SrvMsg { Shutdown, Register(Id, DataSend, OneSend>), Unregister(Id, OneSend>), @@ -390,11 +423,11 @@ enum SrvCmd { Broadcast(Vec), } -type SrvSend = tokio::sync::mpsc::UnboundedSender; +type SrvSend = tokio::sync::mpsc::UnboundedSender; -struct SrvHnd(SrvSend, tokio::task::JoinHandle<()>); +struct SrvCmd(SrvSend, tokio::task::JoinHandle<()>); -impl Drop for SrvHnd { +impl Drop for SrvCmd { fn drop(&mut self) { self.shutdown(); } @@ -402,15 +435,15 @@ impl Drop for SrvHnd { const E_SERVER_SHUTDOWN: &str = "ServerShutdown"; -impl SrvHnd { +impl SrvCmd { pub fn shutdown(&self) { - let _ = self.0.send(SrvCmd::Shutdown); + let _ = self.0.send(SrvMsg::Shutdown); self.1.abort(); } 
pub async fn register(&self, id: Id, data_send: DataSend) -> Result<()> { let (s, r) = tokio::sync::oneshot::channel(); - if self.0.send(SrvCmd::Register(id, data_send, s)).is_err() { + if self.0.send(SrvMsg::Register(id, data_send, s)).is_err() { return Err(Error::id(E_SERVER_SHUTDOWN)); } r.await.map_err(|_| Error::id(E_SERVER_SHUTDOWN))? @@ -418,7 +451,7 @@ impl SrvHnd { pub async fn unregister(&self, id: Id) -> Result<()> { let (s, r) = tokio::sync::oneshot::channel(); - if self.0.send(SrvCmd::Unregister(id, s)).is_err() { + if self.0.send(SrvMsg::Unregister(id, s)).is_err() { return Err(Error::id(E_SERVER_SHUTDOWN)); } r.await.map_err(|_| Error::id(E_SERVER_SHUTDOWN))? @@ -430,14 +463,14 @@ impl SrvHnd { data: Vec, ) -> Result>> { let (s, r) = tokio::sync::oneshot::channel(); - if self.0.send(SrvCmd::Forward(id, data, s)).is_err() { + if self.0.send(SrvMsg::Forward(id, data, s)).is_err() { return Err(Error::id(E_SERVER_SHUTDOWN)); } r.await.map_err(|_| Error::id(E_SERVER_SHUTDOWN))? 
} pub fn broadcast(&self, data: Vec) { - let _ = self.0.send(SrvCmd::Broadcast(data)); + let _ = self.0.send(SrvMsg::Broadcast(data)); } } @@ -452,7 +485,7 @@ impl Srv { } } - fn spawn() -> Arc { + fn spawn() -> Arc { let (hnd_send, mut hnd_recv) = tokio::sync::mpsc::unbounded_channel(); let mut srv = Srv::new(); @@ -464,24 +497,24 @@ impl Srv { } }); - Arc::new(SrvHnd(hnd_send, task)) + Arc::new(SrvCmd(hnd_send, task)) } // we want to make sure this function is *not* async // so that we can chew through the work loop without stalling - fn sync_process(&mut self, cmd: SrvCmd) -> bool { + fn sync_process(&mut self, cmd: SrvMsg) -> bool { match cmd { - SrvCmd::Shutdown => return false, - SrvCmd::Register(id, data_send, resp) => { + SrvMsg::Shutdown => return false, + SrvMsg::Register(id, data_send, resp) => { let _ = resp.send(self.register(id, data_send)); } - SrvCmd::Unregister(id, resp) => { + SrvMsg::Unregister(id, resp) => { let _ = resp.send(self.unregister(id)); } - SrvCmd::Forward(id, data, resp) => { + SrvMsg::Forward(id, data, resp) => { let _ = resp.send(self.forward(id, data)); } - SrvCmd::Broadcast(data) => { + SrvMsg::Broadcast(data) => { self.broadcast(data); } } diff --git a/crates/tx5-signal/Cargo.toml b/crates/tx5-signal/Cargo.toml index 76a73d7d..291154ce 100644 --- a/crates/tx5-signal/Cargo.toml +++ b/crates/tx5-signal/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5-signal" -version = "0.0.6-alpha" +version = "0.0.7-alpha" description = "holochain webrtc signal client" license = "MIT OR Apache-2.0" homepage = "https://github.com/holochain/tx5" diff --git a/crates/tx5-signal/src/cli.rs b/crates/tx5-signal/src/cli.rs index 417ea4c1..d73b4b69 100644 --- a/crates/tx5-signal/src/cli.rs +++ b/crates/tx5-signal/src/cli.rs @@ -591,7 +591,7 @@ impl Cli { } })); - let _ = init_recv.await; + init_recv.await.map_err(|_| Error::id("ShuttingDown"))??; Ok(( Self { @@ -621,47 +621,36 @@ async fn con_task( lair_client: LairClient, write_send: WriteSend, 
write_recv: WriteRecv, - init: tokio::sync::oneshot::Sender<()>, + init: tokio::sync::oneshot::Sender>, ) { - let mut init = Some(init); - let mut write_recv = Some(write_recv); - loop { - if let Some(socket) = con_open_connection( - &use_tls, - &host, - &con_url, - &endpoint, - x25519_pub, - &ice, - &lair_client, - ) - .await - { - // once we've run open_connection once, proceed with init - if let Some(init) = init.take() { - let _ = init.send(()); - } + match con_open_connection( + &use_tls, + &host, + &con_url, + &endpoint, + x25519_pub, + &ice, + &lair_client, + ) + .await + { + Ok(socket) => { + // once we've run open_connection proceed with init + let _ = init.send(Ok(())); - let a_write_recv = con_manage_connection( + con_manage_connection( socket, msg_send.clone(), x25519_pub, &lair_client, write_send.clone(), - write_recv.take().unwrap(), + write_recv, ) .await; - write_recv = Some(a_write_recv); } - - // once we've run open_connection once, proceed with init - if let Some(init) = init.take() { - let _ = init.send(()); + Err(err) => { + let _ = init.send(Err(err)); } - - let s = rand::Rng::gen_range(&mut rand::thread_rng(), 4.0..8.0); - let s = std::time::Duration::from_secs_f64(s); - tokio::time::sleep(s).await; } } @@ -670,24 +659,12 @@ async fn con_stack( host: &str, con_url: &str, addr: std::net::SocketAddr, -) -> Option { +) -> Result { tracing::debug!(?addr, "try connect"); - let socket = match tokio::net::TcpStream::connect(addr).await { - Ok(socket) => socket, - Err(err) => { - tracing::debug!(?err); - return None; - } - }; + let socket = tokio::net::TcpStream::connect(addr).await?; - let socket = match tcp_configure(socket) { - Ok(socket) => socket, - Err(err) => { - tracing::debug!(?err); - return None; - } - }; + let socket = tcp_configure(socket)?; let socket: tokio_tungstenite::MaybeTlsStream = if let Some(tls) = use_tls { @@ -695,37 +672,24 @@ async fn con_stack( .try_into() .unwrap_or_else(|_| "tx5-signal".try_into().unwrap()); - let socket 
= match tokio_rustls::TlsConnector::from(tls.clone()) + let socket = tokio_rustls::TlsConnector::from(tls.clone()) .connect(name, socket) - .await - { - Ok(socket) => socket, - Err(err) => { - tracing::debug!(?err); - return None; - } - }; + .await?; tokio_tungstenite::MaybeTlsStream::Rustls(socket) } else { tokio_tungstenite::MaybeTlsStream::Plain(socket) }; - let (socket, _rsp) = match tokio_tungstenite::client_async_with_config( + let (socket, _rsp) = tokio_tungstenite::client_async_with_config( con_url, socket, Some(WS_CONFIG), ) .await - { - Ok(r) => r, - Err(err) => { - tracing::debug!(?err); - return None; - } - }; + .map_err(Error::err)?; - Some(socket) + Ok(socket) } async fn con_open_connection( @@ -736,51 +700,38 @@ async fn con_open_connection( x25519_pub: Id, ice: &Mutex>, lair_client: &LairClient, -) -> Option { +) -> Result { let mut result_socket = None; - let addr_list = match tokio::net::lookup_host(&endpoint).await { - Ok(addr_list) => addr_list, - Err(err) => { - tracing::debug!(?err); - return None; - } - }; + let addr_list = tokio::net::lookup_host(&endpoint).await?; + + let mut err_list = Vec::new(); for addr in addr_list { - if let Some(con) = con_stack(use_tls, host, con_url, addr).await { - result_socket = Some(con); - break; + match con_stack(use_tls, host, con_url, addr).await { + Ok(con) => { + result_socket = Some(con); + break; + } + Err(err) => err_list.push(err), } } let mut socket = match result_socket { Some(socket) => socket, None => { - tracing::debug!("failed all sig dns addr connects"); - return None; + err_list.push(Error::str("failed all sig dns addr connects")); + return Err(Error::str(format!("{err_list:?}"))); } }; let auth_req = match socket.next().await { Some(Ok(auth_req)) => auth_req.into_data(), - Some(Err(err)) => { - tracing::debug!(?err); - return None; - } - None => { - tracing::debug!("InvalidServerAuthReq"); - return None; - } + Some(Err(err)) => return Err(Error::err(err)), + None => return 
Err(Error::id("InvalidServerAuthReq")), }; - let decode = match wire::Wire::decode(&auth_req) { - Ok(decode) => decode, - Err(err) => { - tracing::debug!(?err); - return None; - } - }; + let decode = wire::Wire::decode(&auth_req)?; let (srv_pub, nonce, cipher, got_ice) = match decode { wire::Wire::AuthReqV1 { @@ -790,12 +741,11 @@ async fn con_open_connection( ice, } => (srv_pub, nonce, cipher, ice), _ => { - tracing::debug!("InvalidServerAuthReq"); - return None; + return Err(Error::id("InvalidServerAuthReq")); } }; - let con_key = match lair_client + let con_key = lair_client .crypto_box_xsalsa_open_by_pub_key( srv_pub.0.into(), x25519_pub.0.into(), @@ -803,46 +753,22 @@ async fn con_open_connection( nonce.0, cipher.0.into(), ) - .await - { - Ok(con_key) => con_key, - Err(err) => { - tracing::debug!(?err); - return None; - } - }; + .await?; - if let Err(err) = socket - .send(Message::binary( - match (wire::Wire::AuthResV1 { - con_key: match Id::from_slice(&con_key) { - Ok(con_key) => con_key, - Err(err) => { - tracing::debug!(?err); - return None; - } - }, - req_addr: true, - }) - .encode() - { - Ok(binary) => binary, - Err(err) => { - tracing::debug!(?err); - return None; - } - }, - )) - .await - { - tracing::debug!(?err); - return None; - } + let con_key = Id::from_slice(&con_key)?; + let msg = Message::binary( + wire::Wire::AuthResV1 { + con_key, + req_addr: true, + } + .encode()?, + ); + socket.send(msg).await.map_err(Error::err)?; tracing::info!(%got_ice, "signal connection established"); *ice.lock() = Arc::new(got_ice); - Some(socket) + Ok(socket) } async fn con_manage_connection( diff --git a/crates/tx5-signal/src/tests.rs b/crates/tx5-signal/src/tests.rs index bd5daf28..3decdb8a 100644 --- a/crates/tx5-signal/src/tests.rs +++ b/crates/tx5-signal/src/tests.rs @@ -69,6 +69,97 @@ impl Test { } } +#[tokio::test(flavor = "multi_thread")] +async fn server_stop_restart() { + init_tracing(); + + let mut task = None; + let mut port = None; + + for p in 
31181..31191 { + let mut srv_config = tx5_signal_srv::Config::default(); + srv_config.port = p; + srv_config.ice_servers = serde_json::json!([]); + + if let Ok((srv_hnd, _, _)) = + tx5_signal_srv::exec_tx5_signal_srv(srv_config).await + { + task = Some(srv_hnd); + port = Some(p); + break; + } + } + + let mut task = task.unwrap(); + let port = port.unwrap(); + + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + // test setup + let (cli1, mut rcv1) = Test::new(port).await; + let id1 = *cli1.cli.local_id(); + + cli1.cli + .ice(id1, serde_json::json!({"test": "ice"})) + .await + .unwrap(); + + let msg = rcv1.recv().await; + tracing::info!(?msg); + assert!(matches!(msg, Some(SignalMsg::Ice { .. }))); + + // drop server + drop(task); + + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + // make sure it now errors + + // first just trigger an update + let _ = cli1.cli.ice(id1, serde_json::json!({"test": "ice"})).await; + + // now our receive ends + let msg = rcv1.recv().await; + tracing::info!(?msg); + assert!(matches!(msg, None)); + + // now we get errors on send + cli1.cli + .ice(id1, serde_json::json!({"test": "ice"})) + .await + .unwrap_err(); + + // new server on same port + + let mut srv_config = tx5_signal_srv::Config::default(); + srv_config.port = port; + srv_config.ice_servers = serde_json::json!([]); + + let (srv_hnd, _, _) = tx5_signal_srv::exec_tx5_signal_srv(srv_config) + .await + .unwrap(); + + task = srv_hnd; + + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + // test continuation + let (cli1, mut rcv1) = Test::new(port).await; + let id1 = *cli1.cli.local_id(); + + cli1.cli + .ice(id1, serde_json::json!({"test": "ice"})) + .await + .unwrap(); + + let msg = rcv1.recv().await; + tracing::info!(?msg); + assert!(matches!(msg, Some(SignalMsg::Ice { .. 
}))); + + // cleanup + drop(task); +} + #[tokio::test(flavor = "multi_thread")] async fn wrong_version() { init_tracing(); @@ -78,23 +169,20 @@ async fn wrong_version() { srv_config.ice_servers = serde_json::json!([]); srv_config.demo = true; - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); + let (_srv_hnd, addr_list, _) = + tx5_signal_srv::exec_tx5_signal_srv(srv_config) + .await + .unwrap(); let srv_port = addr_list.get(0).unwrap().port(); tracing::info!(%srv_port); - tokio::select! { - _ = srv_driver => (), - _ = async move { - // TODO remove - tokio::time::sleep(std::time::Duration::from_millis(10)).await; + // TODO remove + tokio::time::sleep(std::time::Duration::from_millis(10)).await; - tokio_tungstenite::connect_async(format!("ws://127.0.0.1:{srv_port}/tx5-ws/v1/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")).await.unwrap(); - assert!(tokio_tungstenite::connect_async(format!("ws://127.0.0.1:{srv_port}/tx5-ws/v0/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")).await.is_err()); - } => (), - } + tokio_tungstenite::connect_async(format!("ws://127.0.0.1:{srv_port}/tx5-ws/v1/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")).await.unwrap(); + assert!(tokio_tungstenite::connect_async(format!("ws://127.0.0.1:{srv_port}/tx5-ws/v0/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")).await.is_err()); } #[tokio::test(flavor = "multi_thread")] @@ -106,22 +194,19 @@ async fn sanity() { srv_config.ice_servers = serde_json::json!([]); srv_config.demo = true; - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); + let (_srv_hnd, addr_list, _) = + tx5_signal_srv::exec_tx5_signal_srv(srv_config) + .await + .unwrap(); let srv_port = addr_list.get(0).unwrap().port(); tracing::info!(%srv_port); - tokio::select! 
{ - _ = srv_driver => (), - _ = async move { - // TODO remove - tokio::time::sleep(std::time::Duration::from_millis(10)).await; + // TODO remove + tokio::time::sleep(std::time::Duration::from_millis(10)).await; - sanity_inner(srv_port).await; - } => (), - } + sanity_inner(srv_port).await; } async fn sanity_inner(srv_port: u16) { diff --git a/crates/tx5/Cargo.toml b/crates/tx5/Cargo.toml index 6befe345..b7da629a 100644 --- a/crates/tx5/Cargo.toml +++ b/crates/tx5/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tx5" -version = "0.0.6-alpha" +version = "0.0.7-alpha" edition = "2021" description = "The main holochain tx5 webrtc networking crate" license = "MIT OR Apache-2.0" @@ -20,6 +20,7 @@ backend-go-pion = [ "tx5-go-pion" ] backend-webrtc-rs = [ "webrtc" ] [dependencies] +bit_field = { workspace = true } bytes = { workspace = true } futures = { workspace = true } influxive-otel-atomic-obs = { workspace = true } diff --git a/crates/tx5/benches/throughput.rs b/crates/tx5/benches/throughput.rs index 79b07983..7d3d448a 100644 --- a/crates/tx5/benches/throughput.rs +++ b/crates/tx5/benches/throughput.rs @@ -1,17 +1,19 @@ use criterion::{criterion_group, criterion_main, Criterion}; use std::sync::Arc; use tokio::sync::Mutex; -use tx5::{actor::ManyRcv, *}; +use tx5::*; +use tx5_core::EventRecv; const DATA: &[u8] = &[0xdb; 4096]; struct Test { - cli_url1: Tx5Url, - ep1: Ep, - ep_rcv1: ManyRcv, - cli_url2: Tx5Url, - ep2: Ep, - ep_rcv2: ManyRcv, + _sig_srv_hnd: tx5_signal_srv::SrvHnd, + cli_url1: PeerUrl, + ep1: Ep3, + ep_rcv1: EventRecv, + cli_url2: PeerUrl, + ep2: Ep3, + ep_rcv2: EventRecv, } impl Test { @@ -20,41 +22,45 @@ impl Test { srv_config.port = 0; srv_config.demo = true; - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); + let (_sig_srv_hnd, addr_list, _) = + tx5_signal_srv::exec_tx5_signal_srv(srv_config) + .await + .unwrap(); let sig_port = addr_list.get(0).unwrap().port(); let sig_url = 
Tx5Url::new(format!("ws://localhost:{sig_port}")).unwrap(); - let (ep1, mut ep_rcv1) = Ep::new().await.unwrap(); + let config = Arc::new(Config3::default()); + + let (ep1, mut ep_rcv1) = Ep3::new(config.clone()).await; let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - let (ep2, mut ep_rcv2) = Ep::new().await.unwrap(); + let (ep2, mut ep_rcv2) = Ep3::new(config).await; let cli_url2 = ep2.listen(sig_url).await.unwrap(); - ep1.send(cli_url2.clone(), &b"hello"[..]).await.unwrap(); + ep1.send(cli_url2.clone(), b"hello").await.unwrap(); match ep_rcv2.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), + Some(Ep3Event::Connected { .. }) => (), oth => panic!("unexpected: {oth:?}"), } match ep_rcv2.recv().await { - Some(Ok(EpEvt::Data { .. })) => (), + Some(Ep3Event::Message { .. }) => (), oth => panic!("unexpected: {oth:?}"), } - ep2.send(cli_url1.clone(), &b"world"[..]).await.unwrap(); + ep2.send(cli_url1.clone(), b"world").await.unwrap(); match ep_rcv1.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), + Some(Ep3Event::Connected { .. }) => (), oth => panic!("unexpected: {oth:?}"), } match ep_rcv1.recv().await { - Some(Ok(EpEvt::Data { .. })) => (), + Some(Ep3Event::Message { .. 
}) => (), oth => panic!("unexpected: {oth:?}"), } Self { + _sig_srv_hnd, cli_url1, ep1, ep_rcv1, @@ -66,6 +72,7 @@ impl Test { pub async fn test(&mut self) { let Test { + _sig_srv_hnd, cli_url1, ep1, ep_rcv1, @@ -75,15 +82,23 @@ impl Test { } = self; let _ = tokio::try_join!( - ep1.send(cli_url2.clone(), DATA), - ep2.send(cli_url1.clone(), DATA), - async { ep_rcv1.recv().await.unwrap() }, - async { ep_rcv2.recv().await.unwrap() }, + ep1.send(cli_url2.clone(), DATA,), + ep2.send(cli_url1.clone(), DATA,), + async { txerr(ep_rcv1.recv().await) }, + async { txerr(ep_rcv2.recv().await) }, ) .unwrap(); } } +fn txerr(v: Option) -> Result<()> { + match v { + None => Err(Error::id("end")), + Some(Ep3Event::Error(err)) => Err(err.into()), + _ => Ok(()), + } +} + fn criterion_benchmark(c: &mut Criterion) { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() diff --git a/crates/tx5/examples/sig_idle.rs b/crates/tx5/examples/sig_idle.rs index 3b1834cd..e69b940e 100644 --- a/crates/tx5/examples/sig_idle.rs +++ b/crates/tx5/examples/sig_idle.rs @@ -1,6 +1,7 @@ //! Opens a connection to a running signal server, and sits idle. //! This is useful to test our keepalive logic. 
+use std::sync::Arc; use tx5::*; fn init_tracing() { @@ -23,7 +24,7 @@ async fn main() { let sig_url = Tx5Url::new(sig_url).unwrap(); println!("{sig_url}"); - let (ep, _ep_rcv) = Ep::new().await.unwrap(); + let (ep, _ep_rcv) = Ep3::new(Arc::new(Config3::default())).await; ep.listen(sig_url).await.unwrap(); diff --git a/crates/tx5/examples/turn_doctor.rs b/crates/tx5/examples/turn_doctor.rs index e4706642..15505ef3 100644 --- a/crates/tx5/examples/turn_doctor.rs +++ b/crates/tx5/examples/turn_doctor.rs @@ -331,13 +331,8 @@ async fn gather_ice( let _ = tokio::time::timeout( std::time::Duration::from_secs(10), async move { - let (s, mut r) = tokio::sync::mpsc::unbounded_channel(); - let mut con = - tx5_go_pion::PeerConnection::new(config, move |evt| { - let _ = s.send(evt); - }) - .await - .unwrap(); + let (con, mut r) = + tx5_go_pion::PeerConnection::new(config).await.unwrap(); let _dc = con .create_data_channel( diff --git a/crates/tx5/src/actor.rs b/crates/tx5/src/actor.rs deleted file mode 100644 index c57c6dcd..00000000 --- a/crates/tx5/src/actor.rs +++ /dev/null @@ -1,136 +0,0 @@ -//! Quick-n-dirty actor system for the tx5 state. - -use crate::*; -use parking_lot::Mutex; -use std::sync::{Arc, Weak}; - -type ManySnd = tokio::sync::mpsc::UnboundedSender>; - -/// Generic receiver type. -pub struct ManyRcv( - pub(crate) tokio::sync::mpsc::UnboundedReceiver>, -); - -impl ManyRcv { - /// Receive data from this receiver type. - #[inline] - pub async fn recv(&mut self) -> Option> { - tokio::sync::mpsc::UnboundedReceiver::recv(&mut self.0).await - } -} - -/// Weak actor handle that does not add to reference count. -pub struct ActorWeak(Weak>>>); - -impl ActorWeak { - /// Attempt to upgrade to a full actor handle. 
- pub fn upgrade(&self) -> Option> { - match self.0.upgrade() { - None => None, - Some(a) => { - if a.lock().is_some() { - Some(Actor(a)) - } else { - None - } - } - } - } -} - -impl PartialEq for ActorWeak { - fn eq(&self, rhs: &Self) -> bool { - Weak::ptr_eq(&self.0, &rhs.0) - } -} - -impl Eq for ActorWeak {} - -impl PartialEq> for ActorWeak { - fn eq(&self, rhs: &Actor) -> bool { - Weak::ptr_eq(&self.0, &Arc::downgrade(&rhs.0)) - } -} - -impl Clone for ActorWeak { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -/// An actor that executes as a task, providing synchronized messaging. -pub struct Actor(Arc>>>); - -impl PartialEq for Actor { - fn eq(&self, rhs: &Self) -> bool { - Arc::ptr_eq(&self.0, &rhs.0) - } -} - -impl Eq for Actor {} - -impl PartialEq> for Actor { - fn eq(&self, rhs: &ActorWeak) -> bool { - Weak::ptr_eq(&Arc::downgrade(&self.0), &rhs.0) - } -} - -impl Clone for Actor { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl Actor { - /// Construct a new actor. - pub fn new(cb: Cb) -> Self - where - Fut: std::future::Future> + 'static + Send, - Cb: FnOnce(ActorWeak, ManyRcv) -> Fut + 'static + Send, - { - let (s, r) = tokio::sync::mpsc::unbounded_channel(); - let out = Self(Arc::new(Mutex::new(Some(s)))); - let weak = out.weak(); - tokio::task::spawn(cb(weak, ManyRcv(r))); - out - } - - /// Get a weak handle to the actor that does not add to reference count. - pub fn weak(&self) -> ActorWeak { - ActorWeak(Arc::downgrade(&self.0)) - } - - /// Check if this handle is pointing to a closed actor. - pub fn is_closed(&self) -> bool { - match &*self.0.lock() { - None => true, - Some(s) => s.is_closed(), - } - } - - /// Close this actor, stopping the task with an error if it is running. - pub fn close(&self, err: std::io::Error) { - let mut l = self.0.lock(); - if let Some(s) = &*l { - let _ = s.send(Err(err)); - } - let _ = l.take(); - } - - /// Send a message to the actor task. 
- /// If the message sent is an Err variant, the task will be closed. - pub fn send(&self, t: Result) -> Result<()> { - let mut res = Err(Error::id("Closed")); - let close = t.is_err(); - let mut l = self.0.lock(); - if let Some(s) = &*l { - if s.send(t).is_ok() { - res = Ok(()); - } - } - if close { - let _ = l.take(); - } - res - } -} diff --git a/crates/tx5/src/back_buf.rs b/crates/tx5/src/back_buf.rs index 1a36da57..886a0203 100644 --- a/crates/tx5/src/back_buf.rs +++ b/crates/tx5/src/back_buf.rs @@ -8,60 +8,9 @@ pub(crate) mod imp { pub use imp_go_pion::*; } -/// Tx5 buffer creation type via std::io::Write. -pub struct BackBufWriter { - imp: imp::ImpWriter, - _not_sync: std::marker::PhantomData>, -} - -impl BackBufWriter { - /// Create a new Tx5 buffer writer. - #[inline] - pub fn new() -> Result { - Ok(Self { - imp: imp::ImpWriter::new()?, - _not_sync: std::marker::PhantomData, - }) - } - - /// Indicate we are done writing, and extract the internal buffer. - #[inline] - pub fn finish(self) -> BackBuf { - BackBuf { - imp: self.imp.finish(), - _not_sync: std::marker::PhantomData, - } - } -} - -impl std::io::Write for BackBufWriter { - #[inline] - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.imp.write(buf) - } - - #[inline] - fn write_vectored( - &mut self, - bufs: &[std::io::IoSlice<'_>], - ) -> std::io::Result { - self.imp.write_vectored(bufs) - } - - #[inline] - fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { - self.imp.write_all(buf) - } - - #[inline] - fn flush(&mut self) -> std::io::Result<()> { - self.imp.flush() - } -} - /// Tx5 buffer type for sending and receiving data. #[allow(clippy::len_without_is_empty)] -pub struct BackBuf { +pub(crate) struct BackBuf { pub(crate) imp: imp::Imp, pub(crate) _not_sync: std::marker::PhantomData>, } @@ -89,12 +38,6 @@ impl BackBuf { }) } - /// Build a tx5 buffer using std::io::Write. 
- #[inline] - pub fn from_writer() -> Result { - BackBufWriter::new() - } - /// Serialize a type as json into a new BackBuf. #[inline] pub fn from_json(s: S) -> Result { @@ -109,31 +52,6 @@ impl BackBuf { pub fn len(&mut self) -> Result { self.imp.len() } - - /// Attempt to clone this buffer. - #[inline] - pub fn try_clone(&mut self) -> Result { - Ok(Self { - imp: self.imp.try_clone()?, - _not_sync: std::marker::PhantomData, - }) - } - - /// Copy the buffer out into a rust `Vec`. - #[inline] - pub fn to_vec(&mut self) -> Result> { - self.imp.to_vec() - } - - /// Deserialize this buffer as json bytes - /// into a type implementing serde::DeserializeOwned. - #[inline] - pub fn to_json(&mut self) -> Result - where - D: serde::de::DeserializeOwned + Sized, - { - self.imp.to_json() - } } impl std::io::Read for BackBuf { @@ -144,7 +62,7 @@ impl std::io::Read for BackBuf { } /// Conversion type facilitating Into<&mut BackBuf>. -pub enum BackBufRef<'lt> { +pub(crate) enum BackBufRef<'lt> { /// An owned BackBuf. Owned(Result), @@ -152,22 +70,6 @@ pub enum BackBufRef<'lt> { Borrowed(Result<&'lt mut BackBuf>), } -impl<'lt> BackBufRef<'lt> { - /// Get a mutable reference to the buffer. 
- pub fn as_mut_ref(&'lt mut self) -> Result<&'lt mut BackBuf> { - match self { - BackBufRef::Owned(o) => match o { - Ok(o) => Ok(o), - Err(e) => Err(e.err_clone()), - }, - BackBufRef::Borrowed(b) => match b { - Ok(b) => Ok(b), - Err(e) => Err(e.err_clone()), - }, - } - } -} - impl From for BackBufRef<'static> { fn from(b: BackBuf) -> Self { Self::Owned(Ok(b)) diff --git a/crates/tx5/src/back_buf/imp/imp_go_pion.rs b/crates/tx5/src/back_buf/imp/imp_go_pion.rs index eb00ba4c..71fdfc3f 100644 --- a/crates/tx5/src/back_buf/imp/imp_go_pion.rs +++ b/crates/tx5/src/back_buf/imp/imp_go_pion.rs @@ -1,53 +1,5 @@ use crate::*; -pub struct ImpWriter { - buf: tx5_go_pion::GoBuf, - _not_sync: std::marker::PhantomData>, -} - -impl ImpWriter { - #[inline] - pub fn new() -> Result { - Ok(Self { - buf: tx5_go_pion::GoBuf::new()?, - _not_sync: std::marker::PhantomData, - }) - } - - #[inline] - pub fn finish(self) -> Imp { - Imp { - buf: self.buf, - _not_sync: std::marker::PhantomData, - } - } -} - -impl std::io::Write for ImpWriter { - #[inline] - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.buf.write(buf) - } - - #[inline] - fn write_vectored( - &mut self, - bufs: &[std::io::IoSlice<'_>], - ) -> std::io::Result { - self.buf.write_vectored(bufs) - } - - #[inline] - fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> { - self.buf.write_all(buf) - } - - #[inline] - fn flush(&mut self) -> std::io::Result<()> { - self.buf.flush() - } -} - pub struct Imp { pub(crate) buf: tx5_go_pion::GoBuf, pub(crate) _not_sync: std::marker::PhantomData>, @@ -85,34 +37,11 @@ impl Imp { }) } - #[inline] - pub fn try_clone(&mut self) -> Result { - Ok(Self { - buf: self.buf.try_clone()?, - _not_sync: std::marker::PhantomData, - }) - } - #[inline] #[allow(clippy::wrong_self_convention)] // ya, well, we need it mut pub fn len(&mut self) -> Result { self.buf.len() } - - #[inline] - #[allow(clippy::wrong_self_convention)] // ya, well, we need it mut - pub fn to_vec(&mut self) -> Result> { 
- self.buf.to_vec() - } - - #[inline] - #[allow(clippy::wrong_self_convention)] // ya, well, we need it mut - pub fn to_json(&mut self) -> Result - where - D: serde::de::DeserializeOwned + Sized, - { - self.buf.as_json() - } } impl std::io::Read for Imp { diff --git a/crates/tx5/src/config.rs b/crates/tx5/src/config.rs deleted file mode 100644 index ea7ee2f8..00000000 --- a/crates/tx5/src/config.rs +++ /dev/null @@ -1,547 +0,0 @@ -use crate::deps::lair_keystore_api; -use crate::deps::sodoken; -use crate::*; -use lair_keystore_api::prelude::*; -use std::sync::{Arc, Weak}; -use tx5_core::{BoxFut, Tx5Url}; - -/// Tx5 config trait. -pub trait Config: 'static + Send + Sync { - /// Get the max pending send byte count limit. - fn max_send_bytes(&self) -> u32; - - /// The per-data-channel buffer low threshold. - fn per_data_chan_buf_low(&self) -> usize; - - /// Get the max queued recv byte count limit. - fn max_recv_bytes(&self) -> u32; - - /// Get the max concurrent connection limit. - fn max_conn_count(&self) -> u32; - - /// Get the max init (connect) time for a connection. - fn max_conn_init(&self) -> std::time::Duration; - - /// Request the lair client associated with this config. - fn lair_client(&self) -> &LairClient; - - /// Request the lair tag associated with this config. - fn lair_tag(&self) -> &Arc; - - /// A request to open a new signal server connection. - fn on_new_sig(&self, sig_url: Tx5Url, seed: state::SigStateSeed); - - /// A request to open a new peer connection. - fn on_new_conn( - &self, - ice_servers: Arc, - seed: state::ConnStateSeed, - ); - - /// Provide a chance to send preflight handshake data to be received - /// in the `on_conn_validate` hook on the remote side. - /// You may also return any `Err(_)` to cancel the connection even - /// before sending preflight data. 
- fn on_conn_preflight( - &self, - rem_url: Tx5Url, - ) -> BoxFut<'static, Result>>; - - /// Provide as async chance to validate/accept/reject a connection before - /// any events related to that connection are published. - /// This hook is triggered for both outgoing and incoming connections. - /// Return `Ok(())` to accept the connection, or any `Err(_)` to reject. - fn on_conn_validate( - &self, - rem_url: Tx5Url, - preflight_data: Option, - ) -> BoxFut<'static, Result<()>>; -} - -/// Dynamic config type alias. -pub type DynConfig = Arc; - -impl std::fmt::Debug for dyn Config + 'static + Send + Sync { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Config") - .field("max_send_bytes", &self.max_send_bytes()) - .field("per_data_chan_buf_low", &self.per_data_chan_buf_low()) - .field("max_recv_bytes", &self.max_recv_bytes()) - .field("max_conn_count", &self.max_conn_count()) - .field("max_conn_init", &self.max_conn_init()) - .finish() - } -} - -/// Indicates a type is capable of being converted into a Config type. -pub trait IntoConfig: 'static + Send + Sync { - /// Convert this type into a concrete config type. 
- fn into_config(self) -> BoxFut<'static, Result>; -} - -impl IntoConfig for DynConfig { - fn into_config(self) -> BoxFut<'static, Result> { - Box::pin(async move { Ok(self) }) - } -} - -struct DefConfigBuilt { - this: Weak, - max_send_bytes: u32, - per_data_chan_buf_low: usize, - max_recv_bytes: u32, - max_conn_count: u32, - max_conn_init: std::time::Duration, - _lair_keystore: Option, - lair_client: LairClient, - lair_tag: Arc, - on_new_sig_cb: Arc< - dyn Fn(DynConfig, Tx5Url, state::SigStateSeed) + 'static + Send + Sync, - >, - on_new_conn_cb: Arc< - dyn Fn(DynConfig, Arc, state::ConnStateSeed) - + 'static - + Send - + Sync, - >, - #[allow(clippy::type_complexity)] - on_conn_preflight_cb: Arc< - dyn Fn( - DynConfig, - Tx5Url, - ) -> BoxFut<'static, Result>> - + 'static - + Send - + Sync, - >, - #[allow(clippy::type_complexity)] - on_conn_validate_cb: Arc< - dyn Fn( - DynConfig, - Tx5Url, - Option, - ) -> BoxFut<'static, Result<()>> - + 'static - + Send - + Sync, - >, -} - -impl Config for DefConfigBuilt { - fn max_send_bytes(&self) -> u32 { - self.max_send_bytes - } - - fn per_data_chan_buf_low(&self) -> usize { - self.per_data_chan_buf_low - } - - fn max_recv_bytes(&self) -> u32 { - self.max_recv_bytes - } - - fn max_conn_count(&self) -> u32 { - self.max_conn_count - } - - fn max_conn_init(&self) -> std::time::Duration { - self.max_conn_init - } - - fn lair_client(&self) -> &LairClient { - &self.lair_client - } - - fn lair_tag(&self) -> &Arc { - &self.lair_tag - } - - fn on_new_sig(&self, sig_url: Tx5Url, seed: state::SigStateSeed) { - if let Some(this) = self.this.upgrade() { - (self.on_new_sig_cb)(this, sig_url, seed); - } - } - - fn on_new_conn( - &self, - ice_servers: Arc, - seed: state::ConnStateSeed, - ) { - if let Some(this) = self.this.upgrade() { - (self.on_new_conn_cb)(this, ice_servers, seed); - } - } - - fn on_conn_preflight( - &self, - rem_url: Tx5Url, - ) -> BoxFut<'static, Result>> { - if let Some(this) = self.this.upgrade() { - 
(self.on_conn_preflight_cb)(this, rem_url) - } else { - Box::pin(async move { Ok(None) }) - } - } - - fn on_conn_validate( - &self, - rem_url: Tx5Url, - preflight_data: Option, - ) -> BoxFut<'static, Result<()>> { - if let Some(this) = self.this.upgrade() { - (self.on_conn_validate_cb)(this, rem_url, preflight_data) - } else { - Box::pin(async move { Ok(()) }) - } - } -} - -/// Builder type for constructing a DefConfig for a Tx5 endpoint. -#[derive(Default)] -#[allow(clippy::type_complexity)] -pub struct DefConfig { - max_send_bytes: Option, - per_data_chan_buf_low: Option, - max_recv_bytes: Option, - max_conn_count: Option, - max_conn_init: Option, - lair_client: Option, - lair_tag: Option>, - on_new_sig_cb: Option< - Arc< - dyn Fn(DynConfig, Tx5Url, state::SigStateSeed) - + 'static - + Send - + Sync, - >, - >, - on_new_conn_cb: Option< - Arc< - dyn Fn(DynConfig, Arc, state::ConnStateSeed) - + 'static - + Send - + Sync, - >, - >, - on_conn_preflight_cb: Option< - Arc< - dyn Fn( - DynConfig, - Tx5Url, - ) - -> BoxFut<'static, Result>> - + 'static - + Send - + Sync, - >, - >, - on_conn_validate_cb: Option< - Arc< - dyn Fn( - DynConfig, - Tx5Url, - Option, - ) -> BoxFut<'static, Result<()>> - + 'static - + Send - + Sync, - >, - >, -} - -impl IntoConfig for DefConfig { - fn into_config(self) -> BoxFut<'static, Result> { - Box::pin(async move { - let max_send_bytes = - self.max_send_bytes.unwrap_or(16 * 1024 * 1024); - let per_data_chan_buf_low = - self.per_data_chan_buf_low.unwrap_or(64 * 1024); - let max_recv_bytes = - self.max_recv_bytes.unwrap_or(16 * 1024 * 1024); - let max_conn_count = self.max_conn_count.unwrap_or(255); - let max_conn_init = self - .max_conn_init - .unwrap_or(std::time::Duration::from_secs(60)); - let mut lair_keystore = None; - - let lair_tag = self.lair_tag.unwrap_or_else(|| { - rand_utf8::rand_utf8(&mut rand::thread_rng(), 32).into() - }); - - let lair_client = match self.lair_client { - Some(lair_client) => lair_client, - None => { - let 
passphrase = sodoken::BufRead::new_no_lock( - rand_utf8::rand_utf8(&mut rand::thread_rng(), 32) - .as_bytes(), - ); - - // this is a memory keystore, - // so weak persistence security is okay, - // since it will not be persisted. - // The private keys will still be mem_locked - // so they shouldn't be swapped to disk. - let keystore_config = PwHashLimits::Minimum - .with_exec(|| { - LairServerConfigInner::new("/", passphrase.clone()) - }) - .await - .unwrap(); - - let keystore = PwHashLimits::Minimum - .with_exec(|| { - lair_keystore_api::in_proc_keystore::InProcKeystore::new( - Arc::new(keystore_config), - lair_keystore_api::mem_store::create_mem_store_factory(), - passphrase, - ) - }) - .await - .unwrap(); - - let lair_client = keystore.new_client().await.unwrap(); - - lair_client - .new_seed(lair_tag.clone(), None, false) - .await - .unwrap(); - - lair_keystore = Some(keystore); - - lair_client - } - }; - - let on_new_sig_cb = self - .on_new_sig_cb - .unwrap_or_else(|| Arc::new(endpoint::on_new_sig)); - - let on_new_conn_cb = self - .on_new_conn_cb - .unwrap_or_else(|| Arc::new(endpoint::on_new_conn)); - - let on_conn_preflight_cb = - self.on_conn_preflight_cb.unwrap_or_else(|| { - Arc::new(|_, _| Box::pin(async move { Ok(None) })) - }); - - let on_conn_validate_cb = - self.on_conn_validate_cb.unwrap_or_else(|| { - Arc::new(|_, _, _| Box::pin(async move { Ok(()) })) - }); - - let out: DynConfig = Arc::new_cyclic(|this| DefConfigBuilt { - this: this.clone(), - max_send_bytes, - per_data_chan_buf_low, - max_recv_bytes, - max_conn_count, - max_conn_init, - _lair_keystore: lair_keystore, - lair_client, - lair_tag, - on_new_sig_cb, - on_new_conn_cb, - on_conn_preflight_cb, - on_conn_validate_cb, - }); - - Ok(out) - }) - } -} - -impl DefConfig { - /// Set the max queued send bytes to hold before applying backpressure. - /// The default is `16 * 1024 * 1024`. 
- pub fn set_max_send_bytes(&mut self, max_send_bytes: u32) { - self.max_send_bytes = Some(max_send_bytes); - } - - /// See `set_max_send_bytes()`, this is the builder version. - pub fn with_max_send_bytes(mut self, max_send_bytes: u32) -> Self { - self.set_max_send_bytes(max_send_bytes); - self - } - - /// Set the per-data-channel buffer low threshold. - /// The default is `64 * 1024`. - pub fn set_per_data_chan_buf_low(&mut self, per_data_chan_buf_low: usize) { - self.per_data_chan_buf_low = Some(per_data_chan_buf_low); - } - - /// See `set_per_data_chan_buf_low()`, this is the builder version. - pub fn with_per_data_chan_buf_low( - mut self, - per_data_chan_buf_low: usize, - ) -> Self { - self.set_per_data_chan_buf_low(per_data_chan_buf_low); - self - } - - /// Set the max queued recv bytes to hold before dropping connection. - /// The default is `16 * 1024 * 1024`. - pub fn set_max_recv_bytes(&mut self, max_recv_bytes: u32) { - self.max_recv_bytes = Some(max_recv_bytes); - } - - /// See `set_max_recv_bytes()`, this is the builder version. - pub fn with_max_recv_bytes(mut self, max_recv_bytes: u32) -> Self { - self.set_max_recv_bytes(max_recv_bytes); - self - } - - /// Set the max concurrent connection count. - /// The default is `255`. - pub fn set_max_conn_count(&mut self, max_conn_count: u32) { - self.max_conn_count = Some(max_conn_count); - } - - /// See `set_max_conn_count()`, this is the builder version. - pub fn with_max_conn_count(mut self, max_conn_count: u32) -> Self { - self.set_max_conn_count(max_conn_count); - self - } - - /// Set the max connection init (connect) time. - /// The default is `60` seconds. - pub fn set_max_conn_init(&mut self, max_conn_init: std::time::Duration) { - self.max_conn_init = Some(max_conn_init); - } - - /// See `set_max_conn_init()`, this is the builder version. 
- pub fn with_max_conn_init( - mut self, - max_conn_init: std::time::Duration, - ) -> Self { - self.set_max_conn_init(max_conn_init); - self - } - - /// Set the lair client. - /// The default is a generated in-process, in-memory only keystore. - pub fn set_lair_client(&mut self, lair_client: LairClient) { - self.lair_client = Some(lair_client); - } - - /// See `set_lair_client()`, this is the builder version. - pub fn with_lair_client(mut self, lair_client: LairClient) -> Self { - self.set_lair_client(lair_client); - self - } - - /// Set the lair tag used to identify the signing identity keypair. - /// The default is a random 32 byte utf8 string. - pub fn set_lair_tag(&mut self, lair_tag: Arc) { - self.lair_tag = Some(lair_tag); - } - - /// See `set_lair_tag()`, this is the builder version. - pub fn with_lair_tag(mut self, lair_tag: Arc) -> Self { - self.set_lair_tag(lair_tag); - self - } - - /// Override the default new signal connection request handler. - /// The default uses the default tx5-signal dependency. - pub fn set_new_sig_cb(&mut self, cb: Cb) - where - Cb: Fn(DynConfig, Tx5Url, state::SigStateSeed) + 'static + Send + Sync, - { - self.on_new_sig_cb = Some(Arc::new(cb)); - } - - /// See `set_new_sig_cb()`, this is the builder version. - pub fn with_new_sig_cb(mut self, cb: Cb) -> Self - where - Cb: Fn(DynConfig, Tx5Url, state::SigStateSeed) + 'static + Send + Sync, - { - self.set_new_sig_cb(cb); - self - } - - /// Override the default new peer connection request handler. - /// The default uses either tx5-go-pion, or rust-webrtc depending - /// on the feature flipper chosen at compile time. - pub fn set_new_conn_cb(&mut self, cb: Cb) - where - Cb: Fn(DynConfig, Arc, state::ConnStateSeed) - + 'static - + Send - + Sync, - { - self.on_new_conn_cb = Some(Arc::new(cb)); - } - - /// See `set_new_conn_cb()`, this is the builder version. 
- pub fn with_new_conn_cb(mut self, cb: Cb) -> Self - where - Cb: Fn(DynConfig, Arc, state::ConnStateSeed) - + 'static - + Send - + Sync, - { - self.set_new_conn_cb(cb); - self - } - - /// Override the default no-op conn preflight hook. - pub fn set_conn_preflight(&mut self, cb: Cb) - where - Cb: Fn( - DynConfig, - Tx5Url, - ) -> BoxFut<'static, Result>> - + 'static - + Send - + Sync, - { - self.on_conn_preflight_cb = Some(Arc::new(cb)); - } - - /// See `set_conn_preflight()`, this is the builder version. - pub fn with_conn_preflight(mut self, cb: Cb) -> Self - where - Cb: Fn( - DynConfig, - Tx5Url, - ) -> BoxFut<'static, Result>> - + 'static - + Send - + Sync, - { - self.set_conn_preflight(cb); - self - } - - /// Override the default no-op conn validate hook. - pub fn set_conn_validate(&mut self, cb: Cb) - where - Cb: Fn( - DynConfig, - Tx5Url, - Option, - ) -> BoxFut<'static, Result<()>> - + 'static - + Send - + Sync, - { - self.on_conn_validate_cb = Some(Arc::new(cb)); - } - - /// See `set_conn_validate()`, this is the builder version. - pub fn with_conn_validate(mut self, cb: Cb) -> Self - where - Cb: Fn( - DynConfig, - Tx5Url, - Option, - ) -> BoxFut<'static, Result<()>> - + 'static - + Send - + Sync, - { - self.set_conn_validate(cb); - self - } -} diff --git a/crates/tx5/src/endpoint.rs b/crates/tx5/src/endpoint.rs deleted file mode 100644 index 3569b265..00000000 --- a/crates/tx5/src/endpoint.rs +++ /dev/null @@ -1,746 +0,0 @@ -//! Tx5 endpoint. - -use crate::*; -use opentelemetry_api::{metrics::Unit, KeyValue}; -use std::collections::HashMap; -use std::sync::Arc; -use tx5_core::Tx5Url; - -/// Event type emitted by a tx5 endpoint. -pub enum EpEvt { - /// Connection established. - Connected { - /// The remote client url connected. - rem_cli_url: Tx5Url, - }, - - /// Connection closed. - Disconnected { - /// The remote client url disconnected. - rem_cli_url: Tx5Url, - }, - - /// Received data from a remote. 
- Data { - /// The remote client url that sent this message. - rem_cli_url: Tx5Url, - - /// The payload of the message. - data: Box, - - /// Drop this when you've accepted the data to allow additional - /// incoming messages. - permit: Vec, - }, - - /// Received a demo broadcast. - Demo { - /// The remote client url that is available for communication. - rem_cli_url: Tx5Url, - }, -} - -impl std::fmt::Debug for EpEvt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - EpEvt::Connected { rem_cli_url } => f - .debug_struct("EpEvt::Connected") - .field("rem_cli_url", rem_cli_url) - .finish(), - EpEvt::Disconnected { rem_cli_url } => f - .debug_struct("EpEvt::Disconnected") - .field("rem_cli_url", rem_cli_url) - .finish(), - EpEvt::Data { - rem_cli_url, - data, - permit: _, - } => { - let data_len = data.remaining(); - f.debug_struct("EpEvt::Data") - .field("rem_cli_url", rem_cli_url) - .field("data_len", &data_len) - .finish() - } - EpEvt::Demo { rem_cli_url } => f - .debug_struct("EpEvt::Demo") - .field("rem_cli_url", rem_cli_url) - .finish(), - } - } -} - -/// A tx5 endpoint representing an instance that can send and receive. -#[derive(Clone, PartialEq, Eq)] -pub struct Ep { - state: state::State, -} - -impl std::fmt::Debug for Ep { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Ep").finish() - } -} - -impl Ep { - /// Construct a new tx5 endpoint. - pub async fn new() -> Result<(Self, actor::ManyRcv)> { - Self::with_config(DefConfig::default()).await - } - - /// Construct a new tx5 endpoint with configuration. 
- pub async fn with_config( - into_config: I, - ) -> Result<(Self, actor::ManyRcv)> { - let (ep_snd, ep_rcv) = tokio::sync::mpsc::unbounded_channel(); - - let config = into_config.into_config().await?; - let (state, mut state_evt) = state::State::new(config.clone())?; - tokio::task::spawn(async move { - while let Some(evt) = state_evt.recv().await { - match evt { - Ok(state::StateEvt::NewSig(sig_url, seed)) => { - config.on_new_sig(sig_url, seed); - } - Ok(state::StateEvt::Address(_cli_url)) => {} - Ok(state::StateEvt::NewConn(ice_servers, seed)) => { - config.on_new_conn(ice_servers, seed); - } - Ok(state::StateEvt::RcvData(url, buf, permit)) => { - let _ = ep_snd.send(Ok(EpEvt::Data { - rem_cli_url: url, - data: buf, - permit, - })); - } - Ok(state::StateEvt::Demo(cli_url)) => { - let _ = ep_snd.send(Ok(EpEvt::Demo { - rem_cli_url: cli_url, - })); - } - Ok(state::StateEvt::Connected(cli_url)) => { - let _ = ep_snd.send(Ok(EpEvt::Connected { - rem_cli_url: cli_url, - })); - } - Ok(state::StateEvt::Disconnected(cli_url)) => { - let _ = ep_snd.send(Ok(EpEvt::Disconnected { - rem_cli_url: cli_url, - })); - } - Err(err) => { - let _ = ep_snd.send(Err(err)); - break; - } - } - } - }); - let ep = Self { state }; - Ok((ep, actor::ManyRcv(ep_rcv))) - } - - /// Establish a listening connection to a signal server, - /// from which we can accept incoming remote connections. - /// Returns the client url at which this endpoint may now be addressed. - pub fn listen( - &self, - sig_url: Tx5Url, - ) -> impl std::future::Future> + 'static + Send - { - self.state.listener_sig(sig_url) - } - - /// Close down all connections to, fail all outgoing messages to, - /// and drop all incoming messages from, the given remote id, - /// for the specified ban time period. - pub fn ban(&self, rem_id: Id, span: std::time::Duration) { - self.state.ban(rem_id, span); - } - - /// Send data to a remote on this tx5 endpoint. 
- /// The future returned from this method will resolve when - /// the data is handed off to our networking backend. - pub fn send( - &self, - rem_cli_url: Tx5Url, - data: B, - ) -> impl std::future::Future> + 'static + Send { - self.state.snd_data(rem_cli_url, data) - } - - /// Broadcast data to all connections that happen to be open. - /// If no connections are open, no data will be broadcast. - /// The future returned from this method will resolve when all - /// broadcast messages have been handed off to our networking backend. - /// - /// This method is currently not ideal. It naively gets a list - /// of open connection urls and adds the broadcast to all of their queues. - /// This could result in a connection being re-established just - /// for the broadcast to occur. - pub fn broadcast( - &self, - mut data: B, - ) -> impl std::future::Future>>> + 'static + Send - { - let data = data.copy_to_bytes(data.remaining()); - let state = self.state.clone(); - async move { - let url_list = state.list_connected().await?; - Ok(futures::future::join_all( - url_list - .into_iter() - .map(|url| state.snd_data(url, data.clone())), - ) - .await) - } - } - - /// Send a demo broadcast to every connected signal server. - /// Warning, if demo mode is not enabled on these servers, this - /// could result in a ban. - pub fn demo(&self) -> Result<()> { - self.state.snd_demo() - } - - /// Get stats. 
- pub fn get_stats( - &self, - ) -> impl std::future::Future> + 'static + Send - { - self.state.stats() - } -} - -pub(crate) fn on_new_sig( - config: DynConfig, - sig_url: Tx5Url, - seed: state::SigStateSeed, -) { - tokio::task::spawn(new_sig_task(config, sig_url, seed)); -} - -async fn new_sig_task( - config: DynConfig, - sig_url: Tx5Url, - seed: state::SigStateSeed, -) { - tracing::debug!(%sig_url, "spawning new signal task"); - - let (sig, mut sig_rcv, cli_url) = match async { - let (sig, sig_rcv) = tx5_signal::Cli::builder() - .with_lair_client(config.lair_client().clone()) - .with_lair_tag(config.lair_tag().clone()) - .with_url(sig_url.to_string().parse().unwrap()) - .build() - .await?; - - let cli_url = Tx5Url::new(sig.local_addr())?; - - Result::Ok((sig, sig_rcv, cli_url)) - } - .await - { - Ok(r) => r, - Err(err) => { - tracing::error!(?err, "error connecting to signal server"); - seed.result_err(err); - return; - } - }; - - tracing::debug!(%cli_url, "signal connection established"); - - let sig = &sig; - - let ice_servers = sig.ice_servers(); - - let (sig_state, mut sig_evt) = match seed.result_ok(cli_url, ice_servers) { - Err(_) => return, - Ok(r) => r, - }; - - loop { - tokio::select! 
{ - msg = sig_rcv.recv() => { - if let Err(err) = async { - match msg { - Some(tx5_signal::SignalMsg::Demo { rem_pub }) => { - sig_state.demo(rem_pub) - } - Some(tx5_signal::SignalMsg::Offer { rem_pub, offer }) => { - let offer = BackBuf::from_json(offer)?; - sig_state.offer(rem_pub, offer) - } - Some(tx5_signal::SignalMsg::Answer { rem_pub, answer }) => { - let answer = BackBuf::from_json(answer)?; - sig_state.answer(rem_pub, answer) - } - Some(tx5_signal::SignalMsg::Ice { rem_pub, ice }) => { - let ice = BackBuf::from_json(ice)?; - sig_state.ice(rem_pub, ice) - } - None => Err(Error::id("SigClosed")), - } - }.await { - sig_state.close(err); - break; - } - } - msg = sig_evt.recv() => { - match msg { - Some(Ok(state::SigStateEvt::SndOffer( - rem_id, - mut offer, - mut resp, - ))) => { - resp.with(move || async move { - sig.offer(rem_id, offer.to_json()?).await - }).await; - } - Some(Ok(state::SigStateEvt::SndAnswer( - rem_id, - mut answer, - mut resp, - ))) => { - resp.with(move || async move { - sig.answer(rem_id, answer.to_json()?).await - }).await; - } - Some(Ok(state::SigStateEvt::SndIce( - rem_id, - mut ice, - mut resp, - ))) => { - resp.with(move || async move { - sig.ice(rem_id, ice.to_json()?).await - }).await; - } - Some(Ok(state::SigStateEvt::SndDemo)) => { - sig.demo() - } - Some(Err(_)) => break, - None => break, - } - } - }; - } - - tracing::warn!("signal connection CLOSED"); -} - -#[derive(Debug, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -struct BackendMetrics { - #[serde(default)] - messages_sent: u64, - #[serde(default)] - messages_received: u64, - #[serde(default)] - bytes_sent: u64, - #[serde(default)] - bytes_received: u64, -} - -#[cfg(feature = "backend-go-pion")] -pub(crate) fn on_new_conn( - config: DynConfig, - ice_servers: Arc, - seed: state::ConnStateSeed, -) { - tokio::task::spawn(new_conn_task(config, ice_servers, seed)); -} - -#[cfg(feature = "backend-go-pion")] -async fn new_conn_task( - config: DynConfig, - ice_servers: 
Arc, - seed: state::ConnStateSeed, -) { - let config = &config; - - use tx5_go_pion::DataChannelEvent as DataEvt; - use tx5_go_pion::PeerConnectionEvent as PeerEvt; - use tx5_go_pion::PeerConnectionState as PeerState; - - enum MultiEvt { - OneSec, - Stats( - tokio::sync::oneshot::Sender< - Option>, - >, - ), - Peer(PeerEvt), - Data(DataEvt), - } - - let (peer_snd, mut peer_rcv) = tokio::sync::mpsc::unbounded_channel(); - - let peer_snd2 = peer_snd.clone(); - let mut peer = match async { - let peer_config = BackBuf::from_json(ice_servers)?; - - let peer = - tx5_go_pion::PeerConnection::new(peer_config.imp.buf, move |evt| { - let _ = peer_snd2.send(MultiEvt::Peer(evt)); - }) - .await?; - - Result::Ok(peer) - } - .await - { - Ok(r) => r, - Err(err) => { - seed.result_err(err); - return; - } - }; - - let (conn_state, mut conn_evt) = match seed.result_ok() { - Err(_) => return, - Ok(r) => r, - }; - - let state_uniq = conn_state.meta().state_uniq.clone(); - let conn_uniq = conn_state.meta().conn_uniq.clone(); - let rem_id = conn_state.meta().cli_url.id().unwrap(); - - struct Unregister( - Option>, - ); - impl Drop for Unregister { - fn drop(&mut self) { - if let Some(mut unregister) = self.0.take() { - let _ = unregister.unregister(); - } - } - } - - let peer_snd_task = peer_snd.clone(); - tokio::task::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - if peer_snd_task.send(MultiEvt::OneSec).is_err() { - break; - } - } - }); - - let slot: Arc>>> = - Arc::new(std::sync::Mutex::new(None)); - let weak_slot = Arc::downgrade(&slot); - let peer_snd_task = peer_snd.clone(); - tokio::task::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - - if let Some(slot) = weak_slot.upgrade() { - let (s, r) = tokio::sync::oneshot::channel(); - if peer_snd_task.send(MultiEvt::Stats(s)).is_err() { - break; - } - if let Ok(stats) = r.await { - *slot.lock().unwrap() = stats; - } - } else { - break; - } - } - }); 
- - let weak_slot = Arc::downgrade(&slot); - let _unregister = { - use opentelemetry_api::metrics::MeterProvider; - - let meter = opentelemetry_api::global::meter_provider() - .versioned_meter( - "tx5", - None::<&'static str>, - None::<&'static str>, - Some(vec![ - KeyValue::new("state_uniq", state_uniq.to_string()), - KeyValue::new("conn_uniq", conn_uniq.to_string()), - KeyValue::new("remote_id", rem_id.to_string()), - ]), - ); - let ice_snd = meter - .u64_observable_counter("tx5.conn.ice.send") - .with_description("Bytes sent on ice channel") - .with_unit(Unit::new("By")) - .init(); - let ice_rcv = meter - .u64_observable_counter("tx5.conn.ice.recv") - .with_description("Bytes received on ice channel") - .with_unit(Unit::new("By")) - .init(); - let data_snd = meter - .u64_observable_counter("tx5.conn.data.send") - .with_description("Bytes sent on data channel") - .with_unit(Unit::new("By")) - .init(); - let data_rcv = meter - .u64_observable_counter("tx5.conn.data.recv") - .with_description("Bytes received on data channel") - .with_unit(Unit::new("By")) - .init(); - let data_snd_msg = meter - .u64_observable_counter("tx5.conn.data.send.message.count") - .with_description("Message count sent on data channel") - .init(); - let data_rcv_msg = meter - .u64_observable_counter("tx5.conn.data.recv.message.count") - .with_description("Message count received on data channel") - .init(); - let unregister = match meter.register_callback( - &[data_snd.as_any(), data_rcv.as_any()], - move |obs| { - if let Some(slot) = weak_slot.upgrade() { - let guard = slot.lock().unwrap(); - if let Some(slot) = &*guard { - for (k, v) in slot.iter() { - if k.starts_with("DataChannel") { - obs.observe_u64(&data_snd, v.bytes_sent, &[]); - obs.observe_u64( - &data_rcv, - v.bytes_received, - &[], - ); - obs.observe_u64( - &data_snd_msg, - v.messages_sent, - &[], - ); - obs.observe_u64( - &data_rcv_msg, - v.messages_received, - &[], - ); - } else if k.starts_with("iceTransport") { - 
obs.observe_u64(&ice_snd, v.bytes_sent, &[]); - obs.observe_u64( - &ice_rcv, - v.bytes_received, - &[], - ); - } - } - } - } - }, - ) { - Ok(unregister) => Some(unregister), - Err(err) => { - tracing::warn!(?err, "unable to register connection metrics"); - None - } - }; - Unregister(unregister) - }; - - let mut data_chan: Option = None; - let mut data_chan_ready = false; - - let mut check_data_chan_ready = - |data_chan: &mut Option| { - if data_chan_ready { - return Ok(()); - } - if let Some(data_chan) = data_chan.as_mut() { - let state = data_chan.ready_state()?; - if state == 2 - /* open */ - { - data_chan_ready = true; - data_chan.set_buffered_amount_low_threshold( - config.per_data_chan_buf_low(), - )?; - conn_state.ready()?; - } - } - Result::Ok(()) - }; - - tracing::debug!(?conn_uniq, "PEER CON OPEN"); - - loop { - tokio::select! { - msg = peer_rcv.recv() => { - match msg { - None => { - conn_state.close(Error::id("PeerConClosed")); - break; - } - Some(MultiEvt::OneSec) => { - if let Some(data_chan) = data_chan.as_mut() { - if let Ok(buf) = data_chan.buffered_amount() { - if buf <= config.per_data_chan_buf_low() { - conn_state.check_send_waiting(Some(state::BufState::Low)).await; - } - } - } - } - Some(MultiEvt::Stats(resp)) => { - if let Ok(mut buf) = peer.stats().await.map(BackBuf::from_raw) { - if let Ok(val) = buf.to_json() { - let _ = resp.send(Some(val)); - } else { - let _ = resp.send(None); - } - } else { - let _ = resp.send(None); - } - } - Some(MultiEvt::Peer(PeerEvt::Error(err))) => { - conn_state.close(err); - break; - } - Some(MultiEvt::Peer(PeerEvt::State(peer_state))) => { - match peer_state { - PeerState::New - | PeerState::Connecting - | PeerState::Connected => { - tracing::debug!(?peer_state); - } - PeerState::Disconnected - | PeerState::Failed - | PeerState::Closed => { - conn_state.close(Error::err(format!("BackendState:{peer_state:?}"))); - break; - } - } - } - Some(MultiEvt::Peer(PeerEvt::ICECandidate(buf))) => { - let buf = 
BackBuf::from_raw(buf); - if conn_state.ice(buf).is_err() { - break; - } - } - Some(MultiEvt::Peer(PeerEvt::DataChannel(chan))) => { - let peer_snd = peer_snd.clone(); - data_chan = Some(chan.handle(move |evt| { - let _ = peer_snd.send(MultiEvt::Data(evt)); - })); - if check_data_chan_ready(&mut data_chan).is_err() { - break; - } - } - Some(MultiEvt::Data(DataEvt::Open)) => { - if check_data_chan_ready(&mut data_chan).is_err() { - break; - } - } - Some(MultiEvt::Data(DataEvt::Close)) => { - conn_state.close(Error::id("DataChanClosed")); - break; - } - Some(MultiEvt::Data(DataEvt::Message(buf))) => { - if conn_state.rcv_data(BackBuf::from_raw(buf)).is_err() { - break; - } - } - Some(MultiEvt::Data(DataEvt::BufferedAmountLow)) => { - tracing::debug!(?conn_uniq, "BufferedAmountLow"); - conn_state.check_send_waiting(Some(state::BufState::Low)).await; - } - } - } - msg = conn_evt.recv() => { - match msg { - Some(Ok(state::ConnStateEvt::CreateOffer(mut resp))) => { - let peer = &mut peer; - let data_chan_w = &mut data_chan; - let peer_snd = peer_snd.clone(); - resp.with(move || async move { - let chan = peer.create_data_channel( - tx5_go_pion::DataChannelConfig { - label: Some("data".into()), - } - ).await?; - - *data_chan_w = Some(chan.handle(move |evt| { - let _ = peer_snd.send(MultiEvt::Data(evt)); - })); - - let mut buf = peer.create_offer( - tx5_go_pion::OfferConfig::default(), - ).await?; - - if let Ok(bytes) = buf.to_vec() { - tracing::debug!( - offer=%String::from_utf8_lossy(&bytes), - "create_offer", - ); - } - - Ok(BackBuf::from_raw(buf)) - }).await; - - if check_data_chan_ready(&mut data_chan).is_err() { - break; - } - } - Some(Ok(state::ConnStateEvt::CreateAnswer(mut resp))) => { - let peer = &mut peer; - resp.with(move || async move { - let mut buf = peer.create_answer( - tx5_go_pion::AnswerConfig::default(), - ).await?; - if let Ok(bytes) = buf.to_vec() { - tracing::debug!( - offer=%String::from_utf8_lossy(&bytes), - "create_answer", - ); - } - 
Ok(BackBuf::from_raw(buf)) - }).await; - } - Some(Ok(state::ConnStateEvt::SetLoc(buf, mut resp))) => { - let peer = &mut peer; - resp.with(move || async move { - peer.set_local_description(buf.imp.buf).await - }).await; - } - Some(Ok(state::ConnStateEvt::SetRem(buf, mut resp))) => { - let peer = &mut peer; - resp.with(move || async move { - peer.set_remote_description(buf.imp.buf).await - }).await; - } - Some(Ok(state::ConnStateEvt::SetIce(buf, mut resp))) => { - let peer = &mut peer; - resp.with(move || async move { - peer.add_ice_candidate(buf.imp.buf).await - }).await; - } - Some(Ok(state::ConnStateEvt::SndData(buf, mut resp))) => { - let data_chan = &mut data_chan; - resp.with(move || async move { - match data_chan { - None => Err(Error::id("NoDataChannel")), - Some(chan) => { - let buf = chan.send(buf.imp.buf).await?; - if buf > config.per_data_chan_buf_low() { - Ok(state::BufState::High) - } else { - Ok(state::BufState::Low) - } - } - } - }).await; - } - Some(Ok(state::ConnStateEvt::Stats(mut resp))) => { - let peer = &mut peer; - resp.with(move || async move { - peer.stats().await - .map(BackBuf::from_raw) - }).await; - } - Some(Err(_)) => break, - None => break, - } - } - }; - } - - tracing::debug!("PEER CON CLOSE"); -} diff --git a/crates/tx5/src/ep3.rs b/crates/tx5/src/ep3.rs new file mode 100644 index 00000000..9ec6367c --- /dev/null +++ b/crates/tx5/src/ep3.rs @@ -0,0 +1,633 @@ +//! Module containing tx5 endpoint version 3 types. 
+ +use crate::deps::lair_keystore_api; +use crate::deps::sodoken; +use crate::AbortableTimedSharedFuture; +use crate::BackBuf; +use futures::future::BoxFuture; +use lair_keystore_api::prelude::*; +use std::collections::HashMap; +use std::sync::{Arc, Mutex, Weak}; +use tx5_core::{Error, EventRecv, EventSend, Id, Result, Tx5Url}; + +fn next_uniq() -> u64 { + static UNIQ: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(1); + UNIQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed) +} + +type CRes = std::result::Result; + +/// Events generated by a tx5 endpoint version 3. +pub enum Ep3Event { + /// A fatal error indicating the endpoint is no longer viable. + Error(Error), + + /// Connection established. + Connected { + /// Url of the remote peer. + peer_url: PeerUrl, + }, + + /// Connection closed. + Disconnected { + /// Url of the remote peer. + peer_url: PeerUrl, + }, + + /// Receiving an incoming message from a remote peer. + Message { + /// Url of the remote peer. + peer_url: PeerUrl, + + /// Message sent by the remote peer. + message: Vec, + }, +} + +impl From for Ep3Event { + fn from(err: Error) -> Self { + Self::Error(err) + } +} + +impl std::fmt::Debug for Ep3Event { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Error(err) => { + f.debug_struct("Error").field("err", err).finish() + } + Self::Connected { peer_url } => { + let url = format!("{peer_url}"); + f.debug_struct("Connected").field("peer_url", &url).finish() + } + Self::Disconnected { peer_url } => { + let url = format!("{peer_url}"); + f.debug_struct("Disconnected") + .field("peer_url", &url) + .finish() + } + Self::Message { peer_url, .. } => { + let url = format!("{peer_url}"); + f.debug_struct("Message").field("peer_url", &url).finish() + } + } + } +} + +/// A signal server url. +pub type SigUrl = Tx5Url; + +/// A peer connection url. 
+pub type PeerUrl = Tx5Url;
+
+type SigMap = HashMap>)>;
+
+/// Callback in charge of sending preflight data if any.
+pub type PreflightSendCb = Arc<
+    dyn Fn(&PeerUrl) -> BoxFuture<'static, Result>>
+        + 'static
+        + Send
+        + Sync,
+>;
+
+/// Callback in charge of validating preflight data if any.
+pub type PreflightCheckCb = Arc<
+    dyn Fn(&PeerUrl, Vec) -> BoxFuture<'static, Result<()>>
+        + 'static
+        + Send
+        + Sync,
+>;
+
+/// Tx5 endpoint version 3 configuration.
+pub struct Config3 {
+    /// Maximum count of open connections. Default 4096.
+    pub connection_count_max: u32,
+
+    /// Max backend send buffer bytes (per connection). Default 64 KiB.
+    pub send_buffer_bytes_max: u32,
+
+    /// Max backend recv buffer bytes (per connection). Default 64 KiB.
+    pub recv_buffer_bytes_max: u32,
+
+    /// Maximum receive message reconstruction bytes in memory
+    /// (across entire endpoint). Default 512 MiB.
+    pub incoming_message_bytes_max: u32,
+
+    /// Maximum size of an individual message. Default 16 MiB.
+    pub message_size_max: u32,
+
+    /// Internal event channel size. Default is 1024.
+    pub internal_event_channel_size: u32,
+
+    /// Default timeout for network operations. Default 60 seconds.
+    pub timeout: std::time::Duration,
+
+    /// Starting backoff duration for retries. Default 5 seconds.
+    pub backoff_start: std::time::Duration,
+
+    /// Max backoff duration for retries. Default 60 seconds.
+    pub backoff_max: std::time::Duration,
+
+    /// If the protocol should manage a preflight message,
+    /// set the callbacks here, otherwise no preflight will
+    /// be sent nor validated. Default: None.
+ pub preflight: Option<(PreflightSendCb, PreflightCheckCb)>, +} + +impl std::fmt::Debug for Config3 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Config3") + .field("connection_count_max", &self.connection_count_max) + .field("send_buffer_bytes_max", &self.send_buffer_bytes_max) + .field("recv_buffer_bytes_max", &self.recv_buffer_bytes_max) + .field( + "incoming_message_bytes_max", + &self.incoming_message_bytes_max, + ) + .field("message_size_max", &self.message_size_max) + .field( + "internal_event_channel_size", + &self.internal_event_channel_size, + ) + .field("timeout", &self.timeout) + .field("backoff_start", &self.backoff_start) + .field("backoff_max", &self.backoff_max) + .finish() + } +} + +impl Default for Config3 { + fn default() -> Self { + Self { + connection_count_max: 4096, + send_buffer_bytes_max: 64 * 1024, + recv_buffer_bytes_max: 64 * 1024, + incoming_message_bytes_max: 512 * 1024 * 1024, + message_size_max: 16 * 1024 * 1024, + internal_event_channel_size: 1024, + timeout: std::time::Duration::from_secs(60), + backoff_start: std::time::Duration::from_secs(5), + backoff_max: std::time::Duration::from_secs(60), + preflight: None, + } + } +} + +#[derive(Default)] +struct BanMap(HashMap); + +impl BanMap { + fn set_ban(&mut self, id: Id, until: tokio::time::Instant) { + self.0.insert(id, until); + } + + fn is_banned(&mut self, id: Id) -> bool { + let now = tokio::time::Instant::now(); + if let Some(until) = self.0.get(&id).cloned() { + if now < until { + true + } else { + self.0.remove(&id); + false + } + } else { + false + } + } +} + +pub(crate) struct EpShared { + config: Arc, + this_id: Id, + ep_uniq: u64, + lair_tag: Arc, + lair_client: LairClient, + sig_limit: Arc, + peer_limit: Arc, + recv_recon_limit: Arc, + weak_sig_map: Weak>, + evt_send: EventSend, + ban_map: Mutex, + metric_conn_count: + influxive_otel_atomic_obs::AtomicObservableUpDownCounterI64, +} + +/// Tx5 endpoint version 3. 
+pub struct Ep3 { + ep: Arc, + _lair_keystore: lair_keystore_api::in_proc_keystore::InProcKeystore, + _sig_map: Arc>, + listen_sigs: Arc>>>, +} + +impl std::fmt::Debug for Ep3 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Ep3") + .field("this_id", &self.ep.this_id) + .field("ep_uniq", &self.ep.ep_uniq) + .finish() + } +} + +impl Drop for Ep3 { + fn drop(&mut self) { + let handles = std::mem::take(&mut *self.listen_sigs.lock().unwrap()); + for handle in handles { + handle.abort(); + } + } +} + +impl Ep3 { + /// Construct a new tx5 endpoint version 3. + pub async fn new(config: Arc) -> (Self, EventRecv) { + use influxive_otel_atomic_obs::MeterExt; + use opentelemetry_api::metrics::MeterProvider; + + let sig_limit = Arc::new(tokio::sync::Semaphore::new( + config.connection_count_max as usize, + )); + + let peer_limit = Arc::new(tokio::sync::Semaphore::new( + config.connection_count_max as usize, + )); + + let recv_recon_limit = Arc::new(tokio::sync::Semaphore::new( + config.incoming_message_bytes_max as usize, + )); + + let lair_tag: Arc = + rand_utf8::rand_utf8(&mut rand::thread_rng(), 32).into(); + + let passphrase = sodoken::BufRead::new_no_lock( + rand_utf8::rand_utf8(&mut rand::thread_rng(), 32).as_bytes(), + ); + + // this is a memory keystore, + // so weak persistence security is okay, + // since it will not be persisted. + // The private keys will still be mem_locked + // so they shouldn't be swapped to disk. 
+ let keystore_config = PwHashLimits::Minimum + .with_exec(|| LairServerConfigInner::new("/", passphrase.clone())) + .await + .unwrap(); + + let _lair_keystore = PwHashLimits::Minimum + .with_exec(|| { + lair_keystore_api::in_proc_keystore::InProcKeystore::new( + Arc::new(keystore_config), + lair_keystore_api::mem_store::create_mem_store_factory(), + passphrase, + ) + }) + .await + .unwrap(); + + let lair_client = _lair_keystore.new_client().await.unwrap(); + + let seed = lair_client + .new_seed(lair_tag.clone(), None, false) + .await + .unwrap(); + + let this_id = Id(*seed.x25519_pub_key.0); + + let (evt_send, evt_recv) = + EventSend::new(config.internal_event_channel_size); + + let sig_map = Arc::new(Mutex::new(HashMap::new())); + let weak_sig_map = Arc::downgrade(&sig_map); + + let ep_uniq = next_uniq(); + + let meter = opentelemetry_api::global::meter_provider() + .versioned_meter( + "tx5", + None::<&'static str>, + None::<&'static str>, + Some(vec![opentelemetry_api::KeyValue::new( + "ep_uniq", + ep_uniq.to_string(), + )]), + ); + + let metric_conn_count = meter + .i64_observable_up_down_counter_atomic("tx5.endpoint.conn.count", 0) + .with_description("Count of open connections managed by endpoint") + .init() + .0; + + let this = Self { + ep: Arc::new(EpShared { + config, + this_id, + ep_uniq, + lair_tag, + lair_client, + sig_limit, + peer_limit, + recv_recon_limit, + weak_sig_map, + evt_send, + ban_map: Mutex::new(BanMap::default()), + metric_conn_count, + }), + _lair_keystore, + _sig_map: sig_map, + listen_sigs: Arc::new(Mutex::new(Vec::new())), + }; + + (this, evt_recv) + } + + /// Establish a listening connection to a signal server, + /// from which we can accept incoming remote connections. + /// Returns the client url at which this endpoint may now be addressed. 
+ pub async fn listen(&self, sig_url: SigUrl) -> Result { + if !sig_url.is_server() { + return Err(Error::str("Expected SigUrl, got PeerUrl")); + } + + let ep = self.ep.clone(); + let peer_url = sig_url.to_client(ep.this_id); + + let (wait_send, wait_recv) = tokio::sync::oneshot::channel(); + let mut wait_send = Some(wait_send); + + self.listen_sigs + .lock() + .unwrap() + .push(tokio::task::spawn(async move { + let mut backoff = ep.config.backoff_start; + loop { + /* + tracing::error!( + %ep.ep_uniq, + %sig_url, + "TRY ASSERT SIG", + ); + */ + match assert_sig(&ep, &sig_url).await { + Ok(_) => { + //tracing::error!(%ep.ep_uniq, "SIG CONNECTED!"); + // if the conn is still open it's essentially + // a no-op to assert it again, so it's + // okay to do that quickly. + backoff = ep.config.backoff_start; + } + Err(_err) => { + //tracing::error!(%ep.ep_uniq, ?err, "SIG ERROR!"); + backoff *= 2; + if backoff > ep.config.backoff_max { + backoff = ep.config.backoff_max; + } + } + } + + if let Some(wait_send) = wait_send.take() { + let _ = wait_send.send(()); + } + + tokio::time::sleep(backoff).await; + } + })); + + // await at least one loop of connect attempt before returning + let _ = wait_recv.await; + + Ok(peer_url) + } + + /// Close down all connections to, fail all outgoing messages to, + /// and drop all incoming messages from, the given remote id, + /// for the specified ban time period. + pub fn ban(&self, rem_id: Id, span: std::time::Duration) { + self.ep + .ban_map + .lock() + .unwrap() + .set_ban(rem_id, tokio::time::Instant::now() + span); + + let fut_list = self + ._sig_map + .lock() + .unwrap() + .values() + .map(|v| v.1.clone()) + .collect::>(); + for fut in fut_list { + let ep = self.ep.clone(); + // fire and forget + tokio::task::spawn(async move { + if let Ok(sig) = fut.await { + // see if we are still banning this id. 
+ if ep.ban_map.lock().unwrap().is_banned(rem_id) { + sig.ban(rem_id); + } + } + }); + } + } + + /// Send data to a remote on this tx5 endpoint. + /// The future returned from this method will resolve when + /// the data is handed off to our networking backend. + pub async fn send(&self, peer_url: PeerUrl, data: &[u8]) -> Result<()> { + if !peer_url.is_client() { + return Err(Error::str("Expected PeerUrl, got SigUrl")); + } + + let sig_url = peer_url.to_server(); + let peer_id = peer_url.id().unwrap(); + + if self.ep.ban_map.lock().unwrap().is_banned(peer_id) { + return Err(Error::str("Peer is currently banned")); + } + + let sig = assert_sig(&self.ep, &sig_url).await?; + + let peer = sig + .assert_peer(peer_url, peer_id, PeerDir::ActiveOrOutgoing) + .await?; + + peer.send(data).await + } + + /// Broadcast data to all connections that happen to be open. + /// If no connections are open, no data will be broadcast. + /// The future returned from this method will resolve when all + /// broadcast messages have been handed off to our networking backend + /// (or have timed out). + pub async fn broadcast(&self, data: &[u8]) { + let mut task_list = Vec::new(); + + let fut_list = self + ._sig_map + .lock() + .unwrap() + .values() + .map(|v| v.1.clone()) + .collect::>(); + + for fut in fut_list { + task_list.push(async move { + // timeouts are built into this future as well + // as the sig.broadcast function + if let Ok(sig) = fut.await { + sig.broadcast(data).await; + } + }); + } + + futures::future::join_all(task_list).await; + } + + /// Get stats. 
+ pub async fn get_stats(&self) -> serde_json::Value { + let mut task_list = Vec::new(); + + let mut ban_map = serde_json::Map::new(); + + let now = tokio::time::Instant::now(); + for (id, until) in self.ep.ban_map.lock().unwrap().0.iter() { + ban_map.insert(id.to_string(), (*until - now).as_secs_f64().into()); + } + + let fut_list = self + ._sig_map + .lock() + .unwrap() + .values() + .map(|v| v.1.clone()) + .collect::>(); + + for fut in fut_list { + task_list.push(async move { + if let Ok(sig) = fut.await { + Some(sig.get_stats().await) + } else { + None + } + }); + } + + let res: Vec<(Id, serde_json::Value)> = + futures::future::join_all(task_list) + .await + .into_iter() + .flatten() + .flatten() + .collect(); + + let mut map = serde_json::Map::default(); + + #[cfg(feature = "backend-go-pion")] + const BACKEND: &str = "go-pion"; + #[cfg(feature = "backend-webrtc-rs")] + const BACKEND: &str = "webrtc-rs"; + + map.insert("backend".into(), BACKEND.into()); + map.insert("thisId".into(), self.ep.this_id.to_string().into()); + map.insert("banned".into(), ban_map.into()); + + for (id, v) in res { + map.insert(id.to_string(), v); + } + + serde_json::Value::Object(map) + } +} + +async fn assert_sig(ep: &Arc, sig_url: &SigUrl) -> CRes> { + let sig_map = match ep.weak_sig_map.upgrade() { + Some(sig_map) => sig_map, + None => { + return Err(Error::str( + "Signal connection failed due to closed endpoint", + ) + .into()) + } + }; + + let (sig_uniq, fut) = sig_map + .lock() + .unwrap() + .entry(sig_url.clone()) + .or_insert_with(|| { + let sig_uniq = next_uniq(); + let sig_url = sig_url.clone(); + let ep = ep.clone(); + let _sig_drop = SigDrop { + ep_uniq: ep.ep_uniq, + sig_uniq, + sig_url: sig_url.clone(), + weak_sig_map: ep.weak_sig_map.clone(), + }; + ( + sig_uniq, + AbortableTimedSharedFuture::new( + ep.config.timeout, + Error::str("Timeout awaiting signal server connection") + .into(), + Sig::new(_sig_drop, ep, sig_uniq, sig_url), + ), + ) + }) + .clone(); + + match 
fut.await { + Err(err) => { + // if a new sig got added in the mean time, return that instead + let r = sig_map.lock().unwrap().get(sig_url).cloned(); + + if let Some((new_sig_uniq, new_sig_fut)) = r { + if new_sig_uniq != sig_uniq { + return new_sig_fut.await; + } + } + + Err(err) + } + Ok(r) => Ok(r), + } +} + +fn close_sig( + weak_sig_map: &Weak>, + sig_url: &SigUrl, + close_sig_uniq: u64, +) { + let mut tmp = None; + + if let Some(sig_map) = weak_sig_map.upgrade() { + let mut lock = sig_map.lock().unwrap(); + if let Some((sig_uniq, sig)) = lock.remove(sig_url) { + if close_sig_uniq != sig_uniq { + // most of the time we'll be closing the real one, + // so optimize for that case, and cause a hash probe + // in the less likely case some race caused us to + // try to remove the wrong one. + tmp = lock.insert(sig_url.clone(), (sig_uniq, sig)); + } else { + tmp = Some((sig_uniq, sig)); + } + } + } + + // make sure nothing is dropped while we're holding the mutex lock + if let Some((_sig_uniq, sig_fut)) = tmp { + sig_fut.abort(Error::id("Close").into()); + drop(sig_fut); + } +} + +pub(crate) mod sig; +pub(crate) use sig::*; + +pub(crate) mod peer; +pub(crate) use peer::*; + +#[cfg(test)] +mod test; diff --git a/crates/tx5/src/ep3/peer.rs b/crates/tx5/src/ep3/peer.rs new file mode 100644 index 00000000..a08b3aac --- /dev/null +++ b/crates/tx5/src/ep3/peer.rs @@ -0,0 +1,676 @@ +use super::*; +use crate::proto::*; + +pub(crate) enum PeerCmd { + Error(Error), + SigRecvIce(serde_json::Value), +} + +impl From for PeerCmd { + fn from(err: Error) -> Self { + Self::Error(err) + } +} + +pub(crate) enum PeerDir { + ActiveOrOutgoing, + Incoming { offer: serde_json::Value }, +} + +impl PeerDir { + pub fn is_incoming(&self) -> bool { + matches!(self, PeerDir::Incoming { .. 
}) + } +} + +pub(crate) enum NewPeerDir { + Outgoing { + answer_recv: tokio::sync::oneshot::Receiver, + }, + Incoming { + offer: serde_json::Value, + }, +} + +pub(crate) struct PeerDrop { + pub ep_uniq: u64, + pub sig_uniq: u64, + pub peer_uniq: u64, + pub peer_id: Id, + pub weak_peer_map: Weak>, +} + +impl Drop for PeerDrop { + fn drop(&mut self) { + tracing::info!(%self.ep_uniq, %self.sig_uniq, %self.peer_uniq, ?self.peer_id, "Peer Connection Close"); + + close_peer(&self.weak_peer_map, self.peer_id, self.peer_uniq); + } +} + +pub(crate) struct Peer { + _peer_drop: PeerDrop, + created_at: tokio::time::Instant, + sig: Arc, + peer_id: Id, + peer_url: PeerUrl, + _permit: tokio::sync::OwnedSemaphorePermit, + cmd_task: tokio::task::JoinHandle<()>, + recv_task: tokio::task::JoinHandle<()>, + data_task: tokio::task::JoinHandle<()>, + #[allow(dead_code)] + peer: Arc, + data_chan: Arc, + send_limit: Arc, + metric_bytes_send: influxive_otel_atomic_obs::AtomicObservableCounterU64, + metric_unreg: + Option>, + dec: Arc>, +} + +impl Drop for Peer { + fn drop(&mut self) { + let evt_send = self.sig.evt_send.clone(); + let msg = Ep3Event::Disconnected { + peer_url: self.peer_url.clone(), + }; + tokio::task::spawn(async move { + let _ = evt_send.send(msg).await; + }); + self.cmd_task.abort(); + self.recv_task.abort(); + self.data_task.abort(); + if let Some(mut metric_unreg) = self.metric_unreg.take() { + let _ = metric_unreg.unregister(); + } + } +} + +impl Peer { + #[allow(clippy::too_many_arguments)] + pub async fn new( + _peer_drop: PeerDrop, + sig: Arc, + peer_url: PeerUrl, + peer_id: Id, + peer_uniq: u64, + ice_servers: Arc, + new_peer_dir: NewPeerDir, + mut peer_cmd_recv: EventRecv, + ) -> CRes> { + use influxive_otel_atomic_obs::MeterExt; + use opentelemetry_api::metrics::MeterProvider; + + tracing::info!(%sig.ep_uniq, %sig.sig_uniq, %peer_uniq, ?peer_id, "Peer Connection Connecting"); + + let meter = opentelemetry_api::global::meter_provider() + .versioned_meter( + 
"tx5", + None::<&'static str>, + None::<&'static str>, + Some(vec![ + opentelemetry_api::KeyValue::new( + "ep_uniq", + sig.ep_uniq.to_string(), + ), + opentelemetry_api::KeyValue::new( + "sig_uniq", + sig.sig_uniq.to_string(), + ), + opentelemetry_api::KeyValue::new( + "peer_uniq", + peer_uniq.to_string(), + ), + ]), + ); + + let metric_bytes_send = meter + .u64_observable_counter_atomic("tx5.endpoint.conn.send", 0) + .with_description("Outgoing bytes sent on this connection") + .with_unit(opentelemetry_api::metrics::Unit::new("By")) + .init() + .0; + + let metric_bytes_recv = meter + .u64_observable_counter_atomic("tx5.endpoint.conn.recv", 0) + .with_description("Incoming bytes received on this connection") + .with_unit(opentelemetry_api::metrics::Unit::new("By")) + .init() + .0; + + let _permit = + sig.peer_limit.clone().acquire_owned().await.map_err(|_| { + Error::str( + "Endpoint closed while acquiring peer connection permit", + ) + })?; + + let sig_hnd = match sig.weak_sig.upgrade() { + None => { + return Err(Error::str( + "Sig shutdown while opening peer connection", + ) + .into()) + } + Some(sig_hnd) => sig_hnd, + }; + + let peer_config = BackBuf::from_json(ice_servers)?; + + let (peer, mut peer_recv) = + tx5_go_pion::PeerConnection::new(peer_config.imp.buf).await?; + + let peer = Arc::new(peer); + + #[derive(Debug, serde::Deserialize)] + #[serde(rename_all = "camelCase")] + struct BackendMetrics { + #[serde(default)] + messages_sent: u64, + #[serde(default)] + messages_received: u64, + #[serde(default)] + bytes_sent: u64, + #[serde(default)] + bytes_received: u64, + } + + let data_snd = meter + .u64_observable_counter("tx5.conn.data.send") + .with_description("Bytes sent on data channel") + .with_unit(opentelemetry_api::metrics::Unit::new("By")) + .init(); + let data_rcv = meter + .u64_observable_counter("tx5.conn.data.recv") + .with_description("Bytes received on data channel") + .with_unit(opentelemetry_api::metrics::Unit::new("By")) + .init(); + let 
data_snd_msg = meter + .u64_observable_counter("tx5.conn.data.send.message.count") + .with_description("Message count sent on data channel") + .init(); + let data_rcv_msg = meter + .u64_observable_counter("tx5.conn.data.recv.message.count") + .with_description("Message count received on data channel") + .init(); + let ice_snd = meter + .u64_observable_counter("tx5.conn.ice.send") + .with_description("Bytes sent on ice channel") + .with_unit(opentelemetry_api::metrics::Unit::new("By")) + .init(); + let ice_rcv = meter + .u64_observable_counter("tx5.conn.ice.recv") + .with_description("Bytes received on ice channel") + .with_unit(opentelemetry_api::metrics::Unit::new("By")) + .init(); + + let metric_unreg = { + let peer = peer.clone(); + let data: Arc>>> = + Arc::new(Mutex::new(None)); + meter + .register_callback( + &[ + data_snd.as_any(), + data_rcv.as_any(), + data_snd_msg.as_any(), + data_rcv_msg.as_any(), + ice_snd.as_any(), + ice_rcv.as_any(), + ], + move |obs| { + let data2 = data.clone(); + let peer2 = peer.clone(); + tokio::task::spawn(async move { + if let Ok(mut stats) = peer2.stats().await { + if let Ok(stats) = stats.as_json() { + *data2.lock().unwrap() = Some(stats); + } + } + }); + if let Some(stats) = data.lock().unwrap().take() { + for (k, v) in stats.iter() { + if k.starts_with("DataChannel") { + obs.observe_u64( + &data_snd, + v.bytes_sent, + &[], + ); + obs.observe_u64( + &data_rcv, + v.bytes_received, + &[], + ); + obs.observe_u64( + &data_snd_msg, + v.messages_sent, + &[], + ); + obs.observe_u64( + &data_rcv_msg, + v.messages_received, + &[], + ); + } else if k.starts_with("iceTransport") { + obs.observe_u64( + &ice_snd, + v.bytes_sent, + &[], + ); + obs.observe_u64( + &ice_rcv, + v.bytes_received, + &[], + ); + } + } + } + }, + ) + .map_err(Error::err)? 
+ }; + + let (chan_send, chan_recv) = tokio::sync::oneshot::channel(); + let mut chan_send = Some(chan_send); + + match new_peer_dir { + NewPeerDir::Outgoing { answer_recv } => { + let chan = peer + .create_data_channel(tx5_go_pion::DataChannelConfig { + label: Some("data".into()), + }) + .await?; + + if let Some(chan_send) = chan_send.take() { + let _ = chan_send.send(chan); + } + + let mut offer = peer + .create_offer(tx5_go_pion::OfferConfig::default()) + .await?; + + let offer_json = offer.as_json()?; + + tracing::debug!(?offer_json, "create_offer"); + + sig_hnd.offer(peer_id, offer_json).await?; + + peer.set_local_description(offer).await?; + + let answer = answer_recv.await.map_err(|_| { + Error::str("Failed to receive answer on peer connect") + })?; + let answer = BackBuf::from_json(answer)?; + + peer.set_remote_description(answer.imp.buf).await?; + } + NewPeerDir::Incoming { offer } => { + let offer = BackBuf::from_json(offer)?; + + peer.set_remote_description(offer.imp.buf).await?; + + let mut answer = peer + .create_answer(tx5_go_pion::AnswerConfig::default()) + .await?; + + let answer_json = answer.as_json()?; + + tracing::debug!(?answer_json, "create_answer"); + + sig_hnd.answer(peer_id, answer_json).await?; + + peer.set_local_description(answer).await?; + } + } + + let cmd_task = { + let weak_peer = Arc::downgrade(&peer); + let sig = sig.clone(); + tokio::task::spawn(async move { + while let Some(cmd) = peer_cmd_recv.recv().await { + match cmd { + PeerCmd::Error(err) => { + tracing::warn!(?err); + break; + } + PeerCmd::SigRecvIce(ice) => { + if let Some(peer) = weak_peer.upgrade() { + if let Ok(ice) = BackBuf::from_json(ice) { + if let Err(err) = peer + .add_ice_candidate(ice.imp.buf) + .await + { + tracing::trace!(?err); + } + } + } else { + break; + } + } + } + } + + close_peer(&sig.weak_peer_map, peer_id, peer_uniq); + }) + }; + + let recv_task = { + let sig = sig.clone(); + tokio::task::spawn(async move { + while let Some(evt) = 
peer_recv.recv().await { + use tx5_go_pion::PeerConnectionEvent as Evt; + match evt { + Evt::Error(err) => { + tracing::warn!(?err); + break; + } + Evt::State(_state) => (), + Evt::ICECandidate(mut ice) => { + let ice = match ice.as_json() { + Err(err) => { + tracing::warn!(?err, "invalid ice"); + break; + } + Ok(ice) => ice, + }; + if let Some(sig_hnd) = sig.weak_sig.upgrade() { + if sig_hnd.ice(peer_id, ice).await.is_err() { + break; + } + } else { + break; + } + } + Evt::DataChannel(c, r) => { + if let Some(chan_send) = chan_send.take() { + let _ = chan_send.send((c, r)); + } else { + tracing::warn!("Invalid incoming data channel"); + break; + } + } + } + } + + close_peer(&sig.weak_peer_map, peer_id, peer_uniq); + }) + }; + + let (data_chan, mut data_recv) = chan_recv.await.map_err(|_| { + Error::str("Failed to establish peer connection data channel") + })?; + + let data_chan = Arc::new(data_chan); + + let mut ready_state = data_chan.ready_state()?; + let mut backoff = std::time::Duration::from_millis(10); + + loop { + if ready_state >= 2 { + break; + } + + tokio::time::sleep(backoff).await; + backoff *= 2; + ready_state = data_chan.ready_state()?; + } + + if ready_state > 2 { + return Err(Error::str( + "Data channel closed while connecting peer", + ) + .into()); + } + + let (a, b, c, d) = PROTO_VER_1.encode()?; + data_chan + .send(BackBuf::from_slice([a, b, c, d])?.imp.buf) + .await?; + + sig.evt_send + .send(Ep3Event::Connected { + peer_url: peer_url.clone(), + }) + .await?; + + let dec = Arc::new(Mutex::new(ProtoDecoder::default())); + + let data_task = { + let dec = dec.clone(); + let weak_data_chan = Arc::downgrade(&data_chan); + let peer_url = peer_url.clone(); + let sig = sig.clone(); + tokio::task::spawn(async move { + let mut did_preflight = false; + + while let Some(evt) = data_recv.recv().await { + use tx5_go_pion::DataChannelEvent::*; + match evt { + Error(err) => { + tracing::warn!(?err); + break; + } + Open => (), + Close => break, + 
Message(message) => { + let mut message = BackBuf::from_raw(message); + + let len = message.len().unwrap(); + + metric_bytes_recv.add(len as u64); + + let dec_res = dec.lock().unwrap().decode(message); + + use ProtoDecodeResult::*; + match match dec_res { + Ok(r) => r, + Err(err) => { + tracing::debug!(?err, "DecodeError"); + break; + } + } { + Idle => (), + Message(message) => { + if !did_preflight { + did_preflight = true; + + if let Some((_pf_send, pf_check)) = + sig.config.preflight.as_ref() + { + if pf_check(&peer_url, message) + .await + .is_err() + { + break; + } + continue; + } + } + if sig + .evt_send + .send(Ep3Event::Message { + peer_url: peer_url.clone(), + message, + }) + .await + .is_err() + { + break; + } + } + RemotePermitRequest(permit_len) => { + if permit_len > sig.config.message_size_max + { + tracing::debug!(%permit_len, "InvalidPermitSizeRequest"); + break; + } + if let Some(data_chan) = + weak_data_chan.upgrade() + { + let dec = dec.clone(); + let recv_recon_limit = + sig.recv_recon_limit.clone(); + // fire and forget + tokio::task::spawn(async move { + let permit = recv_recon_limit + .acquire_many_owned(permit_len) + .await + .map_err( + tx5_core::Error::err, + )?; + let (a, b, c, d) = + ProtoHeader::PermitGrant( + permit_len, + ) + .encode()?; + data_chan + .send( + BackBuf::from_slice([ + a, b, c, d, + ]) + .unwrap() + .imp + .buf, + ) + .await?; + dec.lock() + .unwrap() + .sent_remote_permit_grant( + permit, + )?; + Result::Ok(()) + }); + } else { + break; + } + } + RemotePermitGrant(_permit_len) => (), + } + } + BufferedAmountLow => (), + } + } + + close_peer(&sig.weak_peer_map, peer_id, peer_uniq); + }) + }; + + tracing::info!(%sig.ep_uniq, %sig.sig_uniq, %peer_uniq, ?peer_id, "Peer Connection Open"); + + let this = Arc::new(Self { + _peer_drop, + created_at: tokio::time::Instant::now(), + sig, + peer_id, + peer_url: peer_url.clone(), + _permit, + cmd_task, + recv_task, + data_task, + peer, + data_chan, + // size 1 semaphore makes sure 
blocks of messages are contiguous + send_limit: Arc::new(tokio::sync::Semaphore::new(1)), + metric_bytes_send, + metric_unreg: Some(metric_unreg), + dec, + }); + + if let Some((pf_send, _pf_check)) = this.sig.config.preflight.as_ref() { + this.send(&pf_send(&peer_url).await?).await?; + } + + Ok(this) + } + + async fn sub_send(&self, mut buf: BackBuf) -> Result<()> { + tokio::time::timeout(self.sig.config.timeout, async { + let mut backoff = std::time::Duration::from_millis(1); + loop { + if self.data_chan.buffered_amount()? + <= self.sig.config.send_buffer_bytes_max as usize + { + break; + } + tokio::time::sleep(backoff).await; + backoff *= 2; + if backoff.as_millis() > 200 { + backoff = std::time::Duration::from_millis(200); + } + } + self.metric_bytes_send.add(buf.len()? as u64); + + if self.sig.ban_map.lock().unwrap().is_banned(self.peer_id) { + return Err(Error::str("Peer is currently banned")); + } + + self.data_chan.send(buf.imp.buf).await + }) + .await + .map_err(|_| { + Error::str("Timeout sending data to backend data channel") + })??; + Ok(()) + } + + pub async fn send(&self, data: &[u8]) -> Result<()> { + if data.len() > self.sig.config.message_size_max as usize { + return Err(Error::str("Message is too large")); + } + + if self.sig.ban_map.lock().unwrap().is_banned(self.peer_id) { + return Err(Error::str("Peer is currently banned")); + } + + // size 1 semaphore makes sure blocks of messages are contiguous + let _permit = tokio::time::timeout( + self.sig.config.timeout, + self.send_limit.acquire(), + ) + .await + .map_err(|_| Error::str("Timeout acquiring send permit"))? + .map_err(|_| Error::str("Failed to acquire send permit"))?; + + match proto_encode(data)? 
{ + ProtoEncodeResult::NeedPermit { + permit_req, + msg_payload, + } => { + let (s, r) = tokio::sync::oneshot::channel(); + + self.dec + .lock() + .unwrap() + .sent_remote_permit_request(Some(s))?; + + self.sub_send(permit_req).await?; + + r.await.map_err(|_| Error::id("ConnectionClosed"))?; + + for message in msg_payload { + self.sub_send(message).await?; + } + } + ProtoEncodeResult::OneMessage(buf) => { + self.sub_send(buf).await?; + } + } + + Ok(()) + } + + pub async fn stats(&self) -> Result { + self.peer.stats().await.map(|mut s| { + let mut out: serde_json::Value = s.as_json()?; + + if let Some(map) = out.as_object_mut() { + map.insert( + "ageSeconds".into(), + self.created_at.elapsed().as_secs_f64().into(), + ); + } + + Ok(out) + })? + } +} diff --git a/crates/tx5/src/ep3/sig.rs b/crates/tx5/src/ep3/sig.rs new file mode 100644 index 00000000..9978a3d6 --- /dev/null +++ b/crates/tx5/src/ep3/sig.rs @@ -0,0 +1,404 @@ +use super::*; + +type PeerCmdSend = EventSend; +type AnswerSend = + Arc>>>; +pub(crate) type PeerMap = HashMap< + Id, + ( + u64, + PeerCmdSend, + AnswerSend, + AbortableTimedSharedFuture>, + ), +>; + +pub(crate) struct SigShared { + pub ep: Arc, + pub weak_sig: Weak, + pub sig_uniq: u64, + pub weak_peer_map: Weak>, +} + +impl std::ops::Deref for SigShared { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.ep + } +} + +pub(crate) struct SigDrop { + pub ep_uniq: u64, + pub sig_uniq: u64, + pub sig_url: SigUrl, + pub weak_sig_map: Weak>, +} + +impl Drop for SigDrop { + fn drop(&mut self) { + tracing::info!(%self.ep_uniq, %self.sig_uniq, %self.sig_url, "Signal Connection Close"); + + close_sig(&self.weak_sig_map, &self.sig_url, self.sig_uniq); + } +} + +pub(crate) struct Sig { + _sig_drop: SigDrop, + sig: Arc, + _permit: tokio::sync::OwnedSemaphorePermit, + recv_task: tokio::task::JoinHandle<()>, + ice_servers: Arc, + peer_map: Arc>, + sig_cli: tx5_signal::Cli, +} + +impl Drop for Sig { + fn drop(&mut self) { + 
self.sig.metric_conn_count.add(-1); + self.recv_task.abort(); + } +} + +impl std::ops::Deref for Sig { + type Target = tx5_signal::Cli; + + fn deref(&self) -> &Self::Target { + &self.sig_cli + } +} + +impl Sig { + pub async fn new( + _sig_drop: SigDrop, + ep: Arc, + sig_uniq: u64, + sig_url: SigUrl, + ) -> CRes> { + ep.metric_conn_count.add(1); + + tracing::info!(%ep.ep_uniq, %sig_uniq, %sig_url, "Signal Connection Connecting"); + + let _permit = + ep.sig_limit.clone().acquire_owned().await.map_err(|_| { + Error::str( + "Endpoint closed while acquiring signal connection permit", + ) + })?; + + let (sig_cli, mut sig_recv) = tx5_signal::Cli::builder() + .with_lair_tag(ep.lair_tag.clone()) + .with_lair_client(ep.lair_client.clone()) + .with_url(sig_url.to_string().parse().unwrap()) + .build() + .await?; + + let peer_url = Tx5Url::new(sig_cli.local_addr())?; + if peer_url.id().unwrap() != ep.this_id { + return Err(Error::str("Invalid signal server peer Id").into()); + } + + let ice_servers = sig_cli.ice_servers(); + + let peer_map: Arc> = + Arc::new(Mutex::new(HashMap::new())); + let weak_peer_map = Arc::downgrade(&peer_map); + + Ok(Arc::new_cyclic(move |weak_sig: &Weak| { + let recv_task = { + let ep = ep.clone(); + let weak_sig = weak_sig.clone(); + let sig_url = sig_url.clone(); + tokio::task::spawn(async move { + while let Some(msg) = sig_recv.recv().await { + use tx5_signal::SignalMsg::*; + match msg { + Demo { .. 
} => (), + Offer { rem_pub, offer } => { + tracing::trace!(%ep.ep_uniq, %sig_uniq, ?rem_pub, ?offer, "Sig Recv Offer"); + if let Some(sig) = weak_sig.upgrade() { + let peer_url = sig_url.to_client(rem_pub); + // fire and forget this + tokio::task::spawn(async move { + let _ = sig + .assert_peer( + peer_url, + rem_pub, + PeerDir::Incoming { offer }, + ) + .await; + }); + } else { + break; + } + } + Answer { rem_pub, answer } => { + tracing::trace!(%ep.ep_uniq, %sig_uniq, ?rem_pub, ?answer, "Sig Recv Answer"); + if let Some(peer_map) = weak_peer_map.upgrade() + { + let r = peer_map + .lock() + .unwrap() + .get(&rem_pub) + .cloned(); + if let Some((_, _, answer_send, _)) = r { + let r = + answer_send.lock().unwrap().take(); + if let Some(answer_send) = r { + let _ = answer_send.send(answer); + } + } + } else { + break; + } + } + Ice { rem_pub, ice } => { + tracing::trace!(%ep.ep_uniq, %sig_uniq, ?rem_pub, ?ice, "Sig Recv ICE"); + if let Some(peer_map) = weak_peer_map.upgrade() + { + let r = peer_map + .lock() + .unwrap() + .get(&rem_pub) + .cloned(); + if let Some((_, peer_cmd_send, _, _)) = r { + if let Some(permit) = + peer_cmd_send.try_permit() + { + if peer_cmd_send + .send_permit( + PeerCmd::SigRecvIce(ice), + permit, + ) + .is_err() + { + break; + } + } else { + break; + } + } + } else { + break; + } + } + } + } + + close_sig(&ep.weak_sig_map, &sig_url, sig_uniq); + }) + }; + + let weak_peer_map = Arc::downgrade(&peer_map); + + tracing::info!(%ep.ep_uniq, %sig_uniq, %sig_url, "Signal Connection Open"); + + Self { + _sig_drop, + sig: Arc::new(SigShared { + ep, + weak_sig: weak_sig.clone(), + sig_uniq, + weak_peer_map, + }), + _permit, + recv_task, + ice_servers, + peer_map, + sig_cli, + } + })) + } + + pub async fn assert_peer( + &self, + peer_url: PeerUrl, + peer_id: Id, + peer_dir: PeerDir, + ) -> CRes> { + if peer_id == self.sig.this_id { + return Err(Error::str("Cannot establish connection with remote peer id matching this id").into()); + } + + if 
self.sig.ban_map.lock().unwrap().is_banned(peer_id) { + return Err(Error::str("Peer is currently banned").into()); + } + + let mut tmp = None; + + let (peer_uniq, _peer_cmd_send, _answer_send, fut) = { + let mut lock = self.peer_map.lock().unwrap(); + + if peer_dir.is_incoming() && lock.contains_key(&peer_id) { + // we need to check negotiation + if peer_id > self.sig.this_id { + // we are the polite node, drop our existing connection + tmp = lock.remove(&peer_id); + } + // otherwise continue on to return the currently + // registered connection because we're the impolite node. + } + + lock.entry(peer_id) + .or_insert_with(|| { + let mut answer_send = None; + let new_peer_dir = match peer_dir { + PeerDir::ActiveOrOutgoing => { + let (s, r) = tokio::sync::oneshot::channel(); + answer_send = Some(s); + NewPeerDir::Outgoing { answer_recv: r } + } + PeerDir::Incoming { offer } => { + NewPeerDir::Incoming { offer } + } + }; + let sig = self.sig.clone(); + let peer_uniq = next_uniq(); + let ice_servers = self.ice_servers.clone(); + let (peer_cmd_send, peer_cmd_recv) = + EventSend::new(sig.config.internal_event_channel_size); + let _peer_drop = PeerDrop { + ep_uniq: sig.ep_uniq, + sig_uniq: sig.sig_uniq, + peer_uniq, + peer_id, + weak_peer_map: sig.weak_peer_map.clone(), + }; + ( + peer_uniq, + peer_cmd_send, + Arc::new(Mutex::new(answer_send)), + AbortableTimedSharedFuture::new( + sig.config.timeout, + Error::str("Timeout awaiting peer connection") + .into(), + Peer::new( + _peer_drop, + sig, + peer_url, + peer_id, + peer_uniq, + ice_servers, + new_peer_dir, + peer_cmd_recv, + ), + ), + ) + }) + .clone() + }; + + // make sure to drop this after releasing our mutex lock + if let Some((_peer_uniq, _cmd, _ans, peer_fut)) = tmp { + peer_fut.abort(Error::str("Dropping connection because we are the polite node and received an offer from the remote").into()); + drop(peer_fut); + } + + match fut.await { + Err(err) => { + // if a new peer got added in the mean time, return that 
instead + let r = self.peer_map.lock().unwrap().get(&peer_id).cloned(); + + if let Some((new_peer_uniq, _cmd, _ans, new_peer_fut)) = r { + if new_peer_uniq != peer_uniq { + return new_peer_fut.await; + } + } + + Err(err) + } + Ok(r) => Ok(r), + } + } + + pub fn ban(&self, id: Id) { + let r = self.peer_map.lock().unwrap().get(&id).cloned(); + if let Some((uniq, _, _, _)) = r { + close_peer(&self.sig.weak_peer_map, id, uniq); + } + } + + pub async fn broadcast(&self, data: &[u8]) { + let mut task_list = Vec::new(); + + let fut_list = self + .peer_map + .lock() + .unwrap() + .values() + .map(|v| v.3.clone()) + .collect::>(); + + for fut in fut_list { + task_list.push(async move { + // timeouts are built into this future as well + // as the peer.send function + if let Ok(peer) = fut.await { + let _ = peer.send(data).await; + } + }); + } + + futures::future::join_all(task_list).await; + } + + pub async fn get_stats(&self) -> Vec<(Id, serde_json::Value)> { + let mut task_list = Vec::new(); + + let fut_list = self + .peer_map + .lock() + .unwrap() + .iter() + .map(|(k, v)| (*k, v.3.clone())) + .collect::>(); + + for (peer_id, fut) in fut_list { + task_list.push(async move { + if let Ok(peer) = fut.await { + match peer.stats().await { + Ok(s) => Some((peer_id, s)), + _ => None, + } + } else { + None + } + }) + } + + futures::future::join_all(task_list) + .await + .into_iter() + .flatten() + .collect() + } +} + +pub(crate) fn close_peer( + weak_peer_map: &Weak>, + peer_id: Id, + close_peer_uniq: u64, +) { + let mut tmp = None; + + if let Some(peer_map) = weak_peer_map.upgrade() { + let mut lock = peer_map.lock().unwrap(); + if let Some((peer_uniq, cmd, ans, peer)) = lock.remove(&peer_id) { + if close_peer_uniq != peer_uniq { + // most of the time we'll be closing the real one, + // so optimize for that case, and cause a hash probe + // in the less likely case some race caused us to + // try to remove the wrong one. 
+ tmp = lock.insert(peer_id, (peer_uniq, cmd, ans, peer)); + } else { + tmp = Some((peer_uniq, cmd, ans, peer)); + } + } + } + + // make sure nothing is dropped while we're holding the mutex lock + if let Some((_peer_uniq, _cmd, _ans, peer_fut)) = tmp { + peer_fut.abort(Error::id("Close").into()); + drop(peer_fut); + } +} diff --git a/crates/tx5/src/ep3/test.rs b/crates/tx5/src/ep3/test.rs new file mode 100644 index 00000000..b8736d3a --- /dev/null +++ b/crates/tx5/src/ep3/test.rs @@ -0,0 +1,621 @@ +use super::*; + +struct Test { + sig_srv_hnd: Option, + sig_port: Option, + sig_url: Option, +} + +impl Test { + pub async fn new() -> Self { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter( + tracing_subscriber::filter::EnvFilter::from_default_env(), + ) + .with_file(true) + .with_line_number(true) + .finish(); + + let _ = tracing::subscriber::set_global_default(subscriber); + + let mut this = Test { + sig_srv_hnd: None, + sig_port: None, + sig_url: None, + }; + + this.restart_sig().await; + + this + } + + pub async fn ep( + &self, + config: Arc, + ) -> (PeerUrl, Ep3, EventRecv) { + let sig_url = self.sig_url.clone().unwrap(); + + let (ep, recv) = Ep3::new(config).await; + let url = ep.listen(sig_url).await.unwrap(); + + (url, ep, recv) + } + + pub fn drop_sig(&mut self) { + drop(self.sig_srv_hnd.take()); + } + + pub async fn restart_sig(&mut self) { + self.drop_sig(); + + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + + let mut srv_config = tx5_signal_srv::Config::default(); + srv_config.port = self.sig_port.unwrap_or(0); + + let (sig_srv_hnd, addr_list, _) = + tx5_signal_srv::exec_tx5_signal_srv(srv_config) + .await + .unwrap(); + self.sig_srv_hnd = Some(sig_srv_hnd); + + let sig_port = addr_list.get(0).unwrap().port(); + self.sig_port = Some(sig_port); + + let sig_url = + SigUrl::new(format!("ws://localhost:{}", sig_port)).unwrap(); + if let Some(old_sig_url) = &self.sig_url { + if old_sig_url != &sig_url { + 
panic!("mismatching new sig url"); + } + } + tracing::info!(%sig_url); + self.sig_url = Some(sig_url); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn ep3_sanity() { + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, _ep2, mut ep2_recv) = test.ep(config).await; + + ep1.send(cli_url2, b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"hello"[..], &message); + } + oth => panic!("{oth:?}"), + } + + let stats = ep1.get_stats().await; + + println!("STATS: {}", serde_json::to_string_pretty(&stats).unwrap()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn ep3_sig_down() { + eprintln!("-- STARTUP --"); + + const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5); + + let mut config = Config3::default(); + config.timeout = TIMEOUT * 2; + config.backoff_start = std::time::Duration::from_millis(200); + config.backoff_max = std::time::Duration::from_millis(200); + let config = Arc::new(config); + let mut test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, ep2, mut ep2_recv) = test.ep(config.clone()).await; + + eprintln!("-- Establish Connection --"); + + ep1.send(cli_url2.clone(), b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. 
} => { + assert_eq!(&b"hello"[..], &message); + } + _ => panic!(), + } + + eprintln!("-- Drop Sig --"); + + test.drop_sig(); + + tokio::time::sleep(TIMEOUT).await; + + // need to trigger another signal message so we know the connection is down + let (cli_url3, _ep3, _ep3_recv) = test.ep(config).await; + + let (a, b) = tokio::join!( + ep1.send(cli_url3.clone(), b"hello",), + ep2.send(cli_url3, b"hello"), + ); + + a.unwrap_err(); + b.unwrap_err(); + + tokio::time::sleep(TIMEOUT).await; + + // now a send to cli_url2 should *also* fail + eprintln!("-- Send Should Fail --"); + + ep1.send(cli_url2.clone(), b"hello").await.unwrap_err(); + + eprintln!("-- Restart Sig --"); + + test.restart_sig().await; + + tokio::time::sleep(TIMEOUT).await; + + eprintln!("-- Send Should Succeed --"); + + ep1.send(cli_url2.clone(), b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Disconnected { .. } => (), + oth => panic!("{oth:?}"), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + oth => panic!("{oth:?}"), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"hello"[..], &message); + } + oth => panic!("{oth:?}"), + } + + eprintln!("-- Done --"); +} + +#[tokio::test(flavor = "multi_thread")] +async fn ep3_drop() { + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, ep2, mut ep2_recv) = test.ep(config.clone()).await; + + ep1.send(cli_url2, b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. 
} => { + assert_eq!(&b"hello"[..], &message); + } + _ => panic!(), + } + + drop(ep2); + drop(ep2_recv); + + let (cli_url3, _ep3, mut ep3_recv) = test.ep(config).await; + + ep1.send(cli_url3, b"world").await.unwrap(); + + let res = ep3_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep3_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"world"[..], &message); + } + _ => panic!(), + } +} + +/// Test negotiation (polite / impolite node logic) by setting up a lot +/// of nodes and having them all try to make connections to each other +/// at the same time and see if we get all the messages. +#[tokio::test(flavor = "multi_thread")] +async fn ep3_negotiation() { + const NODE_COUNT: usize = 9; + + let mut url_list = Vec::new(); + let mut ep_list = Vec::new(); + let mut recv_list = Vec::new(); + + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let mut fut_list = Vec::new(); + for _ in 0..NODE_COUNT { + fut_list.push(test.ep(config.clone())); + } + + for (url, ep, recv) in futures::future::join_all(fut_list).await { + url_list.push(url); + ep_list.push(ep); + recv_list.push(recv); + } + + let first_url = url_list.get(0).unwrap().clone(); + + // first, make sure all the connections are active + // by connecting to the first node + let mut fut_list = Vec::new(); + for (i, ep) in ep_list.iter_mut().enumerate() { + if i != 0 { + fut_list.push(ep.send(first_url.clone(), b"hello")); + } + } + + for r in futures::future::join_all(fut_list).await { + r.unwrap(); + } + + // now send messages between all the nodes + let mut fut_list = Vec::new(); + for (i, ep) in ep_list.iter_mut().enumerate() { + for (j, url) in url_list.iter().enumerate() { + if i != j { + fut_list.push(ep.send(url.clone(), b"world")); + } + } + } + + for r in futures::future::join_all(fut_list).await { + r.unwrap(); + } +} + +#[tokio::test(flavor = "multi_thread")] +async 
fn ep3_messages_contiguous() { + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let (dest_url, _dest_ep, mut dest_recv) = test.ep(config.clone()).await; + + const NODE_COUNT: usize = 3; // 3 nodes + const SEND_COUNT: usize = 10; // sending 10 messages + const CHUNK_COUNT: usize = 10; // broken into 10 chunks + + let mut all_tasks = Vec::new(); + + let start = Arc::new(tokio::sync::Barrier::new(NODE_COUNT)); + let stop = Arc::new(tokio::sync::Barrier::new(NODE_COUNT + 1)); + + for node_id in 0..NODE_COUNT { + let dest_url = dest_url.clone(); + let start = start.clone(); + let stop = stop.clone(); + let (_url, ep, _recv) = test.ep(config.clone()).await; + all_tasks.push(tokio::task::spawn(async move { + let _recv = _recv; + + let mut messages = Vec::new(); + + for msg_id in 0..SEND_COUNT { + let mut chunks = vec![b'-'; ((16 * 1024) - 4) * CHUNK_COUNT]; + + for chunk_id in 0..CHUNK_COUNT { + let data = format!("{node_id}:{msg_id}:{chunk_id}"); + let data = data.as_bytes(); + let s = ((16 * 1024) - 4) * chunk_id; + chunks[s..s + data.len()].copy_from_slice(data); + } + + messages.push(chunks); + } + + start.wait().await; + + for message in messages { + ep.send(dest_url.clone(), &message).await.unwrap(); + } + + stop.wait().await; + })); + } + + let mut sort: HashMap> = HashMap::new(); + + let mut count = 0; + + loop { + let res = dest_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { + peer_url, message, .. 
+ } => { + assert_eq!(((16 * 1024) - 4) * CHUNK_COUNT, message.len()); + for chunk_id in 0..CHUNK_COUNT { + let s = ((16 * 1024) - 4) * chunk_id; + let s = String::from_utf8_lossy(&message[s..s + 32]); + let mut s = s.split("-"); + let s = s.next().unwrap(); + let mut parts = s.split(':'); + let node = parts.next().unwrap().parse().unwrap(); + let msg = parts.next().unwrap().parse().unwrap(); + let chunk = parts.next().unwrap().parse().unwrap(); + sort.entry(peer_url.to_string()) + .or_default() + .push((node, msg, chunk)); + } + + count += 1; + + if count >= NODE_COUNT * SEND_COUNT { + break; + } + } + _ => (), + } + } + + println!("{sort:?}"); + + // make sure the there is no cross-messaging + for (_, list) in sort.iter() { + let (check_node_id, _, _) = list.get(0).unwrap(); + + for (node_id, _, _) in list.iter() { + assert_eq!(check_node_id, node_id); + } + } + + // make sure msg/chunk strictly ascend + for (_, list) in sort.iter() { + let mut expect_msg = 0; + let mut expect_chunk = 0; + + for (_, msg, chunk) in list.iter() { + //println!("msg: {expect_msg}=={msg}, chunk: {expect_chunk}=={chunk}"); + + assert_eq!(expect_msg, *msg); + assert_eq!(expect_chunk, *chunk); + + expect_chunk += 1; + if expect_chunk >= CHUNK_COUNT { + expect_msg += 1; + expect_chunk = 0; + } + } + } + + stop.wait().await; + + for task in all_tasks { + task.await.unwrap(); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn ep3_preflight_happy() { + use rand::Rng; + + let did_send = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let did_valid = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let mut config = Config3::default(); + + let mut preflight = vec![0; 17 * 1024]; + rand::thread_rng().fill(&mut preflight[..]); + + let pf_send: PreflightSendCb = { + let did_send = did_send.clone(); + let preflight = preflight.clone(); + Arc::new(move |_| { + did_send.store(true, std::sync::atomic::Ordering::SeqCst); + let preflight = preflight.clone(); + Box::pin(async move { 
Ok(preflight) }) + }) + }; + + let pf_check: PreflightCheckCb = { + let did_valid = did_valid.clone(); + Arc::new(move |_, bytes| { + did_valid.store(true, std::sync::atomic::Ordering::SeqCst); + assert_eq!(preflight, bytes); + Box::pin(async move { Ok(()) }) + }) + }; + + config.preflight = Some((pf_send, pf_check)); + + let config = Arc::new(config); + let test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, _ep2, mut ep2_recv) = test.ep(config).await; + + ep1.send(cli_url2, b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"hello"[..], &message); + } + _ => panic!(), + } + + assert_eq!(true, did_send.load(std::sync::atomic::Ordering::SeqCst)); + assert_eq!(true, did_valid.load(std::sync::atomic::Ordering::SeqCst)); +} + +#[tokio::test(flavor = "multi_thread")] +async fn ep3_ban_after_connected_outgoing_side() { + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, _ep2, mut ep2_recv) = test.ep(config).await; + + ep1.send(cli_url2.clone(), b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. 
} => { + assert_eq!(&b"hello"[..], &message); + } + _ => panic!(), + } + + ep1.ban(cli_url2.id().unwrap(), std::time::Duration::from_secs(10)); + + assert!(ep1.send(cli_url2, b"hello").await.is_err()); + + let stats = ep1.get_stats().await; + + println!("STATS: {}", serde_json::to_string_pretty(&stats).unwrap()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn ep3_recon_after_ban() { + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, _ep2, mut ep2_recv) = test.ep(config).await; + + ep1.send(cli_url2.clone(), b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"hello"[..], &message); + } + oth => panic!("{oth:?}"), + } + + ep1.ban(cli_url2.id().unwrap(), std::time::Duration::from_millis(10)); + + assert!(ep1.send(cli_url2.clone(), b"hello").await.is_err()); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Disconnected { .. } => (), + oth => panic!("{oth:?}"), + } + + tokio::time::sleep(std::time::Duration::from_millis(15)).await; + + ep1.send(cli_url2.clone(), b"world").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + oth => panic!("{oth:?}"), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. 
} => { + assert_eq!(&b"world"[..], &message); + } + _ => panic!(), + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn ep3_broadcast_happy() { + let config = Arc::new(Config3::default()); + let test = Test::new().await; + + let (_cli_url1, ep1, _ep1_recv) = test.ep(config.clone()).await; + let (cli_url2, _ep2, mut ep2_recv) = test.ep(config.clone()).await; + let (cli_url3, _ep3, mut ep3_recv) = test.ep(config).await; + + ep1.send(cli_url2.clone(), b"hello").await.unwrap(); + + ep1.send(cli_url3.clone(), b"hello").await.unwrap(); + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"hello"[..], &message); + } + _ => panic!(), + } + + let res = ep3_recv.recv().await.unwrap(); + match res { + Ep3Event::Connected { .. } => (), + _ => panic!(), + } + + let res = ep3_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"hello"[..], &message); + } + _ => panic!(), + } + + ep1.broadcast(b"world").await; + + let res = ep2_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"world"[..], &message); + } + _ => panic!(), + } + + let res = ep3_recv.recv().await.unwrap(); + match res { + Ep3Event::Message { message, .. } => { + assert_eq!(&b"world"[..], &message); + } + _ => panic!(), + } +} diff --git a/crates/tx5/src/lib.rs b/crates/tx5/src/lib.rs index 58f8e265..8db61272 100644 --- a/crates/tx5/src/lib.rs +++ b/crates/tx5/src/lib.rs @@ -36,261 +36,156 @@ pub mod deps { pub use tx5_signal::deps::*; } -use deps::{serde, serde_json}; - -use tx5_core::Uniq; pub use tx5_core::{Error, ErrorExt, Id, Result, Tx5InitConfig, Tx5Url}; -pub mod actor; - -mod back_buf; -pub use back_buf::*; - -/// Helper extension trait for `Box`. -pub trait BytesBufExt { - /// Convert into a `Vec`. 
- fn to_vec(self) -> Result>; -} - -impl BytesBufExt for Box { - fn to_vec(self) -> Result> { - use bytes::Buf; - use std::io::Read; - let mut out = Vec::with_capacity(self.remaining()); - self.reader().read_to_end(&mut out)?; - Ok(out) - } -} - -const FINISH: u64 = 1 << 63; - -trait FinishExt: Sized { - fn set_finish(&self) -> Self; - fn unset_finish(&self) -> Self; - fn is_finish(&self) -> bool; -} - -impl FinishExt for u64 { - fn set_finish(&self) -> Self { - *self | FINISH - } - - fn unset_finish(&self) -> Self { - *self & !FINISH - } - - fn is_finish(&self) -> bool { - *self & FINISH > 0 - } +mod ep3; +pub use ep3::*; + +pub(crate) mod back_buf; +pub(crate) use back_buf::*; + +pub(crate) mod proto; + +/// Make a shared (clonable) future abortable and set a timeout. +/// The timeout is managed by tokio::time::timeout. +/// The clone-ability is managed by futures::future::shared. +/// The abortability is NOT managed by futures::future::abortable, +/// because we need to be able to pass in a specific error when aborting, +/// so it is managed via a tokio::sync::oneshot channel and tokio::select!. +#[derive(Clone)] +struct AbortableTimedSharedFuture { + f: futures::future::Shared< + futures::future::BoxFuture<'static, std::result::Result>, + >, + a: std::sync::Arc< + std::sync::Mutex>>, + >, } -/// A set of distinct chunks of bytes that can be treated as a single unit -//#[derive(Clone, Default)] -#[derive(Default)] -struct BytesList(pub std::collections::VecDeque); - -impl BytesList { - /// Construct a new BytesList. - pub fn new() -> Self { - Self::default() - } - - /* - /// Construct a new BytesList with given capacity. - pub fn with_capacity(capacity: usize) -> Self { - Self(std::collections::VecDeque::with_capacity(capacity)) - } - */ - - /// Push a new bytes::Bytes into this BytesList. 
- pub fn push(&mut self, data: bytes::Bytes) { - if bytes::Buf::has_remaining(&data) { - self.0.push_back(data); +impl AbortableTimedSharedFuture { + /// Construct a new AbortableTimedSharedFuture that will timeout + /// after the given duration. + pub fn new( + timeout: std::time::Duration, + timeout_err: Error, + f: F, + ) -> Self + where + F: std::future::Future> + + 'static + + Send, + { + let (a, ar) = tokio::sync::oneshot::channel(); + let a = std::sync::Arc::new(std::sync::Mutex::new(Some(a))); + Self { + f: futures::future::FutureExt::shared( + futures::future::FutureExt::boxed(async move { + tokio::time::timeout( + timeout, + async move { + tokio::select! { + r = async { + Err(ar.await.map_err(|_| Error::id("AbortHandleDropped"))?) + } => r, + r = f => r, + } + }, + ) + .await + .map_err(|_| timeout_err)? + }), + ), + a, } } - /// Convert into a trait object. - pub fn into_dyn(self) -> Box { - Box::new(self) - } - - /* - /// Copy data into a Vec. You should avoid this if possible. - pub fn to_vec(&self) -> Vec { - use std::io::Read; - let mut out = Vec::with_capacity(self.remaining()); - // data is local, it can't error, safe to unwrap - self.clone().reader().read_to_end(&mut out).unwrap(); - out + /// Abort this future with the given error. + pub fn abort(&self, err: Error) { + let a = self.a.lock().unwrap().take(); + if let Some(a) = a { + let _ = a.send(err); } - - /// Extract the contents of this BytesList into a new one - pub fn extract(&mut self) -> Self { - Self(std::mem::take(&mut self.0)) - } - - /// Remove specified byte cnt from the front of this list as a new list. - /// Panics if self doesn't contain enough bytes. 
- #[allow(clippy::comparison_chain)] // clearer written explicitly - pub fn take_front(&mut self, mut cnt: usize) -> Self { - let mut out = BytesList::new(); - loop { - let mut item = match self.0.pop_front() { - Some(item) => item, - None => panic!("UnexpectedEof"), - }; - - let rem = item.remaining(); - if rem == cnt { - out.push(item); - return out; - } else if rem < cnt { - out.push(item); - cnt -= rem; - } else if rem > cnt { - out.push(item.split_to(cnt)); - self.0.push_front(item); - return out; - } - } - } - */ -} - -/* -impl From> for BytesList { - #[inline(always)] - fn from(v: std::collections::VecDeque) -> Self { - Self(v) - } -} - -impl From for BytesList { - #[inline(always)] - fn from(b: bytes::Bytes) -> Self { - let mut out = std::collections::VecDeque::with_capacity(8); - out.push_back(b); - out.into() - } -} - -impl From> for BytesList { - #[inline(always)] - fn from(v: Vec) -> Self { - bytes::Bytes::from(v).into() } } -impl From<&[u8]> for BytesList { - #[inline(always)] - fn from(b: &[u8]) -> Self { - bytes::Bytes::copy_from_slice(b).into() - } -} +impl std::future::Future for AbortableTimedSharedFuture { + type Output = std::result::Result; -impl From<&[u8; N]> for BytesList { - #[inline(always)] - fn from(b: &[u8; N]) -> Self { - bytes::Bytes::copy_from_slice(&b[..]).into() + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + std::pin::Pin::new(&mut self.f).poll(cx) } } -*/ -impl bytes::Buf for BytesList { - fn remaining(&self) -> usize { - self.0.iter().map(|b| b.remaining()).sum() - } +#[cfg(test)] +mod test_behavior; - fn chunk(&self) -> &[u8] { - match self.0.front() { - Some(b) => b.chunk(), - None => &[], +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn atsf_traits() { + fn check(_f: F) + where + F: Send + Sync + Unpin, + { } - } - - #[allow(clippy::comparison_chain)] // clearer written explicitly - fn advance(&mut self, mut cnt: 
usize) { - loop { - let mut item = match self.0.pop_front() { - Some(item) => item, - None => return, - }; - let rem = item.remaining(); - if rem == cnt { - return; - } else if rem < cnt { - cnt -= rem; - } else if rem > cnt { - item.advance(cnt); - self.0.push_front(item); - return; - } + let a = AbortableTimedSharedFuture::new( + std::time::Duration::from_millis(10), + Error::str("my timeout err").into(), + async move { Ok(()) }, + ); + + check(a); + } + + #[tokio::test(flavor = "multi_thread")] + async fn atsf_happy() { + AbortableTimedSharedFuture::new( + std::time::Duration::from_secs(1), + Error::id("to").into(), + async move { Ok(()) }, + ) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn atsf_timeout() { + let r = AbortableTimedSharedFuture::new( + std::time::Duration::from_millis(1), + Error::id("to").into(), + async move { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + Ok(()) + }, + ) + .await; + assert_eq!("to", r.unwrap_err().to_string()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn atsf_abort() { + let a = AbortableTimedSharedFuture::new( + std::time::Duration::from_secs(1), + Error::id("to").into(), + async move { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + Ok(()) + }, + ); + { + let a = a.clone(); + tokio::task::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + a.abort(Error::id("abort").into()); + }); } + let r = a.await; + assert_eq!("abort", r.unwrap_err().to_string()); } } - -pub mod state; - -mod config; -pub use config::*; - -mod endpoint; -pub use endpoint::*; - -fn divide_send( - config: &dyn Config, - snd_ident: &std::sync::atomic::AtomicU64, - mut data: B, -) -> Result> { - use std::io::Write; - - let max_send_bytes = config.max_send_bytes(); - - if bytes::Buf::remaining(&data) > max_send_bytes as usize { - Err(Error::id("DataTooLarge")) - } else { - (|| { - let ident = - snd_ident.fetch_add(1, 
std::sync::atomic::Ordering::Relaxed); - - let mut buf_list = Vec::new(); - - const MAX_MSG: usize = (16 * 1024) - 8; - while data.has_remaining() { - let loc_len = std::cmp::min(data.remaining(), MAX_MSG); - let ident = if data.remaining() <= loc_len { - ident.set_finish() - } else { - ident.unset_finish() - }; - - tracing::trace!(ident=%ident.unset_finish(), is_finish=%ident.is_finish(), %loc_len, "prepare send"); - - let mut tmp = - bytes::Buf::reader(bytes::Buf::take(data, loc_len)); - - // TODO - reserve the bytes before writing - let mut buf = BackBuf::from_writer()?; - buf.write_all(&ident.to_le_bytes())?; - std::io::copy(&mut tmp, &mut buf)?; - - buf_list.push(buf.finish()); - - data = tmp.into_inner().into_inner(); - } - - if buf_list.is_empty() { - let ident = ident.set_finish(); - let mut buf = BackBuf::from_writer()?; - buf.write_all(&ident.to_le_bytes())?; - buf_list.push(buf.finish()); - } - - Ok(buf_list) - })() - } -} - -#[cfg(test)] -mod test; diff --git a/crates/tx5/src/proto.rs b/crates/tx5/src/proto.rs new file mode 100644 index 00000000..a5d9c449 --- /dev/null +++ b/crates/tx5/src/proto.rs @@ -0,0 +1,589 @@ +//! Types associated with the Tx5 protocol. +//! +//! The tx5 protocol solves 2 problems at once: +//! - First, webrtc only supports messages up to 16K, so this lets us send +//! bigger messages. +//! - Second, if we're sending bigger messages, we have to worry about the size +//! in memory taken up by the receiving side. This protocol lets us request +//! permits to send larger messages, which gives the receiving side a tool to +//! be able to manage more open connections at the same time without worrying +//! about the worst case memory usage of each connection all at the same time. + +use crate::BackBuf; +use tx5_core::{Error, Result}; + +/// Tx5 protocol message payload max size. +pub(crate) const MAX_PAYLOAD: u32 = 0b00011111111111111111111111111111; + +/// Protocol version 1 header. 
+pub(crate) const PROTO_VER_1: ProtoHeader = + ProtoHeader::Version(1, [b't', b'x', b'5']); + +/// 4 byte Tx5 protocol header. +/// +/// The initial 3 bits representing values 0, 1, and 7 are reserved. +/// Decoders should error on receiving these values. +#[derive(Debug, PartialEq)] +pub(crate) enum ProtoHeader { + /// This is a protocol version handshake. + /// `3 bits == 2` + /// `5 bits == version number (currently 1)` + /// `24 bits == as bytes b"tx5"` + Version(u8, [u8; 3]), + + /// The remainder of this message represents the entirety of a message. + /// `3 bits == 3` + /// `29 bits == the message size == full chunk size - 4` + CompleteMessage(u32), + + /// The remainder of this message is a chunk of a multi-part message. + /// `3 bits == 4` + /// `29 bits == byte count included in this message chunk` + MultipartMessage(u32), + + /// This is a request for a permit to send a message of size larger than + /// a single chunk. + /// `3 bits == 5` + /// `29 bits == the size of the message` + PermitRequest(u32), + + /// This is an authorization to proceed with sending payload chunks. + /// `3 bits == 6` + /// `29 bits == the size of the message` + PermitGrant(u32), +} + +impl ProtoHeader { + /// Decode 4 bytes into a Tx5 protocol header. + pub fn decode(a: u8, b: u8, c: u8, d: u8) -> Result { + use bit_field::BitField; + let r = u32::from_be_bytes([a, b, c, d]); + match r.get_bits(29..) { + 2 => Ok(Self::Version( + r.get_bits(24..29) as u8, + [ + r.get_bits(16..24) as u8, + r.get_bits(8..16) as u8, + r.get_bits(0..8) as u8, + ], + )), + 3 => Ok(Self::CompleteMessage(r.get_bits(..29))), + 4 => Ok(Self::MultipartMessage(r.get_bits(..29))), + 5 => Ok(Self::PermitRequest(r.get_bits(..29))), + 6 => Ok(Self::PermitGrant(r.get_bits(..29))), + _ => Err(Error::id("ReservedHeaderBits")), + } + } + + /// Encode a Tx5 protocol header into canonical 4 bytes. 
+ pub fn encode(&self) -> Result<(u8, u8, u8, u8)> { + use bit_field::BitField; + let mut out: u32 = 0; + match self { + Self::Version(v, [a, b, c]) => { + if *v > 31 { + return Err(Error::id("VersionOverflow")); + } + out.set_bits(29.., 2); + out.set_bits(24..29, *v as u32); + out.set_bits(16..24, *a as u32); + out.set_bits(8..16, *b as u32); + out.set_bits(0..8, *c as u32); + } + Self::CompleteMessage(s) => { + if *s > MAX_PAYLOAD { + return Err(Error::id("SizeOverflow")); + } + out.set_bits(29.., 3); + out.set_bits(..29, *s); + } + Self::MultipartMessage(s) => { + if *s > MAX_PAYLOAD { + return Err(Error::id("SizeOverflow")); + } + out.set_bits(29.., 4); + out.set_bits(..29, *s); + } + Self::PermitRequest(s) => { + if *s > MAX_PAYLOAD { + return Err(Error::id("SizeOverflow")); + } + out.set_bits(29.., 5); + out.set_bits(..29, *s); + } + Self::PermitGrant(s) => { + if *s > MAX_PAYLOAD { + return Err(Error::id("SizeOverflow")); + } + out.set_bits(29.., 6); + out.set_bits(..29, *s); + } + } + let out = out.to_be_bytes(); + Ok((out[0], out[1], out[2], out[3])) + } +} + +/// Result of encoding a message into the Tx5 protocol. +pub(crate) enum ProtoEncodeResult { + /// We need to request a permit. Send the permit request first, + /// once we receive the authorization, forward the rest of + /// the message payload. + NeedPermit { + /// First, request a permit to send the payload. + permit_req: BackBuf, + + /// Second, send the actual payload chunks. + msg_payload: Vec, + }, + + /// This message fit in a single payload chunk. We do not need + /// to request a permit ahead of time, so just send the chunk. + OneMessage(BackBuf), +} + +/// Encode some data into the Tx5 protocol. 
+pub(crate) fn proto_encode(data: &[u8]) -> Result { + const MAX: usize = (16 * 1024) - 4; + let len = data.len(); + + if len > MAX_PAYLOAD as usize { + return Err(Error::id("PayloadSizeOverflow")); + } + + if len <= MAX { + let (a, b, c, d) = ProtoHeader::CompleteMessage(len as u32).encode()?; + + let mut go_buf = tx5_go_pion::GoBuf::new()?; + go_buf.reserve(len + 4)?; + go_buf.extend(&[a, b, c, d])?; + go_buf.extend(data)?; + + Ok(ProtoEncodeResult::OneMessage(BackBuf::from_raw(go_buf))) + } else { + let (a, b, c, d) = ProtoHeader::PermitRequest(len as u32).encode()?; + let permit_req = BackBuf::from_slice([a, b, c, d])?; + + let mut msg_payload = Vec::new(); + let mut cur = 0; + + while len - cur > 0 { + let amt = std::cmp::min((16 * 1024) - 4, len - cur); + + let (a, b, c, d) = + ProtoHeader::MultipartMessage(amt as u32).encode()?; + + let mut go_buf = tx5_go_pion::GoBuf::new()?; + go_buf.reserve(amt + 4)?; + go_buf.extend(&[a, b, c, d])?; + go_buf.extend(&data[cur..cur + amt])?; + + msg_payload.push(BackBuf::from_raw(go_buf)); + + cur += amt; + } + + Ok(ProtoEncodeResult::NeedPermit { + permit_req, + msg_payload, + }) + } +} + +/// Result of decoding an incoming message chunk. +#[derive(Debug, PartialEq)] +pub(crate) enum ProtoDecodeResult { + /// Nothing needs to happen at the moment... continue receiving chunks. + Idle, + + /// Received incoming message. + Message(Vec), + + /// The remote node is requesting a permit to send us chunks of data. + RemotePermitRequest(u32), + + /// The remote node has granted us a permit to send them chunks of data. + RemotePermitGrant(u32), +} + +#[derive(Clone, Copy, PartialEq)] +enum DecodeState { + NeedVersion, + Ready, + /// The REMOTE requested a permit from US. + /// Totally different from when we make a permit request of the remote : ) + RemoteAwaitingPermit(u32), + ReceiveChunked, +} + +/// Tx5 protocol decoder. 
+pub(crate) struct ProtoDecoder { + state: DecodeState, + want_size: usize, + incoming: Vec, + want_remote_grant: bool, + did_error: bool, + grant_permit: Option, + grant_notify: Option>, +} + +impl Default for ProtoDecoder { + fn default() -> Self { + Self { + state: DecodeState::NeedVersion, + want_size: 0, + incoming: Vec::new(), + want_remote_grant: false, + did_error: false, + grant_permit: None, + grant_notify: None, + } + } +} + +impl ProtoDecoder { + /// Notify the decoder that we sent the previously requested permit + /// to the remote. + pub fn sent_remote_permit_grant( + &mut self, + grant_permit: tokio::sync::OwnedSemaphorePermit, + ) -> Result<()> { + self.check_err()?; + if let DecodeState::RemoteAwaitingPermit(permit_len) = self.state { + self.state = DecodeState::ReceiveChunked; + self.want_size = permit_len as usize; + self.incoming.reserve(self.want_size); + self.grant_permit = Some(grant_permit); + Ok(()) + } else { + self.did_error = true; + Err(Error::id("InvalidStateToSendPermit")) + } + } + + /// Notify the decoder that we have requested a permit from the remote, + /// so we should expect to receive a grant. + pub fn sent_remote_permit_request( + &mut self, + grant_notify: Option>, + ) -> Result<()> { + self.check_err()?; + if self.want_remote_grant { + self.did_error = true; + Err(Error::id("InvalidDuplicatePermitRequest")) + } else { + self.want_remote_grant = true; + self.grant_notify = grant_notify; + Ok(()) + } + } + + /// Process the next incoming chunk from the remote. 
+ pub fn decode(&mut self, chunk: BackBuf) -> Result { + self.check_err()?; + match self.priv_decode(chunk) { + Ok(r) => Ok(r), + Err(err) => { + self.did_error = true; + Err(err) + } + } + } + + fn check_err(&self) -> Result<()> { + if self.did_error { + Err(Error::id("FnCallOnErroredDecoder")) + } else { + Ok(()) + } + } + + fn priv_decode(&mut self, mut chunk: BackBuf) -> Result { + chunk.imp.buf.access(|buf| { + let buf = buf?; + let len = buf.len(); + if len < 4 { + return Err(Error::id("InvalidHeaderLen")); + } + + match ProtoHeader::decode(buf[0], buf[1], buf[2], buf[3])? { + ProtoHeader::Version(v, [a, b, c]) => { + if v != 1 || a != b't' || b != b'x' || c != b'5' { + return Err(Error::err(format!( + "invalid version v = {v}, tag = {}", + String::from_utf8_lossy(&[a, b, c][..]), + ))); + } + + if self.state == DecodeState::NeedVersion { + self.state = DecodeState::Ready; + Ok(ProtoDecodeResult::Idle) + } else { + Err(Error::id("RecvUnexpectedVersionMessage")) + } + } + ProtoHeader::CompleteMessage(msg_len) => { + if self.state == DecodeState::Ready { + if msg_len as usize != len - 4 { + return Err(Error::id("InvalidCompleteMessageLen")); + } + + Ok(ProtoDecodeResult::Message(buf[4..].to_vec())) + } else { + Err(Error::id("RecvUnexpectedCompleteMessage")) + } + } + ProtoHeader::MultipartMessage(msg_len) => { + if self.state == DecodeState::ReceiveChunked { + if msg_len as usize != len - 4 || msg_len == 0 { + return Err(Error::id( + "InvalidMultipartMessageLen", + )); + } + + if msg_len as usize + self.incoming.len() + > self.want_size + { + return Err(Error::id("ChunkTooLarge")); + } + + self.incoming.extend_from_slice(&buf[4..]); + + if self.incoming.len() == self.want_size { + drop(self.grant_permit.take()); + self.state = DecodeState::Ready; + Ok(ProtoDecodeResult::Message(std::mem::take( + &mut self.incoming, + ))) + } else { + Ok(ProtoDecodeResult::Idle) + } + } else { + Err(Error::id("RecvUnexpectedMultipartMessage")) + } + } + 
ProtoHeader::PermitRequest(permit_len) => { + if self.state == DecodeState::Ready { + self.state = + DecodeState::RemoteAwaitingPermit(permit_len); + Ok(ProtoDecodeResult::RemotePermitRequest(permit_len)) + } else { + Err(Error::id("RecvUnexpectedPermitRequest")) + } + } + ProtoHeader::PermitGrant(permit_len) => { + if self.want_remote_grant { + self.want_remote_grant = false; + if let Some(grant_notify) = self.grant_notify.take() { + let _ = grant_notify.send(()); + } + Ok(ProtoDecodeResult::RemotePermitGrant(permit_len)) + } else { + Err(Error::id("RecvUnexpectedPermitGrant")) + } + } + } + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn proto_header_encode_decode() { + fn check(hdr: ProtoHeader) { + let (a, b, c, d) = hdr.encode().unwrap(); + let res = ProtoHeader::decode(a, b, c, d).unwrap(); + assert_eq!(hdr, res); + } + + for v in 0..32 { + check(ProtoHeader::Version(v, [b't', b'x', b'5'])); + } + + for v in &[0, 42, 0b00011111111111111111111111111111] { + check(ProtoHeader::CompleteMessage(*v)); + check(ProtoHeader::MultipartMessage(*v)); + check(ProtoHeader::PermitRequest(*v)); + check(ProtoHeader::PermitGrant(*v)); + } + } + + #[test] + fn proto_header_overflow() { + assert!(ProtoHeader::Version(0b00100000, [b't', b'x', b'5']) + .encode() + .is_err()); + assert!(ProtoHeader::CompleteMessage(u32::MAX).encode().is_err()); + assert!(ProtoHeader::MultipartMessage(u32::MAX).encode().is_err()); + assert!(ProtoHeader::PermitRequest(u32::MAX).encode().is_err()); + assert!(ProtoHeader::PermitGrant(u32::MAX).encode().is_err()); + } + + #[test] + fn proto_header_version_1() { + const PROTO_VERSION_1: &[u8; 4] = &[ + 0b01000001, // 010 for 3 bits == 2, 00001 for version #1 + b't', b'x', b'5', + ]; + + let (a, b, c, d) = PROTO_VER_1.encode().unwrap(); + + assert_eq!(PROTO_VERSION_1[0], a); + assert_eq!(PROTO_VERSION_1[1], b); + assert_eq!(PROTO_VERSION_1[2], c); + assert_eq!(PROTO_VERSION_1[3], d); + } + + #[test] + fn proto_decode_complete_msg() 
{ + let mut dec = ProtoDecoder::default(); + let (a, b, c, d) = PROTO_VER_1.encode().unwrap(); + assert_eq!( + ProtoDecodeResult::Idle, + dec.decode(BackBuf::from_slice(&[a, b, c, d]).unwrap()) + .unwrap(), + ); + match proto_encode(b"hello").unwrap() { + ProtoEncodeResult::OneMessage(buf) => { + match dec.decode(buf).unwrap() { + ProtoDecodeResult::Message(msg) => { + assert_eq!(b"hello", msg.as_slice()); + } + _ => panic!(), + } + } + _ => panic!(), + } + } + + #[test] + fn proto_decode_chunked_msg() { + use rand::Rng; + let mut dec = ProtoDecoder::default(); + let (a, b, c, d) = PROTO_VER_1.encode().unwrap(); + assert_eq!( + ProtoDecodeResult::Idle, + dec.decode(BackBuf::from_slice(&[a, b, c, d]).unwrap()) + .unwrap(), + ); + let mut msg = vec![0; 15 * 1024 * 1024]; + rand::thread_rng().fill(&mut msg[..]); + match proto_encode(&msg).unwrap() { + ProtoEncodeResult::NeedPermit { + permit_req, + mut msg_payload, + } => { + match dec.decode(permit_req).unwrap() { + ProtoDecodeResult::RemotePermitRequest(permit_len) => { + assert_eq!(15 * 1024 * 1024, permit_len); + } + _ => panic!(), + } + + dec.sent_remote_permit_grant( + std::sync::Arc::new(tokio::sync::Semaphore::new(1)) + .try_acquire_owned() + .unwrap(), + ) + .unwrap(); + + while msg_payload.len() > 1 { + assert_eq!( + ProtoDecodeResult::Idle, + dec.decode(msg_payload.remove(0)).unwrap(), + ) + } + + match dec.decode(msg_payload.remove(0)).unwrap() { + ProtoDecodeResult::Message(msg_res) => { + assert_eq!(msg, msg_res); + } + _ => panic!(), + } + } + _ => panic!(), + } + } + + #[test] + fn proto_decode_bad_version() { + let mut dec = ProtoDecoder::default(); + assert!(dec.decode(BackBuf::from_slice(b"hello").unwrap()).is_err()); + } + + #[test] + fn proto_decode_no_duplicate_permit_requests() { + let mut dec = ProtoDecoder::default(); + let (a, b, c, d) = PROTO_VER_1.encode().unwrap(); + assert_eq!( + ProtoDecodeResult::Idle, + dec.decode(BackBuf::from_slice(&[a, b, c, d]).unwrap()) + .unwrap(), + ); + 
dec.sent_remote_permit_request(None).unwrap(); + assert!(dec.sent_remote_permit_request(None).is_err()); + } + + #[test] + fn proto_decode_grant_during_multipart() { + use rand::Rng; + let mut dec = ProtoDecoder::default(); + let (a, b, c, d) = PROTO_VER_1.encode().unwrap(); + assert_eq!( + ProtoDecodeResult::Idle, + dec.decode(BackBuf::from_slice(&[a, b, c, d]).unwrap()) + .unwrap(), + ); + + dec.sent_remote_permit_request(None).unwrap(); + + let mut msg = vec![0; 17 * 1024]; + rand::thread_rng().fill(&mut msg[..]); + match proto_encode(&msg).unwrap() { + ProtoEncodeResult::NeedPermit { + permit_req, + mut msg_payload, + } => { + match dec.decode(permit_req).unwrap() { + ProtoDecodeResult::RemotePermitRequest(permit_len) => { + assert_eq!(17 * 1024, permit_len); + } + _ => panic!(), + } + + dec.sent_remote_permit_grant( + std::sync::Arc::new(tokio::sync::Semaphore::new(1)) + .try_acquire_owned() + .unwrap(), + ) + .unwrap(); + + assert_eq!(2, msg_payload.len()); + + assert_eq!( + ProtoDecodeResult::Idle, + dec.decode(msg_payload.remove(0)).unwrap(), + ); + + let (a, b, c, d) = + ProtoHeader::PermitGrant(18 * 1024).encode().unwrap(); + assert_eq!( + ProtoDecodeResult::RemotePermitGrant(18 * 1024), + dec.decode(BackBuf::from_slice(&[a, b, c, d]).unwrap()) + .unwrap(), + ); + + match dec.decode(msg_payload.remove(0)).unwrap() { + ProtoDecodeResult::Message(msg_res) => { + assert_eq!(msg, msg_res); + } + _ => panic!(), + } + } + _ => panic!(), + } + } +} diff --git a/crates/tx5/src/state.rs b/crates/tx5/src/state.rs deleted file mode 100644 index 1c8396e5..00000000 --- a/crates/tx5/src/state.rs +++ /dev/null @@ -1,1235 +0,0 @@ -//! Tx5 high-level conn mgmt state. 
- -use crate::actor::*; -use crate::*; - -use std::collections::VecDeque; -use std::collections::{hash_map, HashMap}; -use std::future::Future; -use std::sync::Arc; - -use influxive_otel_atomic_obs::*; -use opentelemetry_api::metrics::MeterProvider; - -use tx5_core::{Id, Tx5Url}; - -mod sig; -pub use sig::*; - -mod conn; -pub use conn::*; - -mod drop_consider; - -#[cfg(test)] -mod test; - -/// The max connection open time. Would be nice for this to be a negotiation, -/// so that it could be configured... but right now we just need both sides -/// to agree, so it is hard-coded. -const MAX_CON_TIME: std::time::Duration = - std::time::Duration::from_secs(60 * 5); - -/// The connection send grace period. Connections will not send new messages -/// when within this duration from the MAX_CON_TIME close. -/// Similar to MAX_CON_TIME, this has to be hard-coded for now. -const CON_CLOSE_SEND_GRACE: std::time::Duration = - std::time::Duration::from_secs(30); - -/// Respond type. -#[must_use] -pub struct OneSnd( - Option) + 'static + Send>>, -); - -impl Drop for OneSnd { - fn drop(&mut self) { - self.send(Err(Error::id("Dropped"))) - } -} - -impl OneSnd { - pub(crate) fn new(cb: Cb) -> Self - where - Cb: FnOnce(Result) + 'static + Send, - { - Self(Some(Box::new(cb))) - } - - /// Send data on this single sender respond type. - pub fn send(&mut self, t: Result) { - if let Some(sender) = self.0.take() { - sender(t); - } - } - - /// Wrap such that a closure's result is sent. - pub async fn with(&mut self, cb: Cb) - where - Fut: Future>, - Cb: FnOnce() -> Fut, - { - self.send(cb().await); - } -} - -/// Drop this when you consider the data "received". -pub struct Permit(Vec); - -impl std::fmt::Debug for Permit { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Permit").finish() - } -} - -/// State wishes to invoke an action. -pub enum StateEvt { - /// Request to create a new signal client connection. 
- NewSig(Tx5Url, SigStateSeed), - - /// Indicates the current node is addressable at the given url. - Address(Tx5Url), - - /// Request to create a new webrtc peer connection. - NewConn(Arc, ConnStateSeed), - - /// Incoming data received on a peer connection. - RcvData(Tx5Url, Box, Vec), - - /// Received a demo broadcast. - Demo(Tx5Url), - - /// This is an informational notification indicating a connection - /// has been successfully established. Unlike 'NewConn' above, - /// no action is required, other than to let your users know. - Connected(Tx5Url), - - /// This is an informational notification indicating a connection has - /// been dropped. No action is required, other than to let your users know. - /// Note, you may get disconnected events for connections that were never - /// successfully established. - Disconnected(Tx5Url), -} - -impl std::fmt::Debug for StateEvt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - StateEvt::NewSig(url, _seed) => f - .debug_struct("StateEvt::NewSig") - .field("url", url) - .finish(), - StateEvt::Address(url) => f - .debug_struct("StateEvt::Address") - .field("url", url) - .finish(), - StateEvt::NewConn(info, _seed) => f - .debug_struct("StateEvt::NewConn") - .field("info", info) - .finish(), - StateEvt::RcvData(url, data, _permit) => { - let data_len = data.remaining(); - f.debug_struct("StateEvt::RcvData") - .field("url", url) - .field("data_len", &data_len) - .finish() - } - StateEvt::Demo(url) => { - f.debug_struct("StateEvt::Demo").field("url", url).finish() - } - StateEvt::Connected(url) => f - .debug_struct("StateEvt::Connected") - .field("url", url) - .finish(), - StateEvt::Disconnected(url) => f - .debug_struct("StateEvt::Disconnected") - .field("url", url) - .finish(), - } - } -} - -#[derive(Clone)] -struct StateEvtSnd(tokio::sync::mpsc::UnboundedSender>); - -impl StateEvtSnd { - pub fn err(&self, err: std::io::Error) { - let _ = self.0.send(Err(err)); - } - - pub fn 
publish(&self, evt: StateEvt) -> Result<()> { - self.0.send(Ok(evt)).map_err(|_| Error::id("Closed")) - } -} - -pub(crate) struct SendData { - msg_uniq: Uniq, - data: BackBuf, - timestamp: std::time::Instant, - resp: Option>>, - send_permit: tokio::sync::OwnedSemaphorePermit, -} - -struct IceData { - timestamp: std::time::Instant, - ice: BackBuf, -} - -struct RmConn(StateEvtSnd, Tx5Url); - -impl Drop for RmConn { - fn drop(&mut self) { - let _ = self.0.publish(StateEvt::Disconnected(self.1.clone())); - } -} - -struct StateData { - state_uniq: Uniq, - this_id: Option, - this: StateWeak, - meta: StateMeta, - evt: StateEvtSnd, - signal_map: HashMap, - ban_map: HashMap, - conn_map: HashMap, - send_map: HashMap>, - ice_cache: HashMap>, - recv_limit: Arc, -} - -impl Drop for StateData { - fn drop(&mut self) { - self.shutdown(Error::id("Dropped")); - } -} - -impl StateData { - fn shutdown(&mut self, err: std::io::Error) { - tracing::trace!(state_uniq = %self.state_uniq, this_id = ?self.this_id, "StateShutdown"); - for (_, sig) in self.signal_map.drain() { - if let Some(sig) = sig.upgrade() { - sig.close(err.err_clone()); - } - } - for (_, (conn, _)) in self.conn_map.drain() { - if let Some(conn) = conn.upgrade() { - conn.close(err.err_clone()); - } - } - self.send_map.clear(); - self.evt.err(err); - } - - async fn exec(&mut self, cmd: StateCmd) -> Result<()> { - // TODO - any errors returned by these fn calls will shut down - // the entire endpoint... probably not what we want. - // Instead, maybe shutting down the whole endpoint should - // require a special call, and otherwise errors can be - // logged / ignored OR just not allow returning errors. 
- match cmd { - StateCmd::Tick1s => self.tick_1s().await, - StateCmd::TrackSig { rem_id, ty, bytes } => { - self.track_sig(rem_id, ty, bytes).await - } - StateCmd::SndDemo => self.snd_demo().await, - StateCmd::ListConnected(resp) => self.list_connected(resp).await, - StateCmd::AssertListenerSig { sig_url, resp } => { - self.assert_listener_sig(sig_url, resp).await - } - StateCmd::SendData { - msg_uniq, - rem_id, - data, - timestamp, - send_permit, - resp, - cli_url, - } => { - self.send_data( - msg_uniq, - rem_id, - data, - timestamp, - send_permit, - resp, - cli_url, - ) - .await - } - StateCmd::Ban { rem_id, span } => self.ban(rem_id, span).await, - StateCmd::Stats(resp) => self.stats(resp).await, - StateCmd::Publish { evt } => self.publish(evt).await, - StateCmd::SigConnected { cli_url } => { - self.sig_connected(cli_url).await - } - StateCmd::FetchForSend { conn, rem_id } => { - self.fetch_for_send(conn, rem_id).await - } - StateCmd::InOffer { - sig_url, - rem_id, - data, - } => self.in_offer(sig_url, rem_id, data).await, - StateCmd::InDemo { sig_url, rem_id } => { - self.in_demo(sig_url, rem_id).await - } - StateCmd::CacheIce { rem_id, ice } => { - self.cache_ice(rem_id, ice).await - } - StateCmd::GetCachedIce { rem_id } => { - self.get_cached_ice(rem_id).await - } - StateCmd::CloseSig { sig_url, sig, err } => { - self.close_sig(sig_url, sig, err).await - } - StateCmd::ConnReady { cli_url } => self.conn_ready(cli_url).await, - StateCmd::CloseConn { rem_id, conn, err } => { - self.close_conn(rem_id, conn, err).await - } - } - } - - async fn tick_1s(&mut self) -> Result<()> { - let timeout = self.meta.config.max_conn_init(); - - self.ice_cache.retain(|_, list| { - list.retain_mut(|data| data.timestamp.elapsed() < timeout); - !list.is_empty() - }); - - self.send_map.retain(|_, list| { - list.retain_mut(|info| { - if info.timestamp.elapsed() < timeout { - true - } else { - tracing::trace!(msg_uniq = %info.msg_uniq, "dropping msg due to timeout"); - if let 
Some(resp) = info.resp.take() { - let _ = resp.send(Err(Error::id("Timeout"))); - } - false - } - }); - !list.is_empty() - }); - - let tot_conn_cnt = self.meta.metric_conn_count.get(); - - let mut tot_snd_bytes = 0; - let mut tot_rcv_bytes = 0; - let mut tot_age = 0.0; - let mut age_cnt = 0.0; - - for (_, (conn, _)) in self.conn_map.iter() { - let meta = conn.meta(); - tot_snd_bytes += meta.metric_bytes_snd.get(); - tot_rcv_bytes += meta.metric_bytes_rcv.get(); - tot_age += meta.created_at.elapsed().as_secs_f64(); - age_cnt += 1.0; - } - - let tot_avg_age_s = if age_cnt > 0.0 { - tot_age / age_cnt - } else { - 0.0 - }; - - self.conn_map.retain(|_, (conn, _)| { - let meta = conn.meta(); - - let args = drop_consider::DropConsiderArgs { - conn_uniq: meta.conn_uniq.clone(), - cfg_conn_max_cnt: meta.config.max_conn_count() as i64, - cfg_conn_max_init: meta.config.max_conn_init().as_secs_f64(), - tot_conn_cnt, - tot_snd_bytes, - tot_rcv_bytes, - tot_avg_age_s, - this_connected: meta - .connected - .load(std::sync::atomic::Ordering::SeqCst), - this_snd_bytes: meta.metric_bytes_snd.get(), - this_rcv_bytes: meta.metric_bytes_rcv.get(), - this_age_s: meta.created_at.elapsed().as_secs_f64(), - this_last_active_s: meta.last_active_at.elapsed().as_secs_f64(), - }; - - if let drop_consider::DropConsiderResult::MustDrop = - drop_consider::drop_consider(&args) - { - if let Some(conn) = conn.upgrade() { - conn.close(Error::id("DropContention")); - } - return false; - } - - true - }); - - Ok(()) - } - - async fn track_sig( - &mut self, - rem_id: Id, - ty: &'static str, - bytes: usize, - ) -> Result<()> { - if let Some((conn, _)) = self.conn_map.get(&rem_id) { - if let Some(conn) = conn.upgrade() { - conn.track_sig(ty, bytes); - } - } - Ok(()) - } - - async fn snd_demo(&mut self) -> Result<()> { - for (_, sig) in self.signal_map.iter() { - if let Some(sig) = sig.upgrade() { - sig.snd_demo(); - } - } - - Ok(()) - } - - async fn list_connected( - &mut self, - resp: 
tokio::sync::oneshot::Sender>>, - ) -> Result<()> { - let mut urls = Vec::new(); - for (_, (con, _)) in self.conn_map.iter() { - if let Some(con) = con.upgrade() { - urls.push(con.rem_url()); - } - } - resp.send(Ok(urls)).map_err(|_| Error::id("Closed")) - } - - async fn assert_listener_sig( - &mut self, - sig_url: Tx5Url, - resp: tokio::sync::oneshot::Sender>, - ) -> Result<()> { - tracing::debug!(state_uniq = %self.state_uniq, %sig_url, "begin register with signal server"); - let new_sig = |resp| -> SigState { - let (sig, sig_evt) = SigState::new( - self.this.clone(), - sig_url.clone(), - resp, - self.meta.config.max_conn_init(), - ); - let seed = SigStateSeed::new(sig.clone(), sig_evt); - let _ = self.evt.publish(StateEvt::NewSig(sig_url.clone(), seed)); - sig - }; - match self.signal_map.entry(sig_url.clone()) { - hash_map::Entry::Occupied(mut e) => match e.get().upgrade() { - Some(sig) => sig.push_assert_respond(resp).await, - None => { - let sig = new_sig(resp); - e.insert(sig.weak()); - } - }, - hash_map::Entry::Vacant(e) => { - let sig = new_sig(resp); - e.insert(sig.weak()); - } - } - Ok(()) - } - - fn is_banned(&mut self, rem_id: Id) -> bool { - let now = std::time::Instant::now(); - self.ban_map.retain(|_id, expires_at| *expires_at > now); - self.ban_map.contains_key(&rem_id) - } - - async fn create_new_conn( - &mut self, - sig_url: Tx5Url, - rem_id: Id, - maybe_offer: Option, - maybe_msg_uniq: Option, - ) -> Result<()> { - if self.is_banned(rem_id) { - tracing::warn!( - ?rem_id, - "Ignoring request to create con to banned remote" - ); - return Ok(()); - //return Err(Error::id("Ban")); - } - - let (s, r) = tokio::sync::oneshot::channel(); - if let Err(err) = self.assert_listener_sig(sig_url.clone(), s).await { - tracing::warn!(?err, "failed to assert signal listener"); - return Ok(()); - } - - let sig = self.signal_map.get(&sig_url).unwrap().clone(); - - let conn_uniq = self.state_uniq.sub(); - - tracing::trace!(?maybe_msg_uniq, %conn_uniq, 
"create_new_conn"); - - let cli_url = sig_url.to_client(rem_id); - let conn = match ConnState::new_and_publish( - self.meta.config.clone(), - self.meta.conn_limit.clone(), - self.meta.metric_conn_count.clone(), - self.this.clone(), - sig, - self.state_uniq.clone(), - conn_uniq, - self.this_id.unwrap(), - cli_url.clone(), - rem_id, - self.recv_limit.clone(), - r, - maybe_offer, - self.meta.snd_ident.clone(), - ) { - Err(err) => { - tracing::warn!(?err, "failed to create conn state"); - return Ok(()); - } - Ok(conn) => conn, - }; - - self.conn_map - .insert(rem_id, (conn, RmConn(self.evt.clone(), cli_url))); - - Ok(()) - } - - #[allow(clippy::too_many_arguments)] - async fn send_data( - &mut self, - msg_uniq: Uniq, - rem_id: Id, - data: BackBuf, - timestamp: std::time::Instant, - send_permit: tokio::sync::OwnedSemaphorePermit, - data_sent: tokio::sync::oneshot::Sender>, - cli_url: Tx5Url, - ) -> Result<()> { - if self.is_banned(rem_id) { - tracing::warn!( - ?rem_id, - "Ignoring request to send data to banned remote" - ); - let _ = data_sent.send(Err(Error::id("Ban"))); - return Ok(()); - } - - self.send_map - .entry(rem_id) - .or_default() - .push_back(SendData { - msg_uniq: msg_uniq.clone(), - data, - timestamp, - resp: Some(data_sent), - send_permit, - }); - - let rem_id = cli_url.id().unwrap(); - - if let Some((e, _)) = self.conn_map.get(&rem_id) { - if let Some(conn) = e.upgrade() { - conn.check_send_waiting(None).await; - return Ok(()); - } else { - self.conn_map.remove(&rem_id); - } - } - - let sig_url = cli_url.to_server(); - self.create_new_conn(sig_url, rem_id, None, Some(msg_uniq)) - .await - } - - async fn ban( - &mut self, - rem_id: Id, - span: std::time::Duration, - ) -> Result<()> { - let expires_at = std::time::Instant::now() + span; - self.ban_map.insert(rem_id, expires_at); - self.send_map.remove(&rem_id); - self.ice_cache.remove(&rem_id); - if let Some((conn, _)) = self.conn_map.remove(&rem_id) { - if let Some(conn) = conn.upgrade() { - 
conn.close(Error::id("Ban")); - } - } - Ok(()) - } - - fn stats( - &mut self, - resp: tokio::sync::oneshot::Sender>, - ) -> impl std::future::Future> + 'static + Send { - let this_id = self - .this_id - .map(|id| id.to_string()) - .unwrap_or_else(|| "".into()); - let conn_list = self - .conn_map - .iter() - .map(|(id, (c, _))| (*id, c.clone())) - .collect::>(); - let now = std::time::Instant::now(); - let mut ban_map = serde_json::Map::new(); - for (id, until) in self.ban_map.iter() { - ban_map.insert(id.to_string(), (*until - now).as_secs_f64().into()); - } - async move { - let mut map = serde_json::Map::new(); - - #[cfg(feature = "backend-go-pion")] - const BACKEND: &str = "go-pion"; - #[cfg(feature = "backend-webrtc-rs")] - const BACKEND: &str = "webrtc-rs"; - - map.insert("backend".into(), BACKEND.into()); - map.insert("thisId".into(), this_id.into()); - map.insert("banned".into(), ban_map.into()); - - for (id, conn) in conn_list { - if let Some(conn) = conn.upgrade() { - if let Ok(stats) = conn.stats().await { - map.insert(id.to_string(), stats); - } - } - } - - let _ = resp.send(Ok(map.into())); - - Ok(()) - } - } - - async fn publish(&mut self, evt: StateEvt) -> Result<()> { - let _ = self.evt.publish(evt); - Ok(()) - } - - async fn sig_connected(&mut self, cli_url: Tx5Url) -> Result<()> { - let loc_id = cli_url.id().unwrap(); - if let Some(this_id) = &self.this_id { - if this_id != &loc_id { - return Err(Error::err("MISMATCH LOCAL ID, please use the same lair instance for every sig connection")); - } - } else { - self.this_id = Some(loc_id); - } - let _ = self.evt.publish(StateEvt::Address(cli_url)); - Ok(()) - } - - async fn fetch_for_send( - &mut self, - want_conn: ConnStateWeak, - rem_id: Id, - ) -> Result<()> { - let conn = match self.conn_map.get(&rem_id) { - None => return Ok(()), - Some((cur_conn, _)) => { - if cur_conn != &want_conn { - return Ok(()); - } - match cur_conn.upgrade() { - None => return Ok(()), - Some(conn) => conn, - } - } - }; - let 
to_send = match self.send_map.get_mut(&rem_id) { - None => return Ok(()), - Some(to_send) => match to_send.pop_front() { - None => return Ok(()), - Some(to_send) => to_send, - }, - }; - conn.send(to_send); - Ok(()) - } - - async fn in_offer( - &mut self, - sig_url: Tx5Url, - rem_id: Id, - offer: BackBuf, - ) -> Result<()> { - if let Some((e, _)) = self.conn_map.get(&rem_id) { - if let Some(conn) = e.upgrade() { - // we seem to have a valid conn here... but - // we're receiving an incoming offer: - // activate PERFECT NEGOTIATION - // (https://developer.mozilla.org/en-US/docs/Web/API/WebRTC_API/Perfect_negotiation) - - if self.this_id.is_none() { - return Err(Error::err("Somehow we ended up receiving a webrtc offer before establishing a signal connection... this should be impossible")); - } - - match self.this_id.as_ref().unwrap().cmp(&rem_id) { - std::cmp::Ordering::Less => { - //println!("OFFER_CONFLICT:BEING_POLITE"); - - // we are the POLITE node, delete our connection - // and set up a new one with the incoming offer. - self.conn_map.remove(&rem_id); - conn.close(Error::id( - "PoliteShutdownToAcceptIncomingOffer", - )); - } - std::cmp::Ordering::Greater => { - //println!("OFFER_CONFLICT:BEING_IMPOLITE"); - - // we are the IMPOLITE node, we'll ignore this - // offer and continue with our existing connection. - drop(offer); - return Ok(()); - } - std::cmp::Ordering::Equal => { - tracing::warn!("Invalid incoming webrtc offer with id matching our local id. Please don't share lair connections"); - self.conn_map.remove(&rem_id); - return Ok(()); - //return Err(Error::err("Invalid incoming webrtc offer with id matching our local id. 
Please don't share lair connections")); - } - } - } else { - self.conn_map.remove(&rem_id); - } - } - - self.create_new_conn(sig_url, rem_id, Some(offer), None) - .await - } - - async fn in_demo(&mut self, sig_url: Tx5Url, rem_id: Id) -> Result<()> { - let cli_url = sig_url.to_client(rem_id); - self.evt.publish(StateEvt::Demo(cli_url)) - } - - async fn cache_ice(&mut self, rem_id: Id, ice: BackBuf) -> Result<()> { - let list = self.ice_cache.entry(rem_id).or_default(); - list.push_back(IceData { - timestamp: std::time::Instant::now(), - ice, - }); - Ok(()) - } - - async fn get_cached_ice(&mut self, rem_id: Id) -> Result<()> { - let StateData { - conn_map, - ice_cache, - .. - } = self; - if let Some((conn, _)) = conn_map.get(&rem_id) { - if let Some(conn) = conn.upgrade() { - if let Some(list) = ice_cache.get_mut(&rem_id) { - for ice_data in list.iter_mut() { - conn.in_ice(ice_data.ice.try_clone()?, false); - } - } - } - } - Ok(()) - } - - async fn close_sig( - &mut self, - sig_url: Tx5Url, - sig: SigStateWeak, - err: std::io::Error, - ) -> Result<()> { - if let Some(cur_sig) = self.signal_map.remove(&sig_url) { - if cur_sig == sig { - if let Some(sig) = sig.upgrade() { - sig.close(err); - } - } else { - // Whoops! - self.signal_map.insert(sig_url, cur_sig); - } - } - Ok(()) - } - - async fn conn_ready(&mut self, cli_url: Tx5Url) -> Result<()> { - self.evt.publish(StateEvt::Connected(cli_url)) - } - - async fn close_conn( - &mut self, - rem_id: Id, - conn: ConnStateWeak, - err: std::io::Error, - ) -> Result<()> { - if let Some((cur_conn, rm)) = self.conn_map.remove(&rem_id) { - if cur_conn == conn { - if let Some(conn) = conn.upgrade() { - conn.close(err); - } - } else { - // Whoops! 
- self.conn_map.insert(rem_id, (cur_conn, rm)); - } - } - Ok(()) - } -} - -enum StateCmd { - Tick1s, - TrackSig { - rem_id: Id, - ty: &'static str, - bytes: usize, - }, - SndDemo, - ListConnected(tokio::sync::oneshot::Sender>>), - AssertListenerSig { - sig_url: Tx5Url, - resp: tokio::sync::oneshot::Sender>, - }, - SendData { - msg_uniq: Uniq, - rem_id: Id, - data: BackBuf, - timestamp: std::time::Instant, - send_permit: tokio::sync::OwnedSemaphorePermit, - resp: tokio::sync::oneshot::Sender>, - cli_url: Tx5Url, - }, - Ban { - rem_id: Id, - span: std::time::Duration, - }, - Stats(tokio::sync::oneshot::Sender>), - Publish { - evt: StateEvt, - }, - SigConnected { - cli_url: Tx5Url, - }, - FetchForSend { - conn: ConnStateWeak, - rem_id: Id, - }, - InOffer { - sig_url: Tx5Url, - rem_id: Id, - data: BackBuf, - }, - InDemo { - sig_url: Tx5Url, - rem_id: Id, - }, - CacheIce { - rem_id: Id, - ice: BackBuf, - }, - GetCachedIce { - rem_id: Id, - }, - CloseSig { - sig_url: Tx5Url, - sig: SigStateWeak, - err: std::io::Error, - }, - ConnReady { - cli_url: Tx5Url, - }, - CloseConn { - rem_id: Id, - conn: ConnStateWeak, - err: std::io::Error, - }, -} - -#[allow(clippy::too_many_arguments)] -async fn state_task( - mut rcv: ManyRcv, - state_uniq: Uniq, - this: StateWeak, - meta: StateMeta, - evt: StateEvtSnd, - recv_limit: Arc, -) -> Result<()> { - let mut data = StateData { - state_uniq, - this_id: None, - this, - meta, - evt, - signal_map: HashMap::new(), - ban_map: HashMap::new(), - conn_map: HashMap::new(), - send_map: HashMap::new(), - ice_cache: HashMap::new(), - recv_limit, - }; - let err = match async { - while let Some(cmd) = rcv.recv().await { - data.exec(cmd?).await?; - } - Ok(()) - } - .await - { - Err(err) => err, - Ok(_) => Error::id("Dropped"), - }; - data.shutdown(err.err_clone()); - Err(err) -} - -#[derive(Clone)] -pub(crate) struct StateMeta { - pub(crate) state_uniq: Uniq, - pub(crate) config: DynConfig, - pub(crate) conn_limit: Arc, - pub(crate) snd_limit: Arc, - 
pub(crate) metric_conn_count: AtomicObservableUpDownCounterI64, - pub(crate) snd_ident: Arc, -} - -/// Weak version of State. -#[derive(Clone)] -pub struct StateWeak(ActorWeak, StateMeta); - -impl StateWeak { - /// Upgrade to a full State instance. - pub fn upgrade(&self) -> Option { - self.0.upgrade().map(|s| State(s, self.1.clone())) - } -} - -/// Handle to a state tracking instance. -#[derive(Clone)] -pub struct State(Actor, StateMeta); - -impl PartialEq for State { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Eq for State {} - -impl State { - /// Construct a new state instance. - pub fn new(config: DynConfig) -> Result<(Self, ManyRcv)> { - let conn_limit = Arc::new(tokio::sync::Semaphore::new( - config.max_conn_count() as usize, - )); - - let snd_limit = Arc::new(tokio::sync::Semaphore::new( - config.max_send_bytes() as usize, - )); - let rcv_limit = Arc::new(tokio::sync::Semaphore::new( - config.max_recv_bytes() as usize, - )); - - let state_uniq = Uniq::default(); - - let metric_conn_count = opentelemetry_api::global::meter_provider() - .versioned_meter( - "tx5", - None::<&'static str>, - None::<&'static str>, - Some(vec![opentelemetry_api::KeyValue::new( - "state_uniq", - state_uniq.to_string(), - )]), - ) - .i64_observable_up_down_counter_atomic("tx5.endpoint.conn.count", 0) - .with_description("Count of open connections managed by endpoint") - .init() - .0; - - let meta = StateMeta { - state_uniq: state_uniq.clone(), - config, - conn_limit, - snd_limit, - metric_conn_count, - snd_ident: Arc::new(std::sync::atomic::AtomicU64::new(1)), - }; - - let (state_snd, state_rcv) = tokio::sync::mpsc::unbounded_channel(); - let actor = { - let meta = meta.clone(); - Actor::new(move |this, rcv| { - state_task( - rcv, - state_uniq, - StateWeak(this, meta.clone()), - meta, - StateEvtSnd(state_snd), - rcv_limit, - ) - }) - }; - - let weak = StateWeak(actor.weak(), meta.clone()); - tokio::task::spawn(async move { - loop { - 
tokio::time::sleep(std::time::Duration::from_secs(1)).await; - match weak.upgrade() { - None => break, - Some(actor) => { - if actor.tick_1s().is_err() { - break; - } - } - } - } - }); - - Ok((Self(actor, meta), ManyRcv(state_rcv))) - } - - /// Get a weak version of this State instance. - pub fn weak(&self) -> StateWeak { - StateWeak(self.0.weak(), self.1.clone()) - } - - /// Returns `true` if this State is closed. - pub fn is_closed(&self) -> bool { - self.0.is_closed() - } - - /// Shutdown the state management with an error. - pub fn close(&self, err: std::io::Error) { - self.0.close(err); - } - - /// List the ids of current open connections. - pub fn list_connected( - &self, - ) -> impl Future>> + 'static + Send { - let this = self.clone(); - async move { - let (s, r) = tokio::sync::oneshot::channel(); - this.0.send(Ok(StateCmd::ListConnected(s)))?; - r.await.map_err(|_| Error::id("Closed"))? - } - } - - /// Establish a new listening connection through given signal server. - pub fn listener_sig( - &self, - sig_url: Tx5Url, - ) -> impl Future> + 'static + Send { - let this = self.clone(); - async move { - if !sig_url.is_server() { - return Err(Error::err( - "Invalid tx5 client url, expected signal server url", - )); - } - let (s, r) = tokio::sync::oneshot::channel(); - this.0 - .send(Ok(StateCmd::AssertListenerSig { sig_url, resp: s }))?; - r.await.map_err(|_| Error::id("Closed"))? - } - } - - /// Close down all connections to, fail all outgoing messages to, - /// and drop all incoming messages from, the given remote id, - /// for the specified ban time period. - pub fn ban(&self, rem_id: Id, span: std::time::Duration) { - let _ = self.0.send(Ok(StateCmd::Ban { rem_id, span })); - } - - /// Schedule data to be sent out over a channel managed by the state system. - /// The future will resolve immediately if there is still space - /// in the outgoing buffer, or once there is again space in the buffer. 
- pub fn snd_data( - &self, - cli_url: Tx5Url, - data: B, - ) -> impl Future> + 'static + Send { - let timestamp = std::time::Instant::now(); - - let buf_list = if !cli_url.is_client() { - Err(Error::err( - "Invalid tx5 signal server url, expect client url", - )) - } else { - divide_send(&*self.1.config, &self.1.snd_ident, data) - }; - - let meta = self.1.clone(); - - let this = self.clone(); - async move { - tokio::time::timeout(meta.config.max_conn_init(), async move { - let cli_url = &cli_url; - let msg_uniq = meta.state_uniq.sub(); - let msg_uniq = &msg_uniq; - - let buf_list = buf_list?; - - let mut resp_list = Vec::with_capacity(buf_list.len()); - - for (idx, mut buf) in buf_list.into_iter().enumerate() { - let len = buf.len()?; - - tracing::trace!(%msg_uniq, %len, "snd_data"); - - if meta.snd_limit.available_permits() < len { - tracing::warn!(%msg_uniq, %len, "send queue full, waiting for permits"); - } - - let send_permit = meta - .snd_limit - .clone() - .acquire_many_owned(len as u32) - .await - .map_err(Error::err)?; - - tracing::trace!(%msg_uniq, %idx, %len, "snd_data:got permit"); - - let rem_id = cli_url.id().unwrap(); - - let (s_sent, r_sent) = tokio::sync::oneshot::channel(); - - if let Err(err) = this.0.send(Ok(StateCmd::SendData { - msg_uniq: msg_uniq.clone(), - rem_id, - data: buf, - timestamp, - send_permit, - resp: s_sent, - cli_url: cli_url.clone(), - })) { - tracing::trace!(%msg_uniq, %idx, ?err, "snd_data:complete err"); - return Err(err); - } - - resp_list.push(async move { - match r_sent.await.map_err(|_| Error::id("Closed")) { - Ok(r) => match r { - Ok(_) => { - tracing::trace!(%msg_uniq, %idx, "snd_data:complete ok"); - } - Err(err) => { - tracing::trace!(%msg_uniq, %idx, ?err, "snd_data:complete err"); - return Err(err); - } - }, - Err(err) => { - tracing::trace!(%msg_uniq, %idx, ?err, "snd_data:complete err"); - return Err(err); - } - } - Ok(()) - }); - } - - for resp in resp_list { - resp.await?; - } - - Ok(()) - }).await.map_err(|_| 
Error::id("Timeout"))? - } - } - - /// Send a demo broadcast to every connected signal server. - /// Warning, if demo mode is not enabled on these servers, this - /// could result in a ban. - pub fn snd_demo(&self) -> Result<()> { - self.0.send(Ok(StateCmd::SndDemo)) - } - - /// Get stats. - pub fn stats( - &self, - ) -> impl Future> + 'static + Send { - let this = self.clone(); - async move { - let (s, r) = tokio::sync::oneshot::channel(); - this.0.send(Ok(StateCmd::Stats(s)))?; - r.await.map_err(|_| Error::id("Shutdown"))? - } - } - - // -- // - - fn tick_1s(&self) -> Result<()> { - self.0.send(Ok(StateCmd::Tick1s)) - } - - pub(crate) fn track_sig(&self, rem_id: Id, ty: &'static str, bytes: usize) { - let _ = self.0.send(Ok(StateCmd::TrackSig { rem_id, ty, bytes })); - } - - pub(crate) fn publish(&self, evt: StateEvt) { - let _ = self.0.send(Ok(StateCmd::Publish { evt })); - } - - pub(crate) fn sig_connected(&self, cli_url: Tx5Url) { - let _ = self.0.send(Ok(StateCmd::SigConnected { cli_url })); - } - - pub(crate) fn fetch_for_send( - &self, - conn: ConnStateWeak, - rem_id: Id, - ) -> Result<()> { - self.0.send(Ok(StateCmd::FetchForSend { conn, rem_id })) - } - - pub(crate) fn in_offer( - &self, - sig_url: Tx5Url, - rem_id: Id, - data: BackBuf, - ) -> Result<()> { - self.0.send(Ok(StateCmd::InOffer { - sig_url, - rem_id, - data, - })) - } - - pub(crate) fn in_demo(&self, sig_url: Tx5Url, rem_id: Id) -> Result<()> { - self.0.send(Ok(StateCmd::InDemo { sig_url, rem_id })) - } - - pub(crate) fn cache_ice(&self, rem_id: Id, ice: BackBuf) -> Result<()> { - self.0.send(Ok(StateCmd::CacheIce { rem_id, ice })) - } - - pub(crate) fn get_cached_ice(&self, rem_id: Id) -> Result<()> { - self.0.send(Ok(StateCmd::GetCachedIce { rem_id })) - } - - pub(crate) fn close_sig( - &self, - sig_url: Tx5Url, - sig: SigStateWeak, - err: std::io::Error, - ) { - let _ = self.0.send(Ok(StateCmd::CloseSig { sig_url, sig, err })); - } - - pub(crate) fn conn_ready(&self, cli_url: Tx5Url) { - 
let _ = self.0.send(Ok(StateCmd::ConnReady { cli_url })); - } - - pub(crate) fn close_conn( - &self, - rem_id: Id, - conn: ConnStateWeak, - err: std::io::Error, - ) { - let _ = self.0.send(Ok(StateCmd::CloseConn { rem_id, conn, err })); - } -} diff --git a/crates/tx5/src/state/conn.rs b/crates/tx5/src/state/conn.rs deleted file mode 100644 index 575aca6b..00000000 --- a/crates/tx5/src/state/conn.rs +++ /dev/null @@ -1,1212 +0,0 @@ -use super::*; -use std::sync::atomic; - -/// Temporary indicating we want a new conn instance. -pub struct ConnStateSeed { - done: bool, - output: Option<(ConnState, ManyRcv)>, -} - -impl std::fmt::Debug for ConnStateSeed { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ConnStateSeed").finish() - } -} - -impl Drop for ConnStateSeed { - fn drop(&mut self) { - self.result_err_inner(Error::id("Dropped")); - } -} - -impl ConnStateSeed { - /// Finalize this conn_state seed by indicating a successful connection. - pub fn result_ok(mut self) -> Result<(ConnState, ManyRcv)> { - self.done = true; - let (conn, conn_evt) = self.output.take().unwrap(); - conn.notify_constructed()?; - Ok((conn, conn_evt)) - } - - /// Finalize this conn_state seed by indicating an error connecting. - pub fn result_err(mut self, err: std::io::Error) { - self.result_err_inner(err); - } - - // -- // - - pub(crate) fn new( - conn: ConnState, - conn_evt: ManyRcv, - ) -> Self { - Self { - done: false, - output: Some((conn, conn_evt)), - } - } - - fn result_err_inner(&mut self, err: std::io::Error) { - if !self.done { - self.done = true; - if let Some((conn, _)) = self.output.take() { - conn.close(err); - } - } - } -} - -/// Indication of the current buffer state. -#[derive(Debug, PartialEq)] -pub enum BufState { - /// BackBuffer is low, we can buffer more data. - Low, - - /// BackBuffer is high, we should wait / apply backpressure. - High, -} - -/// State wishes to invoke an action on a connection instance. 
-pub enum ConnStateEvt { - /// Request to create an offer. - CreateOffer(OneSnd), - - /// Request to create an answer. - CreateAnswer(OneSnd), - - /// Request to set a local description. - SetLoc(BackBuf, OneSnd<()>), - - /// Request to set a remote description. - SetRem(BackBuf, OneSnd<()>), - - /// Request to append a trickle ICE candidate. - SetIce(BackBuf, OneSnd<()>), - - /// Request to send a message on the data channel. - SndData(BackBuf, OneSnd), - - /// Request a stats dump of this peer connection. - Stats(OneSnd), -} - -impl std::fmt::Debug for ConnStateEvt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ConnStateEvt::CreateOffer(_) => f.write_str("CreateOffer"), - ConnStateEvt::CreateAnswer(_) => f.write_str("CreateAnswer"), - ConnStateEvt::SetLoc(_, _) => f.write_str("SetLoc"), - ConnStateEvt::SetRem(_, _) => f.write_str("SetRem"), - ConnStateEvt::SetIce(_, _) => f.write_str("SetIce"), - ConnStateEvt::SndData(_, _) => f.write_str("SndData"), - ConnStateEvt::Stats(_) => f.write_str("Stats"), - } - } -} - -#[derive(Clone)] -struct ConnStateEvtSnd( - tokio::sync::mpsc::UnboundedSender>, -); - -impl ConnStateEvtSnd { - pub fn err(&self, err: std::io::Error) { - let _ = self.0.send(Err(err)); - } - - pub fn create_offer(&self, conn: ConnStateWeak) { - let s = OneSnd::new(move |result| { - if let Some(conn) = conn.upgrade() { - conn.self_offer(result); - } - }); - let _ = self.0.send(Ok(ConnStateEvt::CreateOffer(s))); - } - - pub fn create_answer(&self, conn: ConnStateWeak) { - let s = OneSnd::new(move |result| { - if let Some(conn) = conn.upgrade() { - conn.self_answer(result); - } - }); - let _ = self.0.send(Ok(ConnStateEvt::CreateAnswer(s))); - } - - pub fn set_loc(&self, conn: ConnStateWeak, data: BackBuf) { - let s = OneSnd::new(move |result| { - if let Err(err) = result { - if let Some(conn) = conn.upgrade() { - conn.close(err); - } - } - }); - let _ = self.0.send(Ok(ConnStateEvt::SetLoc(data, s))); - } - - pub 
fn set_rem( - &self, - conn: ConnStateWeak, - data: BackBuf, - should_answer: bool, - ) { - let s = if should_answer { - OneSnd::new(move |result| match result { - Ok(_) => { - if let Some(conn) = conn.upgrade() { - conn.req_self_answer(); - } - } - Err(err) => { - if let Some(conn) = conn.upgrade() { - conn.close(err); - } - } - }) - } else { - OneSnd::new(move |result| { - if let Err(err) = result { - if let Some(conn) = conn.upgrade() { - conn.close(err); - } - } - }) - }; - let _ = self.0.send(Ok(ConnStateEvt::SetRem(data, s))); - } - - pub fn set_ice(&self, _conn: ConnStateWeak, data: BackBuf) { - let s = OneSnd::new(move |result| { - if let Err(err) = result { - tracing::debug!(?err, "ICEError"); - // treat ice errors loosely... sometimes things - // get out of order... especially with perfect negotiation - /* - if let Some(conn) = conn.upgrade() { - conn.close(err); - } - */ - } - }); - let _ = self.0.send(Ok(ConnStateEvt::SetIce(data, s))); - } - - pub fn snd_data( - &self, - conn: ConnStateWeak, - data: BackBuf, - resp: Option>>, - send_permit: Option, - ) { - let s = OneSnd::new(move |result| { - let _send_permit = send_permit; - match result { - Err(err) => { - if let Some(conn) = conn.upgrade() { - conn.close(err.err_clone()); - } - if let Some(resp) = resp { - let _ = resp.send(Err(err)); - } - } - Ok(buffer_state) => { - if let Some(conn) = conn.upgrade() { - conn.notify_send_complete(buffer_state); - } - if let Some(resp) = resp { - let _ = resp.send(Ok(())); - } - } - } - }); - let _ = self.0.send(Ok(ConnStateEvt::SndData(data, s))); - } - - pub fn stats( - &self, - additions: Vec<(String, serde_json::Value)>, - rsp: tokio::sync::oneshot::Sender>, - ) { - let _ = self.0.send(Ok(ConnStateEvt::Stats(OneSnd::new( - move |buf: Result| { - let _ = rsp.send((move || { - let mut stats: serde_json::Value = buf?.to_json()?; - for (key, value) in additions { - stats.as_object_mut().unwrap().insert(key, value); - } - Ok(stats) - })()); - }, - )))); - } -} - 
-struct ConnStateData { - conn_uniq: Uniq, - this: ConnStateWeak, - metric_conn_count: AtomicObservableUpDownCounterI64, - meta: ConnStateMeta, - state: StateWeak, - this_id: Id, - rem_id: Id, - conn_evt: ConnStateEvtSnd, - sig_state: SigStateWeak, - rcv_offer: bool, - rcv_pending: - HashMap)>, - wait_preflight: bool, - offer: (u64, u64, u64, u64), - answer: (u64, u64, u64, u64), - ice: (u64, u64, u64, u64), - buf_state: BufState, - send_wait: bool, -} - -impl Drop for ConnStateData { - fn drop(&mut self) { - self.metric_conn_count.add(-1); - self.shutdown(Error::id("Dropped")); - } -} - -impl ConnStateData { - fn connected(&self) -> bool { - self.meta.connected.load(atomic::Ordering::SeqCst) - } - - fn shutdown(&mut self, err: std::io::Error) { - tracing::debug!( - ?err, - conn_uniq = %self.conn_uniq, - this_id = ?self.this_id, - rem_id = ?self.rem_id, - "ConnShutdown", - ); - if let Some(state) = self.state.upgrade() { - state.close_conn(self.rem_id, self.this.clone(), err.err_clone()); - } - if let Some(sig) = self.sig_state.upgrade() { - sig.unregister_conn(self.rem_id, self.this.clone()); - } - self.conn_evt.err(err); - } - - fn get_sig(&mut self) -> Result { - match self.sig_state.upgrade() { - Some(sig) => Ok(sig), - None => Err(Error::id("SigClosed")), - } - } - - async fn exec(&mut self, cmd: ConnCmd) -> Result<()> { - match cmd { - ConnCmd::Tick1s => self.tick_1s().await, - ConnCmd::Stats(rsp) => self.stats(rsp).await, - ConnCmd::TrackSig { ty, bytes } => self.track_sig(ty, bytes).await, - ConnCmd::NotifyConstructed => self.notify_constructed().await, - ConnCmd::CheckConnectedTimeout => { - self.check_connected_timeout().await - } - ConnCmd::Ice { data } => self.ice(data).await, - ConnCmd::SelfOffer { offer } => self.self_offer(offer).await, - ConnCmd::ReqSelfAnswer => self.req_self_answer().await, - ConnCmd::SelfAnswer { answer } => self.self_answer(answer).await, - ConnCmd::InOffer { offer } => self.in_offer(offer).await, - ConnCmd::InAnswer { answer } 
=> self.in_answer(answer).await, - ConnCmd::InIce { ice, cache } => self.in_ice(ice, cache).await, - ConnCmd::Ready => self.ready().await, - ConnCmd::MaybeFetchForSend { - send_complete, - buf_state, - } => self.maybe_fetch_for_send(send_complete, buf_state).await, - ConnCmd::Send { to_send } => self.send(to_send).await, - ConnCmd::Recv { - ident, - data, - permit, - } => self.recv(ident, data, permit).await, - } - } - - async fn tick_1s(&mut self) -> Result<()> { - if self.meta.last_active_at.elapsed() > self.meta.config.max_conn_init() - && !self.connected() - { - self.shutdown(Error::id("InactivityTimeout")); - } - - Ok(()) - } - - async fn stats( - &mut self, - rsp: tokio::sync::oneshot::Sender>, - ) -> Result<()> { - let mut additions = Vec::new(); - additions.push(( - "ageSeconds".into(), - self.meta.created_at.elapsed().as_secs_f64().into(), - )); - - let sig_stats = serde_json::json!({ - "offersSent": self.offer.0, - "offerBytesSent": self.offer.1, - "offersReceived": self.offer.2, - "offerBytesReceived": self.offer.3, - "answersSent": self.answer.0, - "answerBytesSent": self.answer.1, - "answersReceived": self.answer.2, - "answerBytesReceived": self.answer.3, - "iceMessagesSent": self.ice.0, - "iceBytesSent": self.ice.1, - "iceMessagesReceived": self.ice.2, - "iceBytesReceived": self.ice.3, - }); - additions.push(("signalingTransport".into(), sig_stats)); - - self.conn_evt.stats(additions, rsp); - - Ok(()) - } - - async fn track_sig( - &mut self, - ty: &'static str, - bytes: usize, - ) -> Result<()> { - match ty { - "offer_out" => { - self.offer.0 += 1; - self.offer.1 += bytes as u64; - } - "offer_in" => { - self.offer.2 += 1; - self.offer.3 += bytes as u64; - } - "answer_out" => { - self.answer.0 += 1; - self.answer.1 += bytes as u64; - } - "answer_in" => { - self.answer.2 += 1; - self.answer.3 += bytes as u64; - } - "ice_out" => { - self.ice.0 += 1; - self.ice.1 += bytes as u64; - } - "ice_in" => { - self.ice.2 += 1; - self.ice.3 += bytes as u64; - } - _ 
=> (), - } - - Ok(()) - } - - async fn notify_constructed(&mut self) -> Result<()> { - if !self.rcv_offer { - // Kick off connection initialization by requesting - // an outgoing offer be created by this connection. - // This will result in a `self_offer` call. - self.conn_evt.create_offer(self.this.clone()); - } - Ok(()) - } - - async fn check_connected_timeout(&mut self) -> Result<()> { - if !self.connected() { - Err(Error::id("Timeout")) - } else { - Ok(()) - } - } - - async fn ice(&mut self, data: BackBuf) -> Result<()> { - let sig = self.get_sig()?; - sig.snd_ice(self.rem_id, data) - } - - async fn self_offer(&mut self, offer: Result) -> Result<()> { - let sig = self.get_sig()?; - let mut offer = offer?; - self.conn_evt.set_loc(self.this.clone(), offer.try_clone()?); - sig.snd_offer(self.rem_id, offer) - } - - async fn req_self_answer(&mut self) -> Result<()> { - self.conn_evt.create_answer(self.this.clone()); - Ok(()) - } - - async fn self_answer(&mut self, answer: Result) -> Result<()> { - let sig = self.get_sig()?; - let mut answer = answer?; - self.conn_evt - .set_loc(self.this.clone(), answer.try_clone()?); - sig.snd_answer(self.rem_id, answer) - } - - async fn in_offer(&mut self, mut offer: BackBuf) -> Result<()> { - tracing::trace!( - conn_uniq = %self.conn_uniq, - this_id = ?self.this_id, - rem_id = ?self.rem_id, - offer = %String::from_utf8_lossy(&offer.to_vec()?), - "OfferRecv", - ); - self.rcv_offer = true; - self.conn_evt.set_rem(self.this.clone(), offer, true); - self.state - .upgrade() - .ok_or_else(|| Error::id("Closed"))? - .get_cached_ice(self.rem_id)?; - Ok(()) - } - - async fn in_answer(&mut self, mut answer: BackBuf) -> Result<()> { - tracing::trace!( - conn_uniq = %self.conn_uniq, - this_id = ?self.this_id, - rem_id = ?self.rem_id, - answer = %String::from_utf8_lossy(&answer.to_vec()?), - "AnswerRecv", - ); - self.conn_evt.set_rem(self.this.clone(), answer, false); - self.state - .upgrade() - .ok_or_else(|| Error::id("Closed"))? 
- .get_cached_ice(self.rem_id)?; - Ok(()) - } - - async fn in_ice(&mut self, mut ice: BackBuf, cache: bool) -> Result<()> { - tracing::trace!( - conn_uniq = %self.conn_uniq, - this_id = ?self.this_id, - - rem_id = ?self.rem_id, - ice = %String::from_utf8_lossy(&ice.to_vec()?), - "ICERecv", - ); - if cache { - self.state - .upgrade() - .ok_or_else(|| Error::id("Closed"))? - .cache_ice(self.rem_id, ice.try_clone()?)?; - } - self.conn_evt.set_ice(self.this.clone(), ice); - Ok(()) - } - - async fn ready(&mut self) -> Result<()> { - // first, check / send the preflight - let data = self - .meta - .config - .on_conn_preflight(self.meta.cli_url.clone()) - .await? - .unwrap_or_else(bytes::Bytes::new); - - for buf in divide_send(&*self.meta.config, &self.meta.snd_ident, data)? - { - self.conn_evt.snd_data(self.this.clone(), buf, None, None); - } - - self.meta.connected.store(true, atomic::Ordering::SeqCst); - self.maybe_fetch_for_send(false, None).await - } - - async fn maybe_fetch_for_send( - &mut self, - send_complete: bool, - buf_state: Option, - ) -> Result<()> { - if send_complete { - self.send_wait = false; - } - - if let Some(buf_state) = buf_state { - if self.buf_state != buf_state { - tracing::debug!( - conn_uniq = %self.meta.conn_uniq, - old_buf_state = ?self.buf_state, - new_buf_state = ?buf_state, - "Updating BufState", - ); - self.buf_state = buf_state; - } - } - - if !self.connected() { - return Ok(()); - } - - // if we are within the close time send grace period - // do not send any new messages so we can try to shut - // down gracefully - if self.meta.created_at.elapsed() - > (MAX_CON_TIME - CON_CLOSE_SEND_GRACE) - { - return Ok(()); - } - - if let BufState::High = self.buf_state { - // wait for our buffer state to be low before fetching - // more data to send. 
- return Ok(()); - } - - if self.send_wait { - // we already have an outgoing send, don't request another - return Ok(()); - } - - if let Some(state) = self.state.upgrade() { - state.fetch_for_send(self.this.clone(), self.rem_id)?; - Ok(()) - } else { - Err(Error::id("StateClosed")) - } - } - - async fn send(&mut self, to_send: SendData) -> Result<()> { - let SendData { - msg_uniq, - mut data, - resp, - send_permit, - .. - } = to_send; - - tracing::trace!(conn_uniq = %self.conn_uniq, %msg_uniq, "conn send"); - - self.meta.last_active_at = std::time::Instant::now(); - self.meta.metric_bytes_snd.add(data.len()? as u64); - - self.send_wait = true; - self.conn_evt.snd_data( - self.this.clone(), - data, - resp, - Some(send_permit), - ); - - Ok(()) - } - - async fn handle_recv_data( - &mut self, - mut bl: BytesList, - permit: Vec, - ) -> Result<()> { - use bytes::Buf; - - if let Some(state) = self.state.upgrade() { - if self.wait_preflight { - let bytes = if bl.has_remaining() { - Some(bl.copy_to_bytes(bl.remaining())) - } else { - None - }; - self.meta - .config - .on_conn_validate(self.meta.cli_url.clone(), bytes) - .await?; - self.wait_preflight = false; - - state.conn_ready(self.meta.cli_url.clone()); - } else { - state.publish(StateEvt::RcvData( - self.meta.cli_url.clone(), - bl.into_dyn(), - vec![Permit(permit)], - )); - } - } - Ok(()) - } - - async fn recv( - &mut self, - ident: u64, - data: bytes::Bytes, - permit: tokio::sync::OwnedSemaphorePermit, - ) -> Result<()> { - let len = data.len(); - self.meta.last_active_at = std::time::Instant::now(); - - self.meta.metric_bytes_rcv.add(len as u64); - - let is_finish = ident.is_finish(); - let ident = ident.unset_finish(); - - match self.rcv_pending.entry(ident) { - std::collections::hash_map::Entry::Vacant(e) => { - if data.is_empty() || is_finish { - tracing::trace!(%is_finish, %ident, "rcv already finished"); - // special case for oneshot message - let mut bl = BytesList::new(); - if !data.is_empty() { - 
bl.push(data); - } - self.handle_recv_data(bl, vec![permit]).await?; - } else { - tracing::trace!(%is_finish, %ident, byte_count=%len, "rcv new"); - let mut bl = BytesList::new(); - bl.push(data); - e.insert((bl, vec![permit])); - } - } - std::collections::hash_map::Entry::Occupied(mut e) => { - if data.is_empty() || is_finish { - tracing::trace!(%is_finish, %ident, "rcv complete"); - // we've gotten to the end - let (mut bl, permit) = e.remove(); - if !data.is_empty() { - bl.push(data); - } - self.handle_recv_data(bl, permit).await?; - } else { - tracing::trace!(%is_finish, %ident, byte_count=%len, "rcv next"); - e.get_mut().0.push(data); - e.get_mut().1.push(permit); - } - } - } - - Ok(()) - } -} - -enum ConnCmd { - Tick1s, - Stats(tokio::sync::oneshot::Sender>), - TrackSig { - ty: &'static str, - bytes: usize, - }, - NotifyConstructed, - CheckConnectedTimeout, - Ice { - data: BackBuf, - }, - SelfOffer { - offer: Result, - }, - ReqSelfAnswer, - SelfAnswer { - answer: Result, - }, - InOffer { - offer: BackBuf, - }, - InAnswer { - answer: BackBuf, - }, - InIce { - ice: BackBuf, - cache: bool, - }, - Ready, - MaybeFetchForSend { - send_complete: bool, - buf_state: Option, - }, - Send { - to_send: SendData, - }, - Recv { - ident: u64, - data: bytes::Bytes, - permit: tokio::sync::OwnedSemaphorePermit, - }, -} - -#[allow(clippy::too_many_arguments)] -async fn conn_state_task( - conn_limit: Arc, - metric_conn_count: AtomicObservableUpDownCounterI64, - meta: ConnStateMeta, - strong: ConnState, - conn_rcv: ManyRcv, - mut rcv: ManyRcv, - this: ConnStateWeak, - state: StateWeak, - conn_uniq: Uniq, - this_id: Id, - rem_id: Id, - conn_evt: ConnStateEvtSnd, - sig_state: SigStateWeak, - sig_ready: tokio::sync::oneshot::Receiver>, -) -> Result<()> { - metric_conn_count.add(1); - - let mut data = ConnStateData { - conn_uniq, - this, - metric_conn_count, - meta, - state, - this_id, - rem_id, - conn_evt, - sig_state, - rcv_offer: false, - rcv_pending: HashMap::new(), - 
wait_preflight: true, - offer: (0, 0, 0, 0), - answer: (0, 0, 0, 0), - ice: (0, 0, 0, 0), - buf_state: BufState::Low, - send_wait: false, - }; - - let mut permit = None; - - let err = match async { - if conn_limit.available_permits() < 1 { - tracing::warn!(conn_uniq = %data.conn_uniq, "max connections reached, waiting for permit"); - } - - permit = Some( - conn_limit - .acquire_owned() - .await - .map_err(|_| Error::id("Closed"))?, - ); - - sig_ready.await.map_err(|_| Error::id("SigClosed"))??; - - let sig = data.get_sig()?; - let ice_servers = - sig.register_conn(data.rem_id, data.this.clone()).await?; - - match data.state.upgrade() { - None => return Err(Error::id("Closed")), - Some(state) => { - tracing::debug!(conn_uniq = %data.conn_uniq, id = ?data.rem_id, "NewConn"); - let seed = ConnStateSeed::new(strong, conn_rcv); - state.publish(StateEvt::NewConn(ice_servers, seed)); - } - } - - while let Some(cmd) = rcv.recv().await { - data.exec(cmd?).await?; - } - Ok(()) - } - .await - { - Err(err) => err, - Ok(_) => Error::id("Dropped"), - }; - - data.shutdown(err.err_clone()); - - drop(permit); - - Err(err) -} - -#[derive(Clone)] -pub(crate) struct ConnStateMeta { - pub(crate) created_at: std::time::Instant, - pub(crate) last_active_at: std::time::Instant, - pub(crate) cli_url: Tx5Url, - pub(crate) state_uniq: Uniq, - pub(crate) conn_uniq: Uniq, - pub(crate) config: DynConfig, - pub(crate) connected: Arc, - _conn_snd: ConnStateEvtSnd, - pub(crate) rcv_limit: Arc, - pub(crate) metric_bytes_snd: AtomicObservableCounterU64, - pub(crate) metric_bytes_rcv: AtomicObservableCounterU64, - snd_ident: Arc, -} - -/// Weak version on ConnState. 
-#[derive(Clone)] -pub struct ConnStateWeak(ActorWeak, ConnStateMeta); - -impl PartialEq for ConnStateWeak { - fn eq(&self, rhs: &ConnStateWeak) -> bool { - self.0 == rhs.0 - } -} - -impl PartialEq for ConnStateWeak { - fn eq(&self, rhs: &ConnState) -> bool { - self.0 == rhs.0 - } -} - -impl Eq for ConnStateWeak {} - -impl ConnStateWeak { - /// Access the meta struct - pub(crate) fn meta(&self) -> &ConnStateMeta { - &self.1 - } - - /// Upgrade to a full ConnState instance. - pub fn upgrade(&self) -> Option { - self.0.upgrade().map(|i| ConnState(i, self.1.clone())) - } -} - -/// A handle for notifying the state system of connection events. -#[derive(Clone)] -pub struct ConnState(Actor, ConnStateMeta); - -impl PartialEq for ConnState { - fn eq(&self, rhs: &ConnState) -> bool { - self.0 == rhs.0 - } -} - -impl PartialEq for ConnState { - fn eq(&self, rhs: &ConnStateWeak) -> bool { - self.0 == rhs.0 - } -} - -impl Eq for ConnState {} - -impl ConnState { - /* - /// Access the meta struct - pub(crate) fn meta(&self) -> &ConnStateMeta { - &self.1 - } - */ - - pub(crate) fn meta(&self) -> &ConnStateMeta { - &self.1 - } - - /// Get a weak version of this ConnState instance. - pub fn weak(&self) -> ConnStateWeak { - ConnStateWeak(self.0.weak(), self.1.clone()) - } - - /// Returns `true` if this ConnState is closed. - pub fn is_closed(&self) -> bool { - self.0.is_closed() - } - - /// Shutdown the connection with an error. - pub fn close(&self, err: std::io::Error) { - self.0.close(err); - } - - /// Get the remote url of this connection. - pub fn rem_url(&self) -> Tx5Url { - self.1.cli_url.clone() - } - - /// The connection generated an ice candidate for the remote. - pub fn ice(&self, data: BackBuf) -> Result<()> { - self.0.send(Ok(ConnCmd::Ice { data })) - } - - /// The connection is ready to send and receive data. - pub fn ready(&self) -> Result<()> { - self.0.send(Ok(ConnCmd::Ready)) - } - - /// The connection received data on the data channel. 
- /// This synchronous function must not block for now... - /// (we'll need to test some blocking strategies - /// for the goroutine in tx5-go-pion)... but we also can't just - /// fill up memory if the application is processing slowly. - /// So it will error / trigger connection shutdown if we get - /// too much of a backlog. - pub fn rcv_data(&self, mut data: BackBuf) -> Result<()> { - // polling try_acquire doesn't fairly reserve a place in line, - // so we need to timeout an actual acquire future.. - - // we've got 15 ms of time to acquire the recv permit - // this is a little more forgiving than just blanket deny - // if the app is a little slow, but it's also not so much time - // that we'll end up stalling other tasks that need to run. - - let fut = tokio::time::timeout( - std::time::Duration::from_millis(15), - async move { - use std::io::Read; - - let mut len = data.len()?; - if len > 16 * 1024 { - return Err(Error::id("MsgChunkTooLarge")); - } - if len < 8 { - return Err(Error::id("MsgChunkInvalid")); - } - - let mut ident = [0; 8]; - data.read_exact(&mut ident[..])?; - let ident = u64::from_le_bytes(ident); - - len -= 8; - - let buf = bytes::BytesMut::with_capacity(len); - let mut buf = bytes::BufMut::writer(buf); - std::io::copy(&mut data, &mut buf)?; - let buf = buf.into_inner().freeze(); - - if self.1.rcv_limit.available_permits() < len { - tracing::warn!(%len, "recv queue full, waiting for permits"); - } - - let permit = self - .1 - .rcv_limit - .clone() - .acquire_many_owned(len as u32) - .await - .map_err(|_| Error::id("Closed"))?; - - self.0.send(Ok(ConnCmd::Recv { - ident, - data: buf, - permit, - })) - }, - ); - - // need an external (futures) polling executor, because tokio - // won't let us use blocking_recv, since we're still in the runtime. - - futures::executor::block_on(fut).map_err(|_| { - tracing::error!("SLOW_APP: failed to receive in timely manner"); - Error::id("RecvQueueFull") - })? 
- } - - /// The send buffer *was* high, but has now transitioned to low. - pub fn buf_amt_low(&self) -> Result<()> { - todo!() - } - - /// Get stats. - pub async fn stats(&self) -> Result { - let (s, r) = tokio::sync::oneshot::channel(); - if self.0.send(Ok(ConnCmd::Stats(s))).is_err() { - return Err(Error::id("Closed")); - } - r.await.map_err(|_| Error::id("Closed"))? - } - - // -- // - - #[allow(clippy::too_many_arguments)] - pub(crate) fn new_and_publish( - config: DynConfig, - conn_limit: Arc, - metric_conn_count: AtomicObservableUpDownCounterI64, - state: StateWeak, - sig_state: SigStateWeak, - state_uniq: Uniq, - conn_uniq: Uniq, - this_id: Id, - cli_url: Tx5Url, - rem_id: Id, - rcv_limit: Arc, - sig_ready: tokio::sync::oneshot::Receiver>, - maybe_offer: Option, - snd_ident: Arc, - ) -> Result { - let (conn_snd, conn_rcv) = tokio::sync::mpsc::unbounded_channel(); - let conn_snd = ConnStateEvtSnd(conn_snd); - - let metric_bytes_snd = opentelemetry_api::global::meter_provider() - .versioned_meter( - "tx5", - None::<&'static str>, - None::<&'static str>, - Some(vec![ - opentelemetry_api::KeyValue::new( - "state_uniq", - state_uniq.to_string(), - ), - opentelemetry_api::KeyValue::new( - "conn_uniq", - conn_uniq.to_string(), - ), - opentelemetry_api::KeyValue::new( - "remote_id", - rem_id.to_string(), - ), - ]), - ) - .u64_observable_counter_atomic("tx5.endpoint.conn.send", 0) - .with_description("Outgoing bytes sent on this connection") - .with_unit(opentelemetry_api::metrics::Unit::new("By")) - .init() - .0; - - let metric_bytes_rcv = opentelemetry_api::global::meter_provider() - .versioned_meter( - "tx5", - None::<&'static str>, - None::<&'static str>, - Some(vec![ - opentelemetry_api::KeyValue::new( - "state_uniq", - state_uniq.to_string(), - ), - opentelemetry_api::KeyValue::new( - "conn_uniq", - conn_uniq.to_string(), - ), - opentelemetry_api::KeyValue::new( - "remote_id", - rem_id.to_string(), - ), - ]), - ) - 
.u64_observable_counter_atomic("tx5.endpoint.conn.recv", 0) - .with_description("Incoming bytes received on this connection") - .with_unit(opentelemetry_api::metrics::Unit::new("By")) - .init() - .0; - - let meta = ConnStateMeta { - created_at: std::time::Instant::now(), - last_active_at: std::time::Instant::now(), - cli_url, - state_uniq, - conn_uniq: conn_uniq.clone(), - config: config.clone(), - connected: Arc::new(atomic::AtomicBool::new(false)), - _conn_snd: conn_snd.clone(), - rcv_limit, - metric_bytes_snd, - metric_bytes_rcv, - snd_ident, - }; - - let actor = { - let meta = meta.clone(); - Actor::new(move |this, rcv| { - // woo, this is wonkey... - // we actually publish the "strong" version of this - // inside the task after waiting for the sig to be ready - // so upgrade here while the outer strong still exists, - // then the outer strong will be downgraded to return - // from new_and_publish... - let strong = ConnState(this.upgrade().unwrap(), meta.clone()); - conn_state_task( - conn_limit, - metric_conn_count, - meta.clone(), - strong, - ManyRcv(conn_rcv), - rcv, - ConnStateWeak(this, meta), - state, - conn_uniq, - this_id, - rem_id, - conn_snd, - sig_state, - sig_ready, - ) - }) - }; - - let actor = ConnState(actor, meta); - - if let Some(offer) = maybe_offer { - actor.in_offer(offer); - } - - let weak = actor.weak(); - tokio::task::spawn(async move { - tokio::time::sleep(config.max_conn_init()).await; - if let Some(actor) = weak.upgrade() { - actor.check_connected_timeout().await; - } - }); - - let weak = actor.weak(); - tokio::task::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - match weak.upgrade() { - None => break, - Some(actor) => { - if actor.tick_1s().is_err() { - break; - } - } - } - } - }); - - Ok(actor.weak()) - } - - fn tick_1s(&self) -> Result<()> { - self.0.send(Ok(ConnCmd::Tick1s)) - } - - pub(crate) fn track_sig(&self, ty: &'static str, bytes: usize) { - let _ = 
self.0.send(Ok(ConnCmd::TrackSig { ty, bytes })); - } - - async fn check_connected_timeout(&self) { - let _ = self.0.send(Ok(ConnCmd::CheckConnectedTimeout)); - } - - fn notify_constructed(&self) -> Result<()> { - self.0.send(Ok(ConnCmd::NotifyConstructed)) - } - - fn self_offer(&self, offer: Result) { - let _ = self.0.send(Ok(ConnCmd::SelfOffer { offer })); - } - - fn req_self_answer(&self) { - let _ = self.0.send(Ok(ConnCmd::ReqSelfAnswer)); - } - - fn self_answer(&self, answer: Result) { - let _ = self.0.send(Ok(ConnCmd::SelfAnswer { answer })); - } - - pub(crate) fn in_offer(&self, offer: BackBuf) { - let _ = self.0.send(Ok(ConnCmd::InOffer { offer })); - } - - pub(crate) fn in_answer(&self, answer: BackBuf) { - let _ = self.0.send(Ok(ConnCmd::InAnswer { answer })); - } - - pub(crate) fn in_ice(&self, mut ice: BackBuf, cache: bool) { - let bytes = ice.len().unwrap(); - let _ = self.0.send(Ok(ConnCmd::TrackSig { - ty: "ice_in", - bytes, - })); - let _ = self.0.send(Ok(ConnCmd::InIce { ice, cache })); - } - - pub(crate) async fn check_send_waiting(&self, buf_state: Option) { - let _ = self.0.send(Ok(ConnCmd::MaybeFetchForSend { - send_complete: false, - buf_state, - })); - } - - pub(crate) fn send(&self, to_send: SendData) { - let _ = self.0.send(Ok(ConnCmd::Send { to_send })); - } - - pub(crate) fn notify_send_complete(&self, buf_state: BufState) { - let _ = self.0.send(Ok(ConnCmd::MaybeFetchForSend { - send_complete: true, - buf_state: Some(buf_state), - })); - } -} diff --git a/crates/tx5/src/state/drop_consider.rs b/crates/tx5/src/state/drop_consider.rs deleted file mode 100644 index 26b331df..00000000 --- a/crates/tx5/src/state/drop_consider.rs +++ /dev/null @@ -1,282 +0,0 @@ -use tx5_core::Uniq; - -#[derive(Debug)] -pub(crate) enum DropConsiderResult { - /// Force dropping the connection (supercedes ShouldKeep) - MustDrop, - - /// Recommend keeping the connection open - ShouldKeep, -} - -#[derive(Debug)] -pub(crate) struct DropConsiderArgs { - pub(crate) 
conn_uniq: Uniq, - pub(crate) cfg_conn_max_cnt: i64, - pub(crate) cfg_conn_max_init: f64, - pub(crate) tot_conn_cnt: i64, - pub(crate) tot_snd_bytes: u64, - pub(crate) tot_rcv_bytes: u64, - #[allow(dead_code)] - pub(crate) tot_avg_age_s: f64, - pub(crate) this_connected: bool, - pub(crate) this_snd_bytes: u64, - pub(crate) this_rcv_bytes: u64, - pub(crate) this_age_s: f64, - pub(crate) this_last_active_s: f64, -} - -pub(crate) fn drop_consider(args: &DropConsiderArgs) -> DropConsiderResult { - // sneak in a force keep new connections open long enough - // to try to connect. - if args.this_age_s < args.cfg_conn_max_init { - return DropConsiderResult::ShouldKeep; - } - - for consider in [ - consider_max, - consider_long_inactive, - consider_long_unconnected, - consider_connected_contention, - consider_low_throughput, - ] { - if let DropConsiderResult::MustDrop = consider(args) { - return DropConsiderResult::MustDrop; - } - } - - DropConsiderResult::ShouldKeep -} - -fn consider_max(args: &DropConsiderArgs) -> DropConsiderResult { - if std::time::Duration::from_secs_f64(args.this_age_s) > super::MAX_CON_TIME - { - tracing::trace!(conn_uniq = %args.conn_uniq, "MustDrop::consider_max"); - return DropConsiderResult::MustDrop; - } - DropConsiderResult::ShouldKeep -} - -fn consider_long_inactive(args: &DropConsiderArgs) -> DropConsiderResult { - if args.this_last_active_s >= args.cfg_conn_max_init { - tracing::trace!(conn_uniq = %args.conn_uniq, "MustDrop::consider_long_inactive"); - return DropConsiderResult::MustDrop; - } - DropConsiderResult::ShouldKeep -} - -fn consider_long_unconnected(args: &DropConsiderArgs) -> DropConsiderResult { - if !args.this_connected && args.this_age_s >= args.cfg_conn_max_init { - tracing::trace!(conn_uniq = %args.conn_uniq, "MustDrop::consider_long_unconnected"); - return DropConsiderResult::MustDrop; - } - DropConsiderResult::ShouldKeep -} - -fn consider_connected_contention( - args: &DropConsiderArgs, -) -> DropConsiderResult { - if 
args.this_connected - && args.tot_conn_cnt >= args.cfg_conn_max_cnt - && args.this_last_active_s >= args.cfg_conn_max_init / 2.0 - { - tracing::trace!(conn_uniq = %args.conn_uniq, "MustDrop::consider_connected_contention"); - return DropConsiderResult::MustDrop; - } - DropConsiderResult::ShouldKeep -} - -fn consider_low_throughput(args: &DropConsiderArgs) -> DropConsiderResult { - // if there is no contention, keep - if args.tot_conn_cnt < args.cfg_conn_max_cnt { - return DropConsiderResult::ShouldKeep; - } - - // if the total xfer is miniscule, keep - let tot_xfer = (args.tot_snd_bytes + args.tot_rcv_bytes) as f64; - if tot_xfer < 4096.0 { - return DropConsiderResult::ShouldKeep; - } - - // if the con hasn't existed for at least double the connect time, keep - if args.this_age_s < args.cfg_conn_max_init * 2.0 { - return DropConsiderResult::ShouldKeep; - } - - // if we have received/sent data recently, keep - if args.this_last_active_s < args.cfg_conn_max_init / 4.0 { - return DropConsiderResult::ShouldKeep; - } - - // if we are within 30% of the average, keep - let this_xfer = (args.this_snd_bytes + args.this_rcv_bytes) as f64; - let avg_xfer = tot_xfer / args.tot_conn_cnt as f64; - let fact = this_xfer / avg_xfer; - if fact >= 0.3 { - return DropConsiderResult::ShouldKeep; - } - - tracing::trace!(conn_uniq = %args.conn_uniq, "MustDrop::consider_low_throughput"); - - // finally, if we haven't kept, recommend dropping - DropConsiderResult::MustDrop -} - -// -- testing -- // - -#[cfg(test)] -impl Default for DropConsiderArgs { - fn default() -> Self { - Self { - conn_uniq: Uniq::default(), - cfg_conn_max_cnt: 20, - cfg_conn_max_init: 20.0, - tot_conn_cnt: 30, - tot_snd_bytes: 30720, - tot_rcv_bytes: 30720, - tot_avg_age_s: 30.0, - this_connected: true, - this_snd_bytes: 1024, - this_rcv_bytes: 1024, - this_age_s: 30.0, - this_last_active_s: 2.0, - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn drop_long_inactive() { - let mut args = 
DropConsiderArgs::default(); - args.tot_conn_cnt = 2; - - // first, the negative test - args.this_last_active_s = 10.0; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::ShouldKeep), - "\nexpected: ShouldKeep\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - - // then, the positive - args.this_last_active_s = 30.0; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::MustDrop), - "\nexpected: MustDrop\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - } - - #[test] - fn drop_long_unconnected() { - let mut args = DropConsiderArgs::default(); - - // first, the negative test - args.this_connected = true; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::ShouldKeep), - "\nexpected: ShouldKeep\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - - // then, the positive - args.this_connected = false; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::MustDrop), - "\nexpected: MustDrop\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - } - - #[test] - fn drop_connected_contention() { - let mut args = DropConsiderArgs::default(); - - // first, the negative test - args.this_last_active_s = 2.0; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::ShouldKeep), - "\nexpected: ShouldKeep\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - - // then, the positive - args.this_last_active_s = 10.0; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::MustDrop), - "\nexpected: MustDrop\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - } - - #[test] - fn keep_new_conns() { - let mut args = DropConsiderArgs::default(); - args.this_last_active_s = 30.0; - args.this_connected = false; - - // first, the negative test - args.this_age_s = 30.0; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::MustDrop), - "\nexpected: MustDrop\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - - // then, the positive - 
args.this_age_s = 10.0; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::ShouldKeep), - "\nexpected: ShouldKeep\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - } - - #[test] - fn low_throughput() { - let mut args = DropConsiderArgs::default(); - args.this_age_s = 50.0; - args.this_last_active_s = 7.0; - - // first, the negative test - args.this_snd_bytes = 1024; - args.this_rcv_bytes = 1024; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::ShouldKeep), - "\nexpected: ShouldKeep\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - - // then, the positive - args.this_snd_bytes = 102; - args.this_rcv_bytes = 102; - let res = drop_consider(&args); - assert!( - matches!(res, DropConsiderResult::MustDrop), - "\nexpected: MustDrop\ngot: {:?}\nargs: {:#?}", - res, - args, - ); - } -} diff --git a/crates/tx5/src/state/sig.rs b/crates/tx5/src/state/sig.rs deleted file mode 100644 index 969daaaa..00000000 --- a/crates/tx5/src/state/sig.rs +++ /dev/null @@ -1,599 +0,0 @@ -use super::*; - -/// Temporary indicating we want a new signal instance. -pub struct SigStateSeed { - done: bool, - output: Option<(SigState, ManyRcv)>, -} - -impl std::fmt::Debug for SigStateSeed { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SigStateSeed").finish() - } -} - -impl Drop for SigStateSeed { - fn drop(&mut self) { - self.result_err_inner(Error::id("Dropped")); - } -} - -impl SigStateSeed { - /// Finalize this sig_state seed by indicating a successful sig connection. - pub fn result_ok( - mut self, - cli_url: Tx5Url, - ice_servers: Arc, - ) -> Result<(SigState, ManyRcv)> { - self.done = true; - let (sig, sig_evt) = self.output.take().unwrap(); - sig.notify_connected(cli_url, ice_servers)?; - Ok((sig, sig_evt)) - } - - /// Finalize this sig_state seed by indicating an error connecting. 
- pub fn result_err(mut self, err: std::io::Error) { - self.result_err_inner(err); - } - - // -- // - - pub(crate) fn new(sig: SigState, sig_evt: ManyRcv) -> Self { - Self { - done: false, - output: Some((sig, sig_evt)), - } - } - - fn result_err_inner(&mut self, err: std::io::Error) { - if !self.done { - self.done = true; - if let Some((sig, _)) = self.output.take() { - sig.close(err); - } - } - } -} - -/// State wishes to invoke an action on a signal instance. -pub enum SigStateEvt { - /// Forward an offer to a remote. - SndOffer(Id, BackBuf, OneSnd<()>), - - /// Forward an answer to a remote. - SndAnswer(Id, BackBuf, OneSnd<()>), - - /// Forward an ICE candidate to a remote. - SndIce(Id, BackBuf, OneSnd<()>), - - /// Trigger a demo broadcast. - SndDemo, -} - -impl std::fmt::Debug for SigStateEvt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - SigStateEvt::SndOffer(_, _, _) => f.write_str("SndOffer"), - SigStateEvt::SndAnswer(_, _, _) => f.write_str("SndAnswer"), - SigStateEvt::SndIce(_, _, _) => f.write_str("SndIce"), - SigStateEvt::SndDemo => f.write_str("SndDemo"), - } - } -} - -#[derive(Clone)] -struct SigStateEvtSnd(tokio::sync::mpsc::UnboundedSender>); - -impl SigStateEvtSnd { - pub fn err(&self, err: std::io::Error) { - let _ = self.0.send(Err(err)); - } - - pub fn snd_offer( - &self, - state: StateWeak, - sig: SigStateWeak, - rem_id: Id, - mut offer: BackBuf, - ) { - if let Some(state) = state.upgrade() { - state.track_sig(rem_id, "offer_out", offer.len().unwrap()); - } - let s = OneSnd::new(move |result| { - if let Err(err) = result { - if let Some(sig) = sig.upgrade() { - sig.close(err); - } - } - }); - let _ = self.0.send(Ok(SigStateEvt::SndOffer(rem_id, offer, s))); - } - - pub fn snd_answer( - &self, - state: StateWeak, - sig: SigStateWeak, - rem_id: Id, - mut answer: BackBuf, - ) { - if let Some(state) = state.upgrade() { - state.track_sig(rem_id, "answer_out", answer.len().unwrap()); - } - let s = 
OneSnd::new(move |result| { - if let Err(err) = result { - if let Some(sig) = sig.upgrade() { - sig.close(err); - } - } - }); - let _ = self.0.send(Ok(SigStateEvt::SndAnswer(rem_id, answer, s))); - } - - pub fn snd_ice( - &self, - state: StateWeak, - sig: SigStateWeak, - rem_id: Id, - mut ice: BackBuf, - ) { - if let Some(state) = state.upgrade() { - state.track_sig(rem_id, "ice_out", ice.len().unwrap()); - } - let s = OneSnd::new(move |result| { - if let Err(err) = result { - if let Some(sig) = sig.upgrade() { - sig.close(err); - } - } - }); - let _ = self.0.send(Ok(SigStateEvt::SndIce(rem_id, ice, s))); - } - - pub fn snd_demo(&self) { - let _ = self.0.send(Ok(SigStateEvt::SndDemo)); - } -} - -struct SigStateData { - this: SigStateWeak, - state: StateWeak, - sig_url: Tx5Url, - sig_evt: SigStateEvtSnd, - connected: bool, - cli_url: Option, - ice_servers: Option>, - pending_sig_resp: Vec>>, - registered_conn_map: HashMap, -} - -impl Drop for SigStateData { - fn drop(&mut self) { - self.shutdown(Error::id("Dropped")); - } -} - -impl SigStateData { - fn shutdown(&mut self, err: std::io::Error) { - if let Some(state) = self.state.upgrade() { - state.close_sig( - self.sig_url.clone(), - self.this.clone(), - err.err_clone(), - ); - } - for (_, conn) in self.registered_conn_map.drain() { - if let Some(conn) = conn.upgrade() { - conn.close(err.err_clone()); - } - drop(conn); - } - self.sig_evt.err(err); - } - - async fn exec(&mut self, cmd: SigCmd) -> Result<()> { - match cmd { - SigCmd::CheckConnectedTimeout => { - self.check_connected_timeout().await - } - SigCmd::PushAssertRespond { resp } => { - self.push_assert_respond(resp).await - } - SigCmd::NotifyConnected { - cli_url, - ice_servers, - } => self.notify_connected(cli_url, ice_servers).await, - SigCmd::RegisterConn { rem_id, conn, resp } => { - self.register_conn(rem_id, conn, resp).await - } - SigCmd::UnregisterConn { rem_id, conn } => { - self.unregister_conn(rem_id, conn).await - } - SigCmd::Offer { rem_id, data 
} => self.offer(rem_id, data).await, - SigCmd::Answer { rem_id, data } => self.answer(rem_id, data).await, - SigCmd::Ice { rem_id, data } => self.ice(rem_id, data).await, - SigCmd::Demo { rem_id } => self.demo(rem_id).await, - SigCmd::SndOffer { rem_id, data } => { - self.snd_offer(rem_id, data).await - } - SigCmd::SndAnswer { rem_id, data } => { - self.snd_answer(rem_id, data).await - } - SigCmd::SndIce { rem_id, data } => self.snd_ice(rem_id, data).await, - SigCmd::SndDemo => self.snd_demo().await, - } - } - - async fn check_connected_timeout(&mut self) -> Result<()> { - if !self.connected { - Err(Error::id("Timeout")) - } else { - Ok(()) - } - } - - async fn push_assert_respond( - &mut self, - resp: tokio::sync::oneshot::Sender>, - ) -> Result<()> { - if self.connected { - let _ = resp.send(Ok(self.cli_url.clone().unwrap())); - } else { - self.pending_sig_resp.push(resp); - } - Ok(()) - } - - async fn notify_connected( - &mut self, - cli_url: Tx5Url, - ice_servers: Arc, - ) -> Result<()> { - self.connected = true; - self.cli_url = Some(cli_url.clone()); - self.ice_servers = Some(ice_servers); - for resp in self.pending_sig_resp.drain(..) { - let _ = resp.send(Ok(cli_url.clone())); - } - if let Some(state) = self.state.upgrade() { - state.sig_connected(cli_url); - } - Ok(()) - } - - async fn register_conn( - &mut self, - rem_id: Id, - conn: ConnStateWeak, - resp: tokio::sync::oneshot::Sender>>, - ) -> Result<()> { - self.registered_conn_map.insert(rem_id, conn); - let _ = resp.send( - self.ice_servers - .clone() - .ok_or_else(|| Error::id("NoIceServers")), - ); - Ok(()) - } - - async fn unregister_conn( - &mut self, - rem_id: Id, - conn: ConnStateWeak, - ) -> Result<()> { - if let Some(cur_conn) = self.registered_conn_map.remove(&rem_id) { - if cur_conn != conn { - // Whoops! 
- self.registered_conn_map.insert(rem_id, cur_conn); - } - } - Ok(()) - } - - async fn offer(&mut self, rem_id: Id, mut data: BackBuf) -> Result<()> { - let len = data.len().unwrap(); - if let Some(state) = self.state.upgrade() { - state.in_offer(self.sig_url.clone(), rem_id, data)?; - state.track_sig(rem_id, "offer_in", len); - } - Ok(()) - } - - async fn answer(&mut self, rem_id: Id, mut data: BackBuf) -> Result<()> { - let len = data.len().unwrap(); - if let Some(conn) = self.registered_conn_map.get(&rem_id) { - if let Some(conn) = conn.upgrade() { - conn.in_answer(data); - } - } - if let Some(state) = self.state.upgrade() { - state.track_sig(rem_id, "answer_in", len); - } - Ok(()) - } - - async fn ice(&mut self, rem_id: Id, data: BackBuf) -> Result<()> { - if let Some(conn) = self.registered_conn_map.get(&rem_id) { - if let Some(conn) = conn.upgrade() { - conn.in_ice(data, true); - return Ok(()); - } - } - if let Some(state) = self.state.upgrade() { - let _ = state.cache_ice(rem_id, data); - } - Ok(()) - } - - async fn demo(&mut self, rem_id: Id) -> Result<()> { - if let Some(state) = self.state.upgrade() { - state.in_demo(self.sig_url.clone(), rem_id)?; - } - Ok(()) - } - - async fn snd_offer(&mut self, rem_id: Id, data: BackBuf) -> Result<()> { - self.sig_evt.snd_offer( - self.state.clone(), - self.this.clone(), - rem_id, - data, - ); - Ok(()) - } - - async fn snd_answer(&mut self, rem_id: Id, data: BackBuf) -> Result<()> { - self.sig_evt.snd_answer( - self.state.clone(), - self.this.clone(), - rem_id, - data, - ); - Ok(()) - } - - async fn snd_ice(&mut self, rem_id: Id, data: BackBuf) -> Result<()> { - self.sig_evt.snd_ice( - self.state.clone(), - self.this.clone(), - rem_id, - data, - ); - Ok(()) - } - - async fn snd_demo(&mut self) -> Result<()> { - self.sig_evt.snd_demo(); - Ok(()) - } -} - -enum SigCmd { - CheckConnectedTimeout, - PushAssertRespond { - resp: tokio::sync::oneshot::Sender>, - }, - NotifyConnected { - cli_url: Tx5Url, - ice_servers: Arc, - 
}, - RegisterConn { - rem_id: Id, - conn: ConnStateWeak, - resp: tokio::sync::oneshot::Sender>>, - }, - UnregisterConn { - rem_id: Id, - conn: ConnStateWeak, - }, - Offer { - rem_id: Id, - data: BackBuf, - }, - Answer { - rem_id: Id, - data: BackBuf, - }, - Ice { - rem_id: Id, - data: BackBuf, - }, - Demo { - rem_id: Id, - }, - SndOffer { - rem_id: Id, - data: BackBuf, - }, - SndAnswer { - rem_id: Id, - data: BackBuf, - }, - SndIce { - rem_id: Id, - data: BackBuf, - }, - SndDemo, -} - -async fn sig_state_task( - mut rcv: ManyRcv, - this: SigStateWeak, - state: StateWeak, - sig_url: Tx5Url, - sig_evt: SigStateEvtSnd, - pending_sig_resp: Vec>>, -) -> Result<()> { - let mut data = SigStateData { - this, - state, - sig_url, - sig_evt, - connected: false, - cli_url: None, - ice_servers: None, - pending_sig_resp, - registered_conn_map: HashMap::new(), - }; - let err = match async { - while let Some(cmd) = rcv.recv().await { - data.exec(cmd?).await?; - } - Ok(()) - } - .await - { - Err(err) => err, - Ok(_) => Error::id("Dropped"), - }; - data.shutdown(err.err_clone()); - Err(err) -} - -/// Weak version of SigState. -#[derive(Clone, PartialEq, Eq)] -pub struct SigStateWeak(ActorWeak); - -impl PartialEq for SigStateWeak { - fn eq(&self, rhs: &SigState) -> bool { - self.0 == rhs.0 - } -} - -impl SigStateWeak { - /// Upgrade to a full SigState instance. - pub fn upgrade(&self) -> Option { - self.0.upgrade().map(SigState) - } -} - -/// A handle for notifying the state system of signal events. -#[derive(Clone, PartialEq, Eq)] -pub struct SigState(Actor); - -impl PartialEq for SigState { - fn eq(&self, rhs: &SigStateWeak) -> bool { - self.0 == rhs.0 - } -} - -impl SigState { - /// Get a weak version of this SigState instance. - pub fn weak(&self) -> SigStateWeak { - SigStateWeak(self.0.weak()) - } - - /// Returns `true` if this SigState is closed. - pub fn is_closed(&self) -> bool { - self.0.is_closed() - } - - /// Shutdown the signal client with an error. 
- pub fn close(&self, err: std::io::Error) { - self.0.close(err); - } - - /// Receive an incoming offer from a remote. - pub fn offer(&self, rem_id: Id, data: BackBuf) -> Result<()> { - self.0.send(Ok(SigCmd::Offer { rem_id, data })) - } - - /// Receive an incoming answer from a remote. - pub fn answer(&self, rem_id: Id, data: BackBuf) -> Result<()> { - self.0.send(Ok(SigCmd::Answer { rem_id, data })) - } - - /// Receive an incoming ice candidate from a remote. - pub fn ice(&self, rem_id: Id, data: BackBuf) -> Result<()> { - self.0.send(Ok(SigCmd::Ice { rem_id, data })) - } - - /// Receive a demo broadcast from a remote. - pub fn demo(&self, rem_id: Id) -> Result<()> { - self.0.send(Ok(SigCmd::Demo { rem_id })) - } - - // -- // - - pub(crate) fn new( - state: StateWeak, - sig_url: Tx5Url, - resp: tokio::sync::oneshot::Sender>, - timeout: std::time::Duration, - ) -> (Self, ManyRcv) { - let (sig_snd, sig_rcv) = tokio::sync::mpsc::unbounded_channel(); - let actor = Actor::new(move |this, rcv| { - sig_state_task( - rcv, - SigStateWeak(this), - state, - sig_url, - SigStateEvtSnd(sig_snd), - vec![resp], - ) - }); - let weak = SigStateWeak(actor.weak()); - tokio::task::spawn(async move { - tokio::time::sleep(timeout).await; - if let Some(actor) = weak.upgrade() { - actor.check_connected_timeout().await; - } - }); - (Self(actor), ManyRcv(sig_rcv)) - } - - pub(crate) fn snd_offer(&self, rem_id: Id, data: BackBuf) -> Result<()> { - self.0.send(Ok(SigCmd::SndOffer { rem_id, data })) - } - - pub(crate) fn snd_answer(&self, rem_id: Id, data: BackBuf) -> Result<()> { - self.0.send(Ok(SigCmd::SndAnswer { rem_id, data })) - } - - pub(crate) fn snd_ice(&self, rem_id: Id, data: BackBuf) -> Result<()> { - self.0.send(Ok(SigCmd::SndIce { rem_id, data })) - } - - pub(crate) fn snd_demo(&self) { - let _ = self.0.send(Ok(SigCmd::SndDemo)); - } - - async fn check_connected_timeout(&self) { - let _ = self.0.send(Ok(SigCmd::CheckConnectedTimeout)); - } - - pub(crate) async fn register_conn( 
- &self, - rem_id: Id, - conn: ConnStateWeak, - ) -> Result> { - let (s, r) = tokio::sync::oneshot::channel(); - self.0.send(Ok(SigCmd::RegisterConn { - rem_id, - conn, - resp: s, - }))?; - r.await.map_err(|_| Error::id("Closed"))? - } - - pub(crate) fn unregister_conn(&self, rem_id: Id, conn: ConnStateWeak) { - let _ = self.0.send(Ok(SigCmd::UnregisterConn { rem_id, conn })); - } - - pub(crate) async fn push_assert_respond( - &self, - resp: tokio::sync::oneshot::Sender>, - ) { - let _ = self.0.send(Ok(SigCmd::PushAssertRespond { resp })); - } - - fn notify_connected( - &self, - cli_url: Tx5Url, - ice_servers: Arc, - ) -> Result<()> { - self.0.send(Ok(SigCmd::NotifyConnected { - cli_url, - ice_servers, - })) - } -} diff --git a/crates/tx5/src/state/test.rs b/crates/tx5/src/state/test.rs deleted file mode 100644 index f5d78e5f..00000000 --- a/crates/tx5/src/state/test.rs +++ /dev/null @@ -1,650 +0,0 @@ -use super::*; - -fn init_tracing() { - let subscriber = tracing_subscriber::FmtSubscriber::builder() - .with_env_filter( - tracing_subscriber::filter::EnvFilter::from_default_env(), - ) - .with_file(true) - .with_line_number(true) - .finish(); - let _ = tracing::subscriber::set_global_default(subscriber); -} - -#[allow(dead_code)] -struct Test { - shutdown: bool, - cli_a: Tx5Url, - id_a: Id, - cli_b: Tx5Url, - id_b: Id, - state: State, - state_evt: ManyRcv, - sig_state: SigState, - sig_evt: ManyRcv, -} - -impl Drop for Test { - fn drop(&mut self) { - if !self.shutdown { - // print and abort, since panic within drop breaks things - eprintln!("Expected Test::shutdown() to be called, aborting. 
Backtrace: {:#?}", std::backtrace::Backtrace::capture()); - std::process::abort(); - } - } -} - -impl Test { - pub async fn new(as_a: bool) -> Self { - init_tracing(); - - let sig: Tx5Url = Tx5Url::new("wss://s").unwrap(); - let cli_a: Tx5Url = Tx5Url::new( - "wss://s/tx5-ws/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", - ) - .unwrap(); - let id_a = cli_a.id().unwrap(); - let cli_b: Tx5Url = Tx5Url::new( - "wss://s/tx5-ws/BAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", - ) - .unwrap(); - let id_b = cli_b.id().unwrap(); - - let config = DefConfig::default().into_config().await.unwrap(); - let (state, mut state_evt) = State::new(config).unwrap(); - - // -- register with a signal server -- // - - let task = { - let state = state.clone(); - let sig = sig.clone(); - - // can't do this inline, since it won't resolve until result_ok - // call on the seed below - tokio::task::spawn( - async move { state.listener_sig(sig).await.unwrap() }, - ) - }; - - let sig_seed = match state_evt.recv().await { - Some(Ok(StateEvt::NewSig(_url, seed))) => seed, - oth => panic!("unexpected: {:?}", oth), - }; - - let cli = if as_a { cli_a.clone() } else { cli_b.clone() }; - let (sig_state, sig_evt) = sig_seed - .result_ok(cli, Arc::new(serde_json::json!([]))) - .unwrap(); - - task.await.unwrap(); - - if as_a { - assert!(matches!( - state_evt.recv().await, - Some(Ok(StateEvt::Address(tmp))) if tmp == cli_a, - )); - } else { - assert!(matches!( - state_evt.recv().await, - Some(Ok(StateEvt::Address(tmp))) if tmp == cli_b, - )); - } - - println!("got addr"); - - Self { - shutdown: false, - cli_a, - id_a, - cli_b, - id_b, - state, - state_evt, - sig_state, - sig_evt, - } - } - - pub async fn shutdown(mut self) { - self.shutdown = true; - - self.state.close(Error::id("TestShutdown")); - - let res = self.state_evt.recv().await; - assert!( - matches!(res, Some(Ok(StateEvt::Disconnected { .. 
})),), - "expected Disconnected, got: {:?}", - res - ); - - let res = self.state_evt.recv().await; - assert!( - matches!( - res, - Some(Err(ref err)) if &err.to_string() == "TestShutdown", - ), - "expected Err(\"TestShutdown\"), got: {:?}", - res - ); - - // erm... is this what we want?? - assert!(matches!( - self.state_evt.recv().await, - Some(Err(err)) if &err.to_string() == "Dropped", - )); - - assert!(matches!(self.state_evt.recv().await, None)); - } -} - -#[tokio::test(flavor = "multi_thread")] -async fn extended_outgoing() { - better_panic::install(); - - let mut test = Test::new(true).await; - - // -- send data to a "peer" (causes connecting to that peer) -- // - - let task = { - let state = test.state.clone(); - let cli_b = test.cli_b.clone(); - - tokio::task::spawn(async move { - state.snd_data(cli_b.clone(), &b"hello"[..]).await.unwrap() - }) - }; - - // -- new peer connection -- // - - let conn_seed = match test.state_evt.recv().await { - Some(Ok(StateEvt::NewConn(_ice_servers, seed))) => seed, - oth => panic!("unexpected: {:?}", oth), - }; - - println!("got new conn"); - - let (conn_state, mut conn_evt) = conn_seed.result_ok().unwrap(); - - // -- generate an offer -- // - - let mut resp = match conn_evt.recv().await { - Some(Ok(ConnStateEvt::CreateOffer(resp))) => resp, - oth => panic!("unexpected: {:?}", oth), - }; - - resp.send(BackBuf::from_slice(b"offer")); - - println!("got create_offer"); - - match test.sig_evt.recv().await { - Some(Ok(SigStateEvt::SndOffer(id, mut buf, mut resp))) => { - assert_eq!(id, test.id_b); - assert_eq!(&buf.to_vec().unwrap(), b"offer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - println!("sent offer"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetLoc(mut offer, mut resp))) => { - assert_eq!(&offer.to_vec().unwrap(), b"offer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - println!("set loc"); - - test.sig_state - .answer(test.id_b, 
BackBuf::from_slice(b"answer").unwrap()) - .unwrap(); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetRem(mut answer, mut resp))) => { - assert_eq!(&answer.to_vec().unwrap(), b"answer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("set rem"); - - conn_state - .ice(BackBuf::from_slice(b"ice").unwrap()) - .unwrap(); - - match test.sig_evt.recv().await { - Some(Ok(SigStateEvt::SndIce(id, mut buf, mut resp))) => { - assert_eq!(id, test.id_b); - assert_eq!(&buf.to_vec().unwrap(), b"ice"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - test.sig_state - .ice(test.id_b, BackBuf::from_slice(b"rem_ice").unwrap()) - .unwrap(); - - println!("sent ice"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetIce(mut ice, mut resp))) => { - assert_eq!(&ice.to_vec().unwrap(), b"rem_ice"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("set rem ice"); - - conn_state.ready().unwrap(); - - println!("ready"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SndData(mut data, mut resp))) => { - // blank message for preflight data - assert_eq!(8, data.to_vec().unwrap().len()); - resp.send(Ok(BufState::Low)); - } - oth => panic!("unexpected: {:?}", oth), - }; - - // receive the empty preflight - conn_state - .rcv_data(BackBuf::from_slice(b"\0\0\0\0\0\0\0\x80").unwrap()) - .unwrap(); - - match test.state_evt.recv().await { - Some(Ok(StateEvt::Connected { .. 
})) => (), - oth => panic!("unexpected: {:?}", oth), - } - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SndData(mut data, mut resp))) => { - assert_eq!(&data.to_vec().unwrap()[8..], b"hello"); - resp.send(Ok(BufState::Low)); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("snd data"); - - task.await.unwrap(); - - // -- recv data from the remote -- // - - println!("about to rcv"); - - // now, receive the actual message - conn_state - .rcv_data(BackBuf::from_slice(b"\x2a\0\0\0\0\0\0\x80world").unwrap()) - .unwrap(); - - match test.state_evt.recv().await { - Some(Ok(StateEvt::RcvData(url, data, _permit))) => { - assert_eq!(url, test.cli_b); - assert_eq!(&data.to_vec().unwrap(), b"world"); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("rcv data"); - - test.shutdown().await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn short_incoming() { - better_panic::install(); - - let mut test = Test::new(true).await; - - // -- receive an incoming offer -- // - - test.sig_state - .offer(test.id_b, BackBuf::from_slice(b"offer").unwrap()) - .unwrap(); - - // -- new peer connection -- // - - let conn_seed = match test.state_evt.recv().await { - Some(Ok(StateEvt::NewConn(_ice_servers, seed))) => seed, - oth => panic!("unexpected: {:?}", oth), - }; - - println!("got new conn"); - - let (_conn_state, mut conn_evt) = conn_seed.result_ok().unwrap(); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetRem(mut offer, mut resp))) => { - assert_eq!(&offer.to_vec().unwrap(), b"offer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("set rem"); - - let mut resp = match conn_evt.recv().await { - Some(Ok(ConnStateEvt::CreateAnswer(resp))) => resp, - oth => panic!("unexpected {:?}", oth), - }; - - resp.send(BackBuf::from_slice(b"answer")); - - println!("got create_answer"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetLoc(mut answer, mut resp))) => { - 
assert_eq!(&answer.to_vec().unwrap(), b"answer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("set loc"); - - match test.sig_evt.recv().await { - Some(Ok(SigStateEvt::SndAnswer(id, mut buf, mut resp))) => { - assert_eq!(id, test.id_b); - assert_eq!(&buf.to_vec().unwrap(), b"answer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - println!("sent answer"); - - test.shutdown().await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn polite_in_offer() { - better_panic::install(); - - let mut test = Test::new(true).await; - - // -- send data to a "peer" (causes connecting to that peer) -- // - - let task = { - let state = test.state.clone(); - let cli_b = test.cli_b.clone(); - - tokio::task::spawn(async move { - state.snd_data(cli_b.clone(), &b"hello"[..]).await.unwrap() - }) - }; - - let conn_seed = match test.state_evt.recv().await { - Some(Ok(StateEvt::NewConn(_ice_servers, seed))) => seed, - oth => panic!("unexpected: {:?}", oth), - }; - - println!("got new conn"); - - let (_conn_state, mut conn_evt) = conn_seed.result_ok().unwrap(); - - // -- generate an offer -- // - - let mut resp = match conn_evt.recv().await { - Some(Ok(ConnStateEvt::CreateOffer(resp))) => resp, - oth => panic!("unexpected: {:?}", oth), - }; - - resp.send(BackBuf::from_slice(b"offer")); - - println!("got create_offer"); - - match test.sig_evt.recv().await { - Some(Ok(SigStateEvt::SndOffer(id, mut buf, mut resp))) => { - assert_eq!(id, test.id_b); - assert_eq!(&buf.to_vec().unwrap(), b"offer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - println!("sent offer"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetLoc(mut offer, mut resp))) => { - assert_eq!(&offer.to_vec().unwrap(), b"offer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - println!("set loc"); - - // - BUT, instead we get an new incoming OFFER - // maybe because the other node started a 
racy try to connect to us too? - - test.sig_state - .offer(test.id_b, BackBuf::from_slice(b"in_offer").unwrap()) - .unwrap(); - - match conn_evt.recv().await { - Some(Err(err)) => { - assert_eq!("PoliteShutdownToAcceptIncomingOffer", &err.to_string()) - } - oth => panic!("unexpected: {:?}", oth), - } - - match test.state_evt.recv().await { - Some(Ok(StateEvt::Disconnected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - let conn_seed = match test.state_evt.recv().await { - Some(Ok(StateEvt::NewConn(_ice_servers, seed))) => seed, - oth => panic!("unexpected: {:?}", oth), - }; - - println!("got new conn"); - - let (conn_state, mut conn_evt) = conn_seed.result_ok().unwrap(); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetRem(mut offer, mut resp))) => { - assert_eq!(&offer.to_vec().unwrap(), b"in_offer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("set rem"); - - let mut resp = match conn_evt.recv().await { - Some(Ok(ConnStateEvt::CreateAnswer(resp))) => resp, - oth => panic!("unexpected {:?}", oth), - }; - - resp.send(BackBuf::from_slice(b"answer")); - - println!("got create_answer"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetLoc(mut answer, mut resp))) => { - assert_eq!(&answer.to_vec().unwrap(), b"answer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("set loc"); - - match test.sig_evt.recv().await { - Some(Ok(SigStateEvt::SndAnswer(id, mut buf, mut resp))) => { - assert_eq!(id, test.id_b); - assert_eq!(&buf.to_vec().unwrap(), b"answer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - println!("sent answer"); - - conn_state.ready().unwrap(); - - println!("ready"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SndData(mut data, mut resp))) => { - // blank message for preflight data - assert_eq!(8, data.to_vec().unwrap().len()); - resp.send(Ok(BufState::Low)); - } - oth => panic!("unexpected: {:?}", 
oth), - }; - - // receive the empty preflight - conn_state - .rcv_data(BackBuf::from_slice(b"\0\0\0\0\0\0\0\x80").unwrap()) - .unwrap(); - - match test.state_evt.recv().await { - Some(Ok(StateEvt::Connected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SndData(mut data, mut resp))) => { - assert_eq!(&data.to_vec().unwrap()[8..], b"hello"); - resp.send(Ok(BufState::Low)); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("snd data"); - - // finally the data is sent - task.await.unwrap(); - - test.shutdown().await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn impolite_in_offer() { - better_panic::install(); - - let mut test = Test::new(false).await; - - // -- send data to a "peer" (causes connecting to that peer) -- // - - let task = { - let state = test.state.clone(); - let cli_a = test.cli_a.clone(); - - tokio::task::spawn(async move { - state.snd_data(cli_a.clone(), &b"hello"[..]).await.unwrap() - }) - }; - - let conn_seed = match test.state_evt.recv().await { - Some(Ok(StateEvt::NewConn(_ice_servers, seed))) => seed, - oth => panic!("unexpected: {:?}", oth), - }; - - println!("got new conn"); - - let (conn_state, mut conn_evt) = conn_seed.result_ok().unwrap(); - - // -- generate an offer -- // - - let mut resp = match conn_evt.recv().await { - Some(Ok(ConnStateEvt::CreateOffer(resp))) => resp, - oth => panic!("unexpected: {:?}", oth), - }; - - resp.send(BackBuf::from_slice(b"offer")); - - println!("got create_offer"); - - match test.sig_evt.recv().await { - Some(Ok(SigStateEvt::SndOffer(id, mut buf, mut resp))) => { - assert_eq!(id, test.id_a); - assert_eq!(&buf.to_vec().unwrap(), b"offer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - } - - println!("sent offer"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetLoc(mut offer, mut resp))) => { - assert_eq!(&offer.to_vec().unwrap(), b"offer"); - resp.send(Ok(())); - } - oth => 
panic!("unexpected: {:?}", oth), - } - - println!("set loc"); - - // - BUT, instead we get an new incoming OFFER - // maybe because the other node started a racy try to connect to us too? - - test.sig_state - .offer(test.id_a, BackBuf::from_slice(b"in_offer").unwrap()) - .unwrap(); - - // since we're the IMPOLITE node, we just ignore this offer - // and continue with the negotiation of the original connection. - - test.sig_state - .answer(test.id_a, BackBuf::from_slice(b"answer").unwrap()) - .unwrap(); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SetRem(mut answer, mut resp))) => { - assert_eq!(&answer.to_vec().unwrap(), b"answer"); - resp.send(Ok(())); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("set rem"); - - conn_state.ready().unwrap(); - - println!("ready"); - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SndData(mut data, mut resp))) => { - // blank message for preflight data - assert_eq!(8, data.to_vec().unwrap().len()); - resp.send(Ok(BufState::Low)); - } - oth => panic!("unexpected: {:?}", oth), - }; - - // receive the empty preflight - conn_state - .rcv_data(BackBuf::from_slice(b"\0\0\0\0\0\0\0\x80").unwrap()) - .unwrap(); - - match test.state_evt.recv().await { - Some(Ok(StateEvt::Connected { .. 
})) => (), - oth => panic!("unexpected: {:?}", oth), - } - - match conn_evt.recv().await { - Some(Ok(ConnStateEvt::SndData(mut data, mut resp))) => { - assert_eq!(&data.to_vec().unwrap()[8..], b"hello"); - resp.send(Ok(BufState::Low)); - } - oth => panic!("unexpected: {:?}", oth), - }; - - println!("snd data"); - - // finally the data is sent - task.await.unwrap(); - - test.shutdown().await; -} diff --git a/crates/tx5/src/test.rs b/crates/tx5/src/test.rs deleted file mode 100644 index 2f421ab7..00000000 --- a/crates/tx5/src/test.rs +++ /dev/null @@ -1,723 +0,0 @@ -use super::*; -use std::sync::Arc; - -fn init_tracing() { - let subscriber = tracing_subscriber::FmtSubscriber::builder() - .with_env_filter( - tracing_subscriber::filter::EnvFilter::from_default_env(), - ) - .with_file(true) - .with_line_number(true) - .finish(); - let _ = tracing::subscriber::set_global_default(subscriber); -} - -#[tokio::test(flavor = "multi_thread")] -async fn check_send_backpressure() { - init_tracing(); - - let this_id = Id::from([1; 32]); - let this_id = &this_id; - let other_id = Id::from([2; 32]); - let other_id = &other_id; - - let sig_url = Tx5Url::new("ws://fake:1").unwrap(); - - let this_url = sig_url.to_client(this_id.clone()); - let this_url = &this_url; - let other_url = sig_url.to_client(other_id.clone()); - let other_url = &other_url; - - let send_count = Arc::new(std::sync::atomic::AtomicUsize::new(0)); - - let notify = Arc::new(tokio::sync::Notify::new()); - - let this_url_sig = this_url.clone(); - let other_id_sig = other_id.clone(); - let send_count_conn = send_count.clone(); - let notify_conn = notify.clone(); - let config = DefConfig::default() - .with_max_send_bytes(20) - .with_per_data_chan_buf_low(10) - .with_new_sig_cb(move |_, _, seed| { - let (sig_state, mut sig_evt) = seed - .result_ok( - this_url_sig.clone(), - Arc::new(serde_json::json!([])), - ) - .unwrap(); - let other_id = other_id_sig.clone(); - tokio::task::spawn(async move { - while let Some(Ok(evt)) 
= sig_evt.recv().await { - match evt { - state::SigStateEvt::SndOffer(_, _, mut r) => { - r.send(Ok(())); - sig_state - .answer( - other_id.clone(), - BackBuf::from_slice(&[]).unwrap(), - ) - .unwrap(); - } - _ => println!("unhandled SIG_EVT: {evt:?}"), - } - } - }); - }) - .with_new_conn_cb(move |_, _, seed| { - let (conn_state, mut conn_evt) = seed.result_ok().unwrap(); - let send_count_conn = send_count_conn.clone(); - let notify_conn = notify_conn.clone(); - tokio::task::spawn(async move { - while let Some(Ok(evt)) = conn_evt.recv().await { - match evt { - state::ConnStateEvt::CreateOffer(mut r) => { - r.send(BackBuf::from_slice(&[])); - } - state::ConnStateEvt::SetLoc(_, mut r) => { - r.send(Ok(())); - } - state::ConnStateEvt::SetRem(_, mut r) => { - r.send(Ok(())); - } - state::ConnStateEvt::SndData(_, mut r) => { - let notify_conn = notify_conn.clone(); - tokio::task::spawn(async move { - notify_conn.notified().await; - r.send(Ok(state::BufState::Low)); - }); - send_count_conn.fetch_add( - 1, - std::sync::atomic::Ordering::SeqCst, - ); - } - _ => println!("unhandled CONN_EVT: {evt:?}"), - } - } - }); - conn_state.ready().unwrap(); - tokio::task::spawn(async move { - let _conn_state = conn_state; - - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - }); - }); - - let (ep1, _ep_rcv1) = Ep::with_config(config).await.unwrap(); - - ep1.listen(sig_url).await.unwrap(); - - let fut1 = ep1.send(other_url.clone(), &b"1234567890"[..]); - let send_task_1 = tokio::task::spawn(async move { - fut1.await.unwrap(); - }); - - let fut2 = ep1.send(other_url.clone(), &b"1234567890"[..]); - let send_task_2 = tokio::task::spawn(async move { - fut2.await.unwrap(); - }); - - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - // make sure only the preflight and our first message have been "sent" - assert_eq!(2, send_count.load(std::sync::atomic::Ordering::SeqCst)); - - notify.notify_waiters(); - - 
tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - // after sending the preflight and first messages through, - // now the second message is queued up for send on our mock backend. - assert_eq!(3, send_count.load(std::sync::atomic::Ordering::SeqCst)); - - // now let the second message through - notify.notify_waiters(); - - // make sure our send tasks resolve - send_task_1.await.unwrap(); - send_task_2.await.unwrap(); -} - -#[tokio::test(flavor = "multi_thread")] -async fn endpoint_sanity() { - init_tracing(); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let (ep1, _ep_rcv1) = Ep::new().await.unwrap(); - - let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - - println!("cli_url1: {}", cli_url1); - - let (ep2, mut ep_rcv2) = Ep::new().await.unwrap(); - - let cli_url2 = ep2.listen(sig_url).await.unwrap(); - - println!("cli_url2: {}", cli_url2); - - ep1.send(cli_url2, &b"hello"[..]).await.unwrap(); - - match ep_rcv2.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - let recv = ep_rcv2.recv().await; - - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. 
- })) => { - assert_eq!(cli_url1, rem_cli_url); - assert_eq!(b"hello", data.to_vec().unwrap().as_slice()); - } - oth => panic!("unexpected {:?}", oth), - } - - ep1.ban([42; 32].into(), std::time::Duration::from_secs(42)); - ep2.ban([43; 32].into(), std::time::Duration::from_secs(43)); - - println!( - "{}", - serde_json::to_string_pretty(&ep1.get_stats().await.unwrap()).unwrap() - ); - println!( - "{}", - serde_json::to_string_pretty(&ep2.get_stats().await.unwrap()).unwrap() - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn disconnect() { - init_tracing(); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let conf = DefConfig::default() - .with_max_conn_init(std::time::Duration::from_secs(8)); - - let (ep1, mut ep_rcv1) = Ep::with_config(conf).await.unwrap(); - - let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - - println!("cli_url1: {}", cli_url1); - - let (ep2, mut ep_rcv2) = Ep::new().await.unwrap(); - - let cli_url2 = ep2.listen(sig_url).await.unwrap(); - - println!("cli_url2: {}", cli_url2); - - ep1.send(cli_url2.clone(), &b"hello"[..]).await.unwrap(); - - match ep_rcv1.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - match ep_rcv2.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - let recv = ep_rcv2.recv().await; - - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. 
- })) => { - assert_eq!(cli_url1, rem_cli_url); - assert_eq!(b"hello", data.to_vec().unwrap().as_slice()); - } - oth => panic!("unexpected {:?}", oth), - } - - ep2.ban(cli_url1.id().unwrap(), std::time::Duration::from_secs(43)); - - match ep_rcv1.recv().await { - Some(Ok(EpEvt::Disconnected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - assert!(ep1.send(cli_url2, &b"hello"[..]).await.is_err()); -} - -#[tokio::test(flavor = "multi_thread")] -async fn connect_timeout() { - init_tracing(); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let conf = DefConfig::default() - .with_max_conn_init(std::time::Duration::from_secs(8)); - - let (ep1, _ep_rcv1) = Ep::with_config(conf).await.unwrap(); - - ep1.listen(sig_url.clone()).await.unwrap(); - - let cli_url_fake = sig_url.to_client([0xdb; 32].into()); - - let start = std::time::Instant::now(); - - assert!(ep1.send(cli_url_fake, &b"hello"[..]).await.is_err()); - - assert!( - start.elapsed().as_secs_f64() < 10.0, - "expected timeout in 8 seconds, timed out in {} seconds", - start.elapsed().as_secs_f64() - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn preflight_small() { - init_tracing(); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let valid_count = 
Arc::new(std::sync::atomic::AtomicUsize::new(0)); - const SMALL_DATA: &[u8] = &[42; 16]; - - fn make_config( - valid_count: Arc, - ) -> DefConfig { - DefConfig::default() - .with_conn_preflight(|_, _| { - println!("PREFLIGHT"); - Box::pin(async move { - Ok(Some(bytes::Bytes::from_static(SMALL_DATA))) - }) - }) - .with_conn_validate(move |_, _, data| { - println!("VALIDATE"); - assert_eq!(SMALL_DATA, data.unwrap()); - valid_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - Box::pin(async move { Ok(()) }) - }) - } - - let (ep1, mut ep_rcv1) = Ep::with_config(make_config(valid_count.clone())) - .await - .unwrap(); - let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - println!("cli_url1: {}", cli_url1); - - let (ep2, mut ep_rcv2) = Ep::with_config(make_config(valid_count.clone())) - .await - .unwrap(); - let cli_url2 = ep2.listen(sig_url).await.unwrap(); - println!("cli_url2: {}", cli_url2); - - ep1.send(cli_url2, &b"hello"[..]).await.unwrap(); - - match ep_rcv1.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - Some(Ok(EpEvt::Disconnected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - match ep_rcv2.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - Some(Ok(EpEvt::Disconnected { .. 
})) => (), - oth => panic!("unexpected: {:?}", oth), - } - - assert_eq!(2, valid_count.load(std::sync::atomic::Ordering::SeqCst)); -} - -#[tokio::test(flavor = "multi_thread")] -async fn preflight_huge() { - init_tracing(); - - const HUGE_DATA: bytes::Bytes = - bytes::Bytes::from_static(&[42; 16 * 1024 * 512]); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let valid_count = Arc::new(std::sync::atomic::AtomicUsize::new(0)); - - fn make_config( - ep_num: usize, - valid_count: Arc, - ) -> DefConfig { - DefConfig::default() - .with_conn_preflight(move |_, _| { - println!("PREFLIGHT:{ep_num}"); - Box::pin(async move { Ok(Some(HUGE_DATA.clone())) }) - }) - .with_conn_validate(move |_, _, data| { - println!("VALIDATE:{ep_num}"); - assert_eq!(HUGE_DATA, data.unwrap()); - valid_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - Box::pin(async move { Ok(()) }) - }) - } - - let (ep1, mut ep_rcv1) = - Ep::with_config(make_config(1, valid_count.clone())) - .await - .unwrap(); - let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - println!("cli_url1: {}", cli_url1); - - let (ep2, mut ep_rcv2) = - Ep::with_config(make_config(2, valid_count.clone())) - .await - .unwrap(); - let cli_url2 = ep2.listen(sig_url).await.unwrap(); - println!("cli_url2: {}", cli_url2); - - let (s1, r1) = tokio::sync::oneshot::channel(); - let mut s1 = Some(s1); - tokio::task::spawn(async move { - loop { - match ep_rcv1.recv().await { - Some(Ok(EpEvt::Connected { .. })) => { - if let Some(s1) = s1.take() { - println!("ONE connected"); - let _ = s1.send(()); - } - } - Some(Ok(EpEvt::Disconnected { .. 
})) => (), - oth => panic!("unexpected: {:?}", oth), - } - } - }); - - let (s2, r2) = tokio::sync::oneshot::channel(); - let mut s2 = Some(s2); - tokio::task::spawn(async move { - let mut done_count = 0; - let done_count = &mut done_count; - let mut check = move || { - *done_count += 1; - println!("TWO recv: check_count: {done_count}"); - if *done_count == 2 { - if let Some(s2) = s2.take() { - let _ = s2.send(()); - } - } - }; - - loop { - let _ = ep_rcv2.recv().await; - check(); - } - }); - - ep1.send(cli_url2, &b"hello"[..]).await.unwrap(); - - r1.await.unwrap(); - r2.await.unwrap(); - - assert_eq!(2, valid_count.load(std::sync::atomic::Ordering::SeqCst)); -} - -#[tokio::test(flavor = "multi_thread")] -async fn ban() { - init_tracing(); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let (ep1, _ep_rcv1) = Ep::new().await.unwrap(); - - let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - - println!("cli_url1: {}", cli_url1); - - let (ep2, _ep_rcv2) = Ep::new().await.unwrap(); - - let cli_url2 = ep2.listen(sig_url).await.unwrap(); - - println!("cli_url2: {}", cli_url2); - - let msg_sent = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let msg_sent2 = msg_sent.clone(); - - // *Send* the message, but it shouldn't be received - ep2.ban(cli_url1.id().unwrap(), std::time::Duration::from_secs(500)); - let fut = ep1.send(cli_url2.clone(), &b"hello"[..]); - let task = tokio::task::spawn(async move { - if fut.await.is_err() { - // it's okay if this errors, that was the point - return; - } - // the future resolved successfully, that's bad, the ban didn't work. 
- msg_sent2.store(true, std::sync::atomic::Ordering::SeqCst); - }); - - // allow some time for it to be sent - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - - // Now try banning the *sending* side. Should get an error sending. - ep1.ban(cli_url2.id().unwrap(), std::time::Duration::from_secs(500)); - assert!(ep1.send(cli_url2, &b"hello"[..]).await.is_err()); - - // Allow some additional time for the first send to connect / etc - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - - // if the message sent, our ban didn't work - if msg_sent.load(std::sync::atomic::Ordering::SeqCst) { - panic!("message wast sent! ban failed"); - } - - task.abort(); -} - -#[tokio::test(flavor = "multi_thread")] -async fn large_messages() { - init_tracing(); - - use rand::Rng; - let mut rng = rand::thread_rng(); - let mut msg_1 = vec![0; 1024 * 1024 * 16]; - rng.fill(&mut msg_1[..]); - let mut msg_2 = vec![0; 1024 * 58]; - rng.fill(&mut msg_2[..]); - let msg_1_r = msg_1.clone(); - let msg_2_r = msg_2.clone(); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let (ep1, _ep_rcv1) = Ep::new().await.unwrap(); - - let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - - println!("cli_url1: {}", cli_url1); - - let (ep2, mut ep_rcv2) = Ep::new().await.unwrap(); - - let cli_url2 = ep2.listen(sig_url).await.unwrap(); - - println!("cli_url2: {}", cli_url2); - - let recv_task = { - tokio::task::spawn(async move { - match ep_rcv2.recv().await { - Some(Ok(EpEvt::Connected { .. 
})) => (), - oth => panic!("unexpected: {:?}", oth), - } - - for _ in 0..2 { - let recv = ep_rcv2.recv().await; - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. - })) => { - assert_eq!(cli_url1, rem_cli_url); - let msg = data.to_vec().unwrap(); - if msg.len() == msg_1_r.len() { - assert_eq!(msg, msg_1_r); - } else if msg.len() == msg_2_r.len() { - assert_eq!(msg, msg_2_r); - } else { - panic!("unexpected"); - } - } - oth => panic!("unexpected {:?}", oth), - } - } - }) - }; - - let f1 = ep1.send(cli_url2.clone(), msg_1.as_slice()); - let f2 = ep1.send(cli_url2, msg_2.as_slice()); - - tokio::try_join!(f1, f2).unwrap(); - - recv_task.await.unwrap(); -} - -#[tokio::test(flavor = "multi_thread")] -async fn broadcast() { - init_tracing(); - - let mut srv_config = tx5_signal_srv::Config::default(); - srv_config.port = 0; - srv_config.demo = true; - - let (srv_driver, addr_list, _) = - tx5_signal_srv::exec_tx5_signal_srv(srv_config).unwrap(); - tokio::task::spawn(srv_driver); - - let sig_port = addr_list.get(0).unwrap().port(); - - let sig_url = Tx5Url::new(format!("ws://localhost:{}", sig_port)).unwrap(); - println!("sig_url: {}", sig_url); - - let (ep1, mut ep_rcv1) = Ep::new().await.unwrap(); - let cli_url1 = ep1.listen(sig_url.clone()).await.unwrap(); - println!("cli_url1: {}", cli_url1); - - let (ep2, mut ep_rcv2) = Ep::new().await.unwrap(); - let cli_url2 = ep2.listen(sig_url.clone()).await.unwrap(); - println!("cli_url2: {}", cli_url2); - - let (ep3, mut ep_rcv3) = Ep::new().await.unwrap(); - let cli_url3 = ep3.listen(sig_url).await.unwrap(); - println!("cli_url3: {}", cli_url3); - - ep1.send(cli_url2.clone(), &b"hello"[..]).await.unwrap(); - ep1.send(cli_url3, &b"hello"[..]).await.unwrap(); - - match ep_rcv1.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - match ep_rcv1.recv().await { - Some(Ok(EpEvt::Connected { .. 
})) => (), - oth => panic!("unexpected: {:?}", oth), - } - - match ep_rcv2.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - let recv = ep_rcv2.recv().await; - - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. - })) => { - assert_eq!(cli_url1, rem_cli_url); - assert_eq!(b"hello", data.to_vec().unwrap().as_slice()); - } - oth => panic!("unexpected {:?}", oth), - } - - match ep_rcv3.recv().await { - Some(Ok(EpEvt::Connected { .. })) => (), - oth => panic!("unexpected: {:?}", oth), - } - - let recv = ep_rcv3.recv().await; - - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. - })) => { - assert_eq!(cli_url1, rem_cli_url); - assert_eq!(b"hello", data.to_vec().unwrap().as_slice()); - } - oth => panic!("unexpected {:?}", oth), - } - - ep1.broadcast(&b"bcast"[..]).await.unwrap(); - - let recv = ep_rcv2.recv().await; - - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. - })) => { - assert_eq!(cli_url1, rem_cli_url); - assert_eq!(b"bcast", data.to_vec().unwrap().as_slice()); - } - oth => panic!("unexpected {:?}", oth), - } - - let recv = ep_rcv3.recv().await; - - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. - })) => { - assert_eq!(cli_url1, rem_cli_url); - assert_eq!(b"bcast", data.to_vec().unwrap().as_slice()); - } - oth => panic!("unexpected {:?}", oth), - } - - ep2.broadcast(&b"bcast2"[..]).await.unwrap(); - - let recv = ep_rcv1.recv().await; - - match recv { - Some(Ok(EpEvt::Data { - rem_cli_url, data, .. 
- })) => { - assert_eq!(cli_url2, rem_cli_url); - assert_eq!(b"bcast2", data.to_vec().unwrap().as_slice()); - } - oth => panic!("unexpected {:?}", oth), - } -} diff --git a/crates/tx5/src/test_behavior.rs b/crates/tx5/src/test_behavior.rs new file mode 100644 index 00000000..048e2995 --- /dev/null +++ b/crates/tx5/src/test_behavior.rs @@ -0,0 +1,389 @@ +use crate::*; +use std::sync::{Arc, Mutex}; + +#[tokio::test(flavor = "multi_thread")] +async fn behavior_20_sec() { + run_behavior(std::time::Duration::from_secs(20)).await; +} + +#[tokio::test(flavor = "multi_thread")] +#[ignore = "this is a long-running test, `cargo test behavior -- --ignored`"] +async fn behavior_5_min() { + run_behavior(std::time::Duration::from_secs(60 * 5)).await; +} + +struct Share { + pub sig_url: SigUrl, + pub config: Arc, + pub tasks: Arc>>>, + pub errors: Arc>>, +} + +impl Drop for Share { + fn drop(&mut self) { + eprintln!("BEG SHARE DROP"); + for task in self.tasks.lock().unwrap().drain(..) { + task.abort(); + } + eprintln!("END SHARE DROP"); + } +} + +async fn run_behavior(dur: std::time::Duration) { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter( + tracing_subscriber::filter::EnvFilter::from_default_env(), + ) + .with_file(true) + .with_line_number(true) + .finish(); + + let _ = tracing::subscriber::set_global_default(subscriber); + + let mut srv_config = tx5_signal_srv::Config::default(); + srv_config.port = 0; + + let (_sig_srv_hnd, addr_list, _) = + tx5_signal_srv::exec_tx5_signal_srv(srv_config) + .await + .unwrap(); + + let sig_port = addr_list.get(0).unwrap().port(); + + let sig_url = SigUrl::new(format!("ws://localhost:{}", sig_port)).unwrap(); + + tracing::info!(%sig_url); + + let mut config = Config3::default(); + config.timeout = std::time::Duration::from_secs(10); + + let share = Arc::new(Share { + sig_url, + config: Arc::new(config), + tasks: Arc::new(Mutex::new(Vec::new())), + errors: Arc::new(Mutex::new(Vec::new())), + }); + + let 
peer_echo = run_echo(&share).await; + + run_small_msg(share.clone(), peer_echo.clone()); + run_large_msg(share.clone(), peer_echo.clone()); + run_large_msg(share.clone(), peer_echo.clone()); + run_ban(share.clone(), peer_echo.clone()); + run_self_ban(share.clone(), peer_echo.clone()); + run_dropout(share.clone(), peer_echo); + + tokio::time::sleep(dur).await; + + { + let lock = share.errors.lock().unwrap(); + if lock.len() > 0 { + panic!("{:#?}", *lock); + } + } + + drop(share); + eprintln!("TEST DROP"); +} + +macro_rules! track_err { + ($e:ident, $t:literal, $b:block) => { + if let Err(err) = $b { + let err = + Error::str(format!("{}:{}:{}: {err:?}", file!(), line!(), $t,)) + .into(); + eprintln!("{err:?}"); + $e.lock().unwrap().push(err); + } + }; +} + +async fn run_echo(share: &Share) -> PeerUrl { + let (ep, mut recv) = Ep3::new(share.config.clone()).await; + let ep = Arc::new(ep); + let url = ep.listen(share.sig_url.clone()).await.unwrap(); + let errors = share.errors.clone(); + + share + .tasks + .lock() + .unwrap() + .push(tokio::task::spawn(async move { + while let Some(evt) = recv.recv().await { + if let Ep3Event::Message { + peer_url, message, .. 
+ } = evt + { + if message == b"ban" { + ep.ban( + peer_url.id().unwrap(), + std::time::Duration::from_secs(10), + ); + } else { + track_err!(errors, "echo response", { + ep.send(peer_url.clone(), &message).await + }); + } + } + } + })); + + url +} + +fn run_small_msg(share: Arc, peer_echo: PeerUrl) { + share + .clone() + .tasks + .lock() + .unwrap() + .push(tokio::task::spawn(async move { + let (ep, mut recv) = Ep3::new(share.config.clone()).await; + ep.listen(share.sig_url.clone()).await.unwrap(); + let errors = share.errors.clone(); + drop(share); + + let mut msg_id: usize = 0; + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + msg_id += 1; + let msg = format!("msg:{msg_id}"); + + track_err!(errors, "small msg send", { + ep.send(peer_echo.clone(), msg.as_bytes()).await + }); + + loop { + match recv.recv().await.unwrap() { + Ep3Event::Connected { .. } + | Ep3Event::Disconnected { .. } => (), + Ep3Event::Message { + peer_url, message, .. + } => { + if peer_url == peer_echo { + let m = String::from_utf8_lossy(&message); + let mut p = m.split("msg:"); + p.next().unwrap(); + assert_eq!( + msg_id, + p.next().unwrap().parse::().unwrap() + ); + eprintln!("small_msg success"); + break; + } + } + oth => errors.lock().unwrap().push( + Error::str(format!("unexpected: {oth:?}",)).into(), + ), + } + } + } + })); +} + +fn run_large_msg(share: Arc, peer_echo: PeerUrl) { + share + .clone() + .tasks + .lock() + .unwrap() + .push(tokio::task::spawn(async move { + let (ep, mut recv) = Ep3::new(share.config.clone()).await; + ep.listen(share.sig_url.clone()).await.unwrap(); + let errors = share.errors.clone(); + drop(share); + + let mut full = vec![0; 15 * 1024 * 1024]; + + loop { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + + use rand::Rng; + rand::thread_rng().fill(&mut full[..]); + + track_err!(errors, "large msg send", { + ep.send(peer_echo.clone(), &full).await + }); + + loop { + match recv.recv().await.unwrap() { + 
Ep3Event::Connected { .. } + | Ep3Event::Disconnected { .. } => (), + Ep3Event::Message { + peer_url, message, .. + } => { + if peer_url == peer_echo { + assert_eq!(full, message); + eprintln!("large_msg success"); + break; + } + } + oth => errors.lock().unwrap().push( + Error::str(format!("unexpected: {oth:?}",)).into(), + ), + } + } + } + })); +} + +fn run_ban(share: Arc, peer_echo: PeerUrl) { + share + .clone() + .tasks + .lock() + .unwrap() + .push(tokio::task::spawn(async move { + let (ep, _recv) = Ep3::new(share.config.clone()).await; + ep.listen(share.sig_url.clone()).await.unwrap(); + let errors = share.errors.clone(); + drop(share); + + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + eprintln!("ban req"); + + if ep.send(peer_echo.clone(), b"ban").await.is_err() { + // we might get an error when the connection realizes + // it was shut down... just try again + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + track_err!(errors, "ban req send", { + ep.send(peer_echo.clone(), b"ban").await + }); + } + + let start = std::time::Instant::now(); + + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + + eprintln!("ban check"); + if ep.send(peer_echo.clone(), b"hello").await.is_ok() { + let e = Error::id("WasNotBanned").into(); + eprintln!("{e:?}"); + errors.lock().unwrap().push(e); + } else { + eprintln!("ban success"); + } + + let sleep_for = std::time::Duration::from_secs(11) + .saturating_sub(start.elapsed()); + + tokio::time::sleep(sleep_for).await; + } + })); +} + +fn run_self_ban(share: Arc, peer_echo: PeerUrl) { + share + .clone() + .tasks + .lock() + .unwrap() + .push(tokio::task::spawn(async move { + let (ep, mut recv) = Ep3::new(share.config.clone()).await; + ep.listen(share.sig_url.clone()).await.unwrap(); + let errors = share.errors.clone(); + drop(share); + + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + track_err!(errors, "self_ban check send", { + 
ep.send(peer_echo.clone(), b"hello").await + }); + + loop { + match recv.recv().await.unwrap() { + Ep3Event::Connected { .. } + | Ep3Event::Disconnected { .. } => (), + Ep3Event::Message { + peer_url, message, .. + } => { + if peer_url == peer_echo { + assert_eq!(b"hello", message.as_slice()); + break; + } + } + oth => errors.lock().unwrap().push( + Error::str(format!("unexpected: {oth:?}",)).into(), + ), + } + } + + ep.ban( + peer_echo.id().unwrap(), + std::time::Duration::from_secs(11), + ); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let start = std::time::Instant::now(); + + eprintln!("self_ban check"); + if ep.send(peer_echo.clone(), b"hello").await.is_ok() { + let e = Error::id("WasNotBanned").into(); + eprintln!("{e:?}"); + errors.lock().unwrap().push(e); + } else { + eprintln!("self_ban success"); + } + + let sleep_for = std::time::Duration::from_secs(10) + .saturating_sub(start.elapsed()); + + tokio::time::sleep(sleep_for).await; + } + })); +} + +fn run_dropout(share: Arc, peer_echo: PeerUrl) { + share + .clone() + .tasks + .lock() + .unwrap() + .push(tokio::task::spawn(async move { + let sig_url = share.sig_url.clone(); + let config = share.config.clone(); + let errors = share.errors.clone(); + drop(share); + + loop { + let (ep, mut recv) = Ep3::new(config.clone()).await; + ep.listen(sig_url.clone()).await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + track_err!(errors, "dropout check send", { + ep.send(peer_echo.clone(), b"hello").await + }); + + loop { + match recv.recv().await.unwrap() { + Ep3Event::Connected { .. } + | Ep3Event::Disconnected { .. } => (), + Ep3Event::Message { + peer_url, message, .. 
+ } => { + if peer_url == peer_echo { + assert_eq!(b"hello", message.as_slice()); + eprintln!("dropout success"); + break; + } + } + oth => errors.lock().unwrap().push( + Error::str(format!("unexpected: {oth:?}",)).into(), + ), + } + } + + drop(ep); + drop(recv); + + tokio::time::sleep(std::time::Duration::from_secs(4)).await; + } + })); +}