diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 44efe0993..6d54643b0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,7 +74,7 @@ jobs: strategy: fail-fast: false matrix: - msrv: ["1.70.0"] + msrv: ["1.74.0"] os: ["ubuntu", "macOS", "windows"] runs-on: ${{ matrix.os }}-latest steps: diff --git a/CHANGELOG.md b/CHANGELOG.md index f71fea8c2..b03662230 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ All user visible changes to this project will be documented in this file. This p ### Initially re-implemented -- Performed major refactoring with non-server code removing. ([#1]) +- Performed major refactoring with non-server code removing. ([#1], [#2]) - Added TCP transport. ([#1]) ### [Upstream changes](https://github.com/webrtc-rs/webrtc/blob/89285ceba23dc57fc99386cb978d2d23fe909437/turn/CHANGELOG.md#unreleased) @@ -21,6 +21,7 @@ All user visible changes to this project will be documented in this file. This p [@clia]: https://github.com/clia [#1]: /../../pull/1 +[#2]: /../../pull/2 [webrtc-rs/webrtc#330]: https://github.com/webrtc-rs/webrtc/pull/330 [webrtc-rs/webrtc#421]: https://github.com/webrtc-rs/webrtc/pull/421 diff --git a/Cargo.toml b/Cargo.toml index 0e1f97298..c64918896 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,8 +3,8 @@ name = "medea-turn" version = "0.7.0-dev" authors = ["Instrumentisto Team "] edition = "2021" -rust-version = "1.70" -description = "TURN implementation used by Medea media server." +rust-version = "1.74" +description = "STUN/TURN server implementation used by Medea media server." license = "MIT OR Apache-2.0" homepage = "https://github.com/instrumentisto/medea-turn-rs" repository = "https://github.com/instrumentisto/medea-turn-rs" @@ -14,11 +14,11 @@ publish = false async-trait = "0.1" bytecodec = "0.4.15" bytes = "1.6" +derive_more = { version = "1.0.0-beta.6", features = ["debug", "display", "error", "from"] } futures = "0.3" log = "0.4" rand = "0.8" stun_codec = "0.3" -thiserror = "1.0" tokio = { version = "1.32", default-features = false, features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } tokio-util = { version = "0.7", features = ["codec"] } diff --git a/README.md b/README.md index 0921b79dc..75a51dce0 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,13 @@ ============ [![CI](https://github.com/instrumentisto/medea-turn-rs/workflows/CI/badge.svg?branch=main "CI")](https://github.com/instrumentisto/medea-turn-rs/actions?query=workflow%3ACI+branch%3Amain) -[![Rust 1.70+](https://img.shields.io/badge/rustc-1.70+-lightgray.svg "Rust 1.70+")](https://blog.rust-lang.org/2023/06/01/Rust-1.70.0.html) - +[![Rust 1.74+](https://img.shields.io/badge/rustc-1.74+-lightgray.svg "Rust 1.74+")](https://blog.rust-lang.org/2023/11/16/Rust-1.74.0.html) [Changelog](https://github.com/instrumentisto/medea-turn-rs/blob/master/CHANGELOG.md) -TURN implementation used by [Medea media server](https://github.com/instrumentisto/medea). Majorly refactored fork of the [`webrtc-rs/turn` crate](https://github.com/webrtc-rs/webrtc/tree/89285ceba23dc57fc99386cb978d2d23fe909437/turn). +[STUN]/[TURN] implementation used by [Medea media server](https://github.com/instrumentisto/medea). Majorly refactored fork of the [`webrtc-rs/turn` crate](https://github.com/webrtc-rs/webrtc/tree/89285ceba23dc57fc99386cb978d2d23fe909437/turn). + +Hard fork of [`webrtc-rs/turn` crate](https://docs.rs/turn). @@ -25,3 +26,5 @@ Unless you explicitly state otherwise, any contribution intentionally submitted [APACHE]: https://github.com/instrumentisto/medea-turn-rs/blob/main/LICENSE-APACHE [MIT]: https://github.com/instrumentisto/medea-turn-rs/blob/main/LICENSE-MIT +[STUN]: https://en.wikipedia.org/wiki/STUN +[TURN]: https://en.wikipedia.org/wiki/TURN diff --git a/src/allocation/allocation_manager.rs b/src/allocation/allocation_manager.rs deleted file mode 100644 index 987e347b5..000000000 --- a/src/allocation/allocation_manager.rs +++ /dev/null @@ -1,625 +0,0 @@ -//! [Allocation]s storage. -//! -//! [Allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-5 - -use std::{ - collections::HashMap, - mem, - sync::{atomic::Ordering, Arc, Mutex as SyncMutex}, - time::Duration, -}; - -use futures::future; -use tokio::{ - sync::{mpsc, Mutex}, - time::sleep, -}; - -use crate::{ - allocation::{Allocation, AllocationMap}, - attr::Username, - con::Conn, - relay::RelayAllocator, - AllocInfo, Error, FiveTuple, -}; - -/// `ManagerConfig` a bag of config params for [`Manager`]. -pub(crate) struct ManagerConfig { - /// Relay connections allocator. - pub(crate) relay_addr_generator: RelayAllocator, - - /// Injected into allocations to notify when allocation is closed. - pub(crate) alloc_close_notify: Option>, -} - -/// [`Manager`] is used to hold active allocations. -pub(crate) struct Manager { - /// [`Allocation`]s storage. - allocations: AllocationMap, - - /// [Reservation][1]s storage. - /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.9 - reservations: Arc>>, - - /// Relay connections allocator. - relay_allocator: RelayAllocator, - - /// Injected into allocations to notify when allocation is closed. - alloc_close_notify: Option>, -} - -impl Manager { - /// Creates a new [`Manager`]. - pub(crate) fn new(config: ManagerConfig) -> Self { - Self { - allocations: Arc::new(SyncMutex::new(HashMap::new())), - reservations: Arc::new(Mutex::new(HashMap::new())), - relay_allocator: config.relay_addr_generator, - alloc_close_notify: config.alloc_close_notify, - } - } - - /// Returns the information about the all [`Allocation`]s associated with - /// the specified [`FiveTuple`]s. - pub(crate) fn get_allocations_info( - &self, - five_tuples: &Option>, - ) -> HashMap { - let mut infos = HashMap::new(); - - #[allow( - clippy::unwrap_used, - clippy::iter_over_hash_type, - clippy::significant_drop_in_scrutinee - )] - for (five_tuple, alloc) in self.allocations.lock().unwrap().iter() { - #[allow(clippy::unwrap_used)] - if five_tuples.is_none() - || five_tuples.as_ref().unwrap().contains(five_tuple) - { - drop(infos.insert( - *five_tuple, - AllocInfo::new( - *five_tuple, - alloc.username.name().to_owned(), - alloc.relayed_bytes.load(Ordering::Acquire), - ), - )); - } - } - - infos - } - - /// Fetches the [`Allocation`] matching the passed [`FiveTuple`]. - pub(crate) fn has_alloc(&self, five_tuple: &FiveTuple) -> bool { - #[allow(clippy::unwrap_used)] - self.allocations.lock().unwrap().get(five_tuple).is_some() - } - - /// Fetches the [`Allocation`] matching the passed [`FiveTuple`]. - #[allow(clippy::unwrap_in_result)] - pub(crate) fn get_alloc( - &self, - five_tuple: &FiveTuple, - ) -> Option> { - #[allow(clippy::unwrap_used)] - self.allocations.lock().unwrap().get(five_tuple).cloned() - } - - /// Creates a new [`Allocation`] and starts relaying. - #[allow(clippy::too_many_arguments)] - pub(crate) async fn create_allocation( - &self, - five_tuple: FiveTuple, - turn_socket: Arc, - requested_port: u16, - lifetime: Duration, - username: Username, - use_ipv4: bool, - ) -> Result, Error> { - if lifetime == Duration::from_secs(0) { - return Err(Error::LifetimeZero); - } - - if self.has_alloc(&five_tuple) { - return Err(Error::DupeFiveTuple); - } - - let (relay_socket, relay_addr) = self - .relay_allocator - .allocate_conn(use_ipv4, requested_port) - .await?; - let mut a = Allocation::new( - turn_socket, - relay_socket, - relay_addr, - five_tuple, - username, - self.alloc_close_notify.clone(), - ); - a.allocations = Some(Arc::clone(&self.allocations)); - - log::trace!("listening on relay addr: {:?}", a.relay_addr); - a.start(lifetime); - a.packet_handler(); - - let a = Arc::new(a); - #[allow(clippy::unwrap_used)] - drop( - self.allocations.lock().unwrap().insert(five_tuple, Arc::clone(&a)), - ); - - Ok(a) - } - - /// Removes an [`Allocation`]. - pub(crate) async fn delete_allocation(&self, five_tuple: &FiveTuple) { - #[allow(clippy::unwrap_used)] - let allocation = self.allocations.lock().unwrap().remove(five_tuple); - - if let Some(a) = allocation { - if let Err(err) = a.close().await { - log::error!("Failed to close allocation: {}", err); - } - } - } - - /// Deletes the [`Allocation`]s according to the specified username `name`. - pub(crate) async fn delete_allocations_by_username(&self, name: &str) { - let to_delete = { - #[allow(clippy::unwrap_used)] - let mut allocations = self.allocations.lock().unwrap(); - - let mut to_delete = Vec::new(); - - // TODO(logist322): Use `.drain_filter()` once stabilized. - allocations.retain(|_, allocation| { - let match_name = allocation.username.name() == name; - - if match_name { - to_delete.push(Arc::clone(allocation)); - } - - !match_name - }); - - to_delete - }; - - drop( - future::join_all(to_delete.iter().map(|a| async move { - if let Err(err) = a.close().await { - log::error!("Failed to close allocation: {}", err); - } - })) - .await, - ); - } - - /// Stores the reservation for the token+port. - pub(crate) async fn create_reservation(&self, token: u64, port: u16) { - let reservations = Arc::clone(&self.reservations); - - drop(tokio::spawn(async move { - let liftime = sleep(Duration::from_secs(30)); - tokio::pin!(liftime); - - tokio::select! { - () = &mut liftime => { - _ = reservations.lock().await.remove(&token); - }, - } - })); - - _ = self.reservations.lock().await.insert(token, port); - } - - /// Returns a random un-allocated udp4 port. - pub(crate) async fn get_random_even_port(&self) -> Result { - let (_, addr) = self.relay_allocator.allocate_conn(true, 0).await?; - Ok(addr.port()) - } - - /// Closes this [`Manager`] and closes all [`Allocation`]s it manages. - pub(crate) async fn close(&self) -> Result<(), Error> { - #[allow(clippy::unwrap_used)] - let allocations = mem::take(&mut *self.allocations.lock().unwrap()); - - #[allow(clippy::iter_over_hash_type)] - for a in allocations.values() { - a.close().await?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod allocation_manager_test { - use bytecodec::DecodeExt; - use rand::random; - use std::{ - net::{IpAddr, Ipv4Addr, SocketAddr}, - str::FromStr, - }; - use stun_codec::MessageDecoder; - use tokio::net::UdpSocket; - - use crate::{ - attr::{Attribute, ChannelNumber, Data}, - chandata::ChannelData, - server::DEFAULT_LIFETIME, - }; - - use super::*; - - fn new_test_manager() -> Manager { - let config = ManagerConfig { - relay_addr_generator: RelayAllocator { - relay_address: IpAddr::from([127, 0, 0, 1]), - min_port: 49152, - max_port: 65535, - max_retries: 10, - address: String::from("127.0.0.1"), - }, - alloc_close_notify: None, - }; - Manager::new(config) - } - - fn random_five_tuple() -> FiveTuple { - FiveTuple { - src_addr: SocketAddr::new( - Ipv4Addr::new(0, 0, 0, 0).into(), - random(), - ), - dst_addr: SocketAddr::new( - Ipv4Addr::new(0, 0, 0, 0).into(), - random(), - ), - ..Default::default() - } - } - - #[tokio::test] - async fn test_packet_handler() { - // turn server initialization - let turn_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - - // client listener initialization - let client_listener = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let src_addr = client_listener.local_addr().unwrap(); - let (data_ch_tx, mut data_ch_rx) = mpsc::channel(1); - // client listener read data - tokio::spawn(async move { - let mut buffer = vec![0u8; 1500]; - loop { - let n = match client_listener.recv_from(&mut buffer).await { - Ok((n, _)) => n, - Err(_) => break, - }; - - let _ = data_ch_tx.send(buffer[..n].to_vec()).await; - } - }); - - let m = new_test_manager(); - let a = m - .create_allocation( - FiveTuple { - src_addr, - dst_addr: turn_socket.local_addr().unwrap(), - ..Default::default() - }, - Arc::new(turn_socket), - 0, - DEFAULT_LIFETIME, - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - - let peer_listener1 = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let peer_listener2 = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - - let port = { - // add permission with peer1 address - a.add_permission(peer_listener1.local_addr().unwrap().ip()).await; - // add channel with min channel number and peer2 address - a.add_channel_bind( - ChannelNumber::MIN, - peer_listener2.local_addr().unwrap(), - DEFAULT_LIFETIME, - ) - .await - .unwrap(); - - a.relay_socket.local_addr().unwrap().port() - }; - - let relay_addr_with_host_str = format!("127.0.0.1:{port}"); - let relay_addr_with_host = - SocketAddr::from_str(&relay_addr_with_host_str).unwrap(); - - // test for permission and data message - let target_text = "permission"; - let _ = peer_listener1 - .send_to(target_text.as_bytes(), relay_addr_with_host) - .await - .unwrap(); - let data = data_ch_rx.recv().await.unwrap(); - - let msg = MessageDecoder::::new() - .decode_from_bytes(&data) - .unwrap() - .unwrap(); - - let msg_data = msg.get_attribute::().unwrap().data().to_vec(); - assert_eq!( - target_text.as_bytes(), - &msg_data, - "get message doesn't equal the target text" - ); - - // test for channel bind and channel data - let target_text2 = "channel bind"; - let _ = peer_listener2 - .send_to(target_text2.as_bytes(), relay_addr_with_host) - .await - .unwrap(); - let data = data_ch_rx.recv().await.unwrap(); - - // resolve channel data - assert!(ChannelData::is_channel_data(&data), "should be channel data"); - - let channel_data = ChannelData::decode(data).unwrap(); - assert_eq!( - ChannelNumber::MIN, - channel_data.num(), - "get channel data's number is invalid" - ); - assert_eq!( - target_text2.as_bytes(), - &channel_data.data(), - "get data doesn't equal the target text." - ); - - // listeners close - m.close().await.unwrap(); - } - - #[tokio::test] - async fn test_create_allocation_duplicate_five_tuple() { - // turn server initialization - let turn_socket: Arc = - Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - - let m = new_test_manager(); - - let five_tuple = random_five_tuple(); - - let _ = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - - let result = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - Username::new(String::from("user")).unwrap(), - true, - ) - .await; - assert!(result.is_err(), "expected error, but got ok"); - } - - #[tokio::test] - async fn test_delete_allocation() { - // turn server initialization - let turn_socket: Arc = - Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - - let m = new_test_manager(); - - let five_tuple = random_five_tuple(); - - let _ = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - - assert!( - m.has_alloc(&five_tuple), - "Failed to get allocation right after creation" - ); - - m.delete_allocation(&five_tuple).await; - - assert!( - !m.has_alloc(&five_tuple), - "Get allocation with {five_tuple} should be nil after delete" - ); - } - - #[tokio::test] - async fn test_allocation_timeout() { - // turn server initialization - let turn_socket: Arc = - Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - - let m = new_test_manager(); - - let mut allocations = vec![]; - let lifetime = Duration::from_millis(100); - - for _ in 0..5 { - let five_tuple = random_five_tuple(); - - let a = m - .create_allocation( - five_tuple, - Arc::clone(&turn_socket), - 0, - lifetime, - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - - allocations.push(a); - } - - let mut count = 0; - - 'outer: loop { - count += 1; - - if count >= 10 { - panic!("Allocations didn't timeout"); - } - - sleep(lifetime + Duration::from_millis(100)).await; - - let any_outstanding = false; - - for a in &allocations { - if a.close().await.is_ok() { - continue 'outer; - } - } - - if !any_outstanding { - return; - } - } - } - - #[tokio::test] - async fn test_manager_close() { - // turn server initialization - let turn_socket: Arc = - Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - - let m = new_test_manager(); - - let mut allocations = vec![]; - - let a1 = m - .create_allocation( - random_five_tuple(), - Arc::clone(&turn_socket), - 0, - Duration::from_millis(100), - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - allocations.push(a1); - - let a2 = m - .create_allocation( - random_five_tuple(), - Arc::clone(&turn_socket), - 0, - Duration::from_millis(200), - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - allocations.push(a2); - - sleep(Duration::from_millis(150)).await; - - log::trace!("Mgr is going to be closed..."); - - m.close().await.unwrap(); - - for a in allocations { - assert!( - a.close().await.is_err(), - "Allocation should be closed if lifetime timeout" - ); - } - } - - #[tokio::test] - async fn test_delete_allocation_by_username() { - let turn_socket: Arc = - Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - - let m = new_test_manager(); - - let five_tuple1 = random_five_tuple(); - let five_tuple2 = random_five_tuple(); - let five_tuple3 = random_five_tuple(); - - let _ = m - .create_allocation( - five_tuple1, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - let _ = m - .create_allocation( - five_tuple2, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - Username::new(String::from("user")).unwrap(), - true, - ) - .await - .unwrap(); - let _ = m - .create_allocation( - five_tuple3, - Arc::clone(&turn_socket), - 0, - DEFAULT_LIFETIME, - Username::new(String::from("user2")).unwrap(), - true, - ) - .await - .unwrap(); - - assert_eq!(m.allocations.lock().unwrap().len(), 3); - - m.delete_allocations_by_username("user").await; - - assert_eq!(m.allocations.lock().unwrap().len(), 1); - - assert!( - m.get_alloc(&five_tuple1).is_none() - && m.get_alloc(&five_tuple2).is_none() - && m.get_alloc(&five_tuple3).is_some() - ); - } -} diff --git a/src/allocation/channel_bind.rs b/src/allocation/channel_bind.rs index f19a82bde..b9e0c462c 100644 --- a/src/allocation/channel_bind.rs +++ b/src/allocation/channel_bind.rs @@ -1,47 +1,42 @@ -//! TURN [`Channel`]. +//! [Channel] definitions. //! -//! [`Channel`]: https://tools.ietf.org/html/rfc5766#section-2.5 +//! [Channel]: https://tools.ietf.org/html/rfc5766#section-2.5 -use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; use tokio::{ sync::{mpsc, Mutex}, - time::{sleep, Duration, Instant}, + time::{sleep, Instant}, }; -/// TURN [`Channel`]. +/// Representation of a [channel]. /// -/// [`Channel`]: https://tools.ietf.org/html/rfc5766#section-2.5 -#[derive(Clone)] +/// [channel]: https://tools.ietf.org/html/rfc5766#section-2.5 +#[derive(Clone, Debug)] pub(crate) struct ChannelBind { - /// Transport address of the peer. + /// Transport address of the peer behind this [`ChannelBind`]. peer: SocketAddr, - /// Channel number. + /// Number of this [`ChannelBind`]. number: u16, - /// Channel to the internal loop used to update lifetime or drop channel - /// binding. - reset_tx: Option>, + /// [`mpsc::Sender`] to the internal loop of this [`ChannelBind`], used to + /// update its lifetime or stop it. + reset_tx: mpsc::Sender, } impl ChannelBind { - /// Creates a new [`ChannelBind`] - pub(crate) const fn new(number: u16, peer: SocketAddr) -> Self { - Self { number, peer, reset_tx: None } - } - - /// Starts [`ChannelBind`]'s internal lifetime watching loop. - pub(crate) fn start( - &mut self, + /// Creates a new [`ChannelBind`] and [`spawn`]s a loop watching its + /// lifetime. + /// + /// [`spawn`]: tokio::spawn() + pub(crate) fn new( + number: u16, + peer: SocketAddr, bindings: Arc>>, lifetime: Duration, - ) { + ) -> Self { let (reset_tx, mut reset_rx) = mpsc::channel(1); - self.reset_tx = Some(reset_tx); - - let number = self.number; - drop(tokio::spawn(async move { let timer = sleep(lifetime); tokio::pin!(timer); @@ -51,7 +46,8 @@ impl ChannelBind { () = &mut timer => { if bindings.lock().await.remove(&number).is_none() { log::error!( - "Failed to remove ChannelBind for {number}" + "Failed to remove \ + `ChannelBind(number: {number})`", ); } break; @@ -66,44 +62,50 @@ impl ChannelBind { } } })); - } - /// Returns transport address of the peer. + Self { peer, number, reset_tx } + } + /// Returns the [`SocketAddr`] of the peer behind this [`ChannelBind`]. pub(crate) const fn peer(&self) -> SocketAddr { self.peer } - /// Returns channel number. + /// Returns the number of this [`ChannelBind`]. pub(crate) const fn num(&self) -> u16 { self.number } - /// Updates [`ChannelBind`]'s lifetime. + /// Updates the `lifetime` of this [`ChannelBind`]. pub(crate) async fn refresh(&self, lifetime: Duration) { - if let Some(tx) = &self.reset_tx { - _ = tx.send(lifetime).await; - } + _ = self.reset_tx.send(lifetime).await; } } #[cfg(test)] -mod channel_bind_test { - use std::net::Ipv4Addr; +mod allocation_spec { + use std::{ + net::{Ipv4Addr, SocketAddr}, + sync::Arc, + time::Duration, + }; use tokio::net::UdpSocket; use crate::{ - allocation::Allocation, attr::{ChannelNumber, Username}, - con, Error, FiveTuple, + server::DEFAULT_LIFETIME, + Allocation, Error, FiveTuple, }; - use super::*; + #[cfg(doc)] + use super::ChannelBind; - async fn create_channel_bind( + /// Creates an [`Allocation`] with a bound [`ChannelBind`] for testing + /// purposes. + async fn create_channel_bind_allocation( lifetime: Duration, ) -> Result { - let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); + let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); let relay_socket = Arc::clone(&turn_socket); let relay_addr = relay_socket.local_addr().unwrap(); let a = Allocation::new( @@ -111,6 +113,7 @@ mod channel_bind_test { relay_socket, relay_addr, FiveTuple::default(), + DEFAULT_LIFETIME, Username::new(String::from("user")).unwrap(), None, ); @@ -123,12 +126,14 @@ mod channel_bind_test { } #[tokio::test] - async fn test_channel_bind() { - let a = create_channel_bind(Duration::from_millis(20)).await.unwrap(); + async fn channel_bind_is_present() { + let a = create_channel_bind_allocation(Duration::from_millis(20)) + .await + .unwrap(); let result = a.get_channel_addr(&ChannelNumber::MIN).await; if let Some(addr) = result { - assert_eq!(addr.ip().to_string(), "0.0.0.0"); + assert_eq!(addr.ip().to_string(), "0.0.0.0", "wrong IP address"); } else { panic!("expected some, but got none"); } diff --git a/src/allocation/manager.rs b/src/allocation/manager.rs new file mode 100644 index 000000000..3de87d0f1 --- /dev/null +++ b/src/allocation/manager.rs @@ -0,0 +1,491 @@ +//! Storage of [allocation]s. +//! +//! [allocation]: https://tools.ietf.org/html/rfc5766#section-5 + +use std::{ + collections::HashMap, net::SocketAddr, sync::atomic::Ordering, + time::Duration, +}; + +use tokio::sync::mpsc; + +use crate::{attr::Username, relay, Error}; + +use super::{Allocation, DynTransport, FiveTuple, Info}; + +/// Configuration parameters of a [`Manager`]. +#[derive(Debug)] +pub(crate) struct Config { + /// [`relay::Allocator`] of connections. + pub(crate) relay_addr_generator: relay::Allocator, + + /// [`mpsc::Sender`] for notifying when an [`Allocation`] is closed. + pub(crate) alloc_close_notify: Option>, +} + +/// [`Manager`] holding active [`Allocation`]s. +#[derive(Debug)] +pub(crate) struct Manager { + /// Stored [`Allocation`]s. + allocations: HashMap, + + /// [`relay::Allocator`] of connections. + relay_allocator: relay::Allocator, + + /// [`mpsc::Sender`] for notifying when an [`Allocation`] is closed. + alloc_close_notify: Option>, +} + +impl Manager { + /// Creates a new [`Manager`] out of the provided [`Config`]. + pub(crate) fn new(config: Config) -> Self { + Self { + allocations: HashMap::default(), + relay_allocator: config.relay_addr_generator, + alloc_close_notify: config.alloc_close_notify, + } + } + + /// Returns information about all the [`Allocation`]s associated with the + /// provided [`FiveTuple`]s. + pub(crate) fn get_allocations_info( + &self, + five_tuples: &Option>, + ) -> HashMap { + let mut infos = HashMap::new(); + + #[allow(clippy::iter_over_hash_type)] // order doesn't matter here + for (five_tuple, alloc) in &self.allocations { + if five_tuples.as_ref().map_or(true, |f| f.contains(five_tuple)) { + drop(infos.insert( + *five_tuple, + Info::new( + *five_tuple, + alloc.username.clone(), + alloc.relayed_bytes.load(Ordering::Acquire), + ), + )); + } + } + + infos + } + + /// Creates a new [`Allocation`] with provided parameters and starts + /// relaying it. + #[allow(clippy::too_many_arguments)] + pub(crate) async fn create_allocation( + &mut self, + five_tuple: FiveTuple, + turn_socket: DynTransport, + requested_port: u16, + lifetime: Duration, + username: Username, + use_ipv4: bool, + ) -> Result { + if lifetime == Duration::from_secs(0) { + return Err(Error::LifetimeZero); + } + + self.allocations.retain(|_, v| v.is_alive()); + + if self.get_alloc(&five_tuple).is_some() { + return Err(Error::DupeFiveTuple); + } + + let (relay_socket, relay_addr) = self + .relay_allocator + .allocate_conn(use_ipv4, requested_port) + .await?; + let alloc = Allocation::new( + turn_socket, + relay_socket, + relay_addr, + five_tuple, + lifetime, + username, + self.alloc_close_notify.clone(), + ); + + drop(self.allocations.insert(five_tuple, alloc)); + + Ok(relay_addr) + } + + /// Returns the [`Allocation`] matching the provided [`FiveTuple`], if any. + pub(crate) fn get_alloc( + &mut self, + five_tuple: &FiveTuple, + ) -> Option<&Allocation> { + self.allocations.get(five_tuple).and_then(|a| a.is_alive().then_some(a)) + } + + /// Removes the [`Allocation`] matching the provided [`FiveTuple`], if any. + pub(crate) fn delete_allocation(&mut self, five_tuple: &FiveTuple) { + drop(self.allocations.remove(five_tuple)); + } + + /// Removes all the [`Allocation`]s with the provided `username`, if any. + pub(crate) fn delete_allocations_by_username( + &mut self, + username: impl AsRef, + ) { + let username = username.as_ref(); + self.allocations + .retain(|_, allocation| allocation.username.name() != username); + } + + /// Returns a random non-allocated UDP port. + /// + /// # Errors + /// + /// If new port fails to be allocated. See the [`Error`] for details + pub(crate) async fn get_random_even_port(&self) -> Result { + self.relay_allocator + .allocate_conn(true, 0) + .await + .map(|(_, addr)| addr.port()) + } +} + +#[cfg(test)] +mod spec { + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + str::FromStr, + sync::Arc, + time::Duration, + }; + + use bytecodec::DecodeExt; + use rand::random; + use stun_codec::MessageDecoder; + use tokio::{net::UdpSocket, sync::mpsc, time::sleep}; + + use crate::{ + attr::{Attribute, ChannelNumber, Data, Username}, + chandata::ChannelData, + relay, + server::DEFAULT_LIFETIME, + Error, FiveTuple, + }; + + use super::{Config, DynTransport, Manager}; + + /// Creates a new [`Manager`] for testing purposes. + fn create_manager() -> Manager { + let config = Config { + relay_addr_generator: relay::Allocator { + relay_address: IpAddr::from([127, 0, 0, 1]), + min_port: 49152, + max_port: 65535, + max_retries: 10, + address: String::from("127.0.0.1"), + }, + alloc_close_notify: None, + }; + Manager::new(config) + } + + /// Generates a new random [`FiveTuple`] for testing purposes. + fn random_five_tuple() -> FiveTuple { + FiveTuple { + src_addr: SocketAddr::new( + Ipv4Addr::new(0, 0, 0, 0).into(), + random(), + ), + dst_addr: SocketAddr::new( + Ipv4Addr::new(0, 0, 0, 0).into(), + random(), + ), + ..Default::default() + } + } + + #[tokio::test] + async fn packet_handler_works() { + let turn_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + let client_listener = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let src_addr = client_listener.local_addr().unwrap(); + let (data_ch_tx, mut data_ch_rx) = mpsc::channel(1); + // `client_listener` read data + drop(tokio::spawn(async move { + let mut buffer = vec![0u8; 1500]; + loop { + let n = match client_listener.recv_from(&mut buffer).await { + Ok((n, _)) => n, + Err(_) => break, + }; + + drop(data_ch_tx.send(buffer[..n].to_vec()).await); + } + })); + + let five_tuple = FiveTuple { + src_addr, + dst_addr: turn_socket.local_addr().unwrap(), + ..Default::default() + }; + let mut m = create_manager(); + _ = m + .create_allocation( + five_tuple, + Arc::new(turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + let a = m.get_alloc(&five_tuple).unwrap(); + + let peer_listener1 = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let peer_listener2 = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + let port = { + a.add_permission(peer_listener1.local_addr().unwrap().ip()).await; + a.add_channel_bind( + ChannelNumber::MIN, + peer_listener2.local_addr().unwrap(), + DEFAULT_LIFETIME, + ) + .await + .unwrap(); + + a.relay_socket.local_addr().unwrap().port() + }; + + let relay_addr_with_host_str = format!("127.0.0.1:{port}"); + let relay_addr_with_host = + SocketAddr::from_str(&relay_addr_with_host_str).unwrap(); + + let target_text = "permission"; + let _ = peer_listener1 + .send_to(target_text.as_bytes(), relay_addr_with_host) + .await + .unwrap(); + let data = data_ch_rx.recv().await.unwrap(); + + let msg = MessageDecoder::::new() + .decode_from_bytes(&data) + .unwrap() + .unwrap(); + let msg_data = msg.get_attribute::().unwrap().data().to_vec(); + + assert_eq!( + target_text.as_bytes(), + &msg_data, + "get message doesn't equal target text", + ); + + let target_text2 = "channel bind"; + let _ = peer_listener2 + .send_to(target_text2.as_bytes(), relay_addr_with_host) + .await + .unwrap(); + let data = data_ch_rx.recv().await.unwrap(); + + assert!(ChannelData::is_channel_data(&data), "should be channel data"); + + let channel_data = ChannelData::decode(data).unwrap(); + + assert_eq!( + ChannelNumber::MIN, + channel_data.num(), + "get channel data's number is invalid", + ); + assert_eq!( + target_text2.as_bytes(), + &channel_data.data(), + "get data doesn't equal target text", + ); + } + + #[tokio::test] + async fn errors_on_duplicate_five_tuple() { + let turn_socket: DynTransport = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let mut m = create_manager(); + let five_tuple = random_five_tuple(); + _ = m + .create_allocation( + five_tuple, + DynTransport::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + let res = m + .create_allocation( + five_tuple, + DynTransport::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await; + + assert_eq!(res, Err(Error::DupeFiveTuple)); + } + + #[tokio::test] + async fn deletes_allocation() { + let turn_socket: DynTransport = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let mut m = create_manager(); + let five_tuple = random_five_tuple(); + _ = m + .create_allocation( + five_tuple, + DynTransport::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + assert!( + m.get_alloc(&five_tuple).is_some(), + "cannot to get `Allocation` right after creation", + ); + + m.delete_allocation(&five_tuple); + + assert!( + !m.get_alloc(&five_tuple).is_some(), + "`Allocation` of `{five_tuple}` was not deleted", + ); + } + + #[tokio::test] + async fn allocations_timeout() { + let turn_socket: DynTransport = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let mut m = create_manager(); + let mut allocations = vec![]; + let lifetime = Duration::from_millis(100); + for _ in 0..5 { + let five_tuple = random_five_tuple(); + + _ = m + .create_allocation( + five_tuple, + DynTransport::clone(&turn_socket), + 0, + lifetime, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + + allocations.push(five_tuple); + } + + let mut count = 0; + 'outer: loop { + count += 1; + + if count >= 10 { + panic!("`Allocation`s didn't timeout"); + } + + sleep(lifetime + Duration::from_millis(100)).await; + + let any_outstanding = false; + + for a in &allocations { + if m.get_alloc(a).is_some() { + continue 'outer; + } + } + + if !any_outstanding { + return; + } + } + } + + #[tokio::test] + async fn deletes_allocation_by_username() { + let turn_socket: DynTransport = + Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); + + let mut m = create_manager(); + let five_tuple1 = random_five_tuple(); + let five_tuple2 = random_five_tuple(); + let five_tuple3 = random_five_tuple(); + _ = m + .create_allocation( + five_tuple1, + DynTransport::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + _ = m + .create_allocation( + five_tuple2, + DynTransport::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user")).unwrap(), + true, + ) + .await + .unwrap(); + _ = m + .create_allocation( + five_tuple3, + DynTransport::clone(&turn_socket), + 0, + DEFAULT_LIFETIME, + Username::new(String::from("user2")).unwrap(), + true, + ) + .await + .unwrap(); + + assert_eq!( + m.allocations.len(), + 3, + "wrong number of created `Allocation`s", + ); + + m.delete_allocations_by_username("user"); + + assert_eq!( + m.allocations.len(), + 1, + "wrong number of left `Allocation`s", + ); + + assert!( + m.get_alloc(&five_tuple1).is_none(), + "first allocation is not deleted", + ); + assert!( + m.get_alloc(&five_tuple2).is_none(), + "second allocation is not deleted", + ); + assert!( + m.get_alloc(&five_tuple3).is_some(), + "third allocation is deleted", + ); + } +} diff --git a/src/allocation/mod.rs b/src/allocation/mod.rs index 42ed53f44..1e5407e8d 100644 --- a/src/allocation/mod.rs +++ b/src/allocation/mod.rs @@ -1,24 +1,24 @@ -//! TURN server [allocation]. +//! [Allocation] definitions. //! -//! [allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-5 +//! [Allocation]: https://tools.ietf.org/html/rfc5766#section-5 -mod allocation_manager; mod channel_bind; +mod manager; mod permission; use std::{ collections::HashMap, - fmt, marker::{Send, Sync}, mem, net::{IpAddr, SocketAddr}, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Mutex as SyncMutex, + Arc, }, }; use bytecodec::EncodeExt; +use derive_more::Display; use rand::random; use stun_codec::{ rfc5766::methods::DATA, Message, MessageClass, MessageEncoder, @@ -26,11 +26,7 @@ use stun_codec::{ }; use tokio::{ net::UdpSocket, - sync::{ - mpsc, - oneshot::{self, Sender}, - Mutex, - }, + sync::{mpsc, Mutex}, time::{sleep, Duration, Instant}, }; @@ -38,149 +34,167 @@ use crate::{ allocation::permission::PERMISSION_LIFETIME, attr::{Attribute, Data, Username, XorPeerAddress}, chandata::ChannelData, - con::Conn, server::INBOUND_MTU, - Error, + transport, Error, Transport, }; use self::{channel_bind::ChannelBind, permission::Permission}; -pub(crate) use allocation_manager::{Manager, ManagerConfig}; +pub(crate) use self::manager::{Config as ManagerConfig, Manager}; -/// [`Allocation`]s storage. -pub(crate) type AllocationMap = - Arc>>>; +/// Shortcut for a [`Transport`] trait object. +type DynTransport = Arc; -/// Information about an allocation. -#[derive(Debug, Clone)] -pub struct AllocInfo { - /// [`FiveTuple`] of this allocation. +/// 5-tuple uniquely identifying a UDP/TCP session. +/// +/// Consists of: +/// 1. source IP address +/// 2. source port +/// 3. destination IP address +/// 4. destination port +/// 5. transport protocol +#[derive(Clone, Copy, Debug, Display, Eq, Hash, PartialEq)] +#[display("{protocol}_{src_addr}_{dst_addr}")] +pub struct FiveTuple { + /// Number of the transport protocol according to [IANA]. + /// + /// [IANA]: https://tinyurl.com/iana-protocol-numbers + pub protocol: u8, + + /// Source address. + pub src_addr: SocketAddr, + + /// Destination address. + pub dst_addr: SocketAddr, +} + +/// Information about an [allocation]. +/// +/// [allocation]: https://tools.ietf.org/html/rfc5766#section-5 +#[derive(Clone, Debug)] +pub struct Info { + /// [`FiveTuple`] of the [allocation]. + /// + /// [allocation]: https://tools.ietf.org/html/rfc5766#section-5 pub five_tuple: FiveTuple, - /// Username of this allocation. - pub username: String, + /// [`Username`] of the [allocation]. + /// + /// [allocation]: https://tools.ietf.org/html/rfc5766#section-5 + pub username: Username, - /// Relayed bytes with this allocation. + /// Relayed bytes through the [allocation]. + /// + /// [allocation]: https://tools.ietf.org/html/rfc5766#section-5 pub relayed_bytes: usize, } -impl AllocInfo { - /// Creates a new [`AllocInfo`]. +impl Info { + /// Creates a new [`Info`] out of the provided parameters. #[must_use] pub const fn new( five_tuple: FiveTuple, - username: String, + username: Username, relayed_bytes: usize, ) -> Self { Self { five_tuple, username, relayed_bytes } } } -/// The tuple (source IP address, source port, destination IP -/// address, destination port, transport protocol). A 5-tuple -/// uniquely identifies a UDP/TCP session. -#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash)] -pub struct FiveTuple { - /// Transport protocol according to [IANA] protocol numbers. - /// - /// [IANA]: https://tinyurl.com/iana-protocol-numbers - pub protocol: u8, - - /// Packet source address. - pub src_addr: SocketAddr, - - /// Packet target address. - pub dst_addr: SocketAddr, -} - -impl fmt::Display for FiveTuple { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}_{}_{}", self.protocol, self.src_addr, self.dst_addr) - } -} - -/// TURN server [Allocation]. +/// Representation of an [allocation]. /// -/// [Allocation]:https://datatracker.ietf.org/doc/html/rfc5766#section-5 +/// [allocation]: https://tools.ietf.org/html/rfc5766#section-5 +#[derive(Debug)] pub(crate) struct Allocation { - /// [`Conn`] used to create this [`Allocation`]. - turn_socket: Arc, - - /// Relay socket address. + /// Relay [`SocketAddr`]. relay_addr: SocketAddr, - /// Allocated relay socket. + /// Allocated relay [`UdpSocket`]. relay_socket: Arc, - /// [`FiveTuple`] this allocation is created with. + /// [`FiveTuple`] this [`Allocation`] is created for. five_tuple: FiveTuple, - /// Remote user ICE [`Username`]. + /// [`Username`] of the remote [ICE] user. + /// + /// [ICE]: https://webrtcglossary.com/ice username: Username, - /// List of [`Permission`]s for this [`Allocation`] + /// List of [`Permission`]s for this [`Allocation`]. permissions: Arc>>, - /// This [`Allocation`] [`ChannelBind`]ings. + /// [`ChannelBind`]s of this [`Allocation`]. channel_bindings: Arc>>, - /// All [`Allocation`]s storage. - allocations: Option, + /// [`mpsc::Sender`] to the internal loop of this [`Allocation`], used to + /// update its lifetime or stop it. + refresh_tx: mpsc::Sender, - /// Channel to the internal loop used to update lifetime or drop - /// allocation. - reset_tx: SyncMutex>>, - - /// Total number of relayed bytes. + /// Total number of relayed bytes through this [`Allocation`]. relayed_bytes: AtomicUsize, - /// Channel to the packet handler loop used to stop it. - drop_tx: Option>, - - /// Injected into allocations to notify when allocation is closed. - alloc_close_notify: Option>, + /// [`mpsc::Sender`] for notifying when this [`Allocation`] is closed. + alloc_close_notify: Option>, } impl Allocation { - /// Creates a new [`Allocation`]. + /// Creates a new [`Allocation`] out of the provided parameters. pub(crate) fn new( - turn_socket: Arc, + turn_socket: Arc, relay_socket: Arc, relay_addr: SocketAddr, five_tuple: FiveTuple, + lifetime: Duration, username: Username, - alloc_close_notify: Option>, + alloc_close_notify: Option>, ) -> Self { - Self { - turn_socket, + let (refresh_tx, refresh_rx) = mpsc::channel(1); + + let this = Self { relay_addr, relay_socket, five_tuple, username, permissions: Arc::new(Mutex::new(HashMap::new())), channel_bindings: Arc::new(Mutex::new(HashMap::new())), - allocations: None, - reset_tx: SyncMutex::new(None), + refresh_tx, relayed_bytes: AtomicUsize::default(), - drop_tx: None, alloc_close_notify, - } + }; + + this.spawn_relay_handler(refresh_rx, lifetime, turn_socket); + + this } - /// Send the given data via associated relay socket. + /// Indicates whether the underlying relay socket and transmission loop is + /// alive. + pub(crate) fn is_alive(&self) -> bool { + !self.refresh_tx.is_closed() + } + + /// Send the provided `data` via the associated relay socket. + /// + /// # Errors + /// + /// - With an [`Error::NoAllocationFound`] if this [`Allocation`] is dead. + /// - With a [`transport::Error`] if failed to send the `data`. pub(crate) async fn relay( &self, data: &[u8], to: SocketAddr, ) -> Result<(), Error> { - match self.relay_socket.send_to(data, to).await { - Ok(n) => { - _ = self.relayed_bytes.fetch_add(n, Ordering::AcqRel); - - Ok(()) - } - Err(err) => Err(Error::from(err)), + if !self.is_alive() { + return Err(Error::NoAllocationFound); } + + let n = self + .relay_socket + .send_to(data, to) + .await + .map_err(transport::Error::from)?; + _ = self.relayed_bytes.fetch_add(n, Ordering::AcqRel); + Ok(()) } /// Returns [`SocketAddr`] of the associated relay socket. @@ -188,35 +202,50 @@ impl Allocation { self.relay_addr } - /// Checks the Permission for the `addr`. + /// Checks the [`Permission`] for the provided [`SocketAddr`]. pub(crate) async fn has_permission(&self, addr: &SocketAddr) -> bool { + if !self.is_alive() { + return false; + } + self.permissions.lock().await.get(&addr.ip()).is_some() } /// Adds a new [`Permission`] to this [`Allocation`]. pub(crate) async fn add_permission(&self, ip: IpAddr) { + if !self.is_alive() { + return; + } + let mut permissions = self.permissions.lock().await; if let Some(existed_permission) = permissions.get(&ip) { existed_permission.refresh(PERMISSION_LIFETIME).await; } else { - let mut p = Permission::new(ip); - p.start(Arc::clone(&self.permissions), PERMISSION_LIFETIME); + let p = Permission::new( + ip, + Arc::clone(&self.permissions), + PERMISSION_LIFETIME, + ); drop(permissions.insert(p.ip(), p)); } } - /// Adds a new [`ChannelBind`] to this [`Allocation`], it also updates the - /// permissions needed for this [`ChannelBind`]. - #[allow(clippy::significant_drop_tightening)] // false-positive + /// Adds a new [`ChannelBind`] to this [`Allocation`], also updating the + /// [`Permission`]s needed for this [`ChannelBind`]. + #[allow(clippy::significant_drop_tightening)] // false positive pub(crate) async fn add_channel_bind( &self, number: u16, peer_addr: SocketAddr, lifetime: Duration, ) -> Result<(), Error> { - // The channel number is not currently bound to a different transport - // address (same transport address is OK); + if !self.is_alive() { + return Err(Error::NoAllocationFound); + } + + // The `ChannelNumber` is not currently bound to a different transport + // address (same transport address is OK). if let Some(addr) = self.get_channel_addr(&number).await { if addr != peer_addr { return Err(Error::SameChannelDifferentPeer); @@ -224,7 +253,7 @@ impl Allocation { } // The transport address is not currently bound to a different - // channel number. + // `ChannelNumber`. if let Some(n) = self.get_channel_number(&peer_addr).await { if number != n { return Err(Error::SamePeerDifferentChannel); @@ -237,33 +266,45 @@ impl Allocation { cb.refresh(lifetime).await; - // Channel binds also refresh permissions. + // `ChannelBind`s also refresh `Permission`s. self.add_permission(cb.peer().ip()).await; } else { - let mut bind = ChannelBind::new(number, peer_addr); - bind.start(Arc::clone(&self.channel_bindings), lifetime); + let bind = ChannelBind::new( + number, + peer_addr, + Arc::clone(&self.channel_bindings), + lifetime, + ); drop(channel_bindings.insert(number, bind)); - // Channel binds also refresh permissions. + // `ChannelBind`s also refresh `Permission`s. self.add_permission(peer_addr.ip()).await; } Ok(()) } - /// Gets the [`ChannelBind`]'s address by `number`. + /// Returns the [`ChannelBind`]'s address by the provided `number`. pub(crate) async fn get_channel_addr( &self, number: &u16, ) -> Option { + if !self.is_alive() { + return None; + } + self.channel_bindings.lock().await.get(number).map(ChannelBind::peer) } - /// Gets the [`ChannelBind`]'s number from this [`Allocation`] by `addr`. + /// Returns the [`ChannelBind`]'s number from this [`Allocation`] by its + /// `addr`ess. pub(crate) async fn get_channel_number( &self, addr: &SocketAddr, ) -> Option { + if !self.is_alive() { + return None; + } self.channel_bindings .lock() .await @@ -271,189 +312,130 @@ impl Allocation { .find_map(|b| (b.peer() == *addr).then_some(b.num())) } - /// Closes the [`Allocation`]. - pub(crate) async fn close(&self) -> Result<(), Error> { - #[allow(clippy::unwrap_used)] - if self.reset_tx.lock().unwrap().take().is_none() { - return Err(Error::Closed); - } - - drop(mem::take(&mut *self.permissions.lock().await)); - drop(mem::take(&mut *self.channel_bindings.lock().await)); - - log::trace!("allocation with {} closed!", self.five_tuple); - - drop(self.relay_socket.close().await); - - if let Some(notify_tx) = &self.alloc_close_notify { - drop( - notify_tx - .send(AllocInfo { - five_tuple: self.five_tuple, - username: self.username.name().to_owned(), - relayed_bytes: self - .relayed_bytes - .load(Ordering::Acquire), - }) - .await, - ); - } - - Ok(()) - } - - /// Starts the internal lifetime watching loop. - pub(crate) fn start(&self, lifetime: Duration) { - let (reset_tx, mut reset_rx) = mpsc::channel(1); - #[allow(clippy::unwrap_used)] - drop(self.reset_tx.lock().unwrap().replace(reset_tx)); - - let allocations = self.allocations.clone(); - let five_tuple = self.five_tuple; - - drop(tokio::spawn(async move { - let timer = sleep(lifetime); - tokio::pin!(timer); - - loop { - tokio::select! { - () = &mut timer => { - if let Some(allocs) = &allocations{ - #[allow(clippy::unwrap_used)] - let alloc = allocs - .lock() - .unwrap() - .remove(&five_tuple); - - if let Some(a) = alloc { - drop(a.close().await); - } - } - break; - }, - result = reset_rx.recv() => { - if let Some(d) = result { - timer.as_mut().reset(Instant::now() + d); - } else { - break; - } - }, - } - } - })); - } - - /// Updates the allocations lifetime. + /// Updates the `lifetime` of this [`Allocation`]. pub(crate) async fn refresh(&self, lifetime: Duration) { - #[allow(clippy::unwrap_used)] - let reset_tx = self.reset_tx.lock().unwrap().clone(); - - if let Some(tx) = reset_tx { - _ = tx.send(lifetime).await; - } + _ = self.refresh_tx.send(lifetime).await; } - /// When the server receives a UDP datagram at a currently allocated - /// relayed transport address, the server looks up the allocation - /// associated with the relayed transport address. The server then - /// checks to see whether the set of permissions for the allocation allow - /// the relaying of the UDP datagram as described in Section 8. + /// [`spawn`]s a relay handler of this [`Allocation`]. /// - /// If relaying is permitted, then the server checks if there is a - /// channel bound to the peer that sent the UDP datagram (see - /// Section 11). If a channel is bound, then processing proceeds as - /// described in Section 11.7. + /// See [Section 10.3][1]: + /// > When the server receives a UDP datagram at a currently allocated + /// > relayed transport address, the server looks up the allocation + /// > associated with the relayed transport address. The server then + /// > checks to see whether the set of permissions for the allocation allow + /// > the relaying of the UDP datagram as described in [Section 8]. + /// > + /// > If relaying is permitted, then the server checks if there is a + /// > channel bound to the peer that sent the UDP datagram (see + /// > [Section 11]). If a channel is bound, then processing proceeds as + /// > described in [Section 11.7][2]. + /// > + /// > If relaying is permitted but no channel is bound to the peer, then + /// > the server forms and sends a Data indication. The Data indication + /// > MUST contain both an XOR-PEER-ADDRESS and a DATA attribute. The DATA + /// > attribute is set to the value of the 'data octets' field from the + /// > datagram, and the XOR-PEER-ADDRESS attribute is set to the source + /// > transport address of the received UDP datagram. The Data indication + /// > is then sent on the 5-tuple associated with the allocation. /// - /// If relaying is permitted but no channel is bound to the peer, then - /// the server forms and sends a Data indication. The Data indication - /// MUST contain both an XOR-PEER-ADDRESS and a DATA attribute. The DATA - /// attribute is set to the value of the 'data octets' field from the - /// datagram, and the XOR-PEER-ADDRESS attribute is set to the source - /// transport address of the received UDP datagram. The Data indication - /// is then sent on the 5-tuple associated with the allocation. + /// [`spawn`]: tokio::spawn() + /// [1]: https://tools.ietf.org/html/rfc5766#section-10.3 + /// [2]: https://tools.ietf.org/html/rfc5766#section-11.7 + /// [Section 8]: https://tools.ietf.org/html/rfc5766#section-8 + /// [Section 11]: https://tools.ietf.org/html/rfc5766#section-11 #[allow(clippy::too_many_lines)] - fn packet_handler(&mut self) { + fn spawn_relay_handler( + &self, + mut refresh_rx: mpsc::Receiver, + lifetime: Duration, + turn_socket: Arc, + ) { let five_tuple = self.five_tuple; let relay_addr = self.relay_addr; let relay_socket = Arc::clone(&self.relay_socket); - let turn_socket = Arc::clone(&self.turn_socket); - let allocations = self.allocations.clone(); let channel_bindings = Arc::clone(&self.channel_bindings); let permissions = Arc::clone(&self.permissions); - let (drop_tx, drop_rx) = oneshot::channel::(); - self.drop_tx = Some(drop_tx); drop(tokio::spawn(async move { + log::trace!("Listening on relay addr: {relay_addr}"); + + let expired = sleep(lifetime); + tokio::pin!(expired); let mut buffer = vec![0u8; INBOUND_MTU]; - tokio::pin!(drop_rx); loop { - let (n, src_addr) = tokio::select! { + let (data, src_addr) = tokio::select! { result = relay_socket.recv_from(&mut buffer) => { - if let Ok((data, src_addr)) = result { - (data, src_addr) + if let Ok((n, src_addr)) = result { + (&buffer[..n], src_addr) } else { - if let Some(allocs) = &allocations { - #[allow(clippy::unwrap_used)] - drop( - allocs.lock().unwrap().remove(&five_tuple) - ); - } break; } } - _ = drop_rx.as_mut() => { - log::trace!("allocation has stopped, \ - stop packet_handler. five_tuple: {:?}", - five_tuple); + () = &mut expired => { break; - } - }; - - let cb_number = { - let mut cb_number = None; - #[allow( - clippy::iter_over_hash_type, - clippy::significant_drop_in_scrutinee - )] - for cb in channel_bindings.lock().await.values() { - if cb.peer() == src_addr { - cb_number = Some(cb.num()); - break; + }, + refresh = refresh_rx.recv() => { + match refresh { + Some(lf) => { + if lf == Duration::ZERO { + break; + } + expired.as_mut().reset(Instant::now() + lf); + continue; + } + None => { + break; + } } - } - cb_number + }, }; + let cb_number = channel_bindings + .lock() + .await + .iter() + .find(|(_, cb)| cb.peer() == src_addr) + .map(|(cn, _)| *cn); + if let Some(number) = cb_number { - match ChannelData::encode(buffer[..n].to_vec(), number) { + match ChannelData::encode(data, number) { Ok(data) => { - if let Err(err) = turn_socket + if let Err(e) = turn_socket .send_to(data, five_tuple.src_addr) .await { - log::error!( - "Failed to send ChannelData from \ - allocation {src_addr}: {err}", - ); + match e { + transport::Error::TransportIsDead => { + break; + } + transport::Error::Decode(_) + | transport::Error::ChannelData(_) + | transport::Error::Io(_) => { + log::warn!( + "Failed to send `ChannelData` from \ + `Allocation(scr: {src_addr}`: {e}", + ); + } + } } } - Err(err) => { - log::error!( - "Failed to send ChannelData from allocation \ - {src_addr}: {err}" + Err(e) => { + log::warn!( + "Failed to send `ChannelData` from \ + `Allocation(src: {src_addr})`: {e}", ); } }; } else { - let exist = - permissions.lock().await.get(&src_addr.ip()).is_some(); + let has_permission = + permissions.lock().await.contains_key(&src_addr.ip()); - if exist { + if has_permission { log::trace!( - "relaying message from {} to client at {}", - src_addr, - five_tuple.src_addr + "Relaying message from {src_addr} to client at {}", + five_tuple.src_addr, ); let mut msg: Message = Message::new( @@ -462,57 +444,82 @@ impl Allocation { TransactionId::new(random()), ); msg.add_attribute(XorPeerAddress::new(src_addr)); - let Ok(data) = Data::new(buffer[..n].to_vec()) else { - log::error!("DataIndication is too long"); + let Ok(data) = Data::new(data.to_vec()) else { + log::error!("`DataIndication` is too long"); continue; }; msg.add_attribute(data); match MessageEncoder::new().encode_into_bytes(msg) { Ok(encoded) => { - if let Err(err) = turn_socket + if let Err(e) = turn_socket .send_to(encoded, five_tuple.src_addr) .await { log::error!( - "Failed to send DataIndication from \ - allocation {} {}", - src_addr, - err + "Failed to send `DataIndication` from \ + `Allocation(src: {src_addr})`: {e}", ); } } Err(e) => { - log::error!("DataIndication encode err: {e}"); + log::error!( + "`DataIndication` encoding failed: {e}", + ); } } } else { log::info!( - "No Permission or Channel exists for {} on \ - allocation {}", - src_addr, - relay_addr + "No `Permission` or `ChannelBind` exists for \ + `{src_addr}` on `Allocation(relay: {relay_addr})`", ); } } } + drop(mem::take(&mut *channel_bindings.lock().await)); + drop(mem::take(&mut *permissions.lock().await)); + + log::trace!( + "`Allocation(five_tuple: {five_tuple})` stopped, stop \ + `relay_handler`", + ); })); } } +impl Drop for Allocation { + fn drop(&mut self) { + if let Some(notify_tx) = self.alloc_close_notify.take() { + let info = Info { + five_tuple: self.five_tuple, + username: self.username.clone(), + relayed_bytes: self.relayed_bytes.load(Ordering::Acquire), + }; + + drop(tokio::spawn(async move { + drop(notify_tx.send(info).await); + })); + } + } +} + #[cfg(test)] -mod allocation_test { - use std::{net::Ipv4Addr, str::FromStr}; +mod spec { + use std::{ + net::{Ipv4Addr, SocketAddr}, + str::FromStr, + sync::Arc, + }; use tokio::net::UdpSocket; - use super::*; - use crate::{ - attr::{ChannelNumber, PROTO_UDP}, + attr::{ChannelNumber, Username, PROTO_UDP}, server::DEFAULT_LIFETIME, }; + use super::{Allocation, FiveTuple}; + impl Default for FiveTuple { fn default() -> Self { FiveTuple { @@ -524,7 +531,7 @@ mod allocation_test { } #[tokio::test] - async fn test_has_permission() { + async fn has_permission() { let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); let relay_socket = Arc::clone(&turn_socket); let relay_addr = relay_socket.local_addr().unwrap(); @@ -533,6 +540,7 @@ mod allocation_test { relay_socket, relay_addr, FiveTuple::default(), + DEFAULT_LIFETIME, Username::new(String::from("user")).unwrap(), None, ); @@ -546,17 +554,17 @@ mod allocation_test { a.add_permission(addr3.ip()).await; let found_p1 = a.has_permission(&addr1).await; - assert!(found_p1, "Should keep the first one."); + assert!(found_p1, "should keep the first one"); let found_p2 = a.has_permission(&addr2).await; - assert!(found_p2, "Second one should be ignored."); + assert!(found_p2, "second one should be ignored"); let found_p3 = a.has_permission(&addr3).await; - assert!(found_p3, "Permission with another IP should be found"); + assert!(found_p3, "`Permission` with another IP should be found"); } #[tokio::test] - async fn test_add_permission() { + async fn add_permission() { let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); let relay_socket = Arc::clone(&turn_socket); let relay_addr = relay_socket.local_addr().unwrap(); @@ -565,6 +573,7 @@ mod allocation_test { relay_socket, relay_addr, FiveTuple::default(), + DEFAULT_LIFETIME, Username::new(String::from("user")).unwrap(), None, ); @@ -573,11 +582,11 @@ mod allocation_test { a.add_permission(addr.ip()).await; let found_p = a.has_permission(&addr).await; - assert!(found_p, "Should keep the first one."); + assert!(found_p, "should keep the first one"); } #[tokio::test] - async fn test_get_channel_by_number() { + async fn get_channel_by_number() { let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); let relay_socket = Arc::clone(&turn_socket); let relay_addr = relay_socket.local_addr().unwrap(); @@ -586,6 +595,7 @@ mod allocation_test { relay_socket, relay_addr, FiveTuple::default(), + DEFAULT_LIFETIME, Username::new(String::from("user")).unwrap(), None, ); @@ -602,14 +612,11 @@ mod allocation_test { let not_exist_channel = a.get_channel_addr(&(ChannelNumber::MIN + 1)).await; - assert!( - not_exist_channel.is_none(), - "should be nil for not existed channel." - ); + assert!(not_exist_channel.is_none(), "found, but shouldn't"); } #[tokio::test] - async fn test_get_channel_by_addr() { + async fn get_channel_by_addr() { let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); let relay_socket = Arc::clone(&turn_socket); let relay_addr = relay_socket.local_addr().unwrap(); @@ -618,6 +625,7 @@ mod allocation_test { relay_socket, relay_addr, FiveTuple::default(), + DEFAULT_LIFETIME, Username::new(String::from("user")).unwrap(), None, ); @@ -633,14 +641,11 @@ mod allocation_test { assert_eq!(ChannelNumber::MIN, exist_channel_number); let not_exist_channel = a.get_channel_number(&addr2).await; - assert!( - not_exist_channel.is_none(), - "should be nil for not existed channel." - ); + assert!(not_exist_channel.is_none(), "found, but shouldn't"); } #[tokio::test] - async fn test_allocation_close() { + async fn closing() { let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); let relay_socket = Arc::clone(&turn_socket); let relay_addr = relay_socket.local_addr().unwrap(); @@ -649,29 +654,21 @@ mod allocation_test { relay_socket, relay_addr, FiveTuple::default(), + DEFAULT_LIFETIME, Username::new(String::from("user")).unwrap(), None, ); - // add mock lifetimeTimer - a.start(DEFAULT_LIFETIME); - - // add channel let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap(); - a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME) .await .unwrap(); - - // add permission a.add_permission(addr.ip()).await; - - a.close().await.unwrap(); } } #[cfg(test)] -mod five_tuple_test { +mod five_tuple_spec { use std::net::SocketAddr; use crate::{ @@ -680,7 +677,7 @@ mod five_tuple_test { }; #[test] - fn test_five_tuple_equal() { + fn equality() { let src_addr1: SocketAddr = "0.0.0.0:3478".parse::().unwrap(); let src_addr2: SocketAddr = @@ -691,7 +688,7 @@ mod five_tuple_test { let dst_addr2: SocketAddr = "0.0.0.0:3481".parse::().unwrap(); - let tests = vec![ + let tests = [ ( "Equal", true, @@ -749,12 +746,11 @@ mod five_tuple_test { }, ), ]; - for (name, expect, a, b) in tests { let fact = a == b; assert_eq!( expect, fact, - "{name}: {a}, {b} equal check should be {expect}, but {fact}" + "{name}: {a}, {b} equal check should be {expect}, but {fact}", ); } } diff --git a/src/allocation/permission.rs b/src/allocation/permission.rs index 4507a8acf..178dfbed2 100644 --- a/src/allocation/permission.rs +++ b/src/allocation/permission.rs @@ -1,7 +1,7 @@ -//! TURN [Allocation] [Permission]. +//! [Allocation] [permission] definitions. //! -//! [Allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-2.2 -//! [Permission]: https://datatracker.ietf.org/doc/html/rfc5766#section-8 +//! [Allocation]: https://tools.ietf.org/html/rfc5766#section-2.2 +//! [permission]: https://tools.ietf.org/html/rfc5766#section-8 use std::{collections::HashMap, net::IpAddr, sync::Arc}; @@ -10,40 +10,38 @@ use tokio::{ time::{sleep, Duration, Instant}, }; -/// The Permission Lifetime MUST be 300 seconds (= 5 minutes)[1]. +/// [Lifetime][1] of a [`Permission`]. /// -/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-8 +/// > The Permission Lifetime MUST be 300 seconds (= 5 minutes). +/// +/// [1]: https://tools.ietf.org/html/rfc5766#section-8 pub(crate) const PERMISSION_LIFETIME: Duration = Duration::from_secs(5 * 60); -/// TURN [Allocation] [Permission]. +/// Representation of an [allocation] [permission]. /// -/// [Allocation]: https://datatracker.ietf.org/doc/html/rfc5766#section-2.2 -/// [Permission]: https://datatracker.ietf.org/doc/html/rfc5766#section-8 +/// [allocation]: https://tools.ietf.org/html/rfc5766#section-2.2 +/// [permission]: https://tools.ietf.org/html/rfc5766#section-8 +#[derive(Debug)] pub(crate) struct Permission { - /// [`IpAddr`] of this permission that is matched with the source IP + /// [`IpAddr`] of this [`Permission`] that is matched with the source IP /// address of the datagram received. ip: IpAddr, - /// Channel to the inner lifetime watching loop. - reset_tx: Option>, + /// [`mpsc::Sender`] to the inner lifetime watching loop. + reset_tx: mpsc::Sender, } impl Permission { - /// Creates a new [`Permission`]. - pub(crate) const fn new(ip: IpAddr) -> Self { - Self { ip, reset_tx: None } - } - - /// Starts [`Permission`]'s internal lifetime watching loop. - pub(crate) fn start( - &mut self, + /// Creates a new [`Permission`] and [`spawn`]s a loop watching its + /// lifetime. + /// + /// [`spawn`]: tokio::spawn() + pub(crate) fn new( + ip: IpAddr, permissions: Arc>>, lifetime: Duration, - ) { + ) -> Self { let (reset_tx, mut reset_rx) = mpsc::channel(1); - self.reset_tx = Some(reset_tx); - - let ip = self.ip; drop(tokio::spawn(async move { let timer = sleep(lifetime); @@ -65,17 +63,17 @@ impl Permission { } } })); + + Self { ip, reset_tx } } - /// Returns [`IpAddr`] of this [`Permission`]. + /// Returns the [`IpAddr`] of this [`Permission`]. pub(crate) const fn ip(&self) -> IpAddr { self.ip } - /// Updates [`Permission`]'s lifetime. + /// Updates the `lifetime` of this [`Permission`]. pub(crate) async fn refresh(&self, lifetime: Duration) { - if let Some(tx) = &self.reset_tx { - _ = tx.send(lifetime).await; - } + _ = self.reset_tx.send(lifetime).await; } } diff --git a/src/attr.rs b/src/attr.rs index 9c2b4c66a..20795f13f 100644 --- a/src/attr.rs +++ b/src/attr.rs @@ -1,4 +1,8 @@ -//! STUN and TURN attributes used by the server. +//! [STUN] and [TURN] attributes used by a [`Server`]. +//! +//! [`Server`]: crate::Server +//! [STUN]: https://en.wikipedia.org/wiki/STUN +//! [TURN]: https://en.wikipedia.org/wiki/TURN use stun_codec::define_attribute_enums; diff --git a/src/chandata.rs b/src/chandata.rs index f5ddf6aab..a976d4d30 100644 --- a/src/chandata.rs +++ b/src/chandata.rs @@ -1,6 +1,10 @@ -//! [`ChannelData`] message implementation. +//! [TURN ChannelData Message][1] implementation. +//! +//! [1]: https://tools.ietf.org/html/rfc5766#section-11.4 -use crate::{attr::ChannelNumber, Error}; +use derive_more::{Display, Error}; + +use crate::attr::ChannelNumber; /// [`ChannelData`] message MUST be padded to a multiple of four bytes in order /// to ensure the alignment of subsequent messages. @@ -8,74 +12,82 @@ const PADDING: usize = 4; /// [Channel Number] field size. /// -/// [Channel Number]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 -const CHANNEL_DATA_NUMBER_SIZE: usize = 2; +/// [Channel Number]: https://tools.ietf.org/html/rfc5766#section-11.4 +const NUMBER_SIZE: usize = 2; /// [Length] field size. /// -/// [Length]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 -const CHANNEL_DATA_LENGTH_SIZE: usize = 2; +/// [Length]: https://tools.ietf.org/html/rfc5766#section-11.4 +const LENGTH_SIZE: usize = 2; -/// [ChannelData] message header size. +/// [ChannelData Message][1] header size. /// -/// [ChannelData]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 -const CHANNEL_DATA_HEADER_SIZE: usize = - CHANNEL_DATA_LENGTH_SIZE + CHANNEL_DATA_NUMBER_SIZE; +/// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 +const HEADER_SIZE: usize = LENGTH_SIZE + NUMBER_SIZE; -/// [`ChannelData`] represents the `ChannelData` Message defined in -/// [RFC 5766](https://www.rfc-editor.org/rfc/rfc5766#section-11.4). +/// Representation of [TURN ChannelData Message][1] defined in [RFC 5766]. +/// +/// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 +/// [RFC 5766]: https://tools.ietf.org/html/rfc5766 #[derive(Debug)] -pub(crate) struct ChannelData { - /// Parsed [`ChannelData`] [Channel Number][1]. +pub struct ChannelData { + /// Parsed [Channel Number][1]. /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 + /// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 number: u16, - /// Parsed [`ChannelData`] payload. + /// Parsed payload. data: Vec, } impl ChannelData { - /// Returns `true` if `buf` looks like the `ChannelData` Message. - #[allow(clippy::missing_asserts_for_indexing)] // Length is checked - pub(crate) fn is_channel_data(buf: &[u8]) -> bool { - if buf.len() < CHANNEL_DATA_HEADER_SIZE { + /// Checks whether the provided `data` represents a [`ChannelData`] message. + pub(crate) fn is_channel_data(data: &[u8]) -> bool { + // PANIC: Indexing is OK here, since the length is checked with the + // first `if` expression. + #![allow(clippy::missing_asserts_for_indexing)] // false positive + + if data.len() < HEADER_SIZE { return false; } let len = usize::from(u16::from_be_bytes([ - buf[CHANNEL_DATA_NUMBER_SIZE], - buf[CHANNEL_DATA_NUMBER_SIZE + 1], + data[NUMBER_SIZE], + data[NUMBER_SIZE + 1], ])); - if len > buf[CHANNEL_DATA_HEADER_SIZE..].len() { + if len > data[HEADER_SIZE..].len() { return false; } - ChannelNumber::new(u16::from_be_bytes([buf[0], buf[1]])).is_ok() + ChannelNumber::new(u16::from_be_bytes([data[0], data[1]])).is_ok() } - /// Decodes the given raw message as [`ChannelData`]. - pub(crate) fn decode(mut raw: Vec) -> Result { - if raw.len() < CHANNEL_DATA_HEADER_SIZE { - return Err(Error::UnexpectedEof); + /// Decodes the provided `raw` message as a [`ChannelData`] message. + /// + /// # Errors + /// + /// See the [`FormatError`] for details. + pub(crate) fn decode(mut raw: Vec) -> Result { + if raw.len() < HEADER_SIZE { + return Err(FormatError::BadChannelDataLength); } let number = u16::from_be_bytes([raw[0], raw[1]]); if ChannelNumber::new(number).is_err() { - return Err(Error::InvalidChannelNumber); + return Err(FormatError::InvalidChannelNumber); } let l = usize::from(u16::from_be_bytes([ - raw[CHANNEL_DATA_NUMBER_SIZE], - raw[CHANNEL_DATA_NUMBER_SIZE + 1], + raw[NUMBER_SIZE], + raw[NUMBER_SIZE + 1], ])); - if l > raw[CHANNEL_DATA_HEADER_SIZE..].len() { - return Err(Error::BadChannelDataLength); + if l > raw[HEADER_SIZE..].len() { + return Err(FormatError::BadChannelDataLength); } // Discard header and padding. - drop(raw.drain(0..CHANNEL_DATA_HEADER_SIZE)); + drop(raw.drain(0..HEADER_SIZE)); if l != raw.len() { raw.truncate(l); } @@ -83,45 +95,44 @@ impl ChannelData { Ok(Self { data: raw, number }) } - /// Returns [`ChannelData`] [Channel Number][1]. + /// Returns payload of this [`ChannelData`] message. + pub(crate) fn data(self) -> Vec { + self.data + } + + /// Returns [Channel Number][1] of this [`ChannelData`] message. /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 + /// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 pub(crate) const fn num(&self) -> u16 { self.number } - /// Encodes the provided [`ChannelData`] payload and channel number to - /// bytes. + /// Encodes the provided `payload` and [Channel Number][1] as + /// [`ChannelData`] message bytes. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 pub(crate) fn encode( - mut data: Vec, + payload: &[u8], chan_num: u16, - ) -> Result, Error> { - #[allow(clippy::map_err_ignore)] - let len = u16::try_from(data.len()) - .map_err(|_| Error::BadChannelDataLength)?; - for i in len.to_be_bytes().into_iter().rev() { - data.insert(0, i); - } - for i in chan_num.to_be_bytes().into_iter().rev() { - data.insert(0, i); - } + ) -> Result, FormatError> { + let length = HEADER_SIZE + payload.len(); + let padded_length = nearest_padded_value_length(length); - let padded = nearest_padded_value_length(data.len()); - let bytes_to_add = padded - data.len(); - if bytes_to_add > 0 { - data.extend_from_slice(&vec![0; bytes_to_add]); - } + #[allow(clippy::map_err_ignore)] // intentional + let len = u16::try_from(payload.len()) + .map_err(|_| FormatError::BadChannelDataLength)?; - Ok(data) - } + let mut encoded = vec![0u8; padded_length]; - /// Returns [`ChannelData`] payload. - pub(crate) fn data(self) -> Vec { - self.data + encoded[..NUMBER_SIZE].copy_from_slice(&chan_num.to_be_bytes()); + encoded[NUMBER_SIZE..HEADER_SIZE].copy_from_slice(&len.to_be_bytes()); + encoded[HEADER_SIZE..length].copy_from_slice(payload); + + Ok(encoded) } } -/// Calculates nearest padded length for the [`ChannelData`]. +/// Calculates a nearest padded length for a [`ChannelData`] message. pub(crate) const fn nearest_padded_value_length(l: usize) -> usize { let mut n = PADDING * (l / PADDING); if n < l { @@ -130,29 +141,43 @@ pub(crate) const fn nearest_padded_value_length(l: usize) -> usize { n } +/// Possible errors of a [`ChannelData`] message format. +#[derive(Clone, Copy, Debug, Display, Error, Eq, PartialEq)] +pub enum FormatError { + /// [Channel Number][1] is incorrect. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 + #[display("Channel Number not in [0x4000, 0x7FFF]")] + InvalidChannelNumber, + + /// Incorrect message length. + #[display("Invalid `ChannelData` length")] + BadChannelDataLength, +} + #[cfg(test)] -mod chandata_test { - use super::*; +mod spec { + use crate::attr::ChannelNumber; + + use super::{ChannelData, FormatError}; #[test] - fn test_channel_data_encode() { + fn encodes() { let encoded = - ChannelData::encode(vec![1, 2, 3, 4], ChannelNumber::MIN + 1) - .unwrap(); + ChannelData::encode(&[1, 2, 3, 4], ChannelNumber::MIN + 1).unwrap(); let decoded = ChannelData::decode(encoded.clone()).unwrap(); assert!( ChannelData::is_channel_data(&encoded), - "unexpected IsChannelData" + "wrong `is_channel_data`", ); - - assert_eq!(vec![1, 2, 3, 4], decoded.data, "not equal"); - assert_eq!(ChannelNumber::MIN + 1, decoded.number, "not equal"); + assert_eq!(vec![1, 2, 3, 4], decoded.data, "wrong decoded data"); + assert_eq!(ChannelNumber::MIN + 1, decoded.number, "wrong number"); } #[test] - fn test_channel_data_equal() { - let tests = vec![ + fn encoded_equality() { + let tests = [ ( "equal", ChannelData { number: ChannelNumber::MIN, data: vec![1, 2, 3] }, @@ -186,55 +211,52 @@ mod chandata_test { ]; for (name, a, b, r) in tests { - let v = ChannelData::encode(a.data.clone(), a.number) - == ChannelData::encode(b.data.clone(), b.number); - assert_eq!(v, r, "unexpected: ({name}) {r} != {r}"); + let v = ChannelData::encode(&a.data, a.number) + == ChannelData::encode(&b.data, b.number); + + assert_eq!(v, r, "wrong equality of {name}"); } } #[test] - fn test_channel_data_decode() { - let tests = vec![ - ("small", vec![1, 2, 3], Error::UnexpectedEof), + fn fails_decoding_correctly() { + let tests = [ + ("small", vec![1, 2, 3], FormatError::BadChannelDataLength), ( "zeroes", vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - Error::InvalidChannelNumber, + FormatError::InvalidChannelNumber, ), ( "bad chan number", vec![63, 255, 0, 0, 0, 4, 0, 0, 1, 2, 3, 4], - Error::InvalidChannelNumber, + FormatError::InvalidChannelNumber, ), ( "bad length", vec![0x40, 0x40, 0x02, 0x23, 0x16, 0, 0, 0, 0, 0, 0, 0], - Error::BadChannelDataLength, + FormatError::BadChannelDataLength, ), ]; - for (name, buf, want_err) in tests { - if let Err(err) = ChannelData::decode(buf) { - assert_eq!( - want_err, err, - "unexpected: ({name}) {want_err} != {err}" - ); + if let Err(e) = ChannelData::decode(buf) { + assert_eq!(want_err, e, "wrong error of {name}"); } else { - panic!("expected error, but got ok"); + panic!("expected `Err`, but got `Ok` in {name}"); } } } #[test] - fn test_is_channel_data() { - let tests = vec![ + fn is_channel_data_detects_correctly() { + let tests = [ ("small", vec![1, 2, 3, 4], false), ("zeroes", vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], false), ]; - for (name, buf, r) in tests { let v = ChannelData::is_channel_data(&buf); - assert_eq!(v, r, "unexpected: ({name}) {r} != {v}"); + + assert_eq!(v, r, "wrong result in {name}"); } } @@ -261,11 +283,11 @@ mod chandata_test { ]; #[test] - fn test_chrome_channel_data() { + fn chrome_channel_data() { let mut data = vec![]; let mut messages = vec![]; - // Decoding hex data into binary. + // Decoding HEX data into binary. for h in &CHANDATA_TEST_HEX { let b = match hex::decode(h) { Ok(b) => b, @@ -274,21 +296,20 @@ mod chandata_test { data.push(b); } - // All hex streams decoded to raw binary format and stored in data + // All HEX streams decoded to raw binary format and stored in the `data` // slice. Decoding packets to messages. for packet in data { let m = ChannelData::decode(packet.clone()).unwrap(); - let encoded = - ChannelData::encode(m.data.clone(), m.number).unwrap(); + let encoded = ChannelData::encode(&m.data, m.number).unwrap(); let decoded = ChannelData::decode(encoded.clone()).unwrap(); - assert_eq!(m.data, decoded.data, "should be equal"); - assert_eq!(m.number, decoded.number, "should be equal"); + assert_eq!(m.data, decoded.data, "wrong payload"); + assert_eq!(m.number, decoded.number, "wrong number"); messages.push(m); } - assert_eq!(messages.len(), 2, "unexpected message slice list"); + assert_eq!(messages.len(), 2, "wrong number of messages"); } } diff --git a/src/con/mod.rs b/src/con/mod.rs deleted file mode 100644 index 08c5d69d8..000000000 --- a/src/con/mod.rs +++ /dev/null @@ -1,137 +0,0 @@ -//! Main STUN/TURN transport implementation. - -mod tcp; - -use std::io; - -use std::net::SocketAddr; - -use async_trait::async_trait; - -use tokio::{ - net, - net::{ToSocketAddrs, UdpSocket}, -}; - -use crate::{attr::PROTO_UDP, server::INBOUND_MTU, Error}; - -pub use tcp::TcpServer; - -/// Abstracting over transport implementation. -#[async_trait] -pub trait Conn { - async fn recv_from(&self) -> Result<(Vec, SocketAddr), Error>; - async fn send_to( - &self, - buf: Vec, - target: SocketAddr, - ) -> Result; - - /// Returns the local transport address. - fn local_addr(&self) -> SocketAddr; - - /// Return the transport protocol according to [IANA]. - /// - /// [IANA]: https://tinyurl.com/iana-protocol-numbers - fn proto(&self) -> u8; - - /// Closes the underlying transport. - async fn close(&self) -> Result<(), Error>; -} - -/// Performs a DNS resolution. -pub(crate) async fn lookup_host( - use_ipv4: bool, - host: T, -) -> Result -where - T: ToSocketAddrs, -{ - for remote_addr in net::lookup_host(host).await? { - if (use_ipv4 && remote_addr.is_ipv4()) - || (!use_ipv4 && remote_addr.is_ipv6()) - { - return Ok(remote_addr); - } - } - - Err(io::Error::new( - io::ErrorKind::Other, - format!( - "No available {} IP address found!", - if use_ipv4 { "ipv4" } else { "ipv6" }, - ), - ) - .into()) -} - -#[async_trait] -impl Conn for UdpSocket { - async fn recv_from(&self) -> Result<(Vec, SocketAddr), Error> { - let mut buf = vec![0u8; INBOUND_MTU]; - let (len, addr) = self.recv_from(&mut buf).await?; - buf.truncate(len); - - Ok((buf, addr)) - } - - async fn send_to( - &self, - data: Vec, - target: SocketAddr, - ) -> Result { - Ok(self.send_to(&data, target).await?) - } - - fn local_addr(&self) -> SocketAddr { - #[allow(clippy::unwrap_used)] - self.local_addr().unwrap() - } - - fn proto(&self) -> u8 { - PROTO_UDP - } - - async fn close(&self) -> Result<(), Error> { - Ok(()) - } -} - -#[cfg(test)] -mod conn_test { - use super::*; - - #[tokio::test] - async fn test_conn_lookup_host() { - let stun_serv_addr = "stun1.l.google.com:19302"; - - if let Ok(ipv4_addr) = lookup_host(true, stun_serv_addr).await { - assert!( - ipv4_addr.is_ipv4(), - "expected ipv4 but got ipv6: {ipv4_addr}" - ); - } - - if let Ok(ipv6_addr) = lookup_host(false, stun_serv_addr).await { - assert!( - ipv6_addr.is_ipv6(), - "expected ipv6 but got ipv4: {ipv6_addr}" - ); - } - } -} - -#[cfg(test)] -mod net_test { - use super::*; - - #[tokio::test] - async fn test_net_native_resolve_addr() { - let udp_addr = lookup_host(true, "localhost:1234").await.unwrap(); - assert_eq!(udp_addr.ip().to_string(), "127.0.0.1", "should match"); - assert_eq!(udp_addr.port(), 1234, "should match"); - - let result = lookup_host(false, "127.0.0.1:1234").await; - assert!(result.is_err(), "should not match"); - } -} diff --git a/src/lib.rs b/src/lib.rs index b1d83732a..1453ff189 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ -//! A pure Rust implementation of TURN. - +#![doc = include_str!("../README.md")] #![deny( macro_use_extern_crate, nonstandard_style, @@ -149,25 +148,26 @@ mod allocation; mod attr; mod chandata; -mod con; -mod relay; +pub mod relay; mod server; +pub mod transport; -use std::{io, net::SocketAddr}; +use std::{net::SocketAddr, sync::Arc}; -use thiserror::Error; +use derive_more::{Display, Error as StdError, From}; +#[cfg(test)] +pub(crate) use self::allocation::Allocation; +pub(crate) use self::transport::Transport; pub use self::{ - allocation::{AllocInfo, FiveTuple}, - con::TcpServer, - relay::RelayAllocator, - server::{Config, ConnConfig, Server}, + allocation::{FiveTuple, Info as AllocationInfo}, + server::{Config as ServerConfig, Server}, }; -/// External authentication handler. +/// Authentication handler. pub trait AuthHandler { - /// Perform authentication of the given user data returning ICE password - /// on success. + /// Performs authentication of the specified user, returning its ICE + /// password on success. /// /// # Errors /// @@ -180,154 +180,134 @@ pub trait AuthHandler { ) -> Result, Error>; } -/// TURN server errors. -#[derive(Debug, Error, PartialEq)] +impl AuthHandler for Arc { + fn auth_handle( + &self, + username: &str, + realm: &str, + src_addr: SocketAddr, + ) -> Result, Error> { + (**self).auth_handle(username, realm, src_addr) + } +} + +/// Possible errors of a [STUN]/[TURN] [`Server`]. +/// +/// [STUN]: https://en.wikipedia.org/wiki/STUN +/// [TURN]: https://en.wikipedia.org/wiki/TURN +#[derive(Debug, Display, Eq, From, PartialEq, StdError)] #[non_exhaustive] #[allow(variant_size_differences)] pub enum Error { - /// Failed to allocate new relay connection sine maximum retires count + /// Failed to allocate new relay connection, since maximum retires count /// exceeded. - #[error("turn: max retries exceeded")] + #[display("turn: max retries exceeded")] MaxRetriesExceeded, - /// Failed to handle channel data since channel number is incorrect. - #[error("channel number not in [0x4000, 0x7FFF]")] - InvalidChannelNumber, - - /// Failed to handle channel data cause of incorrect message length. - #[error("channelData length != len(Data)")] - BadChannelDataLength, - - /// Failed to handle message since it's shorter than expected. - #[error("unexpected EOF")] - UnexpectedEof, - - /// A peer address is part of a different address family than that of the + /// {eer address is part of a different address family than that of the /// relayed transport address of the allocation. - #[error("error code 443: peer address family mismatch")] + #[display("error code 443: peer address family mismatch")] PeerAddressFamilyMismatch, /// Error when trying to perform action after closing server. - #[error("use of closed network connection")] + #[display("use of closed network connection")] Closed, - /// Channel binding request failed since channel number is currently bound + /// Channel binding request failed, since channel number is currently bound /// to a different transport address. - #[error("you cannot use the same channel number with different peer")] + #[display("cannot use the same channel number with different peer")] SameChannelDifferentPeer, - /// Channel binding request failed since the transport address is currently + /// Channel binding request failed, since the transport address is currently /// bound to a different channel number. - #[error("you cannot use the same peer number with different channel")] + #[display("cannot use the same peer number with different channel")] SamePeerDifferentChannel, /// Cannot create allocation with zero lifetime. - #[error("allocations must not be created with a lifetime of 0")] + #[display("allocations must not be created with a lifetime of 0")] LifetimeZero, /// Cannot create allocation for the same five-tuple. - #[error("allocation attempt created with duplicate FiveTuple")] + #[display("allocation attempt created with duplicate 5-TUPLE")] DupeFiveTuple, - /// The given nonce is wrong or already been used. - #[error("duplicated Nonce generated, discarding request")] - RequestReplay, - /// Authentication error. - #[error("no such user exists")] + #[display("no such user exists")] NoSuchUser, /// Unsupported request class. - #[error("unexpected class")] + #[display("unexpected class")] UnexpectedClass, - /// Allocate request failed since allocation already exists for the given - /// five-tuple. - #[error("relay already allocated for 5-TUPLE")] + /// Allocation request failed, since allocation already exists for the + /// provided [`FiveTuple`]. + #[display("relay already allocated for 5-TUPLE")] RelayAlreadyAllocatedForFiveTuple, - /// STUN message does not have a required attribute. - #[error("requested attribute not found")] + /// [STUN] message doesn't have a required attribute. + /// + /// [STUN]: https://en.wikipedia.org/wiki/STUN + #[display("requested attribute not found")] AttributeNotFound, - /// STUN message contains wrong message integrity. - #[error("message integrity mismatch")] + /// [STUN] message contains wrong [`MessageIntegrity`]. + /// + /// [`MessageIntegrity`]: attr::MessageIntegrity + /// [STUN]: https://en.wikipedia.org/wiki/STUN + #[display("message integrity mismatch")] IntegrityMismatch, /// [DONT-FRAGMENT][1] attribute is not supported. /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.8 - #[error("no support for DONT-FRAGMENT")] + /// [1]: https://tools.ietf.org/html/rfc5766#section-14.8 + #[display("no support for DONT-FRAGMENT")] NoDontFragmentSupport, - /// Allocate request cannot have both [RESERVATION-TOKEN][1] and - /// [EVEN-PORT]. + /// Allocation request cannot have both [RESERVATION-TOKEN][1] and + /// [EVEN-PORT][2]. /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.9 - /// [EVEN-PORT]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.6 - #[error("Request must not contain RESERVATION-TOKEN and EVEN-PORT")] + /// [1]: https://tools.ietf.org/html/rfc5766#section-14.9 + /// [2]: https://tools.ietf.org/html/rfc5766#section-14.6 + #[display("Request must not contain RESERVATION-TOKEN and EVEN-PORT")] RequestWithReservationTokenAndEvenPort, /// Allocation request cannot contain both [RESERVATION-TOKEN][1] and /// [REQUESTED-ADDRESS-FAMILY][2]. /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-14.9 - /// [2]: https://www.rfc-editor.org/rfc/rfc6156#section-4.1.1 - #[error( + /// [1]: https://tools.ietf.org/html/rfc5766#section-14.9 + /// [2]: https://tools.ietf.org/html/rfc6156#section-4.1.1 + #[display( "Request must not contain RESERVATION-TOKEN \ and REQUESTED-ADDRESS-FAMILY" )] RequestWithReservationTokenAndReqAddressFamily, - /// No allocation for the given five-tuple. - #[error("no allocation found")] + /// No allocation for the provided [`FiveTuple`]. + #[display("no allocation found")] NoAllocationFound, /// The specified protocol is not supported. - #[error("allocation requested unsupported proto")] + #[display("allocation requested unsupported proto")] UnsupportedRelayProto, - /// Failed to handle send indication since there is no permission for the - /// given address. - #[error("unable to handle send-indication, no permission added")] + /// Failed to handle [Send Indication][1], since there is no permission for + /// the provided address. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-10.2 + #[display("unable to handle send-indication, no permission added")] NoPermission, - /// Failed to handle channel data since ther is no binding for the given - /// channel. - #[error("no such channel bind")] + /// Failed to handle channel data, since there is no binding for the + /// provided channel. + #[display("no such channel bind")] NoSuchChannelBind, - /// Failed to decode message. - #[error("Failed to decode STUN/TURN message: {0:?}")] - Decode(bytecodec::ErrorKind), - /// Failed to encode message. - #[error("Failed to encode STUN/TURN message: {0:?}")] - Encode(bytecodec::ErrorKind), - - /// Tried to use dead transport. - #[error("Underlying TCP/UDP transport is dead")] - TransportIsDead, - - /// Error for transport. - #[error("{0}")] - Io(#[source] IoError), -} - -/// [`io::Error`] wrapper. -#[derive(Debug, Error)] -#[error("io error: {0}")] -pub struct IoError(#[from] pub io::Error); - -// Workaround for wanting PartialEq for io::Error. -impl PartialEq for IoError { - fn eq(&self, other: &Self) -> bool { - self.0.kind() == other.0.kind() - } -} + #[display("Failed to encode STUN/TURN message: {_0:?}")] + #[from(ignore)] + Encode(#[error(not(source))] bytecodec::ErrorKind), -impl From for Error { - fn from(e: io::Error) -> Self { - Self::Io(IoError(e)) - } + /// Failed to send message. + #[display("Transport error: {_0}")] + Transport(transport::Error), } diff --git a/src/relay.rs b/src/relay.rs index 79d60502b..7e1d09bf9 100644 --- a/src/relay.rs +++ b/src/relay.rs @@ -1,7 +1,4 @@ -//! [`RelayAllocator`] is used to create relay transports wit the given -//! configuration. - -#![allow(clippy::module_name_repetitions)] +//! Relay definitions. use std::{ net::{IpAddr, SocketAddr}, @@ -9,39 +6,39 @@ use std::{ }; use tokio::net::UdpSocket; -use crate::{con, Error}; +use crate::{transport, Error}; -/// [`RelayAllocator`] is used to generate a Relay Address when creating an -/// allocation. -#[derive(Debug)] -pub struct RelayAllocator { - /// `relay_address` is the IP returned to the user when the relay is - /// created. +/// Generator of relay addresses when creating an [allocation]. +/// +/// [allocation]: https://tools.ietf.org/html/rfc5766#section-5 +#[derive(Clone, Debug)] +pub struct Allocator { + /// [`IpAddr`] returned to the user when a relay is created. pub relay_address: IpAddr, - /// `min_port` the minimum port to allocate. + /// Minimum (inclusive) port to allocate. pub min_port: u16, - /// `max_port` the maximum (inclusive) port to allocate. + /// Maximum (inclusive) port to allocate. pub max_port: u16, - /// `max_retries` the amount of tries to allocate a random port in the - /// defined range. + /// Amount of tries to allocate a random port in the allowed range. pub max_retries: u16, - /// `address` is passed to Listen/ListenPacket when creating the Relay. + /// Address passed when creating a relay. pub address: String, } -impl RelayAllocator { +impl Allocator { /// Allocates a new relay connection. /// /// # Errors /// - /// With [`Error::MaxRetriesExceeded`] if the requested port is `0` and - /// failed to find a free port in the specified maximum retries. + /// - With an [`Error::MaxRetriesExceeded`] if the requested port is `0` and + /// failed to find a free port in the specified [`max_retries`]. + /// - With an [`Error::Transport`] if failed to bind to the specified port. /// - /// With [`Error::Io`] if failed to bind to the specified port. + /// [`max_retries`]: Allocator::max_retries pub async fn allocate_conn( &self, use_ipv4: bool, @@ -55,29 +52,33 @@ impl RelayAllocator { let port = self.min_port + rand::random::() % (self.max_port - self.min_port + 1); - let addr = con::lookup_host( + let addr = transport::lookup_host( use_ipv4, - &format!("{}:{}", self.address, port), + &format!("{}:{port}", self.address), ) .await?; let Ok(conn) = UdpSocket::bind(addr).await else { continue; }; - let mut relay_addr = conn.local_addr()?; + let mut relay_addr = + conn.local_addr().map_err(transport::Error::from)?; relay_addr.set_ip(self.relay_address); return Ok((Arc::new(conn), relay_addr)); } Err(Error::MaxRetriesExceeded) } else { - let addr = con::lookup_host( + let addr = transport::lookup_host( use_ipv4, - &format!("{}:{}", self.address, requested_port), + &format!("{}:{requested_port}", self.address), ) .await?; - let conn = Arc::new(UdpSocket::bind(addr).await?); - let mut relay_addr = conn.local_addr()?; + let conn = Arc::new( + UdpSocket::bind(addr).await.map_err(transport::Error::from)?, + ); + let mut relay_addr = + conn.local_addr().map_err(transport::Error::from)?; relay_addr.set_ip(self.relay_address); Ok((conn, relay_addr)) diff --git a/src/server/config.rs b/src/server/config.rs deleted file mode 100644 index 57c991803..000000000 --- a/src/server/config.rs +++ /dev/null @@ -1,82 +0,0 @@ -//! TURN server configuration. - -#![allow(clippy::module_name_repetitions)] - -use std::{fmt, sync::Arc}; - -use tokio::{sync::mpsc, time::Duration}; - -use crate::{ - allocation::AllocInfo, con::Conn, relay::RelayAllocator, AuthHandler, -}; - -/// Main STUN/TURN socket configuration. -pub struct ConnConfig { - /// STUN socket. - pub conn: Arc, - - /// Relay connections allocator. - pub relay_addr_generator: RelayAllocator, -} - -impl ConnConfig { - /// Creates a new [`ConnConfig`]. - /// - /// # Panics - /// - /// If the configured min port or max port is `0`. - /// If the configured min port is greater than max port. - /// If the configured address is an empty string. - pub fn new(conn: Arc, gen: RelayAllocator) -> Self { - assert!(gen.min_port > 0, "min_port must be greater than 0"); - assert!(gen.max_port > 0, "max_port must be greater than 0"); - assert!( - gen.min_port > gen.max_port, - "max_port must be greater than min_port" - ); - assert!(gen.address.is_empty(), "address must not be an empty string"); - - Self { conn, relay_addr_generator: gen } - } -} - -impl fmt::Debug for ConnConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ConnConfig") - .field("relay_addr_generator", &self.relay_addr_generator) - .field("conn", &self.conn.local_addr()) - .finish() - } -} - -/// [`Config`] configures the TURN Server. -pub struct Config { - /// `conn_configs` are a list of all the turn listeners. - /// Each listener can have custom behavior around the creation of Relays. - pub conn_configs: Vec, - - /// `realm` sets the realm for this server - pub realm: String, - - /// `auth_handler` is a callback used to handle incoming auth requests, - /// allowing users to customize Pion TURN with custom behavior. - pub auth_handler: Arc, - - /// Sets the lifetime of channel binding. - pub channel_bind_lifetime: Duration, - - /// To receive notify on allocation close event, with metrics data. - pub alloc_close_notify: Option>, -} - -impl fmt::Debug for Config { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Config") - .field("conn_configs", &self.conn_configs) - .field("realm", &self.realm) - .field("channel_bind_lifetime", &self.channel_bind_lifetime) - .field("alloc_close_notify", &self.alloc_close_notify) - .field("auth_handler", &"dyn AuthHandler") - .finish() - } -} diff --git a/src/server/mod.rs b/src/server/mod.rs index 43f06116f..3a2e64617 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,125 +1,227 @@ -//! TURN server implementation. +//! [STUN]/[TURN] server implementation. +//! +//! [STUN]: https://en.wikipedia.org/wiki/STUN +//! [TURN]: https://en.wikipedia.org/wiki/TURN -mod config; mod request; -use std::{collections::HashMap, fmt, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; +use derive_more::Debug; use tokio::{ sync::{ broadcast::{ error::RecvError, {self}, }, - mpsc, oneshot, Mutex, + mpsc, oneshot, }, - time::{Duration, Instant}, + time::Duration, }; +#[cfg(doc)] +use crate::allocation::Allocation; use crate::{ - allocation::{AllocInfo, FiveTuple, Manager, ManagerConfig}, - con::Conn, + allocation::{FiveTuple, Info, Manager, ManagerConfig}, + relay, + transport::Transport, AuthHandler, Error, }; -pub use self::config::{Config, ConnConfig}; - -/// `DEFAULT_LIFETIME` in RFC 5766 is 10 minutes. +/// Default lifetime of an [allocation][1] (10 minutes) as defined in +/// [RFC 5766 Section 2.2][1]. /// -/// [RFC 5766 Section 2.2](https://www.rfc-editor.org/rfc/rfc5766#section-2.2) +/// [1]: https://tools.ietf.org/html/rfc5766#section-2.2 pub(crate) const DEFAULT_LIFETIME: Duration = Duration::from_secs(10 * 60); -/// MTU used for UDP connections. +/// [MTU] of UDP connections. +/// +/// [MTU]: https://en.wikipedia.org/wiki/Maximum_transmission_unit pub(crate) const INBOUND_MTU: usize = 1500; -/// Server is an instance of the TURN Server -pub struct Server { - /// [`AuthHandler`] used to authenticate certain types of requests. - auth_handler: Arc, +/// Configuration of a [`Server`]. +#[derive(Debug)] +pub struct Config { + /// List of all [STUN]/[TURN] connections listeners. + /// + /// Each listener may have a custom behavior around the creation of + /// [`relay`]s. + /// + /// [STUN]: https://en.wikipedia.org/wiki/STUN + /// [TURN]: https://en.wikipedia.org/wiki/TURN + #[debug("{:?}", connections.iter() + .map(|c| (c.local_addr(), c.proto())) + .collect::>())] + pub connections: Vec>, + + /// [`Allocator`] of [`relay`] connections. + /// + /// [`Allocator`]: relay::Allocator + pub relay_addr_generator: relay::Allocator, + + /// [Realm][1] of the [`Server`]. + /// + /// > A string used to describe the server or a context within the server. + /// > The realm tells the client which username and password combination to + /// > use to authenticate requests. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-3 + pub realm: String, - /// A string used to describe the server or a context within the server. - realm: String, + /// Callback for handling incoming authentication requests, allowing users + /// to customize it with custom behavior. + pub auth_handler: Arc, - /// [Channel binding][1] lifetime. + /// Lifetime of a [channel bindings][1]. /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11 - channel_bind_lifetime: Duration, + /// [1]: https://tools.ietf.org/html/rfc5766#section-2.5 + pub channel_bind_lifetime: Duration, - /// Nonces generated by server. - pub(crate) nonces: Arc>>, + /// [`mpsc::Sender`] receiving notify on [allocation][1] close event, along + /// with metrics data. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-2.2 + pub alloc_close_notify: Option>, +} - /// Channel to [`Server`]'s internal loop. - command_tx: Mutex>>, +/// Instance of a [STUN]/[TURN] server. +/// +/// [STUN]: https://en.wikipedia.org/wiki/STUN +/// [TURN]: https://en.wikipedia.org/wiki/TURN +#[derive(Debug)] +pub struct Server { + /// [`broadcast::Sender`] to this [`Server`]'s internal loop. + command_tx: broadcast::Sender, } impl Server { - /// creates a new TURN server + /// Creates a new [`Server`] according to the provided [`Config`], and + /// [`spawn`]s its internal loop. + /// + /// [`spawn`]: tokio::spawn() #[must_use] - pub fn new(config: Config) -> Self { + pub fn new(config: Config) -> Self + where + A: AuthHandler + Send + Sync + 'static, + { let (command_tx, _) = broadcast::channel(16); - let mut this = Self { - auth_handler: config.auth_handler, - realm: config.realm, - channel_bind_lifetime: config.channel_bind_lifetime, - nonces: Arc::new(Mutex::new(HashMap::new())), - command_tx: Mutex::new(Some(command_tx.clone())), - }; - if this.channel_bind_lifetime == Duration::from_secs(0) { - this.channel_bind_lifetime = DEFAULT_LIFETIME; - } - for p in config.conn_configs { - let nonces = Arc::clone(&this.nonces); - let auth_handler = Arc::clone(&this.auth_handler); - let realm = this.realm.clone(); - let channel_bind_lifetime = this.channel_bind_lifetime; - let handle_rx = command_tx.subscribe(); - let conn = p.conn; - let allocation_manager = Arc::new(Manager::new(ManagerConfig { - relay_addr_generator: p.relay_addr_generator, + let this = Self { command_tx: command_tx.clone() }; + let channel_bind_lifetime = + if config.channel_bind_lifetime == Duration::from_secs(0) { + DEFAULT_LIFETIME + } else { + config.channel_bind_lifetime + }; + + for conn in config.connections { + let auth_handler = Arc::clone(&config.auth_handler); + let realm = config.realm.clone(); + let mut nonces = HashMap::new(); + let mut handle_rx = command_tx.subscribe(); + let mut allocation_manager = Manager::new(ManagerConfig { + relay_addr_generator: config.relay_addr_generator.clone(), alloc_close_notify: config.alloc_close_notify.clone(), - })); + }); - Self::spawn_read_loop( - conn, - allocation_manager, - nonces, - auth_handler, - realm, - channel_bind_lifetime, - handle_rx, - ); + let (mut close_tx, mut close_rx) = oneshot::channel::<()>(); + drop(tokio::spawn(async move { + let local_con_addr = conn.local_addr(); + let protocol = conn.proto(); + + loop { + let (msg, src_addr) = tokio::select! { + cmd = handle_rx.recv() => { + match cmd { + Ok(Command::DeleteAllocations( + name, + completion, + )) => { + allocation_manager + .delete_allocations_by_username( + &name, + ); + drop(completion); + } + Ok(Command::GetAllocationsInfo( + five_tuples, + tx, + )) => { + let infos = allocation_manager + .get_allocations_info(&five_tuples); + drop(tx.send(infos).await); + } + Err(RecvError::Closed) => { + close_rx.close(); + break; + } + Err(RecvError::Lagged(n)) => { + log::warn!( + "`Server` has lagged by {n} messages", + ); + } + } + continue; + }, + v = conn.recv_from() => { + match v { + Ok(v) => v, + Err(e) => { + log::debug!("Exit read loop on error: {e}"); + break; + } + } + }, + () = close_tx.closed() => break + }; + + let handle = request::handle( + msg, + &conn, + FiveTuple { + src_addr, + dst_addr: local_con_addr, + protocol, + }, + &realm, + channel_bind_lifetime, + &mut allocation_manager, + &mut nonces, + &auth_handler, + ); + if let Err(e) = handle.await { + log::warn!("Error when handling `Request`: {e}"); + } + } + })); } this } - /// Deletes all existing allocations by the provided `username`. + /// Deletes all existing [allocations][1] with the provided `username`. /// /// # Errors /// - /// With [`Error::Closed`] if the [`Server`] was closed already. + /// With an [`Error::Closed`] if the [`Server`] was closed already. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-2.2 pub async fn delete_allocations_by_username( &self, username: String, ) -> Result<(), Error> { - let tx = self.command_tx.lock().await.clone(); - - #[allow(clippy::map_err_ignore)] - if let Some(tx) = tx { - let (closed_tx, closed_rx) = mpsc::channel(1); - _ = tx - .send(Command::DeleteAllocations(username, Arc::new(closed_rx))) - .map_err(|_| Error::Closed)?; + let (closed_tx, closed_rx) = mpsc::channel(1); + #[allow(clippy::map_err_ignore)] // intentional + let _: usize = self + .command_tx + .send(Command::DeleteAllocations(username, Arc::new(closed_rx))) + .map_err(|_| Error::Closed)?; - closed_tx.closed().await; + closed_tx.closed().await; - Ok(()) - } else { - Err(Error::Closed) - } + Ok(()) } - /// Returns [`AllocInfo`]s by specified [`FiveTuple`]s. + /// Returns [`Info`]s for the provided [`FiveTuple`]s. /// /// If `five_tuples` is: /// - [`None`]: It returns information about the all @@ -130,177 +232,43 @@ impl Server { /// /// # Errors /// - /// With [`Error::Closed`] if the [`Server`] was closed already. + /// With an [`Error::Closed`] if the [`Server`] was closed already. pub async fn get_allocations_info( &self, five_tuples: Option>, - ) -> Result, Error> { + ) -> Result, Error> { if let Some(five_tuples) = &five_tuples { if five_tuples.is_empty() { return Ok(HashMap::new()); } } - let tx = self.command_tx.lock().await.clone(); - #[allow(clippy::map_err_ignore)] - if let Some(tx) = tx { - let (infos_tx, mut infos_rx) = mpsc::channel(1); - - _ = tx - .send(Command::GetAllocationsInfo(five_tuples, infos_tx)) - .map_err(|_| Error::Closed)?; - - let mut info: HashMap = HashMap::new(); - - for _ in 0..tx.receiver_count() { - info.extend(infos_rx.recv().await.ok_or(Error::Closed)?); - } - - Ok(info) - } else { - Err(Error::Closed) - } - } - - /// Spawns a message handler task for the given [`Conn`]. - fn spawn_read_loop( - conn: Arc, - allocation_manager: Arc, - nonces: Arc>>, - auth_handler: Arc, - realm: String, - channel_bind_lifetime: Duration, - mut handle_rx: broadcast::Receiver, - ) { - let (mut close_tx, mut close_rx) = oneshot::channel::<()>(); - - drop(tokio::spawn({ - let allocation_manager = Arc::clone(&allocation_manager); - - async move { - loop { - match handle_rx.recv().await { - Ok(Command::DeleteAllocations(name, completion)) => { - allocation_manager - .delete_allocations_by_username(name.as_str()) - .await; - drop(completion); - continue; - } - Ok(Command::GetAllocationsInfo(five_tuples, tx)) => { - let infos = allocation_manager - .get_allocations_info(&five_tuples); - drop(tx.send(infos).await); - - continue; - } - Err(RecvError::Closed) => { - close_rx.close(); - break; - } - Ok(Command::Close(completion)) => { - close_rx.close(); - drop(completion); - break; - } - Err(RecvError::Lagged(n)) => { - log::warn!( - "Turn server has lagged by {} messages", - n - ); - continue; - } - } - } - } - })); - - drop(tokio::spawn(async move { - let local_con_addr = conn.local_addr(); - let protocol = conn.proto(); + let (infos_tx, mut infos_rx) = mpsc::channel(1); - loop { - let (msg, src_addr) = tokio::select! { - v = conn.recv_from() => { - match v { - Ok(v) => v, - Err(err) => { - log::debug!("exit read loop on error: {}", err); - break; - } - } - }, - () = close_tx.closed() => break - }; + #[allow(clippy::map_err_ignore)] // intentional + let _: usize = self + .command_tx + .send(Command::GetAllocationsInfo(five_tuples, infos_tx)) + .map_err(|_| Error::Closed)?; - let handle = request::handle_message( - msg, - &conn, - FiveTuple { src_addr, dst_addr: local_con_addr, protocol }, - realm.as_str(), - channel_bind_lifetime, - &allocation_manager, - &nonces, - &auth_handler, - ); - - if let Err(err) = handle.await { - log::error!("error when handling datagram: {}", err); - } - } - - drop(allocation_manager.close().await); - drop(conn.close().await); - })); - } - - /// Close stops the TURN Server. It cleans up any associated state and - /// closes all connections it is managing. - pub async fn close(&self) { - let tx = self.command_tx.lock().await.take(); - - if let Some(tx) = tx { - if tx.receiver_count() == 0 { - return; - } - - let (closed_tx, closed_rx) = mpsc::channel(1); - drop(tx.send(Command::Close(Arc::new(closed_rx)))); - closed_tx.closed().await; + let mut info: HashMap = HashMap::new(); + for _ in 0..self.command_tx.receiver_count() { + info.extend(infos_rx.recv().await.ok_or(Error::Closed)?); } + Ok(info) } } -impl fmt::Debug for Server { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Server") - .field("realm", &self.realm) - .field("channel_bind_lifetime", &self.channel_bind_lifetime) - .field("nonces", &self.nonces) - .field("command_tx", &self.command_tx) - .field("auth_handler", &"dyn AuthHandler") - .finish() - } -} - -/// The protocol to communicate between the [`Server`]'s public methods -/// and the tasks spawned in the [`Server::spawn_read_loop`] method. +/// Commands for communication between [`Server`]'s public methods and the tasks +/// spawned in its inner loop. #[derive(Clone)] enum Command { - /// Command to delete [`Allocation`][`Allocation`] by provided `username`. - /// - /// [`Allocation`]: `crate::allocation::Allocation` + /// Delete [`Allocation`] by the provided `username`. DeleteAllocations(String, Arc>), - /// Command to get information of [`Allocation`][`Allocation`]s by provided - /// [`FiveTuple`]s. - /// - /// [`Allocation`]: `crate::allocation::Allocation` + /// Return information about [`Allocation`] for the provided [`FiveTuple`]s. GetAllocationsInfo( Option>, - mpsc::Sender>, + mpsc::Sender>, ), - - /// Command to close the [`Server`]. - Close(Arc>), } diff --git a/src/server/request.rs b/src/server/request.rs index b8d4e2168..de30fcd63 100644 --- a/src/server/request.rs +++ b/src/server/request.rs @@ -1,10 +1,9 @@ -//! Ingress STUN/TURN messages handlers. +//! Ingress [`Request`] handling. -use bytecodec::{DecodeExt, EncodeExt}; +use bytecodec::EncodeExt; use std::{ collections::HashMap, marker::{Send, Sync}, - mem, net::SocketAddr, sync::Arc, }; @@ -23,14 +22,12 @@ use stun_codec::{ methods::{ALLOCATE, CHANNEL_BIND, CREATE_PERMISSION, REFRESH, SEND}, }, rfc8656::errors::{AddressFamilyNotSupported, PeerAddressFamilyMismatch}, - Attribute as _, Message, MessageClass, MessageDecoder, MessageEncoder, - TransactionId, -}; -use tokio::{ - sync::Mutex, - time::{Duration, Instant}, + Attribute as _, Message, MessageClass, MessageEncoder, TransactionId, }; +use tokio::time::{Duration, Instant}; +#[cfg(doc)] +use crate::allocation::Allocation; use crate::{ allocation::{FiveTuple, Manager}, attr::{ @@ -41,132 +38,146 @@ use crate::{ XorRelayAddress, PROTO_UDP, }, chandata::ChannelData, - con::Conn, server::DEFAULT_LIFETIME, + transport, + transport::{Request, Transport}, AuthHandler, Error, }; -/// It is RECOMMENDED that the server use a maximum allowed lifetime value of no -/// more than 3600 seconds (1 hour). +/// Maximum allowed lifetime of an [allocation][1]. +/// +/// See [RFC 5766 Section 6.2][2]: +/// > It is RECOMMENDED that the server use a maximum allowed lifetime value of +/// > no more than 3600 seconds (1 hour). +/// +/// [1]: https://tools.ietf.org/html/rfc5766#section-2.2 +/// [2]: https://tools.ietf.org/html/rfc5766#section-6.2 const MAXIMUM_ALLOCATION_LIFETIME: Duration = Duration::from_secs(3600); -/// Lifetime of the NONCE sent by server. +/// Lifetime of a [`Nonce`] sent by a server. const NONCE_LIFETIME: Duration = Duration::from_secs(3600); -/// Handles the given STUN/TURN message according to [spec]. +/// Handles the provided [`Request`] according to [RFC 5389 Section 7.3][1]. /// -/// [spec]: https://datatracker.ietf.org/doc/html/rfc5389#section-7.3 -#[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_message( - mut raw: Vec, - conn: &Arc, +/// # Errors +/// +/// See the [`Error`] for details. +/// +/// [1]: https://tools.ietf.org/html/rfc5389#section-7.3 +#[allow(clippy::too_many_arguments)] // TODO: refactor +pub(crate) async fn handle( + msg: Request, + conn: &Arc, five_tuple: FiveTuple, server_realm: &str, channel_bind_lifetime: Duration, - allocs: &Arc, - nonces: &Arc>>, - auth_handler: &Arc, + allocs: &mut Manager, + nonces: &mut HashMap, + auth_handler: &(impl AuthHandler + Send + Sync), ) -> Result<(), Error> { - if ChannelData::is_channel_data(&raw) { - let data = ChannelData::decode(mem::take(&mut raw))?; - - handle_data_packet(data, five_tuple, allocs).await - } else { - use stun_codec::MessageClass::{Indication, Request}; - - let msg = MessageDecoder::::new() - .decode_from_bytes(&raw) - .map_err(|e| Error::Decode(*e.kind()))? - .map_err(|e| Error::Decode(*e.error().kind()))?; - - let auth = match (msg.method(), msg.class()) { - ( - ALLOCATE | REFRESH | CREATE_PERMISSION | CHANNEL_BIND, - Request, - ) => { - authenticate_request( - &msg, - auth_handler, - conn, - nonces, - five_tuple, - server_realm, - ) - .await? - } - _ => None, - }; - - match (msg.method(), msg.class()) { - (ALLOCATE, Request) => { - if let Some((uname, realm, pass)) = auth { - handle_allocate_request( - msg, conn, allocs, five_tuple, uname, realm, pass, + match msg { + Request::ChannelData(data) => { + handle_data_packet(data, five_tuple, allocs).await + } + Request::Message(msg) => { + use stun_codec::MessageClass::{Indication, Request}; + + let auth = match (msg.method(), msg.class()) { + ( + ALLOCATE | REFRESH | CREATE_PERMISSION | CHANNEL_BIND, + Request, + ) => { + authenticate_request( + &msg, + auth_handler, + conn, + nonces, + five_tuple, + server_realm, ) - .await - } else { - Ok(()) + .await? } - } - (REFRESH, Request) => { - if let Some((uname, realm, pass)) = auth { - handle_refresh_request( - msg, conn, allocs, five_tuple, uname, realm, pass, - ) - .await - } else { - Ok(()) + _ => None, + }; + + match (msg.method(), msg.class()) { + (ALLOCATE, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_allocate_request( + msg, conn, allocs, five_tuple, uname, realm, pass, + ) + .await + } else { + Ok(()) + } } - } - (CREATE_PERMISSION, Request) => { - if let Some((uname, realm, pass)) = auth { - handle_create_permission_request( - msg, conn, allocs, five_tuple, uname, realm, pass, - ) - .await - } else { - Ok(()) + (REFRESH, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_refresh_request( + msg, conn, allocs, five_tuple, uname, realm, pass, + ) + .await + } else { + Ok(()) + } } - } - (CHANNEL_BIND, Request) => { - if let Some((uname, realm, pass)) = auth { - handle_channel_bind_request( - msg, - conn, - allocs, - five_tuple, - channel_bind_lifetime, - uname, - realm, - pass, - ) - .await - } else { - Ok(()) + (CREATE_PERMISSION, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_create_permission_request( + msg, conn, allocs, five_tuple, uname, realm, pass, + ) + .await + } else { + Ok(()) + } } + (CHANNEL_BIND, Request) => { + if let Some((uname, realm, pass)) = auth { + handle_channel_bind_request( + msg, + conn, + allocs, + five_tuple, + channel_bind_lifetime, + uname, + realm, + pass, + ) + .await + } else { + Ok(()) + } + } + (BINDING, Request) => { + handle_binding_request(conn, five_tuple).await + } + (SEND, Indication) => { + handle_send_indication(msg, allocs, five_tuple).await + } + (_, _) => Err(Error::UnexpectedClass), } - (BINDING, Request) => { - handle_binding_request(conn, five_tuple).await - } - (SEND, Indication) => { - handle_send_indication(msg, allocs, five_tuple).await - } - (_, _) => Err(Error::UnexpectedClass), } } } -/// Relays the given [`ChannelData`]. +/// Relays the provided [`ChannelData`]. +/// +/// # Errors +/// +/// - With an [`Error::NoAllocationFound`] if there is no [`Allocation`] found +/// for the provided [`FiveTuple`]. +/// - With an [`Error::NoSuchChannelBind`] if the there is no channel in the +/// [`Allocation`] for the provided [`ChannelData::num()`]. +/// - Or if fails to relay the provided `data`. async fn handle_data_packet( data: ChannelData, five_tuple: FiveTuple, - allocs: &Arc, + allocs: &mut Manager, ) -> Result<(), Error> { if let Some(alloc) = allocs.get_alloc(&five_tuple) { let channel = alloc.get_channel_addr(&data.num()).await; if let Some(peer) = channel { alloc.relay(&data.data(), peer).await?; - Ok(()) } else { Err(Error::NoSuchChannelBind) @@ -176,14 +187,19 @@ async fn handle_data_packet( } } -/// Handles the given STUN [`Message`] as an [AllocateRequest]. +/// Handles the provided [STUN] [`Message`] as an [Allocate request][1]. +/// +/// # Errors +/// +/// See the [`Error`] for details. /// -/// [AllocateRequest]: https://datatracker.ietf.org/doc/html/rfc5766#section-6.2 -#[allow(clippy::too_many_lines)] +/// [1]: https://tools.ietf.org/html/rfc5766#section-6.2 +/// [STUN]: https://en.wikipedia.org/wiki/STUN +#[allow(clippy::too_many_lines)] // TODO: refactor async fn handle_allocate_request( msg: Message, - conn: &Arc, - allocs: &Arc, + conn: &Arc, + allocs: &mut Manager, five_tuple: FiveTuple, uname: Username, realm: Realm, @@ -196,21 +212,14 @@ async fn handle_allocate_request( // some procedure outside the scope of this document. let mut requested_port = 0; - let mut reservation_token: Option = None; let mut use_ipv4 = true; // 2. The server checks if the 5-tuple is currently in use by an existing // allocation. If yes, the server rejects the request with a 437 // (Allocation Mismatch) error. - if allocs.has_alloc(&five_tuple) { - let mut msg = Message::new( - MessageClass::ErrorResponse, - ALLOCATE, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(AllocationMismatch)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + if allocs.get_alloc(&five_tuple).is_some() { + respond_with_err(&msg, AllocationMismatch, conn, five_tuple.src_addr) + .await?; return Err(Error::RelayAlreadyAllocatedForFiveTuple); } @@ -225,27 +234,19 @@ async fn handle_allocate_request( .get_attribute::() .map(RequestedTransport::protocol) else { - let mut msg = Message::new( - MessageClass::ErrorResponse, - ALLOCATE, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(BadRequest)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err(&msg, BadRequest, conn, five_tuple.src_addr).await?; return Err(Error::AttributeNotFound); }; if requested_proto != PROTO_UDP { - let mut msg = Message::new( - MessageClass::ErrorResponse, - ALLOCATE, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(UnsupportedTransportProtocol)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err( + &msg, + UnsupportedTransportProtocol, + conn, + five_tuple.src_addr, + ) + .await?; return Err(Error::UnsupportedRelayProto); } @@ -266,7 +267,7 @@ async fn handle_allocate_request( vec![DontFragment.get_type()], )); - answer_with_err(conn, five_tuple.src_addr, msg).await?; + send_to(msg, conn, five_tuple.src_addr).await?; return Err(Error::NoDontFragmentSupport); } @@ -284,14 +285,7 @@ async fn handle_allocate_request( let even_port = msg.get_attribute::(); if has_reservation_token && even_port.is_some() { - let mut msg = Message::new( - MessageClass::ErrorResponse, - ALLOCATE, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(BadRequest)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err(&msg, BadRequest, conn, five_tuple.src_addr).await?; return Err(Error::RequestWithReservationTokenAndEvenPort); } @@ -311,14 +305,13 @@ async fn handle_allocate_request( .map(RequestedAddressFamily::address_family) { if has_reservation_token { - let mut msg = Message::new( - MessageClass::ErrorResponse, - ALLOCATE, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(AddressFamilyNotSupported)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err( + &msg, + AddressFamilyNotSupported, + conn, + five_tuple.src_addr, + ) + .await?; return Err(Error::RequestWithReservationTokenAndReqAddressFamily); } @@ -340,14 +333,13 @@ async fn handle_allocate_request( random_port = match allocs.get_random_even_port().await { Ok(port) => port, Err(err) => { - let mut msg = Message::new( - MessageClass::ErrorResponse, - ALLOCATE, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(InsufficientCapacity)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err( + &msg, + InsufficientCapacity, + conn, + five_tuple.src_addr, + ) + .await?; return Err(err); } @@ -355,7 +347,6 @@ async fn handle_allocate_request( } requested_port = random_port; - reservation_token = Some(random()); } // 7. At any point, the server MAY choose to reject the request with a 486 @@ -370,7 +361,7 @@ async fn handle_allocate_request( // different server. The use of this error code and attribute follow the // specification in [RFC5389]. let lifetime_duration = get_lifetime(&msg); - let a = match allocs + let relay_addr = match allocs .create_allocation( five_tuple, Arc::clone(conn), @@ -383,14 +374,13 @@ async fn handle_allocate_request( { Ok(a) => a, Err(err) => { - let mut msg = Message::new( - MessageClass::ErrorResponse, - ALLOCATE, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(InsufficientCapacity)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err( + &msg, + InsufficientCapacity, + conn, + five_tuple.src_addr, + ) + .await?; return Err(err); } @@ -406,29 +396,20 @@ async fn handle_allocate_request( // was reserved). // * An XOR-MAPPED-ADDRESS attribute containing the client's IP address // and port (from the 5-tuple). - let msg = { - if let Some(token) = reservation_token { - allocs.create_reservation(token, a.relay_addr().port()).await; - } - let mut msg = Message::new( MessageClass::SuccessResponse, ALLOCATE, msg.transaction_id(), ); - msg.add_attribute(XorRelayAddress::new(a.relay_addr())); + msg.add_attribute(XorRelayAddress::new(relay_addr)); msg.add_attribute( Lifetime::new(lifetime_duration) .map_err(|e| Error::Encode(*e.kind()))?, ); msg.add_attribute(XorMappedAddress::new(five_tuple.src_addr)); - if let Some(token) = reservation_token { - msg.add_attribute(ReservationToken::new(token)); - } - let integrity = MessageIntegrity::new_long_term_credential( &msg, &uname, &realm, &pass, ) @@ -438,15 +419,19 @@ async fn handle_allocate_request( msg }; - build_and_send(conn, five_tuple.src_addr, msg).await + send_to(msg, conn, five_tuple.src_addr).await } -/// Authenticates the given [`Message`]. +/// Authenticates the provided [`Message`]. +/// +/// # Errors +/// +/// See the [`Error`] for details. async fn authenticate_request( msg: &Message, - auth_handler: &Arc, - conn: &Arc, - nonces: &Arc>>, + auth_handler: &(impl AuthHandler + Send + Sync), + conn: &Arc, + nonces: &mut HashMap, five_tuple: FiveTuple, realm: &str, ) -> Result)>, Error> { @@ -463,23 +448,14 @@ async fn authenticate_request( return Ok(None); }; - let mut bad_request_msg = Message::new( - MessageClass::ErrorResponse, - msg.method(), - msg.transaction_id(), - ); - bad_request_msg.add_attribute(ErrorCode::from(BadRequest)); - let Some(nonce_attr) = &msg.get_attribute::() else { - answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + respond_with_err(msg, BadRequest, conn, five_tuple.src_addr).await?; return Err(Error::AttributeNotFound); }; - let to_be_deleted = { - // Assert Nonce exists and is not expired - let mut nonces = nonces.lock().await; - - let to_be_deleted = nonces.get(nonce_attr.value()).map_or( + let stale_nonce = { + // Assert that the nonce exists and is not yet expired. + let stale_nonce = nonces.get(nonce_attr.value()).map_or( true, |nonce_creation_time| { Instant::now() @@ -489,13 +465,13 @@ async fn authenticate_request( }, ); - if to_be_deleted { + if stale_nonce { _ = nonces.remove(nonce_attr.value()); } - to_be_deleted + stale_nonce }; - if to_be_deleted { + if stale_nonce { respond_with_nonce( msg, ErrorCode::from(StaleNonce), @@ -509,11 +485,11 @@ async fn authenticate_request( } let Some(uname_attr) = msg.get_attribute::() else { - answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + respond_with_err(msg, BadRequest, conn, five_tuple.src_addr).await?; return Err(Error::AttributeNotFound); }; let Some(realm_attr) = msg.get_attribute::() else { - answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + respond_with_err(msg, BadRequest, conn, five_tuple.src_addr).await?; return Err(Error::AttributeNotFound); }; @@ -522,21 +498,14 @@ async fn authenticate_request( realm_attr.text(), five_tuple.src_addr, ) else { - answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + respond_with_err(msg, BadRequest, conn, five_tuple.src_addr).await?; return Err(Error::NoSuchUser); }; if let Err(err) = integrity.check_long_term_credential(uname_attr, realm_attr, &pass) { - let mut unauthorized_msg = Message::new( - MessageClass::ErrorResponse, - msg.method(), - msg.transaction_id(), - ); - unauthorized_msg.add_attribute(err); - - answer_with_err(conn, five_tuple.src_addr, unauthorized_msg).await?; + respond_with_err(msg, err, conn, five_tuple.src_addr).await?; Err(Error::IntegrityMismatch) } else { @@ -545,12 +514,16 @@ async fn authenticate_request( } /// Sends a [`MessageClass::SuccessResponse`] message with a -/// [`XorMappedAddress`] attribute to the given [`Conn`]. +/// [`XorMappedAddress`] attribute to the provided [`Transport`]. +/// +/// # Errors +/// +/// See the [`Error`] for details. async fn handle_binding_request( - conn: &Arc, + conn: &Arc, five_tuple: FiveTuple, ) -> Result<(), Error> { - log::trace!("received BindingRequest from {}", five_tuple.src_addr); + log::trace!("Received `BindingRequest` from {}", five_tuple.src_addr); let mut msg = Message::new( MessageClass::SuccessResponse, @@ -562,26 +535,30 @@ async fn handle_binding_request( Fingerprint::new(&msg).map_err(|e| Error::Encode(*e.kind()))?; msg.add_attribute(fingerprint); - build_and_send(conn, five_tuple.src_addr, msg).await + send_to(msg, conn, five_tuple.src_addr).await } -/// Handle the given [`Message`] as [Refresh Request]. +/// Handles the provided [`Message`] as a [Refresh Request][1]. /// -/// [Refresh Request]: https://datatracker.ietf.org/doc/html/rfc5766#section-7.2 +/// # Errors +/// +/// See the [`Error`] for details. +/// +/// [1]: https://tools.ietf.org/html/rfc5766#section-7.2 async fn handle_refresh_request( msg: Message, - conn: &Arc, - allocs: &Arc, + conn: &Arc, + allocs: &mut Manager, five_tuple: FiveTuple, uname: Username, realm: Realm, pass: Box, ) -> Result<(), Error> { - log::trace!("received RefreshRequest from {}", five_tuple.src_addr); + log::trace!("Received `RefreshRequest` from {}", five_tuple.src_addr); let lifetime_duration = get_lifetime(&msg); if lifetime_duration == Duration::from_secs(0) { - allocs.delete_allocation(&five_tuple).await; + allocs.delete_allocation(&five_tuple); } else if let Some(a) = allocs.get_alloc(&five_tuple) { // If a server receives a Refresh Request with a // REQUESTED-ADDRESS-FAMILY attribute, and the @@ -596,14 +573,13 @@ async fn handle_refresh_request( if (family == AddressFamily::V6 && !a.relay_addr().is_ipv6()) || (family == AddressFamily::V4 && !a.relay_addr().is_ipv4()) { - let mut msg = Message::new( - MessageClass::ErrorResponse, - REFRESH, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(PeerAddressFamilyMismatch)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err( + &msg, + PeerAddressFamilyMismatch, + conn, + five_tuple.src_addr, + ) + .await?; return Err(Error::PeerAddressFamilyMismatch); } @@ -627,22 +603,26 @@ async fn handle_refresh_request( .map_err(|e| Error::Encode(*e.kind()))?; msg.add_attribute(integrity); - build_and_send(conn, five_tuple.src_addr, msg).await + send_to(msg, conn, five_tuple.src_addr).await } -/// Handles the given [`Message`] as a [CreatePermission Request][1]. +/// Handles the provided [`Message`] as a [CreatePermission Request][1]. +/// +/// # Errors +/// +/// See the [`Error`] for details. /// -/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-9.2 +/// [1]: https://tools.ietf.org/html/rfc5766#section-9.2 async fn handle_create_permission_request( msg: Message, - conn: &Arc, - allocs: &Arc, + conn: &Arc, + allocs: &mut Manager, five_tuple: FiveTuple, uname: Username, realm: Realm, pass: Box, ) -> Result<(), Error> { - log::trace!("received CreatePermission from {}", five_tuple.src_addr); + log::trace!("Received `CreatePermission` from {}", five_tuple.src_addr); let Some(alloc) = allocs.get_alloc(&five_tuple) else { return Err(Error::NoAllocationFound); @@ -656,27 +636,25 @@ async fn handle_create_permission_request( }; let addr = attr.address(); - // If an XOR-PEER-ADDRESS attribute contains an address of an - // address family different than that of the relayed transport - // address for the allocation, the server MUST generate an error - // response with the 443 (Peer Address Family Mismatch) response - // code. [RFC 6156, Section 6.2] + // If an XOR-PEER-ADDRESS attribute contains an address of an address + // family different than that of the relayed transport address for the + // allocation, the server MUST generate an error response with the 443 + // (Peer Address Family Mismatch) response code. [RFC 6156, Section 6.2] if (addr.is_ipv4() && !alloc.relay_addr().is_ipv4()) || (addr.is_ipv6() && !alloc.relay_addr().is_ipv6()) { - let mut msg = Message::new( - MessageClass::ErrorResponse, - CREATE_PERMISSION, - msg.transaction_id(), - ); - msg.add_attribute(ErrorCode::from(PeerAddressFamilyMismatch)); - - answer_with_err(conn, five_tuple.src_addr, msg).await?; + respond_with_err( + &msg, + PeerAddressFamilyMismatch, + conn, + five_tuple.src_addr, + ) + .await?; return Err(Error::PeerAddressFamilyMismatch); } - log::trace!("adding permission for {}", addr); + log::trace!("Adding permission for {addr}"); alloc.add_permission(addr.ip()).await; add_count += 1; @@ -700,18 +678,22 @@ async fn handle_create_permission_request( msg }; - build_and_send(conn, five_tuple.src_addr, msg).await + send_to(msg, conn, five_tuple.src_addr).await } -/// Handles the given [`Message`] as a [Send Indication][1]. +/// Handles the provided [`Message`] as a [Send Indication][1]. +/// +/// # Errors +/// +/// See the [`Error`] for details. /// -/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-10.2 +/// [1]: https://tools.ietf.org/html/rfc5766#section-10.2 async fn handle_send_indication( msg: Message, - allocs: &Arc, + allocs: &mut Manager, five_tuple: FiveTuple, ) -> Result<(), Error> { - log::trace!("received SendIndication from {}", five_tuple.src_addr); + log::trace!("Received `SendIndication` from {}", five_tuple.src_addr); let Some(a) = allocs.get_alloc(&five_tuple) else { return Err(Error::NoAllocationFound); @@ -724,23 +706,25 @@ async fn handle_send_indication( .map(XorPeerAddress::address) .ok_or(Error::AttributeNotFound)?; - let has_perm = a.has_permission(&peer_address).await; - if !has_perm { + if !a.has_permission(&peer_address).await { return Err(Error::NoPermission); } - // TODO: dont clone a.relay(data_attr.data(), peer_address).await.map_err(Into::into) } -/// Handles the given [`Message`] as a [ChannelBind Request][1]. +/// Handles the provided [`Message`] as a [ChannelBind Request][1]. /// -/// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.2 -#[allow(clippy::too_many_arguments)] +/// # Errors +/// +/// See the [`Error`] for details. +/// +/// [1]: https://tools.ietf.org/html/rfc5766#section-11.2 +#[allow(clippy::too_many_arguments)] // TODO: refactor async fn handle_channel_bind_request( msg: Message, - conn: &Arc, - allocs: &Arc, + conn: &Arc, + allocs: &mut Manager, five_tuple: FiveTuple, channel_bind_lifetime: Duration, uname: Username, @@ -748,64 +732,50 @@ async fn handle_channel_bind_request( pass: Box, ) -> Result<(), Error> { if let Some(alloc) = allocs.get_alloc(&five_tuple) { - let mut bad_request_msg = Message::new( - MessageClass::ErrorResponse, - CHANNEL_BIND, - msg.transaction_id(), - ); - bad_request_msg.add_attribute(ErrorCode::from(BadRequest)); - let Some(ch_num) = msg.get_attribute::().map(|a| a.value()) else { - answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + respond_with_err(&msg, BadRequest, conn, five_tuple.src_addr) + .await?; return Err(Error::AttributeNotFound); }; let Some(peer_addr) = msg.get_attribute::().map(XorPeerAddress::address) else { - answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; + respond_with_err(&msg, BadRequest, conn, five_tuple.src_addr) + .await?; return Err(Error::AttributeNotFound); }; - // If the XOR-PEER-ADDRESS attribute contains an address of - // an address family different than that - // of the relayed transport address for the - // allocation, the server MUST generate an error response - // with the 443 (Peer Address Family - // Mismatch) response code. [RFC 6156, Section 7.2] + // If the XOR-PEER-ADDRESS attribute contains an address of an address + // family different than that of the relayed transport address for the + // allocation, the server MUST generate an error response with the 443 + // (Peer Address Family Mismatch) response code. [RFC 6156, Section 7.2] if (peer_addr.is_ipv4() && !alloc.relay_addr().is_ipv4()) || (peer_addr.is_ipv6() && !alloc.relay_addr().is_ipv6()) { - let mut peer_address_family_mismatch_msg = Message::new( - MessageClass::ErrorResponse, - CHANNEL_BIND, - msg.transaction_id(), - ); - peer_address_family_mismatch_msg - .add_attribute(ErrorCode::from(PeerAddressFamilyMismatch)); - - answer_with_err( + respond_with_err( + &msg, + PeerAddressFamilyMismatch, conn, five_tuple.src_addr, - peer_address_family_mismatch_msg, ) .await?; return Err(Error::PeerAddressFamilyMismatch); } - log::trace!("binding channel {ch_num} to {peer_addr}",); + log::trace!("Binding channel {ch_num} to {peer_addr}"); - if let Err(err) = alloc + if let Err(e) = alloc .add_channel_bind(ch_num, peer_addr, channel_bind_lifetime) .await { - answer_with_err(conn, five_tuple.src_addr, bad_request_msg).await?; - - return Err(err); + respond_with_err(&msg, BadRequest, conn, five_tuple.src_addr) + .await?; + return Err(e); } let mut msg = Message::new( @@ -820,21 +790,25 @@ async fn handle_channel_bind_request( .map_err(|e| Error::Encode(*e.kind()))?; msg.add_attribute(integrity); - build_and_send(conn, five_tuple.src_addr, msg).await + send_to(msg, conn, five_tuple.src_addr).await } else { Err(Error::NoAllocationFound) } } -/// Responds the given [`Message`] with a [`MessageClass::ErrorResponse`] with -/// a new random nonce. +/// Responds to the provided [`Message`] with a [`MessageClass::ErrorResponse`] +/// with a new random nonce. +/// +/// # Errors +/// +/// See the [`Error`] for details. async fn respond_with_nonce( msg: &Message, response_code: ErrorCode, - conn: &Arc, + conn: &Arc, realm: &str, five_tuple: FiveTuple, - nonces: &Arc>>, + nonces: &mut HashMap, ) -> Result<(), Error> { let nonce: String = rand::thread_rng() .sample_iter(&Alphanumeric) @@ -842,14 +816,7 @@ async fn respond_with_nonce( .map(char::from) .collect(); - { - // Nonce has already been taken - let mut nonces = nonces.lock().await; - if nonces.contains_key(&nonce) { - return Err(Error::RequestReplay); - } - _ = nonces.insert(nonce.clone(), Instant::now()); - } + _ = nonces.insert(nonce.clone(), Instant::now()); let mut msg = Message::new( MessageClass::ErrorResponse, @@ -862,37 +829,56 @@ async fn respond_with_nonce( Realm::new(realm.to_owned()).map_err(|e| Error::Encode(*e.kind()))?, ); - build_and_send(conn, five_tuple.src_addr, msg).await + send_to(msg, conn, five_tuple.src_addr).await } -/// Encodes and sends the provided [`Message`] to the given [`SocketAddr`] -/// via given [`Conn`]. -async fn build_and_send( - conn: &Arc, - dst: SocketAddr, +/// Encodes and sends the provided [`Message`] to the provided [`SocketAddr`] +/// via the provided [`Transport`]. +/// +/// # Errors +/// +/// See the [`Error`] for details. +async fn send_to( msg: Message, + conn: &Arc, + dst: SocketAddr, ) -> Result<(), Error> { let bytes = MessageEncoder::new() .encode_into_bytes(msg) .map_err(|e| Error::Encode(*e.kind()))?; - _ = conn.send_to(bytes, dst).await?; - Ok(()) + match conn.send_to(bytes, dst).await { + Ok(()) | Err(transport::Error::TransportIsDead) => Ok(()), + Err(err) => Err(Error::from(err)), + } } -/// Send a STUN packet and return the original error to the caller -async fn answer_with_err( - conn: &Arc, +/// Sends a [`MessageClass::ErrorResponse`] to the client and returns the +/// original error to the caller. +/// +/// # Errors +/// +/// See the [`Error`] for details. +async fn respond_with_err( + req: &Message, + err: impl Into, + conn: &Arc, dst: SocketAddr, - msg: Message, ) -> Result<(), Error> { - build_and_send(conn, dst, msg).await?; + let mut err_msg = Message::new( + MessageClass::ErrorResponse, + req.method(), + req.transaction_id(), + ); + err_msg.add_attribute(err.into()); + + send_to(err_msg, conn, dst).await?; Ok(()) } -/// Calculates a [`Lifetime`] fetching it from the given [`Message`] and -/// ensuring that it is not greater than configured +/// Calculates a [`Lifetime`], fetching it from the provided [`Message`] and +/// ensuring that it's not greater than the configured /// [`MAXIMUM_ALLOCATION_LIFETIME`]. fn get_lifetime(m: &Message) -> Duration { m.get_attribute::().map(Lifetime::lifetime).map_or( @@ -908,22 +894,19 @@ fn get_lifetime(m: &Message) -> Duration { } #[cfg(test)] -mod request_test { - use std::{net::IpAddr, str::FromStr}; +mod get_lifetime_spec { + use std::time::Duration; - use tokio::{ - net::UdpSocket, - time::{Duration, Instant}, + use crate::attr::Lifetime; + use rand::random; + use stun_codec::{ + rfc5766::methods::ALLOCATE, Message, MessageClass, TransactionId, }; - use crate::{allocation::ManagerConfig, relay::RelayAllocator}; - - use super::*; - - const STATIC_KEY: &str = "ABC"; + use super::{get_lifetime, DEFAULT_LIFETIME, MAXIMUM_ALLOCATION_LIFETIME}; #[tokio::test] - async fn test_allocation_lifetime_parsing() { + async fn lifetime_parsing() { let lifetime = Lifetime::new(Duration::from_secs(5)).unwrap(); let mut m = Message::new( @@ -935,7 +918,7 @@ mod request_test { assert_eq!( lifetime_duration, DEFAULT_LIFETIME, - "Allocation lifetime should be default time duration" + "allocation lifetime should be default time duration", ); m.add_attribute(lifetime.clone()); @@ -944,13 +927,12 @@ mod request_test { assert_eq!( lifetime_duration, lifetime.lifetime(), - "Expect lifetime_duration is {lifetime:?}, but \ - {lifetime_duration:?}" + "wrong lifetime duration", ); } #[tokio::test] - async fn test_allocation_lifetime_overflow() { + async fn lifetime_overflow() { let lifetime = Lifetime::new(MAXIMUM_ALLOCATION_LIFETIME * 2).unwrap(); let mut m2 = Message::new( @@ -963,12 +945,43 @@ mod request_test { let lifetime_duration = get_lifetime(&m2); assert_eq!( lifetime_duration, DEFAULT_LIFETIME, - "Expect lifetime_duration is {DEFAULT_LIFETIME:?}, \ - but {lifetime_duration:?}" + "wrong lifetime duration", ); } +} + +#[cfg(test)] +mod handle_spec { + use std::{ + collections::HashMap, + net::{IpAddr, SocketAddr}, + str::FromStr, + sync::Arc, + }; + + use rand::random; + use stun_codec::{ + rfc5766::methods::REFRESH, Message, MessageClass, TransactionId, + }; + use tokio::{ + net::UdpSocket, + time::{Duration, Instant}, + }; + + use crate::{ + allocation, + attr::{Attribute, Lifetime, MessageIntegrity, Nonce, Realm, Username}, + relay, + transport::Request, + AuthHandler, Error, FiveTuple, Transport, + }; + + use super::handle; + + const STATIC_KEY: &str = "ABC"; struct TestAuthHandler; + impl AuthHandler for TestAuthHandler { fn auth_handle( &self, @@ -981,20 +994,21 @@ mod request_test { } #[tokio::test] - async fn test_allocation_lifetime_deletion_zero_lifetime() { - let conn: Arc = + async fn lifetime_deletion_zero_lifetime() { + let conn: Arc = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap()); - let allocation_manager = Arc::new(Manager::new(ManagerConfig { - relay_addr_generator: RelayAllocator { - relay_address: IpAddr::from([127, 0, 0, 1]), - min_port: 49152, - max_port: 65535, - max_retries: 10, - address: String::from("127.0.0.1"), - }, - alloc_close_notify: None, - })); + let mut allocation_manager = + allocation::Manager::new(allocation::ManagerConfig { + relay_addr_generator: relay::Allocator { + relay_address: IpAddr::from([127, 0, 0, 1]), + min_port: 49152, + max_port: 65535, + max_retries: 10, + address: String::from("127.0.0.1"), + }, + alloc_close_notify: None, + }); let socket = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 5000); @@ -1003,11 +1017,11 @@ mod request_test { dst_addr: conn.local_addr(), protocol: conn.proto(), }; - let nonces = Arc::new(Mutex::new(HashMap::new())); + let mut nonces = HashMap::new(); - nonces.lock().await.insert(STATIC_KEY.to_owned(), Instant::now()); + _ = nonces.insert(STATIC_KEY.to_owned(), Instant::now()); - allocation_manager + _ = allocation_manager .create_allocation( five_tuple, Arc::clone(&conn), @@ -1039,18 +1053,16 @@ mod request_test { .unwrap(); m.add_attribute(integrity); - let bytes = MessageEncoder::new().encode_into_bytes(m).unwrap(); - let auth: Arc = Arc::new(TestAuthHandler {}); - handle_message( - bytes, + handle( + Request::Message(m), &conn, five_tuple, STATIC_KEY, Duration::from_secs(60), - &allocation_manager, - &nonces, + &mut allocation_manager, + &mut nonces, &auth, ) .await diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 000000000..d76d7c9e0 --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1,218 @@ +//! [STUN]/[TURN] transport definitions. +//! +//! [STUN]: https://en.wikipedia.org/wiki/STUN +//! [TURN]: https://en.wikipedia.org/wiki/TURN + +mod tcp; + +use std::io; + +use std::net::SocketAddr; + +use async_trait::async_trait; +use bytecodec::DecodeExt; +use derive_more::{Display, Error as StdError, From}; +use stun_codec::{Message, MessageDecoder}; +use tokio::net::{self, ToSocketAddrs}; + +use crate::{ + attr::{Attribute, PROTO_UDP}, + chandata, + chandata::ChannelData, + server::INBOUND_MTU, +}; + +pub use tokio::net::UdpSocket; + +pub use self::tcp::Server as TcpServer; + +/// Parsed ingress [STUN]/[TURN] message. +/// +/// [STUN]: https://en.wikipedia.org/wiki/STUN +/// [TURN]: https://en.wikipedia.org/wiki/TURN +#[derive(Debug)] +pub enum Request { + /// [STUN Message]. + /// + /// [STUN Message]: https://tools.ietf.org/html/rfc5389#section-6 + Message(Message), + + /// [TURN ChannelData Message][1]. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 + ChannelData(ChannelData), +} + +/// Abstraction of [STUN]/[TURN] transport implementation. +/// +/// [STUN]: https://en.wikipedia.org/wiki/STUN +/// [TURN]: https://en.wikipedia.org/wiki/TURN +#[async_trait] +pub trait Transport { + /// Receives a [`Request`] datagram message. + /// + /// # Errors + /// + /// See the [`Error`] for details. + async fn recv_from(&self) -> Result<(Request, SocketAddr), Error>; + + /// Sends `data` to the provided [`SocketAddr`]. + /// + /// # Errors + /// + /// See the [`Error`] for details. + async fn send_to( + &self, + data: Vec, + target: SocketAddr, + ) -> Result<(), Error>; + + /// Returns the local [`SocketAddr`] of this [`Transport`]. + fn local_addr(&self) -> SocketAddr; + + /// Returns the protocol number of this [`Transport`] according to [IANA]. + /// + /// [IANA]: https://tinyurl.com/iana-protocol-numbers + fn proto(&self) -> u8; +} + +#[async_trait] +impl Transport for UdpSocket { + async fn recv_from(&self) -> Result<(Request, SocketAddr), Error> { + let mut buf = vec![0u8; INBOUND_MTU]; + let (n, addr) = self.recv_from(&mut buf).await?; + + let msg = if ChannelData::is_channel_data(&buf[0..n]) { + buf.truncate(n); + let data = ChannelData::decode(buf)?; + + Request::ChannelData(data) + } else { + let msg = MessageDecoder::::new() + .decode_from_bytes(&buf[0..n]) + .map_err(|e| Error::Decode(*e.kind()))? + .map_err(|e| Error::Decode(*e.error().kind()))?; + + Request::Message(msg) + }; + + Ok((msg, addr)) + } + + async fn send_to( + &self, + data: Vec, + target: SocketAddr, + ) -> Result<(), Error> { + Ok(self.send_to(&data, target).await.map(|_| ())?) + } + + fn local_addr(&self) -> SocketAddr { + // PANIC: Unwrapping is OK here, as this function is intended to be + // called on the bound `UdpSocket` only. + #[allow(clippy::unwrap_used)] // intentional + self.local_addr().unwrap() + } + + fn proto(&self) -> u8 { + PROTO_UDP + } +} + +/// Performs a DNS resolution of the provided `host`. +/// +/// # Errors +/// +/// If the provided `host` cannot be resolved to IP address. +pub(crate) async fn lookup_host( + use_ipv4: bool, + host: impl ToSocketAddrs, +) -> Result { + for remote_addr in net::lookup_host(host).await? { + if (use_ipv4 && remote_addr.is_ipv4()) + || (!use_ipv4 && remote_addr.is_ipv6()) + { + return Ok(remote_addr); + } + } + + Err(io::Error::other(format!( + "No available {} IP address found!", + if use_ipv4 { "ipv4" } else { "ipv6" }, + )) + .into()) +} + +/// Possible errors of a [`Transport`]. +#[derive(Debug, Display, From, Eq, PartialEq, StdError)] +#[allow(variant_size_differences)] +pub enum Error { + /// Tried to use a dead [`Transport`]. + #[display("Underlying TCP/UDP transport is dead")] + TransportIsDead, + + /// Failed to decode message. + #[display("Failed to decode STUN/TURN message: {_0:?}")] + Decode(#[error(not(source))] bytecodec::ErrorKind), + + /// [TURN ChannelData][1] format error. + /// + /// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 + #[from(chandata::FormatError)] + ChannelData(chandata::FormatError), + + /// I/O error of the [`Transport`]. + #[display("I/O error: {_0}")] + #[from(io::Error, IoError)] + Io(IoError), +} + +/// [`io::Error`] implementing [`Eq`] and [`PartialEq`] by its [`kind`]. +/// +/// [`kind`]: io::Error::kind() +#[derive(Debug, Display, From, StdError)] +pub struct IoError(pub io::Error); + +impl Eq for IoError {} + +impl PartialEq for IoError { + fn eq(&self, other: &Self) -> bool { + self.0.kind() == other.0.kind() + } +} + +#[cfg(test)] +mod lookup_host_spec { + use super::lookup_host; + + #[tokio::test] + async fn considers_ip_version() { + let stun_serv_addr = "stun1.l.google.com:19302"; + + if let Ok(ipv4_addr) = lookup_host(true, stun_serv_addr).await { + assert!( + ipv4_addr.is_ipv4(), + "expected ipv4 but got ipv6: {ipv4_addr}", + ); + } + + if let Ok(ipv6_addr) = lookup_host(false, stun_serv_addr).await { + assert!( + ipv6_addr.is_ipv6(), + "expected ipv6 but got ipv4: {ipv6_addr}", + ); + } + } + + #[tokio::test] + async fn resolves_localhost() { + let udp_addr = lookup_host(true, "localhost:1234").await.unwrap(); + + assert_eq!(udp_addr.ip().to_string(), "127.0.0.1"); + assert_eq!(udp_addr.port(), 1234); + + let res = lookup_host(false, "127.0.0.1:1234").await; + + assert!(res.is_err(), "expected error, got: {res:?}"); + } +} diff --git a/src/con/tcp.rs b/src/transport/tcp.rs similarity index 54% rename from src/con/tcp.rs rename to src/transport/tcp.rs index e8c4f222b..eb13512eb 100644 --- a/src/con/tcp.rs +++ b/src/transport/tcp.rs @@ -1,6 +1,7 @@ -//! STUN/TURN TCP server connection implementation. - -#![allow(clippy::module_name_repetitions)] +//! [STUN]/[TURN] TCP-based [`Transport`] implementation. +//! +//! [STUN]: https://en.wikipedia.org/wiki/STUN +//! [TURN]: https://en.wikipedia.org/wiki/TURN use std::{ collections::{hash_map::Entry, HashMap}, @@ -9,8 +10,10 @@ use std::{ }; use async_trait::async_trait; +use bytecodec::DecodeExt; use bytes::BytesMut; use futures::StreamExt; +use stun_codec::MessageDecoder; use tokio::{ io::AsyncWriteExt as _, net::{TcpListener, TcpStream}, @@ -19,37 +22,41 @@ use tokio::{ use tokio_util::codec::{Decoder, FramedRead}; use crate::{ - attr::PROTO_TCP, - chandata::nearest_padded_value_length, - con::{Conn, Error}, + attr::{Attribute, PROTO_TCP}, + chandata::{nearest_padded_value_length, ChannelData}, }; -/// Channels to the active TCP sessions. +use super::{Error, Request, Transport}; + +/// Shortcut for a [`HashMap`] of active TCP sessions. type TcpWritersMap = Arc< Mutex< HashMap< SocketAddr, - mpsc::Sender<(Vec, oneshot::Sender>)>, + mpsc::Sender<(Vec, oneshot::Sender>)>, >, >, >; -/// TURN TCP transport. +/// Server implementing [STUN]/[TURN] TCP-based [`Transport`]. +/// +/// [STUN]: https://en.wikipedia.org/wiki/STUN +/// [TURN]: https://en.wikipedia.org/wiki/TURN #[derive(Debug)] -pub struct TcpServer { - /// Ingress messages receiver. - ingress_rx: Mutex, SocketAddr)>>, +pub struct Server { + /// [`mpsc::Receiver`] of [`Request`]s. + ingress_rx: Mutex>, - /// Local [`TcpListener`] address. + /// Local [`SocketAddr`] of the [`TcpListener`]. local_addr: SocketAddr, - /// Channels to all active TCP sessions. + /// Active TCP sessions. writers: TcpWritersMap, } #[async_trait] -impl Conn for TcpServer { - async fn recv_from(&self) -> Result<(Vec, SocketAddr), Error> { +impl Transport for Server { + async fn recv_from(&self) -> Result<(Request, SocketAddr), Error> { if let Some((data, addr)) = self.ingress_rx.lock().await.recv().await { Ok((data, addr)) } else { @@ -57,22 +64,23 @@ impl Conn for TcpServer { } } - #[allow(clippy::significant_drop_in_scrutinee)] async fn send_to( &self, data: Vec, target: SocketAddr, - ) -> Result { + ) -> Result<(), Error> { let mut writers = self.writers.lock().await; + #[allow(clippy::significant_drop_in_scrutinee)] // intentional match writers.entry(target) { Entry::Occupied(mut e) => { let (res_tx, res_rx) = oneshot::channel(); if e.get_mut().send((data, res_tx)).await.is_err() { // Underlying TCP stream is dead. drop(e.remove_entry()); + Err(Error::TransportIsDead) } else { - #[allow(clippy::map_err_ignore)] + #[allow(clippy::map_err_ignore)] // intentional res_rx.await.map_err(|_| Error::TransportIsDead)? } } @@ -87,19 +95,17 @@ impl Conn for TcpServer { fn proto(&self) -> u8 { PROTO_TCP } - - async fn close(&self) -> Result<(), Error> { - Ok(()) - } } -impl TcpServer { - /// Creates a new [`TcpServer`]. +impl Server { + /// Creates and [`spawn`]s a new [`Server`] on the provided [`TcpListener`]. /// /// # Errors /// - /// With [`enum@Error`] if failed to receive local [`SocketAddr`] for the - /// provided [`TcpListener`]. + /// If fails to receive the local [`SocketAddr`] of the provided + /// [`TcpListener`]. + /// + /// [`spawn`]: tokio::spawn() pub fn new(listener: TcpListener) -> Result { let local_addr = listener.local_addr()?; let (ingress_tx, ingress_rx) = mpsc::channel(256); @@ -109,60 +115,65 @@ impl TcpServer { let writers = Arc::clone(&writers); async move { loop { - let Ok((stream, remote)) = listener.accept().await else { - log::debug!("Closing TCP listener at {local_addr}"); - break; - }; - if ingress_tx.is_closed() { - break; + tokio::select! { + stream = listener.accept() => { + match stream { + Ok((stream, remote)) => { + Self::spawn_stream_handler( + stream, + local_addr, + remote, + ingress_tx.clone(), + Arc::clone(&writers), + ); + }, + Err(_) => { + break; + } + } + } + () = ingress_tx.closed() => { + break; + } } - - Self::spawn_stream_handler( - stream, - local_addr, - remote, - ingress_tx.clone(), - Arc::clone(&writers), - ); } + log::debug!("Closing `TcpListener` at {local_addr}"); } })); Ok(Self { ingress_rx: Mutex::new(ingress_rx), local_addr, writers }) } - /// Spawns a handler task for the given [`TcpStream`] + /// [`spawn`]s a handler for the provided [`TcpStream`]. + /// + /// [`spawn`]: tokio::spawn() fn spawn_stream_handler( mut stream: TcpStream, - local_addr: SocketAddr, + local: SocketAddr, remote: SocketAddr, - ingress_tx: mpsc::Sender<(Vec, SocketAddr)>, + ingress_tx: mpsc::Sender<(Request, SocketAddr)>, writers: TcpWritersMap, ) { drop(tokio::spawn(async move { let (egress_tx, mut egress_rx) = mpsc::channel::<( Vec, - oneshot::Sender>, + oneshot::Sender>, )>(256); drop(writers.lock().await.insert(remote, egress_tx)); let (reader, mut writer) = stream.split(); - let mut reader = FramedRead::new(reader, StunTcpCodec::default()); + let mut reader = FramedRead::new(reader, Codec::default()); loop { tokio::select! { msg = egress_rx.recv() => { if let Some((msg, tx)) = msg { - let len = msg.len(); let res = writer.write_all(msg.as_slice()).await - .map(|()| len) .map_err(Error::from); drop(tx.send(res)); } else { - log::debug!("Closing TCP {local_addr} <=> \ - {remote}"); - + log::debug!("Closing TCP {local} <=> {remote}"); break; } }, @@ -172,23 +183,23 @@ impl TcpServer { match ingress_tx.try_send((msg, remote)) { Ok(()) => {}, Err(TrySendError::Full(_)) => { - log::debug!("Dropped ingress message \ - from TCP {local_addr} <=> {remote}"); + log::debug!( + "Dropped ingress message from TCP \ + {local} <=> {remote}", + ); } Err(TrySendError::Closed(_)) => { - log::debug!("Closing TCP \ - {local_addr} <=> {remote}"); - + log::debug!( + "Closing TCP {local} <=> {remote}", + ); break; } } } Some(Err(_)) => {}, None => { - log::debug!("Closing TCP \ - {local_addr} <=> {remote}"); - + log::debug!("Closing TCP {local} <=> {remote}"); break; } } @@ -199,10 +210,12 @@ impl TcpServer { } } -#[derive(Debug, Clone, Copy)] -enum StunMessageKind { - /// STUN method. +/// Kind of a [`Request`] message. +#[derive(Clone, Copy, Debug)] +enum RequestKind { + /// [STUN Message]. /// + /// ```ascii /// 0 1 2 3 /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -214,10 +227,14 @@ enum StunMessageKind { /// | Transaction ID (96 bits) | /// | | /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - Method(usize), + /// ``` + /// + /// [STUN Message]: https://tools.ietf.org/html/rfc5389#section-6 + Message(usize), - /// TURN [ChannelData][1]. + /// [TURN ChannelData Message][1]. /// + /// ```ascii /// 0 1 2 3 /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -230,13 +247,15 @@ enum StunMessageKind { /// | +-------------------------------+ /// | | /// +-------------------------------+ + /// ``` /// - /// [1]: https://datatracker.ietf.org/doc/html/rfc5766#section-11.4 + /// [1]: https://tools.ietf.org/html/rfc5766#section-11.4 ChannelData(usize), } -impl StunMessageKind { - /// Detects [`StunMessageKind`] from the given 4 bytes. +impl RequestKind { + /// Detects a [`RequestKind`] from the provided first 4 bytes of a + /// [`Request`]. fn detect_kind(first_4_bytes: [u8; 4]) -> Self { let size = usize::from(u16::from_be_bytes([ first_4_bytes[2], @@ -245,48 +264,72 @@ impl StunMessageKind { // If the first two bits are zeroes, then this is a STUN method. if first_4_bytes[0] & 0b1100_0000 == 0 { - Self::Method(nearest_padded_value_length(size + 20)) + Self::Message(nearest_padded_value_length(size + 20)) } else { Self::ChannelData(nearest_padded_value_length(size + 4)) } } - /// Returns the expected length of the message. + /// Returns the expected length of the [`Request`] message. const fn length(&self) -> usize { *match self { - Self::Method(l) | Self::ChannelData(l) => l, + Self::Message(l) | Self::ChannelData(l) => l, } } } -/// [`Decoder`] that splits STUN/TURN stream into frames. +/// [`Decoder`] splitting a [STUN]/[TURN] stream into frames. +/// +/// [STUN]: https://en.wikipedia.org/wiki/STUN +/// [TURN]: https://en.wikipedia.org/wiki/TURN #[derive(Default)] -struct StunTcpCodec { - /// Current message kind. - current: Option, +struct Codec { + /// Current [`RequestKind`]. + current: Option, + + /// [STUN Message] decoder. + /// + /// [STUN Message]: https://tools.ietf.org/html/rfc5389#section-6 + msg_decoder: MessageDecoder, } -impl Decoder for StunTcpCodec { +impl Decoder for Codec { + type Item = Request; type Error = Error; - type Item = Vec; - #[allow(clippy::unwrap_in_result, clippy::missing_asserts_for_indexing)] fn decode( &mut self, buf: &mut BytesMut, ) -> Result, Self::Error> { + // PANIC: Indexing is OK below, since we guard it with `if` condition. + #![allow(clippy::missing_asserts_for_indexing)] // false positive + if self.current.is_none() && buf.len() >= 4 { - self.current = Some(StunMessageKind::detect_kind([ + self.current = Some(RequestKind::detect_kind([ buf[0], buf[1], buf[2], buf[3], ])); } - if let Some(pending) = self.current { - if buf.len() >= pending.length() { - #[allow(clippy::unwrap_used)] - return Ok(Some( - buf.split_to(self.current.take().unwrap().length()) - .to_vec(), - )); + + if let Some(current) = self.current { + if buf.len() >= current.length() { + _ = self.current.take(); + + let raw = buf.split_to(current.length()); + let msg = match current { + RequestKind::Message(_) => { + let msg = self + .msg_decoder + .decode_from_bytes(&raw) + .map_err(|e| Error::Decode(*e.kind()))? + .map_err(|e| Error::Decode(*e.error().kind()))?; + + Request::Message(msg) + } + RequestKind::ChannelData(_) => { + Request::ChannelData(ChannelData::decode(raw.to_vec())?) + } + }; + return Ok(Some(msg)); } }