Skip to content

Commit

Permalink
Merge pull request #3 from asonix/asonix/wrong-send-error-fix
Browse files Browse the repository at this point in the history
Fix pending async senders erroring with mismatched values
  • Loading branch information
mahdi-shojaee authored Apr 15, 2024
2 parents e38433d + a2517b7 commit 2b740ee
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
30 changes: 23 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,20 @@ fn try_recv<T>(mut guard: MutexGuard<'_, SharedState<T>>) -> TryRecvResult<T> {
#[derive(Debug)]
pub struct SendFuture<T> {
shared_state: Arc<Mutex<SharedState<T>>>,
msg: Option<T>,
msg: MessageOrId<T>,
}

#[derive(Debug)]
enum MessageOrId<T> {
Message(T),
Id(usize),
Invalid,
}

impl<T> MessageOrId<T> {
fn take(&mut self) -> Self {
std::mem::replace(self, Self::Invalid)
}
}

impl<T> std::marker::Unpin for SendFuture<T> {}
Expand All @@ -412,16 +425,17 @@ impl<T> Future for SendFuture<T> {

fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let m = match self.msg.take() {
Some(m) => m,
None => {
MessageOrId::Message(m) => m,
MessageOrId::Id(id) => {
let mut guard = self.shared_state.lock();
if guard.closed {
if let Some((_, (m, Some(_)))) = guard.pending_sends.dequeue() {
if let Some((_, (m, Some(_)))) = guard.pending_sends.remove(id) {
return Poll::Ready(Err(SendError(m)));
}
}
return Poll::Ready(Ok(()));
}
MessageOrId::Invalid => panic!("Future polled after completion"),
};
let mut guard = self.shared_state.lock();
let id = guard.get_next_id();
Expand All @@ -433,10 +447,12 @@ impl<T> Future for SendFuture<T> {
guard
.pending_sends
.enqueue(id, (m, Some(cx.waker().clone().into())));
if let Some((_, s)) = guard.pending_recvs.dequeue() {
drop(guard);
let opt = guard.pending_recvs.dequeue();
drop(guard);
if let Some((_, s)) = opt {
s.wake();
}
self.msg = MessageOrId::Id(id);
Poll::Pending
}
}
Expand Down Expand Up @@ -542,7 +558,7 @@ impl<T> Sender<T> {
pub fn send_async(&self, m: T) -> SendFuture<T> {
SendFuture {
shared_state: Arc::clone(&self.shared_state),
msg: Some(m),
msg: MessageOrId::Message(m),
}
}

Expand Down
32 changes: 31 additions & 1 deletion tests/async.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::time::{Duration, Instant};
use std::{
future::Future,
pin::Pin,
task::Context,
time::{Duration, Instant},
};

use futures::{future::join_all, FutureExt};
use loole::{bounded, RecvError, SendError};
Expand All @@ -11,6 +16,31 @@ async fn async_sleep(ms: u64) {
tokio::time::sleep(Duration::from_millis(ms)).await
}

#[test]
fn ordered_deques() {
let (tx, rx) = bounded(0);

let mut send_future_1 = tx.send_async(1);
let mut send_future_2 = tx.send_async(2);

let mut cx = Context::from_waker(futures::task::noop_waker_ref());
let cx = &mut cx;

assert!(Pin::new(&mut send_future_1).poll(cx).is_pending());
assert!(Pin::new(&mut send_future_2).poll(cx).is_pending());

drop(rx);

assert_eq!(
Pin::new(&mut send_future_2).poll(cx),
std::task::Poll::Ready(Err(SendError(2)))
);
assert_eq!(
Pin::new(&mut send_future_1).poll(cx),
std::task::Poll::Ready(Err(SendError(1)))
);
}

#[tokio::test]
async fn async_send_before_recv_buffer_0() {
let (tx, rx) = bounded(0);
Expand Down

0 comments on commit 2b740ee

Please sign in to comment.