From 7df1201ee172bc5d12d55b5aeeecce6ae0920695 Mon Sep 17 00:00:00 2001 From: Cameron Garnham Date: Wed, 19 Jun 2024 15:44:02 +0200 Subject: [PATCH] dev: use stream for udp requests --- .cargo/config.toml | 1 - cSpell.json | 2 + src/lib.rs | 2 +- src/servers/udp/handlers.rs | 8 +- src/servers/udp/server.rs | 481 +++++++++++++------ src/shared/bit_torrent/tracker/udp/client.rs | 57 ++- tests/servers/health_check_api/contract.rs | 9 + tests/servers/udp/contract.rs | 75 ++- tests/servers/udp/environment.rs | 11 +- 9 files changed, 447 insertions(+), 199 deletions(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 34d6230b..a88db5f3 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -23,4 +23,3 @@ rustflags = [ "-D", "unused", ] - diff --git a/cSpell.json b/cSpell.json index ef807f03..6a9da032 100644 --- a/cSpell.json +++ b/cSpell.json @@ -34,10 +34,12 @@ "codecov", "codegen", "completei", + "Condvar", "connectionless", "Containerfile", "conv", "curr", + "cvar", "Cyberneering", "dashmap", "datagram", diff --git a/src/lib.rs b/src/lib.rs index cf283441..bb6826dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -494,7 +494,7 @@ pub mod bootstrap; pub mod console; pub mod core; pub mod servers; -pub mod shared; +pub mod shared; #[macro_use] extern crate lazy_static; diff --git a/src/servers/udp/handlers.rs b/src/servers/udp/handlers.rs index 36825f08..f7e3aac6 100644 --- a/src/servers/udp/handlers.rs +++ b/src/servers/udp/handlers.rs @@ -10,7 +10,6 @@ use aquatic_udp_protocol::{ ErrorResponse, Ipv4AddrBytes, Ipv6AddrBytes, NumberOfDownloads, NumberOfPeers, Port, Request, Response, ResponsePeer, ScrapeRequest, ScrapeResponse, TorrentScrapeStatistics, TransactionId, }; -use tokio::net::UdpSocket; use torrust_tracker_located_error::DynError; use torrust_tracker_primitives::info_hash::InfoHash; use tracing::debug; @@ -34,13 +33,12 @@ use crate::shared::bit_torrent::common::MAX_SCRAPE_TORRENTS; /// - Delegating the request to the correct handler depending on the request type. /// /// It will return an `Error` response if the request is invalid. -pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc, socket: Arc) -> Response { +pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc, addr: SocketAddr) -> Response { debug!("Handling Packets: {udp_request:?}"); let start_time = Instant::now(); let request_id = RequestId::make(&udp_request); - let server_socket_addr = socket.local_addr().expect("Could not get local_addr for socket."); match Request::parse_bytes(&udp_request.payload[..udp_request.payload.len()], MAX_SCRAPE_TORRENTS).map_err(|e| { Error::InternalServer { @@ -49,7 +47,7 @@ pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc { - log_request(&request, &request_id, &server_socket_addr); + log_request(&request, &request_id, &addr); let transaction_id = match &request { Request::Connect(connect_request) => connect_request.transaction_id, @@ -64,7 +62,7 @@ pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc { /// /// It panics if unable to receive the bound socket address from service. /// - pub async fn start(self, tracker: Arc, form: ServiceRegistrationForm) -> Result, Error> { + pub async fn start(self, tracker: Arc, form: ServiceRegistrationForm) -> Result, std::io::Error> { let (tx_start, rx_start) = tokio::sync::oneshot::channel::(); let (tx_halt, rx_halt) = tokio::sync::oneshot::channel::(); @@ -129,6 +139,7 @@ impl UdpServer { let task = self.state.launcher.start(tracker, tx_start, rx_halt); let binding = rx_start.await.expect("it should be able to start the service").address; + let local_addr = format!("udp://{binding}"); form.send(ServiceRegistration::new(binding, Udp::check)) .expect("it should be able to send service registration"); @@ -141,7 +152,7 @@ impl UdpServer { }, }; - trace!("Running UDP Tracker on Socket: {}", running_udp_server.state.binding); + tracing::trace!(target: "UDP TRACKER: UdpServer::start", local_addr, "(running)"); Ok(running_udp_server) } @@ -159,13 +170,13 @@ impl UdpServer { /// # Panics /// /// It panics if unable to shutdown service. - pub async fn stop(self) -> Result, Error> { + pub async fn stop(self) -> Result, UdpError> { self.state .halt_task .send(Halted::Normal) - .map_err(|e| Error::Error(e.to_string()))?; + .map_err(|e| UdpError::Error(e.to_string()))?; - let launcher = self.state.task.await.expect("unable to shutdown service"); + let launcher = self.state.task.await.expect("it should shutdown service"); let stopped_api_server: UdpServer = UdpServer { state: Stopped { launcher }, @@ -200,23 +211,12 @@ impl Launcher { } } +/// Ring-Buffer of Active Requests +#[derive(Default)] struct ActiveRequests { rb: StaticRb, // the number of requests we handle at the same time. } -impl ActiveRequests { - /// Creates a new [`ActiveRequests`] filled with finished tasks. - async fn new() -> Self { - let mut rb = StaticRb::default(); - - let () = while rb.try_push(tokio::task::spawn_blocking(|| ()).abort_handle()).is_ok() {}; - - task::yield_now().await; - - Self { rb } - } -} - impl std::fmt::Debug for ActiveRequests { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let (left, right) = &self.rb.as_slices(); @@ -235,6 +235,86 @@ impl Drop for ActiveRequests { } } +/// Wrapper for Tokio [`UdpSocket`][`tokio::net::UdpSocket`] that can be canceled +struct Socket { + socket: Arc, + recv: (Mutex>>, Condvar), +} + +impl Socket { + async fn new(addr: SocketAddr) -> Result> { + let socket = tokio::net::UdpSocket::bind(addr).await; + + let socket = match socket { + Ok(socket) => socket, + Err(e) => Err(e)?, + }; + + let local_addr = format!("udp://{addr}"); + tracing::debug!(target: "UDP TRACKER: UdpSocket::new", local_addr, "(bound)"); + + Ok(Self { + socket: Arc::new(socket), + recv: (Mutex::default(), Condvar::default()), + }) + } +} + +impl Deref for Socket { + type Target = tokio::net::UdpSocket; + + fn deref(&self) -> &Self::Target { + &self.socket + } +} + +impl Debug for Socket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let local_addr = match self.socket.local_addr() { + Ok(socket) => format!("Receiving From: {socket}"), + Err(err) => format!("Socket Broken: {err}"), + }; + + f.debug_struct("UdpSocket").field("addr", &local_addr).finish_non_exhaustive() + } +} + +struct Receiver { + socket: Arc, + tracker: Arc, + data: RefCell<[u8; MAX_PACKET_SIZE]>, +} + +impl Stream for Receiver { + type Item = std::io::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut buf = *self.data.borrow_mut(); + let mut buf = tokio::io::ReadBuf::new(&mut buf); + + let Poll::Ready(ready) = self.socket.poll_recv_from(cx, &mut buf) else { + return Poll::Pending; + }; + + let res = match ready { + Ok(from) => { + let payload = buf.filled().to_vec(); + let request = UdpRequest { payload, from }; + + Some(Ok(tokio::task::spawn(Udp::process_request( + request, + self.tracker.clone(), + self.socket.clone(), + )) + .abort_handle())) + } + Err(err) => Some(Err(err)), + }; + + Poll::Ready(res) + } +} + /// A UDP server instance launcher. #[derive(Constructor)] pub struct Udp; @@ -252,124 +332,178 @@ impl Udp { tx_start: oneshot::Sender, rx_halt: oneshot::Receiver, ) { - let socket = Arc::new( - UdpSocket::bind(bind_to) - .await - .unwrap_or_else(|_| panic!("Could not bind to {bind_to}.")), - ); - let address = socket - .local_addr() - .unwrap_or_else(|_| panic!("Could not get local_addr from {bind_to}.")); - let halt = shutdown_signal_with_message(rx_halt, format!("Halting Http Service Bound to Socket: {address}")); - - info!(target: "UDP TRACKER", "Starting on: udp://{}", address); - - let running = tokio::task::spawn(async move { - debug!(target: "UDP TRACKER", "Started: Waiting for packets on socket address: udp://{address} ..."); - Self::run_udp_server(tracker, socket).await; - }); + let halt_task = tokio::task::spawn(shutdown_signal_with_message( + rx_halt, + format!("Halting Http Service Bound to Socket: {bind_to}"), + )); + + let socket = tokio::time::timeout(Duration::from_millis(5000), Socket::new(bind_to)) + .await + .expect("it should bind to the socket within five seconds"); + + let socket = match socket { + Ok(socket) => socket, + Err(e) => { + tracing::error!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", addr = %bind_to, err = %e, "panic! (error when building socket)" ); + panic!("could not bind to socket!"); + } + }; + + let address = socket.local_addr().expect("it should get the locally bound address"); + let local_addr = format!("udp://{address}"); + + // note: this log message is parsed by our container. i.e: + // + // `[UDP TRACKER][INFO] Starting on: udp://` + // + tracing::info!(target: "UDP TRACKER", "Starting on: {local_addr}"); + + let socket = socket.socket; + + let direct = Receiver { + socket, + tracker, + data: RefCell::new([0; MAX_PACKET_SIZE]), + }; + + tracing::trace!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_addr, "(spawning main loop)"); + let running = { + let local_addr = local_addr.clone(); + tokio::task::spawn(async move { + tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown::task", local_addr, "(listening...)"); + let () = Self::run_udp_server_main(direct).await; + }) + }; tx_start .send(Started { address }) .expect("the UDP Tracker service should not be dropped"); - info!(target: "UDP TRACKER", "Started on: udp://{}", address); + tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_addr, "(started)"); let stop = running.abort_handle(); select! { - _ = running => { debug!(target: "UDP TRACKER", "Socket listener stopped on address: udp://{address}"); }, - () = halt => { debug!(target: "UDP TRACKER", "Halt signal spawned task stopped on address: udp://{address}"); } + _ = running => { tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown", local_addr, "(stopped)"); }, + _ = halt_task => { tracing::debug!(target: "UDP TRACKER: Udp::run_with_graceful_shutdown",local_addr, "(halting)"); } } stop.abort(); - task::yield_now().await; // lets allow the other threads to complete. + tokio::task::yield_now().await; // lets allow the other threads to complete. } - async fn run_udp_server(tracker: Arc, socket: Arc) { - let tracker = tracker.clone(); - let socket = socket.clone(); + async fn run_udp_server_main(mut direct: Receiver) { + let reqs = &mut ActiveRequests::default(); - let reqs = &mut ActiveRequests::new().await; + let addr = direct.socket.local_addr().expect("it should get local address"); + let local_addr = format!("udp://{addr}"); loop { - task::yield_now().await; - for h in reqs.rb.iter_mut() { - if h.is_finished() { - std::mem::swap( - h, - &mut Self::spawn_request_processor( - Self::receive_request(socket.clone()).await, - tracker.clone(), - socket.clone(), - ) - .abort_handle(), - ); - } else { - // the task is still running, lets yield and give it a chance to flush. + if let Some(req) = { + tracing::trace!(target: "UDP TRACKER: Udp::run_udp_server", local_addr, "(wait for request)"); + direct.next().await + } { + tracing::trace!(target: "UDP TRACKER: Udp::run_udp_server::loop", local_addr, "(in)"); + + let req = match req { + Ok(req) => req, + Err(e) => { + if e.kind() == std::io::ErrorKind::Interrupted { + tracing::warn!(target: "UDP TRACKER: Udp::run_udp_server::loop", local_addr, err = %e, "(interrupted)"); + return; + } + tracing::error!(target: "UDP TRACKER: Udp::run_udp_server::loop", local_addr, err = %e, "break: (got error)"); + break; + } + }; + + if req.is_finished() { + continue; + } + + // fill buffer with requests + let Err(req) = reqs.rb.try_push(req) else { + continue; + }; + + let mut finished: u64 = 0; + let mut unfinished_task = None; + // buffer is full.. lets make some space. + for h in reqs.rb.pop_iter() { + // remove some finished tasks + if h.is_finished() { + finished += 1; + continue; + } + + // task is unfinished.. give it another chance. tokio::task::yield_now().await; - h.abort(); + // if now finished, we continue. + if h.is_finished() { + finished += 1; + continue; + } - let server_socket_addr = socket.local_addr().expect("Could not get local_addr for socket."); + tracing::debug!(target: "UDP TRACKER: Udp::run_udp_server::loop", local_addr, removed_count = finished, "(got unfinished task)"); - tracing::span!( - target: "UDP TRACKER", - tracing::Level::WARN, "request-aborted", server_socket_addr = %server_socket_addr); + if finished == 0 { + // we have _no_ finished tasks.. will abort the unfinished task to make space... + h.abort(); - // force-break a single thread, then loop again. - break; - } - } - } - } + tracing::warn!(target: "UDP TRACKER: Udp::run_udp_server::loop", local_addr, "aborting request: (no finished tasks)"); + break; + } - async fn receive_request(socket: Arc) -> Result> { - // Wait for the socket to be readable - socket.readable().await?; + // we have space, return unfinished task for re-entry. + unfinished_task = Some(h); + } - let mut buf = Vec::with_capacity(MAX_PACKET_SIZE); + // re-insert the previous unfinished task. + if let Some(h) = unfinished_task { + reqs.rb.try_push(h).expect("it was previously inserted"); + } - match socket.recv_buf_from(&mut buf).await { - Ok((n, from)) => { - Vec::truncate(&mut buf, n); - trace!("GOT {buf:?}"); - Ok(UdpRequest { payload: buf, from }) + // insert the new task. + if !req.is_finished() { + reqs.rb.try_push(req).expect("it should remove at least one element."); + } + } else { + tokio::task::yield_now().await; + // the request iterator returned `None`. + tracing::error!(target: "UDP TRACKER: Udp::run_udp_server", local_addr, "breaking: (ran dry, should not happen in production!)"); + break; } - - Err(e) => Err(Box::new(e)), } } - fn spawn_request_processor( - result: Result>, - tracker: Arc, - socket: Arc, - ) -> JoinHandle<()> { - tokio::task::spawn(Self::process_request(result, tracker, socket)) - } - - async fn process_request(result: Result>, tracker: Arc, socket: Arc) { - match result { - Ok(udp_request) => { - trace!("Received Request from: {}", udp_request.from); - Self::process_valid_request(tracker.clone(), socket.clone(), udp_request).await; - } - Err(error) => { - debug!("error: {error}"); - } - } + async fn process_request(request: UdpRequest, tracker: Arc, socket: Arc) { + tracing::trace!(target: "UDP TRACKER: Udp::process_request", request = %request.from, "(receiving)"); + Self::process_valid_request(tracker, socket, request).await; } async fn process_valid_request(tracker: Arc, socket: Arc, udp_request: UdpRequest) { - trace!("Making Response to {udp_request:?}"); + tracing::trace!(target: "UDP TRACKER: Udp::process_valid_request", "Making Response to {udp_request:?}"); let from = udp_request.from; - let response = handlers::handle_packet(udp_request, &tracker.clone(), socket.clone()).await; + let response = handlers::handle_packet( + udp_request, + &tracker.clone(), + socket.local_addr().expect("it should get the local address"), + ) + .await; Self::send_response(&socket.clone(), from, response).await; } async fn send_response(socket: &Arc, to: SocketAddr, response: Response) { - trace!("Sending Response: {response:?} to: {to:?}"); + let response_type = match &response { + Response::Connect(_) => "Connect".to_string(), + Response::AnnounceIpv4(_) => "AnnounceIpv4".to_string(), + Response::AnnounceIpv6(_) => "AnnounceIpv6".to_string(), + Response::Scrape(_) => "Scrape".to_string(), + Response::Error(e) => format!("Error: {e:?}"), + }; + + tracing::debug!(target: "UDP TRACKER: Udp::send_response", target = ?to, response_type, "(sending)"); let buffer = vec![0u8; MAX_PACKET_SIZE]; let mut cursor = Cursor::new(buffer); @@ -380,22 +514,21 @@ impl Udp { let position = cursor.position() as usize; let inner = cursor.get_ref(); - debug!("Sending {} bytes ...", &inner[..position].len()); - debug!("To: {:?}", &to); - debug!("Payload: {:?}", &inner[..position]); + tracing::debug!(target: "UDP TRACKER: Udp::send_response", ?to, bytes_count = &inner[..position].len(), "(sending...)" ); + tracing::trace!(target: "UDP TRACKER: Udp::send_response", ?to, bytes_count = &inner[..position].len(), payload = ?&inner[..position], "(sending...)"); Self::send_packet(socket, &to, &inner[..position]).await; - debug!("{} bytes sent", &inner[..position].len()); + tracing::trace!(target: "UDP TRACKER: Udp::send_response", ?to, bytes_count = &inner[..position].len(), "(sent)"); } - Err(_) => { - error!("could not write response to bytes."); + Err(e) => { + tracing::error!(target: "UDP TRACKER: Udp::send_response", ?to, response_type, err = %e, "(error)"); } } } async fn send_packet(socket: &Arc, remote_addr: &SocketAddr, payload: &[u8]) { - trace!("Sending Packets: {payload:?} to: {remote_addr:?}"); + tracing::trace!(target: "UDP TRACKER: Udp::send_response", to = %remote_addr, ?payload, "(sending)"); // doesn't matter if it reaches or not drop(socket.send_to(payload, remote_addr).await); @@ -413,55 +546,46 @@ impl Udp { #[cfg(test)] mod tests { - use std::sync::Arc; - use std::time::Duration; + use std::{sync::Arc, time::Duration}; - use ringbuf::traits::{Consumer, Observer, RingBuffer}; use torrust_tracker_test_helpers::configuration::ephemeral_mode_public; - use super::ActiveRequests; - use crate::bootstrap::app::initialize_with_configuration; - use crate::servers::registar::Registar; - use crate::servers::udp::server::{Launcher, UdpServer}; + use crate::{ + bootstrap::app::initialize_with_configuration, + servers::{ + registar::Registar, + udp::server::{Launcher, UdpServer}, + }, + }; #[tokio::test] - async fn it_should_return_to_the_start_of_the_ring_buffer() { - let mut a_req = ActiveRequests::new().await; - - tokio::time::sleep(Duration::from_millis(10)).await; + async fn it_should_be_able_to_start_and_stop() { + let cfg = Arc::new(ephemeral_mode_public()); + let tracker = initialize_with_configuration(&cfg); + let udp_trackers = cfg.udp_trackers.clone().expect("missing UDP trackers configuration"); + let config = &udp_trackers[0]; + let bind_to = config.bind_address; + let register = &Registar::default(); - let mut count: usize = 0; - let cap: usize = a_req.rb.capacity().into(); + let stopped = UdpServer::new(Launcher::new(bind_to)); - // Add a single pending task to check that the ring-buffer is looping correctly. - a_req - .rb - .push_overwrite(tokio::task::spawn(std::future::pending::<()>()).abort_handle()); + let started = stopped + .start(tracker, register.give_form()) + .await + .expect("it should start the server"); - count += 1; + let stopped = started.stop().await.expect("it should stop the server"); - for _ in 0..2 { - for h in a_req.rb.iter() { - let first = count % cap; - println!("{count},{first},{}", h.is_finished()); - - if first == 0 { - assert!(!h.is_finished()); - } else { - assert!(h.is_finished()); - } + tokio::time::sleep(Duration::from_secs(1)).await; - count += 1; - } - } + assert_eq!(stopped.state.launcher.bind_to, bind_to); } #[tokio::test] - async fn it_should_be_able_to_start_and_stop() { + async fn it_should_be_able_to_start_and_stop_with_wait() { let cfg = Arc::new(ephemeral_mode_public()); let tracker = initialize_with_configuration(&cfg); - let udp_trackers = cfg.udp_trackers.clone().expect("missing UDP trackers configuration"); - let config = &udp_trackers[0]; + let config = &cfg.udp_trackers.as_ref().unwrap().first().unwrap(); let bind_to = config.bind_address; let register = &Registar::default(); @@ -472,6 +596,8 @@ mod tests { .await .expect("it should start the server"); + tokio::time::sleep(Duration::from_secs(1)).await; + let stopped = started.stop().await.expect("it should stop the server"); tokio::time::sleep(Duration::from_secs(1)).await; @@ -479,3 +605,68 @@ mod tests { assert_eq!(stopped.state.launcher.bind_to, bind_to); } } + +/// Todo: submit test to tokio documentation. +#[cfg(test)] +mod test_tokio { + use std::sync::Arc; + use std::time::Duration; + + use tokio::sync::Barrier; + use tokio::task::JoinSet; + + #[tokio::test] + async fn test_barrier_with_aborted_tasks() { + // Create a barrier that requires 10 tasks to proceed. + let barrier = Arc::new(Barrier::new(10)); + let mut tasks = JoinSet::default(); + let mut handles = Vec::default(); + + // Set Barrier to 9/10. + for _ in 0..9 { + let c = barrier.clone(); + handles.push(tasks.spawn(async move { + c.wait().await; + })); + } + + // Abort two tasks: Barrier: 7/10. + for _ in 0..2 { + if let Some(handle) = handles.pop() { + handle.abort(); + } + } + + // Spawn a single task: Barrier 8/10. + let c = barrier.clone(); + handles.push(tasks.spawn(async move { + c.wait().await; + })); + + // give a chance fro the barrier to release. + tokio::time::sleep(Duration::from_millis(50)).await; + + // assert that the barrier isn't removed, i.e. 8, not 10. + for h in &handles { + assert!(!h.is_finished()); + } + + // Spawn two more tasks to trigger the barrier release: Barrier 10/10. + for _ in 0..2 { + let c = barrier.clone(); + handles.push(tasks.spawn(async move { + c.wait().await; + })); + } + + // give a chance fro the barrier to release. + tokio::time::sleep(Duration::from_millis(50)).await; + + // assert that the barrier has been triggered + for h in &handles { + assert!(h.is_finished()); + } + + tasks.shutdown().await; + } +} diff --git a/src/shared/bit_torrent/tracker/udp/client.rs b/src/shared/bit_torrent/tracker/udp/client.rs index 45b51ad3..90054346 100644 --- a/src/shared/bit_torrent/tracker/udp/client.rs +++ b/src/shared/bit_torrent/tracker/udp/client.rs @@ -15,7 +15,7 @@ use crate::shared::bit_torrent::tracker::udp::{source_address, MAX_PACKET_SIZE}; /// Default timeout for sending and receiving packets. And waiting for sockets /// to be readable and writable. -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); #[allow(clippy::module_name_repetitions)] #[derive(Debug)] @@ -37,7 +37,16 @@ impl UdpClient { .parse::() .context(format!("{local_address} is not a valid socket address"))?; - let socket = UdpSocket::bind(socket_addr).await?; + let socket = match time::timeout(DEFAULT_TIMEOUT, UdpSocket::bind(socket_addr)).await { + Ok(bind_result) => match bind_result { + Ok(socket) => { + debug!("Bound to socket: {socket_addr}"); + Ok(socket) + } + Err(e) => Err(anyhow!("Failed to bind to socket: {socket_addr}, error: {e:?}")), + }, + Err(e) => Err(anyhow!("Timeout waiting to bind to socket: {socket_addr}, error: {e:?}")), + }?; let udp_client = Self { socket: Arc::new(socket), @@ -54,12 +63,15 @@ impl UdpClient { .parse::() .context(format!("{remote_address} is not a valid socket address"))?; - match self.socket.connect(socket_addr).await { - Ok(()) => { - debug!("Connected successfully"); - Ok(()) - } - Err(e) => Err(anyhow!("Failed to connect: {e:?}")), + match time::timeout(self.timeout, self.socket.connect(socket_addr)).await { + Ok(connect_result) => match connect_result { + Ok(()) => { + debug!("Connected to socket {socket_addr}"); + Ok(()) + } + Err(e) => Err(anyhow!("Failed to connect to socket {socket_addr}: {e:?}")), + }, + Err(e) => Err(anyhow!("Timeout waiting to connect to socket {socket_addr}, error: {e:?}")), } } @@ -100,7 +112,9 @@ impl UdpClient { /// /// # Panics /// - pub async fn receive(&self, bytes: &mut [u8]) -> Result { + pub async fn receive(&self) -> Result> { + let mut response_buffer = [0u8; MAX_PACKET_SIZE]; + debug!(target: "UDP client", "receiving ..."); match time::timeout(self.timeout, self.socket.readable()).await { @@ -113,21 +127,20 @@ impl UdpClient { Err(e) => return Err(anyhow!("Timeout waiting for the socket to become readable: {e:?}")), }; - let size_result = match time::timeout(self.timeout, self.socket.recv(bytes)).await { + let size = match time::timeout(self.timeout, self.socket.recv(&mut response_buffer)).await { Ok(recv_result) => match recv_result { Ok(size) => Ok(size), Err(e) => Err(anyhow!("IO error during send: {e:?}")), }, Err(e) => Err(anyhow!("Receive operation timed out: {e:?}")), - }; + }?; - if size_result.is_ok() { - let size = size_result.as_ref().unwrap(); - debug!(target: "UDP client", "{size} bytes received {bytes:?}"); - size_result - } else { - size_result - } + let mut res: Vec = response_buffer.to_vec(); + Vec::truncate(&mut res, size); + + debug!(target: "UDP client", "{size} bytes received {res:?}"); + + Ok(res) } } @@ -181,13 +194,11 @@ impl UdpTrackerClient { /// /// Will return error if can't create response from the received payload (bytes buffer). pub async fn receive(&self) -> Result { - let mut response_buffer = [0u8; MAX_PACKET_SIZE]; - - let payload_size = self.udp_client.receive(&mut response_buffer).await?; + let payload = self.udp_client.receive().await?; - debug!(target: "UDP tracker client", "received {payload_size} bytes. Response {response_buffer:?}"); + debug!(target: "UDP tracker client", "received {} bytes. Response {payload:?}", payload.len()); - let response = Response::parse_bytes(&response_buffer[..payload_size], true)?; + let response = Response::parse_bytes(&payload, true)?; Ok(response) } diff --git a/tests/servers/health_check_api/contract.rs b/tests/servers/health_check_api/contract.rs index 3c3c1315..6164a516 100644 --- a/tests/servers/health_check_api/contract.rs +++ b/tests/servers/health_check_api/contract.rs @@ -245,6 +245,8 @@ mod udp { use crate::servers::health_check_api::Started; use crate::servers::udp; + //static INIT: std::sync::Once = std::sync::Once::new(); + #[tokio::test] pub(crate) async fn it_should_return_good_health_for_udp_service() { let configuration = Arc::new(configuration::ephemeral()); @@ -288,6 +290,13 @@ mod udp { #[tokio::test] pub(crate) async fn it_should_return_error_when_udp_service_was_stopped_after_registration() { + // INIT.call_once(|| { + // let () = tracing_subscriber::fmt() + // .compact() + // .with_max_level(tracing::Level::TRACE) + // .init(); + // }); + let configuration = Arc::new(configuration::ephemeral()); let service = udp::Started::new(&configuration).await; diff --git a/tests/servers/udp/contract.rs b/tests/servers/udp/contract.rs index 7abd6092..677aad3b 100644 --- a/tests/servers/udp/contract.rs +++ b/tests/servers/udp/contract.rs @@ -17,10 +17,6 @@ fn empty_udp_request() -> [u8; MAX_PACKET_SIZE] { [0; MAX_PACKET_SIZE] } -fn empty_buffer() -> [u8; MAX_PACKET_SIZE] { - [0; MAX_PACKET_SIZE] -} - async fn send_connection_request(transaction_id: TransactionId, client: &UdpTrackerClient) -> ConnectionId { let connect_request = ConnectRequest { transaction_id }; @@ -54,13 +50,12 @@ async fn should_return_a_bad_request_response_when_the_client_sends_an_empty_req Err(err) => panic!("{err}"), }; - let mut buffer = empty_buffer(); - match client.receive(&mut buffer).await { - Ok(_) => (), + let response = match client.receive().await { + Ok(response) => response, Err(err) => panic!("{err}"), }; - let response = Response::parse_bytes(&buffer, true).unwrap(); + let response = Response::parse_bytes(&response, true).unwrap(); assert!(is_error_response(&response, "bad request")); @@ -75,8 +70,17 @@ mod receiving_a_connection_request { use crate::servers::udp::asserts::is_connect_response; use crate::servers::udp::Started; + static INIT: std::sync::Once = std::sync::Once::new(); + #[tokio::test] async fn should_return_a_connect_response() { + INIT.call_once(|| { + let () = tracing_subscriber::fmt() + .compact() + .with_max_level(tracing::Level::TRACE) + .init(); + }); + let env = Started::new(&configuration::ephemeral().into()).await; let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { @@ -111,30 +115,20 @@ mod receiving_an_announce_request { AnnounceActionPlaceholder, AnnounceEvent, AnnounceRequest, ConnectionId, InfoHash, NumberOfBytes, NumberOfPeers, PeerId, PeerKey, Port, TransactionId, }; - use torrust_tracker::shared::bit_torrent::tracker::udp::client::new_udp_tracker_client_connected; + use torrust_tracker::shared::bit_torrent::tracker::udp::client::{new_udp_tracker_client_connected, UdpTrackerClient}; use torrust_tracker_test_helpers::configuration; use crate::servers::udp::asserts::is_ipv4_announce_response; use crate::servers::udp::contract::send_connection_request; use crate::servers::udp::Started; - #[tokio::test] - async fn should_return_an_announce_response() { - let env = Started::new(&configuration::ephemeral().into()).await; - - let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { - Ok(udp_tracker_client) => udp_tracker_client, - Err(err) => panic!("{err}"), - }; - - let connection_id = send_connection_request(TransactionId::new(123), &client).await; - + pub async fn send_and_get_announce(tx_id: TransactionId, c_id: ConnectionId, client: &UdpTrackerClient) { // Send announce request let announce_request = AnnounceRequest { - connection_id: ConnectionId(connection_id.0), + connection_id: ConnectionId(c_id.0), action_placeholder: AnnounceActionPlaceholder::default(), - transaction_id: TransactionId::new(123i32), + transaction_id: tx_id, info_hash: InfoHash([0u8; 20]), peer_id: PeerId([255u8; 20]), bytes_downloaded: NumberOfBytes(0i64.into()), @@ -160,6 +154,43 @@ mod receiving_an_announce_request { println!("test response {response:?}"); assert!(is_ipv4_announce_response(&response)); + } + + #[tokio::test] + async fn should_return_an_announce_response() { + let env = Started::new(&configuration::ephemeral().into()).await; + + let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + Ok(udp_tracker_client) => udp_tracker_client, + Err(err) => panic!("{err}"), + }; + + let tx_id = TransactionId::new(123); + + let c_id = send_connection_request(tx_id, &client).await; + + send_and_get_announce(tx_id, c_id, &client).await; + + env.stop().await; + } + + #[tokio::test] + async fn should_return_many_announce_response() { + let env = Started::new(&configuration::ephemeral().into()).await; + + let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + Ok(udp_tracker_client) => udp_tracker_client, + Err(err) => panic!("{err}"), + }; + + let tx_id = TransactionId::new(123); + + let c_id = send_connection_request(tx_id, &client).await; + + for x in 0..1000 { + tracing::info!("req no: {x}"); + send_and_get_announce(tx_id, c_id, &client).await; + } env.stop().await; } diff --git a/tests/servers/udp/environment.rs b/tests/servers/udp/environment.rs index 1ba038c7..7b21defc 100644 --- a/tests/servers/udp/environment.rs +++ b/tests/servers/udp/environment.rs @@ -5,6 +5,7 @@ use torrust_tracker::bootstrap::app::initialize_with_configuration; use torrust_tracker::core::Tracker; use torrust_tracker::servers::registar::Registar; use torrust_tracker::servers::udp::server::{Launcher, Running, Stopped, UdpServer}; +use torrust_tracker::shared::bit_torrent::tracker::udp::client::DEFAULT_TIMEOUT; use torrust_tracker_configuration::{Configuration, UdpTracker}; use torrust_tracker_primitives::info_hash::InfoHash; use torrust_tracker_primitives::peer; @@ -58,16 +59,22 @@ impl Environment { impl Environment { pub async fn new(configuration: &Arc) -> Self { - Environment::::new(configuration).start().await + tokio::time::timeout(DEFAULT_TIMEOUT, Environment::::new(configuration).start()) + .await + .expect("it should create an environment within the timeout") } #[allow(dead_code)] pub async fn stop(self) -> Environment { + let stopped = tokio::time::timeout(DEFAULT_TIMEOUT, self.server.stop()) + .await + .expect("it should stop the environment within the timeout"); + Environment { config: self.config, tracker: self.tracker, registar: Registar::default(), - server: self.server.stop().await.expect("it stop the udp tracker service"), + server: stopped.expect("it stop the udp tracker service"), } }