From 35506a317a95896608ac56e41457bded875a84ca Mon Sep 17 00:00:00 2001 From: asonix Date: Wed, 10 Apr 2024 10:48:02 -0500 Subject: [PATCH 1/2] Add failing test that demonstrates how closed receivers break sender error assumptions --- tests/async.rs | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/async.rs b/tests/async.rs index 1d40dc0..0ff0df6 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,4 +1,9 @@ -use std::time::{Duration, Instant}; +use std::{ + future::Future, + pin::Pin, + task::Context, + time::{Duration, Instant}, +}; use futures::{future::join_all, FutureExt}; use loole::{bounded, RecvError, SendError}; @@ -11,6 +16,31 @@ async fn async_sleep(ms: u64) { tokio::time::sleep(Duration::from_millis(ms)).await } +#[test] +fn ordered_deques() { + let (tx, rx) = bounded(0); + + let mut send_future_1 = tx.send_async(1); + let mut send_future_2 = tx.send_async(2); + + let mut cx = Context::from_waker(futures::task::noop_waker_ref()); + let cx = &mut cx; + + assert!(Pin::new(&mut send_future_1).poll(cx).is_pending()); + assert!(Pin::new(&mut send_future_2).poll(cx).is_pending()); + + drop(rx); + + assert_eq!( + Pin::new(&mut send_future_2).poll(cx), + std::task::Poll::Ready(Err(SendError(2))) + ); + assert_eq!( + Pin::new(&mut send_future_1).poll(cx), + std::task::Poll::Ready(Err(SendError(1))) + ); +} + #[tokio::test] async fn async_send_before_recv_buffer_0() { let (tx, rx) = bounded(0); From a2517b71dde170c8a74a131d8458f172c6b82bb3 Mon Sep 17 00:00:00 2001 From: asonix Date: Wed, 10 Apr 2024 11:23:48 -0500 Subject: [PATCH 2/2] Remove correct messsage from sender queue When the receiver is dropped, the sender retrieves its value from the sender queue to return as an error. Due to the concurrent nature of futures, the retrieval is not guaranteed to occur in the same order as the insertion, and therefore we cannot rely on .dequeue() for this, and must instead use the more costly .remove() --- src/lib.rs | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 954b621..ebe16a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -402,7 +402,20 @@ fn try_recv(mut guard: MutexGuard<'_, SharedState>) -> TryRecvResult { #[derive(Debug)] pub struct SendFuture { shared_state: Arc>>, - msg: Option, + msg: MessageOrId, +} + +#[derive(Debug)] +enum MessageOrId { + Message(T), + Id(usize), + Invalid, +} + +impl MessageOrId { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Invalid) + } } impl std::marker::Unpin for SendFuture {} @@ -412,16 +425,17 @@ impl Future for SendFuture { fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let m = match self.msg.take() { - Some(m) => m, - None => { + MessageOrId::Message(m) => m, + MessageOrId::Id(id) => { let mut guard = self.shared_state.lock(); if guard.closed { - if let Some((_, (m, Some(_)))) = guard.pending_sends.dequeue() { + if let Some((_, (m, Some(_)))) = guard.pending_sends.remove(id) { return Poll::Ready(Err(SendError(m))); } } return Poll::Ready(Ok(())); } + MessageOrId::Invalid => panic!("Future polled after completion"), }; let mut guard = self.shared_state.lock(); let id = guard.get_next_id(); @@ -433,10 +447,12 @@ impl Future for SendFuture { guard .pending_sends .enqueue(id, (m, Some(cx.waker().clone().into()))); - if let Some((_, s)) = guard.pending_recvs.dequeue() { - drop(guard); + let opt = guard.pending_recvs.dequeue(); + drop(guard); + if let Some((_, s)) = opt { s.wake(); } + self.msg = MessageOrId::Id(id); Poll::Pending } } @@ -542,7 +558,7 @@ impl Sender { pub fn send_async(&self, m: T) -> SendFuture { SendFuture { shared_state: Arc::clone(&self.shared_state), - msg: Some(m), + msg: MessageOrId::Message(m), } }