diff --git a/.cargo/config.toml b/.cargo/config.toml index 34d6230b..a88db5f3 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -23,4 +23,3 @@ rustflags = [ "-D", "unused", ] - diff --git a/Cargo.lock b/Cargo.lock index 523ea575..a41d6327 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3802,6 +3802,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.11" @@ -3910,6 +3921,7 @@ dependencies = [ "rand", "reqwest", "ringbuf", + "rstest", "serde", "serde_bencode", "serde_bytes", @@ -3917,6 +3929,7 @@ dependencies = [ "serde_repr", "thiserror", "tokio", + "tokio-stream", "torrust-tracker-clock", "torrust-tracker-configuration", "torrust-tracker-contrib-bencode", diff --git a/Cargo.toml b/Cargo.toml index 5183c606..f95fc549 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ r2d2_sqlite = { version = "0", features = ["bundled"] } rand = "0" reqwest = { version = "0", features = ["json"] } ringbuf = "0" +rstest = "0" serde = { version = "1", features = ["derive"] } serde_bencode = "0" serde_bytes = "0" @@ -70,6 +71,7 @@ serde_json = { version = "1", features = ["preserve_order"] } serde_repr = "0" thiserror = "1" tokio = { version = "1", features = ["macros", "net", "rt-multi-thread", "signal", "sync"] } +tokio-stream = "0" torrust-tracker-clock = { version = "3.0.0-alpha.12-develop", path = "packages/clock" } torrust-tracker-configuration = { version = "3.0.0-alpha.12-develop", path = "packages/configuration" } torrust-tracker-contrib-bencode = { version = "3.0.0-alpha.12-develop", path = "contrib/bencode" } @@ -80,12 +82,13 @@ tower = { version = "0.4.13", features = ["timeout"] } tower-http = { version = "0", features = ["compression-full", "cors", 
"propagate-header", "request-id", "trace"] } trace = "0" tracing = "0" +#tracing-subscriber = { version = "0", features = ["json"] } url = "2" uuid = { version = "1", features = ["v4"] } zerocopy = "0.7.33" [package.metadata.cargo-machete] -ignored = ["crossbeam-skiplist", "dashmap", "figment", "parking_lot", "serde_bytes"] +ignored = ["crossbeam-skiplist", "dashmap", "figment", "parking_lot", "serde_bytes", "tokio-stream"] [dev-dependencies] local-ip-address = "0" diff --git a/cSpell.json b/cSpell.json index 2b5cf55b..cc1bc3da 100644 --- a/cSpell.json +++ b/cSpell.json @@ -34,10 +34,12 @@ "codecov", "codegen", "completei", + "Condvar", "connectionless", "Containerfile", "conv", "curr", + "cvar", "Cyberneering", "dashmap", "datagram", diff --git a/src/servers/udp/handlers.rs b/src/servers/udp/handlers.rs index fee00a0b..b540a7ed 100644 --- a/src/servers/udp/handlers.rs +++ b/src/servers/udp/handlers.rs @@ -11,7 +11,6 @@ use aquatic_udp_protocol::{ ScrapeRequest, ScrapeResponse, TorrentScrapeStatistics, TransactionId, }; use log::debug; -use tokio::net::UdpSocket; use torrust_tracker_located_error::DynError; use torrust_tracker_primitives::info_hash::InfoHash; use uuid::Uuid; @@ -35,13 +34,12 @@ use crate::shared::bit_torrent::common::MAX_SCRAPE_TORRENTS; /// type. /// /// It will return an `Error` response if the request is invalid. 
-pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc, socket: Arc) -> Response { +pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc, addr: SocketAddr) -> Response { debug!("Handling Packets: {udp_request:?}"); let start_time = Instant::now(); let request_id = RequestId::make(&udp_request); - let server_socket_addr = socket.local_addr().expect("Could not get local_addr for socket."); match Request::parse_bytes(&udp_request.payload[..udp_request.payload.len()], MAX_SCRAPE_TORRENTS).map_err(|e| { Error::InternalServer { @@ -50,7 +48,7 @@ pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc { - log_request(&request, &request_id, &server_socket_addr); + log_request(&request, &request_id, &addr); let transaction_id = match &request { Request::Connect(connect_request) => connect_request.transaction_id, @@ -65,7 +63,7 @@ pub(crate) async fn handle_packet(udp_request: UdpRequest, tracker: &Arc { /// /// It panics if unable to receive the bound socket address from service. /// - pub async fn start(self, tracker: Arc, form: ServiceRegistrationForm) -> Result, Error> { + pub async fn start(self, tracker: Arc, form: ServiceRegistrationForm) -> Result, std::io::Error> { let (tx_start, rx_start) = tokio::sync::oneshot::channel::(); let (tx_halt, rx_halt) = tokio::sync::oneshot::channel::(); @@ -143,7 +150,7 @@ impl UdpServer { }, }; - trace!("Running UDP Tracker on Socket: {}", running_udp_server.state.binding); + tracing::trace!("Running UDP Tracker on Socket: {}", running_udp_server.state.binding); Ok(running_udp_server) } @@ -161,11 +168,11 @@ impl UdpServer { /// # Panics /// /// It panics if unable to shutdown service. 
- pub async fn stop(self) -> Result, Error> { + pub async fn stop(self) -> Result, UdpError> { self.state .halt_task .send(Halted::Normal) - .map_err(|e| Error::Error(e.to_string()))?; + .map_err(|e| UdpError::Error(e.to_string()))?; let launcher = self.state.task.await.expect("unable to shutdown service"); @@ -202,23 +209,12 @@ impl Launcher { } } +/// Ring-Buffer of Active Requests +#[derive(Default)] struct ActiveRequests { rb: StaticRb, // the number of requests we handle at the same time. } -impl ActiveRequests { - /// Creates a new [`ActiveRequests`] filled with finished tasks. - async fn new() -> Self { - let mut rb = StaticRb::default(); - - let () = while rb.try_push(tokio::task::spawn_blocking(|| ()).abort_handle()).is_ok() {}; - - task::yield_now().await; - - Self { rb } - } -} - impl std::fmt::Debug for ActiveRequests { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let (left, right) = &self.rb.as_slices(); @@ -237,6 +233,411 @@ impl Drop for ActiveRequests { } } +/// Tokio Select Output for Waiting for Binding to the Socket +enum SocketBind { + Bind(std::io::Result), + Interrupted, +} + +/// Tokio Select Output for Waiting for the Socket to become Ready +enum SocketReady { + Ready(std::io::Result), + Interrupted, +} + +/// Tokio Select Output for Waiting for Reading from the Socket +enum SocketRead { + Read(std::io::Result<(usize, SocketAddr)>), + Interrupted, +} + +/// Wrapper for Tokio [`UdpSocket`][`tokio::net::UdpSocket`] that can be canceled +struct UdpSocket { + socket: tokio::net::UdpSocket, + halt: Arc, +} + +impl UdpSocket { + async fn new(addr: SocketAddr, halt: Arc) -> Result> { + let bind = tokio::net::UdpSocket::bind(addr); + let interrupt = halt.notified(); + + let bind = tokio::select! 
{ + socket = bind => SocketBind::Bind(socket), + () = interrupt => SocketBind::Interrupted, + }; + + let socket = match bind { + SocketBind::Bind(socket) => socket, + SocketBind::Interrupted => { + tracing::info!(target: "UDP TRACKER", local_addr = %addr, "UdpSocket: interrupting wait for binding socket"); + return Err(std::io::Error::new(std::io::ErrorKind::Interrupted, "caught halt"))?; + } + }; + + let socket = match socket { + Ok(socket) => socket, + Err(e) => Err(e)?, + }; + + Ok(Self { socket, halt }) + } + + async fn receive_request(&self) -> Result> { + let local_addr = self.local_addr().expect("it should get the local address"); + + // wait for the socket to become readable + { + let wait_ready = self.ready(Interest::READABLE); + let interrupt = self.halt.notified(); + + let ready = tokio::select! { + ready = wait_ready => SocketReady::Ready(ready), + () = interrupt => SocketReady::Interrupted, + }; + + let ready = match ready { + SocketReady::Ready(ready) => ready, + SocketReady::Interrupted => { + tracing::info!(target: "UDP TRACKER", %local_addr, "UdpSocket: it was interrupted while waiting to become ready."); + return Err(std::io::Error::new(std::io::ErrorKind::Interrupted, "caught halt"))?; + } + }; + + let () = match ready { + Ok(ready) => assert!(ready.is_readable(), "it should be readable now"), + Err(e) => { + tracing::error!(target: "UDP TRACKER", %local_addr, err = %e, "UdpSocket: waiting for ready errored!"); + return Err(e)?; + } + }; + } + + // wait for request to be received + { + let mut buf = Vec::with_capacity(MAX_PACKET_SIZE); + + let wait_read = self.recv_buf_from(&mut buf); + let interrupt = self.halt.notified(); + + let read = tokio::select! 
{ + read = wait_read => SocketRead::Read(read), + () = interrupt => SocketRead::Interrupted + }; + + let read = match read { + SocketRead::Read(read) => read, + SocketRead::Interrupted => { + tracing::info!(target: "UDP TRACKER", %local_addr, "UdpSocket: it was interrupted while waiting to read."); + return Err(std::io::Error::new(std::io::ErrorKind::Interrupted, "caught halt"))?; + } + }; + + match read { + Ok((n, from)) => { + Vec::truncate(&mut buf, n); + tracing::trace!(target: "UDP TRACKER", %local_addr, buffer = ?buf, "UdpSocket: read buffer from socket!"); + Ok(UdpRequest { payload: buf, from }) + } + + Err(e) => { + tracing::error!(target: "UDP TRACKER", %local_addr, err = %e, "UdpSocket: waiting for ready errored!"); + Err(e)? + } + } + } + } +} + +impl Deref for UdpSocket { + type Target = tokio::net::UdpSocket; + + fn deref(&self) -> &Self::Target { + &self.socket + } +} + +impl Debug for UdpSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let local_addr = match self.socket.local_addr() { + Ok(socket) => format!("Receiving From: {socket}"), + Err(err) => format!("Socket Broken: {err}"), + }; + + f.debug_struct("UdpSocket").field("addr", &local_addr).finish_non_exhaustive() + } +} +/// Handle for a Working Udp Response Task +type UdpTaskHandleResult = Result>; + +/// Listens to a Socket and Receives Requests. 
+#[derive(Default, Debug)] +struct Receiver { + next: Arc<(Mutex>, Condvar)>, + waker: Arc>>, + task: Option<(std::thread::JoinHandle<()>, SocketAddr, Arc)>, + halted: Arc, +} + +impl Receiver { + fn listen(mut self, socket: Arc, tracker: Arc, handle: tokio::runtime::Handle) -> Result { + if self.halted.load(Ordering::SeqCst) { + tracing::error!(target: "UDP TRACKER", "was already halted"); + return Err(self); + } + + let local_addr = socket.local_addr().expect("it should get local address"); + let halt = socket.halt.clone(); + + let next = self.next.clone(); + let waker = self.waker.clone(); + let halted = self.halted.clone(); + + let task = std::thread::spawn(move || { + tracing::debug!(target: "UDP TRACKER", %local_addr, "main thread of receiving process started"); + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("it should make a runtime"); + + let (next, cvar) = &*next; + + loop { + tracing::trace!(target: "UDP TRACKER", %local_addr, "in main loop of receiving thread"); + let mut stop = false; + + let need_next = { + let next = next.lock().expect("it should get a lock"); + next.is_none() + }; + + if need_next { + tracing::trace!(target: "UDP TRACKER", %local_addr, "we do not have the next request, lest collect it"); + + let request = rt.block_on(socket.clone().receive_request()); + + let res = Some(match request { + Ok(request) => { + tracing::trace!(target: "UDP TRACKER", %local_addr, "got a request, will start processing it"); + Ok(handle + .spawn(Udp::process_request(request, tracker.clone(), socket.clone())) + .abort_handle()) + } + + Err(e) => { + if e.kind() == std::io::ErrorKind::Interrupted { + tracing::info!(target: "UDP TRACKER", err = %e, "interrupted"); + stop = true; + } else { + tracing::error!(target: "UDP TRACKER", err = %e, "request error"); + } + Err(e) + } + }); + + { + tracing::trace!(target: "UDP TRACKER", %local_addr, "now we have the processing request, lets set it"); + let mut next = next.lock().expect("it 
should get a lock"); + *next = res; + } + } + + tracing::trace!(target: "UDP TRACKER", %local_addr, "the next request is set, lets notify any waiters.."); + + if stop { + tracing::info!(target: "UDP TRACKER", %local_addr, "halting request receiver"); + halted.store(true, Ordering::SeqCst); + break; + } + + let mut next = next.lock().expect("it should get a lock"); + while next.is_some() { + tracing::trace!(target: "UDP TRACKER", "we have a waiting processing request"); + { + let waker = waker.lock().expect("it should get a lock"); + tracing::trace!(target: "UDP TRACKER", "waker is: {}", waker.is_some()); + + if let Some(waker) = waker.as_ref() { + waker.wake_by_ref(); + }; + } + + tracing::trace!(target: "UDP TRACKER", "wait until we need to refresh"); + let (update, timeout) = cvar + .wait_timeout(next, Duration::from_millis(10)) + .expect("it should get the new lock"); + + next = update; + tracing::trace!(target: "UDP TRACKER", ?timeout, "cvar updated lock"); + } + } + }); + + self.task = Some((task, local_addr, halt)); + + Ok(self) + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + let was_halted = self.halted.swap(true, Ordering::SeqCst); + + let Some(task) = self.task.take() else { + tracing::error!(target: "UDP TRACKER", info = %self, was_halted, "halting: receiver: task not found!"); + return; + }; + + let (task, _, halt) = task; + + // notify waiting, and next that we are halting. 
+ halt.notify_waiters(); + halt.notify_one(); + + let was_finished = task.is_finished(); + + tracing::info!(target: "UDP TRACKER", info = %self, was_halted, was_finished, "halting: receiver"); + + match task.join() { + Ok(()) => tracing::debug!(target: "UDP TRACKER", info = %self, "halting: receiver: finished"), + Err(e) => { + tracing::warn!(target: "UDP TRACKER", info = %self, err = ?e, "halting: receiver: finished with error"); + } + } + } +} + +impl std::fmt::Display for Receiver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.task { + Some(task) => { + if task.0.is_finished() { + f.write_fmt(format_args!("finished: {}", task.1)) + } else { + f.write_fmt(format_args!("running: {}", task.1)) + } + } + + None => f.write_fmt(format_args!("not_started")), + } + } +} + +impl Future for Receiver { + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if !self.task.as_ref().is_some_and(|(t, _, _)| !t.is_finished()) { + tracing::warn!(target: "UDP TRACKER", "it should be listening before polling"); + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::ConnectionRefused, + "polling but not listening", + ))?); + }; + + tracing::trace!(target: "UDP TRACKER", %self, "receiver was polled"); + + let (next, cvar) = &*self.next; + + let mut next = next.lock().expect("it should get a lock"); + tracing::trace!(target: "UDP TRACKER", %self,"next is: {}", next.is_some()); + + let mut waker = self.waker.lock().expect("it should get a lock"); + tracing::trace!(target: "UDP TRACKER", %self," waker is: {}", waker.is_some()); + + if let Some(res) = next.take() { + drop(waker.take()); // drop any waker + cvar.notify_all(); + tracing::trace!(target: "UDP TRACKER", %self,"returned waiting task"); + return Poll::Ready(res); + }; + + if let Some(mut waker) = waker.as_ref() { + tracing::trace!(target: "UDP TRACKER", %self,"updating waker"); + waker.clone_from(&cx.waker()); + } else { + 
tracing::trace!(target: "UDP TRACKER", %self, "setting new waker"); + *waker = Some(cx.waker().clone()); + } + + cvar.notify_all(); + + tracing::trace!(target: "UDP TRACKER", %self, "lets wait until we are awoken"); + Poll::Pending + } +} + +/// A Generic Trait that produces a Stream of [`UdpTaskHandleResult`] +trait Reqs: Stream + Debug + Display +where + T: Future + Display + Debug + Unpin, +{ +} + +impl Reqs for Requests where T: Future + Display + Debug + Unpin {} + +/// Provides a Stream of [`UdpTaskHandleResult`] +#[derive(Debug)] +struct Requests +where + T: Future + Display + Debug + Unpin, +{ + receiver: T, + errored: bool, +} + +impl Requests +where + T: Future + Display + Debug + Unpin, +{ + fn new(receiver: T) -> Self { + Self { + receiver, + errored: false, + } + } +} + +impl std::fmt::Display for Requests +where + T: Future + Display + Debug + Unpin, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("T: {}", self.receiver)) + } +} + +impl Stream for Requests +where + T: Future + Display + Debug + Unpin, +{ + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.errored { + tracing::warn!(target: "UDP TRACKER", "stream errored: returning none."); + return std::task::Poll::Ready(None); + } + + match Pin::new(&mut self.receiver).poll(cx) { + Poll::Ready(h) => match h { + Ok(h) => { + tracing::trace!(target: "UDP TRACKER", "got handle"); + std::task::Poll::Ready(Some(Ok(h))) + } + Err(e) => { + self.errored = true; + tracing::trace!(target: "UDP TRACKER", "got error"); + std::task::Poll::Ready(Some(Err(e))) + } + }, + Poll::Pending => { + tracing::trace!(target: "UDP TRACKER", "not ready: pending"); + std::task::Poll::Pending + } + } + } +} + /// A UDP server instance launcher. 
#[derive(Constructor)] pub struct Udp; @@ -254,124 +655,164 @@ impl Udp { tx_start: oneshot::Sender, rx_halt: oneshot::Receiver, ) { - let socket = Arc::new( - UdpSocket::bind(bind_to) - .await - .unwrap_or_else(|_| panic!("Could not bind to {bind_to}.")), - ); - let address = socket - .local_addr() - .unwrap_or_else(|_| panic!("Could not get local_addr from {bind_to}.")); - let halt = shutdown_signal_with_message(rx_halt, format!("Halting Http Service Bound to Socket: {address}")); - - info!(target: "UDP TRACKER", "Starting on: udp://{}", address); + let halt = Arc::new(Notify::new()); + + let halt_tx = halt.clone(); + let halt_task = tokio::task::spawn(async move { + shutdown_signal_with_message(rx_halt, format!("Halting Http Service Bound to Socket: {bind_to}")).await; + halt_tx.notify_one(); + }); + + let socket = UdpSocket::new(bind_to, halt).await; + + let socket = match socket { + Ok(socket) => socket, + Err(e) => { + tracing::error!(target: "UDP TRACKER", addr = %bind_to, err = %e, "error when building socket" ); + panic!("could not bind to socket!"); + } + }; + let address = socket.local_addr().expect("it should get the locally bound address"); + tracing::info!(target: "UDP TRACKER", "Starting on: udp://{}", address); + + let receiver = Receiver::default(); + + let receiver = Receiver::listen(receiver, socket.into(), tracker, tokio::runtime::Handle::current()) + .expect("it should not be halted already"); + tracing::trace!(target: "UDP TRACKER", "receiver was created"); + let requests = Requests::new(receiver); + + tracing::trace!(target: "UDP TRACKER", "spawning udp server inner-loop"); let running = tokio::task::spawn(async move { - debug!(target: "UDP TRACKER", "Started: Waiting for packets on socket address: udp://{address} ..."); - Self::run_udp_server(tracker, socket).await; + tracing::debug!(target: "UDP TRACKER", "Started: Waiting for packets on socket address: udp://{address} ..."); + Self::run_udp_server(requests).await.expect("it should not run 
dry"); }); tx_start .send(Started { address }) .expect("the UDP Tracker service should not be dropped"); - debug!(target: "UDP TRACKER", "Started on: udp://{}", address); + tracing::debug!(target: "UDP TRACKER", "Started on: udp://{}", address); let stop = running.abort_handle(); select! { - _ = running => { debug!(target: "UDP TRACKER", "Socket listener stopped on address: udp://{address}"); }, - () = halt => { debug!(target: "UDP TRACKER", "Halt signal spawned task stopped on address: udp://{address}"); } + _ = running => { tracing::debug!(target: "UDP TRACKER", "Socket listener stopped on address: udp://{address}"); }, + _ = halt_task => { tracing::debug!(target: "UDP TRACKER", "Halt signal spawned task stopped on address: udp://{address}"); } } stop.abort(); task::yield_now().await; // lets allow the other threads to complete. } - async fn run_udp_server(tracker: Arc, socket: Arc) { - let tracker = tracker.clone(); - let socket = socket.clone(); - - let reqs = &mut ActiveRequests::new().await; + async fn run_udp_server<'a, R, T>(mut requests: R) -> Result<(), R> + where + R: Reqs + Unpin, + T: Future + Display + Debug + Unpin, + { + let reqs = &mut ActiveRequests::default(); loop { - task::yield_now().await; - for h in reqs.rb.iter_mut() { - if h.is_finished() { - std::mem::swap( - h, - &mut Self::spawn_request_processor( - Self::receive_request(socket.clone()).await, - tracker.clone(), - socket.clone(), - ) - .abort_handle(), - ); - } else { - // the task is still running, lets yield and give it a chance to flush. 
+ if let Some(req) = { + tracing::trace!(target: "UDP TRACKER", info = %requests, "getting request"); + requests.next().await + } { + tracing::trace!(target: "UDP TRACKER", info = %requests, "processing request"); + + let req = match req { + Ok(req) => req, + Err(e) => { + if e.kind() == std::io::ErrorKind::Interrupted { + tracing::warn!(target: "UDP TRACKER", info = %requests, err = %e, "was interrupted"); + return Ok(()); + } + tracing::error!(target: "UDP TRACKER", info = %requests, err = %e, ""); + break; + } + }; + + if req.is_finished() { + continue; + } + + // fill buffer with requests + let Err(req) = reqs.rb.try_push(req) else { + continue; + }; + + let mut finished: u64 = 0; + let mut unfinished_task = None; + // buffer is full.. lets make some space. + for h in reqs.rb.pop_iter() { + // remove some finished tasks + if h.is_finished() { + finished += 1; + continue; + } + + // task is unfinished.. give it another chance. tokio::task::yield_now().await; - h.abort(); + // if now finished, we continue. + if h.is_finished() { + finished += 1; + continue; + } - let server_socket_addr = socket.local_addr().expect("Could not get local_addr for socket."); + if finished == 0 { + // we have _no_ finished tasks.. will abort the unfinished task to make space... + h.abort(); - tracing::span!( - target: "UDP TRACKER", - tracing::Level::WARN, "request-aborted", server_socket_addr = %server_socket_addr); + tracing::warn!(target: "UDP TRACKER", info = %requests, "request-aborted: removed no tasks, will abort to make space for new request"); + break; + } - // force-break a single thread, then loop again. - break; - } - } - } - } + tracing::debug!(target: "UDP TRACKER", info = %requests, removed_count = finished, "removed completed-requests"); - async fn receive_request(socket: Arc) -> Result> { - // Wait for the socket to be readable - socket.readable().await?; + // we have space, return unfinished task for re-entry. 
+ unfinished_task = Some(h); + } - let mut buf = Vec::with_capacity(MAX_PACKET_SIZE); + // re-insert the previous unfinished task. + if let Some(h) = unfinished_task { + reqs.rb.try_push(h).expect("it was previously inserted"); + } - match socket.recv_buf_from(&mut buf).await { - Ok((n, from)) => { - Vec::truncate(&mut buf, n); - trace!("GOT {buf:?}"); - Ok(UdpRequest { payload: buf, from }) + // insert the new task. + if !req.is_finished() { + reqs.rb.try_push(req).expect("it should remove at least one element."); + } + } else { + tokio::task::yield_now().await; + // the request iterator returned `None`. + tracing::error!(target: "UDP TRACKER", info = %requests, "request socket ran dry! this should never happen in production!"); + break; } - - Err(e) => Err(Box::new(e)), } - } - fn spawn_request_processor( - result: Result>, - tracker: Arc, - socket: Arc, - ) -> JoinHandle<()> { - tokio::task::spawn(Self::process_request(result, tracker, socket)) + Err(requests) } - async fn process_request(result: Result>, tracker: Arc, socket: Arc) { - match result { - Ok(udp_request) => { - trace!("Received Request from: {}", udp_request.from); - Self::process_valid_request(tracker.clone(), socket.clone(), udp_request).await; - } - Err(error) => { - debug!("error: {error}"); - } - } + async fn process_request(request: UdpRequest, tracker: Arc, socket: Arc) { + tracing::trace!("Received Request from: {}", request.from); + Self::process_valid_request(tracker.clone(), socket.clone(), request).await; } async fn process_valid_request(tracker: Arc, socket: Arc, udp_request: UdpRequest) { - trace!("Making Response to {udp_request:?}"); + tracing::trace!("Making Response to {udp_request:?}"); let from = udp_request.from; - let response = handlers::handle_packet(udp_request, &tracker.clone(), socket.clone()).await; + let response = handlers::handle_packet( + udp_request, + &tracker.clone(), + socket.local_addr().expect("it should get the local address"), + ) + .await; 
Self::send_response(&socket.clone(), from, response).await; } async fn send_response(socket: &Arc, to: SocketAddr, response: Response) { - trace!("Sending Response: {response:?} to: {to:?}"); + tracing::trace!("Sending Response: {response:?} to: {to:?}"); let buffer = vec![0u8; MAX_PACKET_SIZE]; let mut cursor = Cursor::new(buffer); @@ -382,22 +823,22 @@ impl Udp { let position = cursor.position() as usize; let inner = cursor.get_ref(); - debug!("Sending {} bytes ...", &inner[..position].len()); - debug!("To: {:?}", &to); - debug!("Payload: {:?}", &inner[..position]); + tracing::debug!("Sending {} bytes ...", &inner[..position].len()); + tracing::debug!("To: {:?}", &to); + tracing::debug!("Payload: {:?}", &inner[..position]); Self::send_packet(socket, &to, &inner[..position]).await; - debug!("{} bytes sent", &inner[..position].len()); + tracing::debug!("{} bytes sent", &inner[..position].len()); } Err(_) => { - error!("could not write response to bytes."); + tracing::error!("could not write response to bytes."); } } } async fn send_packet(socket: &Arc, remote_addr: &SocketAddr, payload: &[u8]) { - trace!("Sending Packets: {payload:?} to: {remote_addr:?}"); + tracing::trace!("Sending Packets: {payload:?} to: {remote_addr:?}"); // doesn't matter if it reaches or not drop(socket.send_to(payload, remote_addr).await); @@ -415,51 +856,360 @@ impl Udp { #[cfg(test)] mod tests { - use std::sync::Arc; + use std::fmt::{Debug, Display, Formatter}; + use std::pin::Pin; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll, Waker}; use std::time::Duration; - use ringbuf::traits::{Consumer, Observer, RingBuffer}; + use futures::{Future, Stream}; + use rstest::{fixture, rstest}; + use tokio::sync::Barrier; + use tokio::task::AbortHandle; use torrust_tracker_test_helpers::configuration::ephemeral_mode_public; - use super::ActiveRequests; + use super::{Reqs, UdpTaskHandleResult}; use 
crate::bootstrap::app::initialize_with_configuration; use crate::servers::registar::Registar; - use crate::servers::udp::server::{Launcher, UdpServer}; + use crate::servers::udp::server::{Launcher, Requests, Udp, UdpServer}; + + // todo, this dose-not need to be multi-threaded + #[derive(Debug)] + struct MockReceiver { + waker: Option, + #[allow(clippy::type_complexity)] + next: Option, + barrier: Arc, + handle: tokio::runtime::Handle, + inner: MockInner, + } + + impl Display for MockReceiver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{:?}", self.inner)) + } + } + + impl Future for MockReceiver { + type Output = Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.inner.remaining == 0 { + return Poll::Ready(Err(Box::new(std::io::Error::new( + std::io::ErrorKind::NotFound, + "reached limit", + )))); + } + + { + let next = self.next.take(); + if let Some(next) = next { + drop(self.waker.take()); + return Poll::Ready(next); + }; + } + + // if we hae a waker + if let Some(ref mut waker) = self.waker { + waker.clone_from(cx.waker()); + } else { + self.inner.remaining -= 1; + + self.waker = Some(cx.waker().clone()); + + let barrier = self.barrier.clone(); + let status: Arc = Arc::default(); + let () = self.inner.completed.push(status.clone()); + + let task = self.handle.spawn(async move { + if barrier.wait().await.is_leader() { + tracing::info!("group processed"); + } + let () = status.store(true, Ordering::SeqCst); + }); + + self.inner.handles.push(task.abort_handle()); + + // we have a new request, lets start processing it + self.next = Some(Ok(task.abort_handle())); + + // our next task is ready, lets notify any waiters + cx.waker().wake_by_ref(); + } + + // wait until the waker is triggered + Poll::Pending + } + } + + #[derive(Debug)] + struct MockRequests { + inner: Arc>, + barrier: Arc, + guard: Option>, + } + + impl Display for MockRequests { + fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { + let inner = self.inner.lock().expect("it should get a lock"); + + f.write_fmt(format_args!("{inner:?}")) + } + } + + struct MockInner { + remaining: isize, + completed: Vec>, + handles: Vec, + } + + impl Debug for MockInner { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MockRequestsInner") + .field("remaining", &self.remaining()) + .field("issued", &self.issued()) + .field("finished", &self.finished()) + .field("completed", &self.completed()) + .field("pending", &self.pending()) + .field("aborted", &self.aborted()) + .finish() + } + } + + impl MockInner { + fn remaining(&self) -> isize { + self.remaining + } + fn issued(&self) -> isize { + self.handles.len().try_into().expect("it_should_fit_into") + } + fn finished(&self) -> isize { + self.handles + .iter() + .filter(|h| h.is_finished()) + .count() + .try_into() + .expect("it_should_fit_into") + } + fn completed(&self) -> isize { + self.completed + .iter() + .filter(|a| a.load(Ordering::SeqCst)) + .count() + .try_into() + .expect("it_should_fit_into") + } + fn pending(&self) -> isize { + self.handles + .iter() + .filter(|h| !h.is_finished()) + .count() + .try_into() + .expect("it_should_fit_into") + } + fn aborted(&self) -> isize { + self.finished() - self.completed() + } + } + + impl Reqs for MockRequests {} + + impl Stream for MockRequests { + type Item = Result>; + + fn poll_next(self: std::pin::Pin<&mut Self>, _: &mut std::task::Context<'_>) -> std::task::Poll> { + tracing::debug!("polling for new task"); + + let _guard = match &self.guard { + Some(guard) => match guard.try_lock() { + Ok(guard) => Some(guard), + Err(e) => match e { + std::sync::TryLockError::Poisoned(e) => { + tracing::error!("lock is poisoned: {e}"); + panic!() + } + std::sync::TryLockError::WouldBlock => return std::task::Poll::Pending, + }, + }, + None => None, + }; + + let mut inner = self.inner.lock().expect("it should get a lock"); + + tracing::debug!("poll 
locked, {} remaining tasks to make", inner.remaining); + + let r = if inner.remaining == 0 { + std::task::Poll::Ready(None) + } else { + inner.remaining -= 1; + + let barrier = Arc::clone(&self.barrier); + let status: Arc = Arc::default(); // false + + let () = inner.completed.push(status.clone()); + + let remaining = inner.remaining; + let task = tokio::task::spawn(async move { + if barrier.wait().await.is_leader() { + tracing::info!("group_processed with remaining: {remaining}"); + } + status.store(true, Ordering::SeqCst); + }); + + inner.handles.push(task.abort_handle()); + + std::task::Poll::Ready(Some(Ok(task.abort_handle()))) + }; + + println!("inner: {inner:?}"); + + r + } + } + #[fixture] + fn tokio_single_thread() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .expect("it should build runtime") + } + + #[fixture] + fn tokio_multi_thread() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_multi_thread() + .build() + .expect("it should build runtime") + } + + //static INIT: std::sync::Once = std::sync::Once::new(); + + #[rstest] #[tokio::test] - async fn it_should_return_to_the_start_of_the_ring_buffer() { - let mut a_req = ActiveRequests::new().await; + async fn it_should_process_many_requests_2(#[values(1, 20, 1000)] total: isize, #[values(1, 3, 20)] barrier: usize) { + // INIT.call_once(|| { + // let () = tracing_subscriber::fmt() + // .compact() + // .with_max_level(tracing::Level::TRACE) + // .init(); + // }); + + let mock_receiver = MockReceiver { + waker: None, + next: None, + barrier: Arc::new(Barrier::new(barrier)), + handle: tokio::runtime::Handle::current(), + inner: MockInner { + remaining: total, + completed: Vec::default(), + handles: Vec::default(), + }, + }; - tokio::time::sleep(Duration::from_millis(10)).await; + let requests = Requests::new(mock_receiver); + + let Err(requests) = Udp::run_udp_server(requests).await else { + panic!("it should return error from a empty stream"); + }; - let 
mut count: usize = 0; - let cap: usize = a_req.rb.capacity().into(); + let inner = requests.receiver.inner; - // Add a single pending task to check that the ring-buffer is looping correctly. - a_req - .rb - .push_overwrite(tokio::task::spawn(std::future::pending::<()>()).abort_handle()); + assert_eq!(inner.remaining(), 0); + assert_eq!(inner.issued(), total); + assert_eq!(inner.completed(), total - inner.aborted() - inner.pending()); + } - count += 1; + #[rstest] + #[case::rt_st(tokio_single_thread())] + #[case::rt_mt(tokio_multi_thread())] + fn it_should_process_many_requests( + #[case] rt: tokio::runtime::Runtime, + #[values(None, Some(Mutex::default()))] guard: Option>, + #[values(1, 20, 1000)] total: isize, + #[values(1, 3, 20)] barrier: usize, + ) { + let mock_requests = MockRequests { + guard, + inner: Arc::new(Mutex::new(MockInner { + remaining: total, + completed: Vec::default(), + handles: Vec::default(), + })), + barrier: Barrier::new(barrier).into(), + }; - for _ in 0..2 { - for h in a_req.rb.iter() { - let first = count % cap; - println!("{count},{first},{}", h.is_finished()); + let mock_requests = { + let rt = rt; - if first == 0 { - assert!(!h.is_finished()); - } else { - assert!(h.is_finished()); + let mock_requests = match rt.block_on(Udp::run_udp_server::(mock_requests)) { + Ok(()) => unreachable!("it should end with error"), + Err(e) => { + tracing::trace!("finished"); + e } + }; - count += 1; + { + let inner = mock_requests.inner.lock().expect("it should get a lock"); + + tracing::info!(inner = ?inner, "sleeping..."); + + let () = std::thread::sleep(Duration::from_millis(1000)); + + tracing::info!(inner = ?inner, "finished inner..."); } - } + mock_requests + }; + + let inner = mock_requests.inner.lock().expect("it should get a lock"); + + tracing::info!(inner = ?inner, "finished outer..."); + println!("Finished with:{inner:?}"); + + assert_eq!(inner.remaining(), 0); + assert_eq!(inner.issued(), total); + assert_eq!(inner.completed(), total - 
inner.aborted()); + assert_eq!(inner.pending(), 0); } #[tokio::test] async fn it_should_be_able_to_start_and_stop() { + // INIT.call_once(|| { + // let () = tracing_subscriber::fmt() + // .compact() + // .with_max_level(tracing::Level::TRACE) + // .init(); + // }); + + let cfg = Arc::new(ephemeral_mode_public()); + let tracker = initialize_with_configuration(&cfg); + let config = &cfg.udp_trackers[0]; + let bind_to = config.bind_address; + let register = &Registar::default(); + + let stopped = UdpServer::new(Launcher::new(bind_to)); + + let started = stopped + .start(tracker, register.give_form()) + .await + .expect("it should start the server"); + + let stopped = started.stop().await.expect("it should stop the server"); + + tokio::time::sleep(Duration::from_secs(1)).await; + + assert_eq!(stopped.state.launcher.bind_to, bind_to); + } + + #[tokio::test] + async fn it_should_be_able_to_start_and_stop_with_wait() { + // INIT.call_once(|| { + // let () = tracing_subscriber::fmt() + // .compact() + // .with_max_level(tracing::Level::TRACE) + // .init(); + // }); + let cfg = Arc::new(ephemeral_mode_public()); let tracker = initialize_with_configuration(&cfg); let config = &cfg.udp_trackers[0]; @@ -473,6 +1223,8 @@ mod tests { .await .expect("it should start the server"); + tokio::time::sleep(Duration::from_secs(1)).await; + let stopped = started.stop().await.expect("it should stop the server"); tokio::time::sleep(Duration::from_secs(1)).await; @@ -480,3 +1232,68 @@ mod tests { assert_eq!(stopped.state.launcher.bind_to, bind_to); } } + +/// Todo: submit test to tokio documentation. +#[cfg(test)] +mod test_tokio { + use std::sync::Arc; + use std::time::Duration; + + use tokio::sync::Barrier; + use tokio::task::JoinSet; + + #[tokio::test] + async fn test_barrier_with_aborted_tasks() { + // Create a barrier that requires 10 tasks to proceed. 
+ let barrier = Arc::new(Barrier::new(10)); + let mut tasks = JoinSet::default(); + let mut handles = Vec::default(); + + // Set Barrier to 9/10. + for _ in 0..9 { + let c = barrier.clone(); + handles.push(tasks.spawn(async move { + c.wait().await; + })); + } + + // Abort two tasks: Barrier 7/10. + for _ in 0..2 { + if let Some(handle) = handles.pop() { + handle.abort(); + } + } + + // Spawn a single task: Barrier 8/10. + let c = barrier.clone(); + handles.push(tasks.spawn(async move { + c.wait().await; + })); + + // give a chance for the barrier to release. + tokio::time::sleep(Duration::from_millis(50)).await; + + // assert that the barrier hasn't released, i.e. 8, not 10. + for h in &handles { + assert!(!h.is_finished()); + } + + // Spawn two more tasks to trigger the barrier release: Barrier 10/10. + for _ in 0..2 { + let c = barrier.clone(); + handles.push(tasks.spawn(async move { + c.wait().await; + })); + } + + // give a chance for the barrier to release. + tokio::time::sleep(Duration::from_millis(50)).await; + + // assert that the barrier has been triggered. + for h in &handles { + assert!(h.is_finished()); + } + + tasks.shutdown().await; + } +} diff --git a/tests/servers/health_check_api/contract.rs b/tests/servers/health_check_api/contract.rs index 3c3c1315..6164a516 100644 --- a/tests/servers/health_check_api/contract.rs +++ b/tests/servers/health_check_api/contract.rs @@ -245,6 +245,8 @@ mod udp { use crate::servers::health_check_api::Started; use crate::servers::udp; + //static INIT: std::sync::Once = std::sync::Once::new(); + #[tokio::test] pub(crate) async fn it_should_return_good_health_for_udp_service() { let configuration = Arc::new(configuration::ephemeral()); @@ -288,6 +290,13 @@ mod udp { #[tokio::test] pub(crate) async fn it_should_return_error_when_udp_service_was_stopped_after_registration() { + // INIT.call_once(|| { + // let () = tracing_subscriber::fmt() + // .compact() + // .with_max_level(tracing::Level::TRACE) + // .init(); + // }); +
let configuration = Arc::new(configuration::ephemeral()); let service = udp::Started::new(&configuration).await; diff --git a/tests/servers/udp/contract.rs b/tests/servers/udp/contract.rs index 7abd6092..3961fada 100644 --- a/tests/servers/udp/contract.rs +++ b/tests/servers/udp/contract.rs @@ -75,8 +75,17 @@ mod receiving_a_connection_request { use crate::servers::udp::asserts::is_connect_response; use crate::servers::udp::Started; + // static INIT: std::sync::Once = std::sync::Once::new(); + #[tokio::test] async fn should_return_a_connect_response() { + // INIT.call_once(|| { + // let () = tracing_subscriber::fmt() + // .compact() + // .with_max_level(tracing::Level::TRACE) + // .init(); + // }); + let env = Started::new(&configuration::ephemeral().into()).await; let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { @@ -111,30 +120,20 @@ mod receiving_an_announce_request { AnnounceActionPlaceholder, AnnounceEvent, AnnounceRequest, ConnectionId, InfoHash, NumberOfBytes, NumberOfPeers, PeerId, PeerKey, Port, TransactionId, }; - use torrust_tracker::shared::bit_torrent::tracker::udp::client::new_udp_tracker_client_connected; + use torrust_tracker::shared::bit_torrent::tracker::udp::client::{new_udp_tracker_client_connected, UdpTrackerClient}; use torrust_tracker_test_helpers::configuration; use crate::servers::udp::asserts::is_ipv4_announce_response; use crate::servers::udp::contract::send_connection_request; use crate::servers::udp::Started; - #[tokio::test] - async fn should_return_an_announce_response() { - let env = Started::new(&configuration::ephemeral().into()).await; - - let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { - Ok(udp_tracker_client) => udp_tracker_client, - Err(err) => panic!("{err}"), - }; - - let connection_id = send_connection_request(TransactionId::new(123), &client).await; - + pub async fn send_and_get_announce(tx_id: TransactionId, c_id: ConnectionId, client: 
&UdpTrackerClient) { // Send announce request let announce_request = AnnounceRequest { - connection_id: ConnectionId(connection_id.0), + connection_id: ConnectionId(c_id.0), action_placeholder: AnnounceActionPlaceholder::default(), - transaction_id: TransactionId::new(123i32), + transaction_id: tx_id, info_hash: InfoHash([0u8; 20]), peer_id: PeerId([255u8; 20]), bytes_downloaded: NumberOfBytes(0i64.into()), @@ -160,6 +159,43 @@ mod receiving_an_announce_request { println!("test response {response:?}"); assert!(is_ipv4_announce_response(&response)); + } + + #[tokio::test] + async fn should_return_an_announce_response() { + let env = Started::new(&configuration::ephemeral().into()).await; + + let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + Ok(udp_tracker_client) => udp_tracker_client, + Err(err) => panic!("{err}"), + }; + + let tx_id = TransactionId::new(123); + + let c_id = send_connection_request(tx_id, &client).await; + + send_and_get_announce(tx_id, c_id, &client).await; + + env.stop().await; + } + + #[tokio::test] + async fn should_return_many_announce_response() { + let env = Started::new(&configuration::ephemeral().into()).await; + + let client = match new_udp_tracker_client_connected(&env.bind_address().to_string()).await { + Ok(udp_tracker_client) => udp_tracker_client, + Err(err) => panic!("{err}"), + }; + + let tx_id = TransactionId::new(123); + + let c_id = send_connection_request(tx_id, &client).await; + + for x in 0..1000 { + log::info!("req no: {x}"); + send_and_get_announce(tx_id, c_id, &client).await; + } env.stop().await; }