From 2e72a8f2b98a539afb59e0e83e40e70c59085d53 Mon Sep 17 00:00:00 2001 From: yngrtc Date: Sat, 16 Mar 2024 09:17:09 -0700 Subject: [PATCH] fix client_test.rs --- rtc-shared/src/util.rs | 17 +++ rtc-turn/src/client/client_test.rs | 183 +++++++++++++++--------- rtc-turn/src/client/mod.rs | 33 +++-- rtc-turn/src/client/relay.rs | 13 +- rtc-turn/src/client/relay/relay_test.rs | 56 -------- 5 files changed, 158 insertions(+), 144 deletions(-) delete mode 100644 rtc-turn/src/client/relay/relay_test.rs diff --git a/rtc-shared/src/util.rs b/rtc-shared/src/util.rs index bff0b2e..2887cef 100644 --- a/rtc-shared/src/util.rs +++ b/rtc-shared/src/util.rs @@ -1,3 +1,6 @@ +use crate::error::{Error, Result}; +use std::net::{SocketAddr, ToSocketAddrs}; + // match_range is a MatchFunc that accepts packets with the first byte in [lower..upper] fn match_range(lower: u8, upper: u8) -> impl Fn(&[u8]) -> bool { move |buf: &[u8]| -> bool { @@ -53,3 +56,17 @@ pub fn match_srtp(buf: &[u8]) -> bool { pub fn match_srtcp(buf: &[u8]) -> bool { match_srtp_or_srtcp(buf) && is_rtcp(buf) } + +/// lookup host to SocketAddr +pub fn lookup_host(use_ipv4: bool, host: T) -> Result +where + T: ToSocketAddrs, +{ + for remote_addr in host.to_socket_addrs()? { + if (use_ipv4 && remote_addr.is_ipv4()) || (!use_ipv4 && remote_addr.is_ipv6()) { + return Ok(remote_addr); + } + } + + Err(Error::ErrAddressParseFailed) +} diff --git a/rtc-turn/src/client/client_test.rs b/rtc-turn/src/client/client_test.rs index 9f08f6b..188bb10 100644 --- a/rtc-turn/src/client/client_test.rs +++ b/rtc-turn/src/client/client_test.rs @@ -1,120 +1,167 @@ -use tokio::net::UdpSocket; - use super::*; -use crate::auth::*; +use std::collections::HashSet; +use std::net::UdpSocket; -async fn create_listening_test_client(rto_in_ms: u16) -> Result { - let conn = UdpSocket::bind("0.0.0.0:0").await?; +fn create_listening_test_client(rto_in_ms: u64) -> Result<(UdpSocket, Client)> { + let udp_socket = UdpSocket::bind("0.0.0.0:0")?; - let c = Client::new(ClientConfig { + let client = Client::new(ClientConfig { stun_serv_addr: String::new(), turn_serv_addr: String::new(), + local_addr: udp_socket.local_addr()?, + protocol: Protocol::UDP, username: String::new(), password: String::new(), realm: String::new(), software: "TEST SOFTWARE".to_owned(), rto_in_ms, - conn: Arc::new(conn), - vnet: None, - }) - .await?; - - c.listen().await?; + })?; - Ok(c) + Ok((udp_socket, client)) } -async fn create_listening_test_client_with_stun_serv() -> Result { - let conn = UdpSocket::bind("0.0.0.0:0").await?; +fn create_listening_test_client_with_stun_serv() -> Result<(UdpSocket, Client)> { + let udp_socket = UdpSocket::bind("0.0.0.0:0")?; - let c = Client::new(ClientConfig { + let client = Client::new(ClientConfig { stun_serv_addr: "stun1.l.google.com:19302".to_owned(), turn_serv_addr: String::new(), + local_addr: udp_socket.local_addr()?, + protocol: Protocol::UDP, username: String::new(), password: String::new(), realm: String::new(), software: "TEST SOFTWARE".to_owned(), rto_in_ms: 0, - conn: Arc::new(conn), - vnet: None, - }) - .await?; - - c.listen().await?; + })?; - Ok(c) + Ok((udp_socket, client)) } -#[tokio::test] -async fn test_client_with_stun_send_binding_request() -> Result<()> { +#[test] +fn test_client_with_stun_send_binding_request() -> Result<()> { //env_logger::init(); - let c = create_listening_test_client_with_stun_serv().await?; + let (conn, mut client) = create_listening_test_client_with_stun_serv()?; + let local_addr = conn.local_addr()?; + + let tid = client.send_binding_request()?; + + while let Some(transmit) = client.poll_transmit() { + conn.send_to(&transmit.message, transmit.transport.peer_addr)?; + } - let resp = c.send_binding_request().await?; - log::debug!("mapped-addr: {}", resp); - { - let ci = c.client_internal.lock().await; - let tm = ci.tr_map.lock().await; - assert_eq!(0, tm.size(), "should be no transaction left"); + let mut buffer = vec![0u8; 2048]; + let (n, peer_addr) = conn.recv_from(&mut buffer)?; + client.handle_transmit(Transmit { + now: Instant::now(), + transport: TransportContext { + local_addr, + peer_addr, + protocol: Protocol::UDP, + ecn: None, + }, + message: BytesMut::from(&buffer[..n]), + })?; + + if let Some(event) = client.poll_event() { + match event { + Event::BindingResponse(id, refl_addr) => { + assert_eq!(tid, id); + log::debug!("mapped-addr: {}", refl_addr); + } + _ => assert!(false), + } + } else { + assert!(false); } - c.close().await?; + assert_eq!(0, client.tr_map.size(), "should be no transaction left"); + + client.close(); Ok(()) } -#[tokio::test] -async fn test_client_with_stun_send_binding_request_to_parallel() -> Result<()> { - env_logger::init(); - - let c1 = create_listening_test_client(0).await?; - let c2 = c1.clone(); +#[test] +fn test_client_with_stun_send_binding_request_to_parallel() -> Result<()> { + //env_logger::init(); - let (stared_tx, mut started_rx) = mpsc::channel::<()>(1); - let (finished_tx, mut finished_rx) = mpsc::channel::<()>(1); + let (conn, mut client) = create_listening_test_client(0)?; + let local_addr = conn.local_addr()?; - let to = lookup_host(true, "stun1.l.google.com:19302").await?; + let to = lookup_host(true, "stun1.l.google.com:19302")?; - tokio::spawn(async move { - drop(stared_tx); - if let Ok(resp) = c2.send_binding_request_to(&to.to_string()).await { - log::debug!("mapped-addr: {}", resp); - } - drop(finished_tx); - }); + let tid1 = client.send_binding_request_to(to)?; + let tid2 = client.send_binding_request_to(to)?; + while let Some(transmit) = client.poll_transmit() { + conn.send_to(&transmit.message, transmit.transport.peer_addr)?; + } - let _ = started_rx.recv().await; + let mut buffer = vec![0u8; 2048]; + for _ in 0..2 { + let (n, peer_addr) = conn.recv_from(&mut buffer)?; + client.handle_transmit(Transmit { + now: Instant::now(), + transport: TransportContext { + local_addr, + peer_addr, + protocol: Protocol::UDP, + ecn: None, + }, + message: BytesMut::from(&buffer[..n]), + })?; + } - let resp = c1.send_binding_request_to(&to.to_string()).await?; - log::debug!("mapped-addr: {}", resp); + let mut tids = HashSet::new(); + while let Some(event) = client.poll_event() { + match event { + Event::BindingResponse(tid, refl_addr) => { + tids.insert(tid); + log::debug!("mapped-addr: {}", refl_addr); + } + _ => {} + } + } - let _ = finished_rx.recv().await; + assert_eq!(2, tids.len()); + assert!(tids.contains(&tid1)); + assert!(tids.contains(&tid2)); - c1.close().await?; + client.close(); Ok(()) } -#[tokio::test] -async fn test_client_with_stun_send_binding_request_to_timeout() -> Result<()> { +#[test] +fn test_client_with_stun_send_binding_request_to_timeout() -> Result<()> { //env_logger::init(); - let c = create_listening_test_client(10).await?; + let (conn, mut client) = create_listening_test_client(10)?; - let to = lookup_host(true, "127.0.0.1:9").await?; + let to = lookup_host(true, "127.0.0.1:9")?; - let result = c.send_binding_request_to(&to.to_string()).await; - assert!(result.is_err(), "expected error, but got ok"); - - c.close().await?; + let tid = client.send_binding_request_to(to)?; + while let Some(transmit) = client.poll_transmit() { + conn.send_to(&transmit.message, transmit.transport.peer_addr)?; + } - Ok(()) -} + while let Some(to) = client.poll_timout() { + client.handle_timeout(to); + } -struct TestAuthHandler; -impl AuthHandler for TestAuthHandler { - fn auth_handle(&self, username: &str, realm: &str, _src_addr: SocketAddr) -> Result> { - Ok(generate_auth_key(username, realm, "pass")) + if let Some(event) = client.poll_event() { + match event { + Event::TransactionTimeout(id) => { + assert_eq!(tid, id); + } + _ => assert!(false), + } + } else { + assert!(false); } + + client.close(); + + Ok(()) } diff --git a/rtc-turn/src/client/mod.rs b/rtc-turn/src/client/mod.rs index 6daaf60..43f5102 100644 --- a/rtc-turn/src/client/mod.rs +++ b/rtc-turn/src/client/mod.rs @@ -1,6 +1,6 @@ -/*TODO:#[cfg(test)] +#[cfg(test)] mod client_test; -*/ + pub mod binding; pub mod permission; pub mod relay; @@ -9,7 +9,6 @@ pub mod transaction; use bytes::BytesMut; use std::collections::{HashMap, VecDeque}; use std::net::SocketAddr; -use std::str::FromStr; use std::time::Instant; use stun::attributes::*; @@ -31,6 +30,7 @@ use crate::proto::relayaddr::RelayedAddress; use crate::proto::reqtrans::RequestedTransport; use crate::proto::{PROTO_TCP, PROTO_UDP}; use shared::error::{Error, Result}; +use shared::util::lookup_host; use shared::{Protocol, Transmit, TransportContext}; use stun::error_code::ErrorCodeAttribute; use stun::fingerprint::FINGERPRINT; @@ -89,7 +89,7 @@ pub struct ClientConfig { /// Client is a STUN client pub struct Client { stun_serv_addr: Option, - turn_serv_addr: SocketAddr, + turn_serv_addr: Option, local_addr: SocketAddr, protocol: Protocol, username: Username, @@ -112,13 +112,19 @@ impl Client { let stun_serv_addr = if config.stun_serv_addr.is_empty() { None } else { - Some(SocketAddr::from_str(config.stun_serv_addr.as_str())?) + Some(lookup_host( + config.local_addr.is_ipv4(), + config.stun_serv_addr.as_str(), + )?) }; let turn_serv_addr = if config.turn_serv_addr.is_empty() { - return Err(Error::ErrNilTurnSocket); + None } else { - SocketAddr::from_str(config.turn_serv_addr.as_str())? + Some(lookup_host( + config.local_addr.is_ipv4(), + config.turn_serv_addr.as_str(), + )?) }; Ok(Client { @@ -500,8 +506,11 @@ impl Client { ])?; log::debug!("client.Allocate call PerformTransaction 1"); - let tid = - self.perform_transaction(&msg, self.turn_serv_addr, TransactionType::AllocateAttempt); + let tid = self.perform_transaction( + &msg, + self.turn_server_addr()?, + TransactionType::AllocateAttempt, + ); Ok(tid) } @@ -558,7 +567,7 @@ impl Client { log::debug!("client.Allocate call PerformTransaction 2"); self.perform_transaction( &msg, - self.turn_serv_addr, + self.turn_server_addr()?, TransactionType::AllocateRequest(nonce), ); } @@ -599,8 +608,8 @@ impl Client { } /// turn_server_addr return the TURN server address - fn turn_server_addr(&self) -> SocketAddr { - self.turn_serv_addr + fn turn_server_addr(&self) -> Result { + self.turn_serv_addr.ok_or(Error::ErrNilTurnSocket) } /// username returns username diff --git a/rtc-turn/src/client/relay.rs b/rtc-turn/src/client/relay.rs index 046ce4e..e146e06 100644 --- a/rtc-turn/src/client/relay.rs +++ b/rtc-turn/src/client/relay.rs @@ -1,6 +1,3 @@ -//TODO #[cfg(test)] -//mod relay_test; - use log::{debug, warn}; use std::collections::HashMap; use std::net::SocketAddr; @@ -206,7 +203,7 @@ impl<'a> Relay<'a> { // indication has no transaction (fire-and-forget) self.client - .write_to(&msg.raw, self.client.turn_server_addr()); + .write_to(&msg.raw, self.client.turn_server_addr()?); return Ok(()); } @@ -270,7 +267,7 @@ impl<'a> Relay<'a> { let _ = self.client.perform_transaction( &msg, - self.client.turn_server_addr(), + self.client.turn_server_addr()?, TransactionType::CreatePermissionRequest(self.relayed_addr, peer_addr_opt), ); @@ -335,7 +332,7 @@ impl<'a> Relay<'a> { let _ = self.client.perform_transaction( &msg, - self.client.turn_server_addr(), + self.client.turn_server_addr()?, TransactionType::RefreshRequest(self.relayed_addr), ); @@ -415,7 +412,7 @@ impl<'a> Relay<'a> { let mut msg = Message::new(); msg.build(&setters)?; - (msg, self.client.turn_server_addr()) + (msg, self.client.turn_server_addr()?) }; debug!("UDPConn.bind call PerformTransaction 1"); @@ -480,7 +477,7 @@ impl<'a> Relay<'a> { ch_data.encode(); self.client - .write_to(&ch_data.raw, self.client.turn_server_addr()); + .write_to(&ch_data.raw, self.client.turn_server_addr()?); Ok(()) } diff --git a/rtc-turn/src/client/relay/relay_test.rs b/rtc-turn/src/client/relay/relay_test.rs deleted file mode 100644 index 04876bc..0000000 --- a/rtc-turn/src/client/relay/relay_test.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::net::Ipv4Addr; - -use super::*; -use shared::error::Result; - -struct DummyRelayConnObserver { - turn_server_addr: String, - username: Username, - realm: Realm, -} - -#[test] -fn test_relay() -> Result<()> { - let obs = DummyRelayConnObserver { - turn_server_addr: String::new(), - username: Username::new(ATTR_USERNAME, "username".to_owned()), - realm: Realm::new(ATTR_REALM, "realm".to_owned()), - }; - - let (_read_ch_tx, read_ch_rx) = mpsc::channel(100); - - let config = RelayConfig { - relayed_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0), - integrity: MessageIntegrity::default(), - nonce: Nonce::new(ATTR_NONCE, "nonce".to_owned()), - lifetime: Duration::from_secs(0), - binding_mgr: Arc::new(Mutex::new(BindingManager::new())), - read_ch_rx: Arc::new(Mutex::new(read_ch_rx)), - }; - - let rc = Relay::new(Arc::new(Mutex::new(obs)), config).await; - - let rci = rc.relay_conn.lock().await; - let (bind_addr, bind_number) = { - let mut bm = rci.binding_mgr.lock().await; - let b = bm - .create(SocketAddr::new(Ipv4Addr::new(127, 0, 0, 1).into(), 1234)) - .unwrap(); - (b.addr, b.number) - }; - - //let binding_mgr = Arc::clone(&rci.binding_mgr); - let rc_obs = Arc::clone(&rci.obs); - let nonce = rci.nonce.clone(); - let integrity = rci.integrity.clone(); - - if let Err(err) = - RelayConnInternal::bind(rc_obs, bind_addr, bind_number, nonce, integrity).await - { - assert!(Error::ErrUnexpectedResponse != err); - } else { - panic!("should fail"); - } - - Ok(()) -}