diff --git a/mea/src/barrier/mod.rs b/mea/src/barrier/mod.rs index a099347..a4a5d3f 100644 --- a/mea/src/barrier/mod.rs +++ b/mea/src/barrier/mod.rs @@ -34,8 +34,8 @@ //! let barrier = barrier.clone(); //! handles.push(tokio::spawn(async move { //! println!("Task {} before barrier", i); -//! let is_leader = barrier.wait().await; -//! println!("Task {} after barrier (leader: {})", i, is_leader); +//! let result = barrier.wait().await; +//! println!("Task {} after barrier (leader: {})", i, result.is_leader()); //! })); //! } //! @@ -48,8 +48,8 @@ //! let barrier = barrier.clone(); //! handles.push(tokio::spawn(async move { //! println!("Task {} before barrier", i); -//! let is_leader = barrier.wait().await; -//! println!("Task {} after barrier (leader: {})", i, is_leader); +//! let result = barrier.wait().await; +//! println!("Task {} after barrier (leader: {})", i, result.is_leader()); //! })); //! } //! @@ -95,8 +95,8 @@ mod tests; /// let barrier = barrier.clone(); /// handles.push(tokio::spawn(async move { /// println!("Task {} before barrier", i); -/// let is_leader = barrier.wait().await; -/// println!("Task {} after barrier (leader: {})", i, is_leader); +/// let result = barrier.wait().await; +/// println!("Task {} after barrier (leader: {})", i, result.is_leader()); /// })); /// } /// @@ -109,8 +109,8 @@ mod tests; /// let barrier = barrier.clone(); /// handles.push(tokio::spawn(async move { /// println!("Task {} before barrier", i); -/// let is_leader = barrier.wait().await; -/// println!("Task {} after barrier (leader: {})", i, is_leader); +/// let result = barrier.wait().await; +/// println!("Task {} after barrier (leader: {})", i, result.is_leader()); /// })); /// } /// @@ -142,6 +142,54 @@ impl fmt::Debug for BarrierState { } } +/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads +/// in the [`Barrier`] have rendezvoused. +/// +/// # Examples +/// +/// ``` +/// # #[tokio::main] +/// # async fn main() { +/// use mea::barrier::Barrier; +/// +/// let barrier = Barrier::new(1); +/// let barrier_wait_result = barrier.wait().await; +/// # } +/// ``` +pub struct BarrierWaitResult(bool); + +impl fmt::Debug for BarrierWaitResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BarrierWaitResult") + .field("is_leader", &self.is_leader()) + .finish() + } +} + +impl BarrierWaitResult { + /// Returns `true` if this worker is the "leader" for the call to [`Barrier::wait()`]. + /// + /// Only one worker will have `true` returned from their result, all other + /// workers will have `false` returned. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use mea::barrier::Barrier; + /// + /// let barrier = Barrier::new(1); + /// let barrier_wait_result = barrier.wait().await; + /// println!("{:?}", barrier_wait_result.is_leader()); + /// # } + /// ``` + #[must_use] + pub fn is_leader(&self) -> bool { + self.0 + } +} + impl Barrier { /// Creates a new barrier that can block the specified number of tasks. /// @@ -199,16 +247,16 @@ impl Barrier { /// let barrier2 = barrier.clone(); /// /// let handle = tokio::spawn(async move { - /// let is_leader = barrier2.wait().await; - /// println!("Task 1: leader = {}", is_leader); + /// let result = barrier2.wait().await; + /// println!("Task 1: leader = {}", result.is_leader()); /// }); /// - /// let is_leader = barrier.wait().await; - /// println!("Task 2: leader = {}", is_leader); + /// let result = barrier.wait().await; + /// println!("Task 2: leader = {}", result.is_leader()); /// handle.await.unwrap(); /// # } /// ``` - pub async fn wait(&self) -> bool { + pub async fn wait(&self) -> BarrierWaitResult { let generation = { let mut state = self.state.lock(); let generation = state.generation; @@ -220,7 +268,7 @@ impl Barrier { state.arrived = 0; state.generation += 1; state.waiters.wake_all(); - return true; + return BarrierWaitResult(true); } generation @@ -232,7 +280,7 @@ impl Barrier { barrier: self, }; fut.await; - false + BarrierWaitResult(false) } } diff --git a/mea/src/barrier/tests.rs b/mea/src/barrier/tests.rs index 2b407f2..76802b4 100644 --- a/mea/src/barrier/tests.rs +++ b/mea/src/barrier/tests.rs @@ -24,12 +24,12 @@ fn zero_does_not_block() { { let mut f = spawn(b.wait()); let leader = assert_ready!(f.poll()); - assert!(leader); + assert!(leader.is_leader()); } { let mut f = spawn(b.wait()); let leader = assert_ready!(f.poll()); - assert!(leader); + assert!(leader.is_leader()); } } @@ -39,17 +39,17 @@ fn single() { { let mut f = spawn(b.wait()); let leader = assert_ready!(f.poll()); - assert!(leader); + assert!(leader.is_leader()); } { let mut f = spawn(b.wait()); let leader = assert_ready!(f.poll()); - assert!(leader); + assert!(leader.is_leader()); } { let mut f = spawn(b.wait()); let leader = assert_ready!(f.poll()); - assert!(leader); + assert!(leader.is_leader()); } } @@ -61,8 +61,8 @@ fn tango() { assert_pending!(f1.poll()); let mut f2 = spawn(b.wait()); - let f2_leader = assert_ready!(f2.poll()); - let f1_leader = assert_ready!(f1.poll()); + let f2_leader = assert_ready!(f2.poll()).is_leader(); + let f1_leader = assert_ready!(f1.poll()).is_leader(); assert!(f1_leader || f2_leader); assert!(!(f1_leader && f2_leader)); @@ -86,10 +86,10 @@ fn lots() { // pass the barrier let mut f = spawn(b.wait()); - let mut found_leader = assert_ready!(f.poll()); + let mut found_leader = assert_ready!(f.poll()).is_leader(); for mut f in wait { let leader = assert_ready!(f.poll()); - if leader { + if leader.is_leader() { assert!(!found_leader); found_leader = true; } diff --git a/mea/src/internal/semaphore.rs b/mea/src/internal/semaphore.rs index ec18807..ef62693 100644 --- a/mea/src/internal/semaphore.rs +++ b/mea/src/internal/semaphore.rs @@ -98,12 +98,14 @@ impl Semaphore { } /// Acquires `n` permits from the semaphore. - pub(crate) fn acquire(&self, n: u32) -> Acquire<'_> { - Acquire { + pub(crate) async fn acquire(&self, n: u32) { + let fut = Acquire { permits: n, index: None, semaphore: self, - } + done: false, + }; + fut.await } /// Adds `n` new permits to the semaphore. @@ -162,6 +164,7 @@ pub(crate) struct Acquire<'a> { permits: u32, index: Option, semaphore: &'a Semaphore, + done: bool, } impl Drop for Acquire<'_> { @@ -190,8 +193,13 @@ impl Future for Acquire<'_> { permits, index, semaphore, + done, } = self.get_mut(); + if *done { + return Poll::Ready(()); + } + match index { Some(idx) => { let mut waiters = semaphore.waiters.lock(); @@ -214,6 +222,7 @@ impl Future for Acquire<'_> { if ready { *index = None; + *done = true; return Poll::Ready(()); } } @@ -256,6 +265,7 @@ impl Future for Acquire<'_> { Ok(_) => { acquired += acq; if remaining == 0 { + *done = true; return Poll::Ready(()); } break lock.expect("lock not acquired"); diff --git a/mea/src/latch/mod.rs b/mea/src/latch/mod.rs index 738e14f..3ee54b8 100644 --- a/mea/src/latch/mod.rs +++ b/mea/src/latch/mod.rs @@ -235,11 +235,12 @@ impl Latch { /// handle.await.unwrap(); /// # } /// ``` - pub fn wait(&self) -> LatchWait<'_> { - LatchWait { + pub async fn wait(&self) { + let fut = LatchWait { idx: None, latch: self, - } + }; + fut.await } } diff --git a/mea/src/mutex/mod.rs b/mea/src/mutex/mod.rs index 0b28a5e..571cc8b 100644 --- a/mea/src/mutex/mod.rs +++ b/mea/src/mutex/mod.rs @@ -147,11 +147,8 @@ impl Mutex { /// # } /// ``` pub async fn lock(&self) -> MutexGuard<'_, T> { - let fut = async { - self.s.acquire(1).await; - MutexGuard { lock: self } - }; - fut.await + self.s.acquire(1).await; + MutexGuard { lock: self } } /// Attempts to acquire the lock, and returns `None` if the lock is currently held somewhere