From 41898626e9801a1894c938107b18061c83ba74d0 Mon Sep 17 00:00:00 2001 From: tison Date: Sat, 26 Oct 2024 02:27:45 +0800 Subject: [PATCH] feat: implement WaitGroup Signed-off-by: tison --- Cargo.toml | 18 ++++- src/internal.rs | 104 ++++++++++++++++++++++++++ src/lib.rs | 18 +++++ src/timeout.rs | 21 ++++++ src/waitgroup.rs | 189 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 346 insertions(+), 4 deletions(-) create mode 100644 src/internal.rs create mode 100644 src/timeout.rs create mode 100644 src/waitgroup.rs diff --git a/Cargo.toml b/Cargo.toml index eb73cd1..c423a37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,12 +14,22 @@ [package] name = "mea" +version = "0.0.1" -description = "Async Rust utilities that are runtime agnostic." edition = "2021" +rust-version = "1.75.0" + +description = "Async Rust utilities that are runtime agnostic." +documentation = "https://docs.rs/mea" +homepage = "https://github.com/tisonkun/mea" license = "Apache-2.0" readme = "README.md" -rust-version = "1.75.0" -version = "0.0.1" +repository = "https://github.com/tisonkun/mea" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] -[dependencies] +[dev-dependencies] +pollster = { version = "0.3.0" } +tokio = { version = "1.41.0", features = ["full"] } diff --git a/src/internal.rs b/src/internal.rs new file mode 100644 index 0000000..e10dfe2 --- /dev/null +++ b/src/internal.rs @@ -0,0 +1,104 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::mem::take; +use std::task::Waker; + +pub(crate) struct Mutex(std::sync::Mutex); + +impl Mutex { + #[must_use] + #[inline] + pub(crate) const fn new(t: T) -> Self { + Self(std::sync::Mutex::new(t)) + } + + pub(crate) fn lock(&self) -> std::sync::MutexGuard<'_, T> { + self.0 + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } +} + +pub(crate) struct Waiters { + data: Vec, + next: usize, +} + +enum Waiter { + Occupied(Waker), + Vacant(usize), +} + +impl Waiters { + pub(crate) const fn new() -> Self { + Self { + data: Vec::new(), + next: 0, + } + } + + pub(crate) fn upsert(&mut self, id: &mut Option, waker: &Waker) { + match *id { + Some(key) => match self.data.get_mut(key) { + Some(Waiter::Occupied(w)) => { + if !w.will_wake(waker) { + *w = waker.clone() + } + } + _ => unreachable!("update non-existent waker"), + }, + None => { + let key = self.next; + + if self.data.len() == key { + self.data.push(Waiter::Occupied(waker.clone())); + self.next = key + 1; + } else if let Some(&Waiter::Vacant(n)) = self.data.get(key) { + self.data[key] = Waiter::Occupied(waker.clone()); + self.next = n; + } else { + unreachable!(); + } + + *id = Some(key); + } + } + } + + pub(crate) fn remove(&mut self, id: &mut Option) { + if let Some(key) = id.take() { + if let Some(waiter) = self.data.get_mut(key) { + if let Waiter::Occupied(_) = waiter { + *waiter = Waiter::Vacant(self.next); + self.next = key; + } + } + } + } + + pub(crate) fn wake_all(mutex: &Mutex) { + let waiters = { + let mut lock = mutex.lock(); + lock.next = 0; + take(&mut lock.data) + }; + + for waiter in waiters { + if let Waiter::Occupied(w) = waiter { + w.wake(); + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index e69de29..f6dfeb6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -0,0 +1,18 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod internal; + +pub mod timeout; +pub mod waitgroup; diff --git a/src/timeout.rs b/src/timeout.rs new file mode 100644 index 0000000..57450c1 --- /dev/null +++ b/src/timeout.rs @@ -0,0 +1,21 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum MaybeTimedOut { + /// Action completed before timeout. + Completed, + /// Action timed out. `T` contains the asynchronous timer result. + TimedOut(T), +} diff --git a/src/waitgroup.rs b/src/waitgroup.rs new file mode 100644 index 0000000..1324ba2 --- /dev/null +++ b/src/waitgroup.rs @@ -0,0 +1,189 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::future::IntoFuture; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Weak; +use std::task::Context; +use std::task::Poll; + +use crate::internal::Mutex; +use crate::internal::Waiters; +use crate::timeout::MaybeTimedOut; + +#[derive(Clone)] +pub struct WaitGroup { + inner: Arc, +} + +impl WaitGroup { + pub fn new() -> Self { + Self { + inner: Arc::new(Inner { + waiters: Mutex::new(Waiters::new()), + }), + } + } + + pub fn wait(self) -> WaitGroupFuture { + self.into_future() + } + + pub fn wait_timeout(self, timer: T) -> WaitGroupTimeoutFuture { + WaitGroupTimeoutFuture { + id: None, + inner: Arc::downgrade(&self.inner), + timer, + } + } +} + +impl Default for WaitGroup { + fn default() -> Self { + Self::new() + } +} + +impl IntoFuture for WaitGroup { + type Output = (); + + type IntoFuture = WaitGroupFuture; + + fn into_future(self) -> Self::IntoFuture { + WaitGroupFuture { + id: None, + inner: Arc::downgrade(&self.inner), + } + } +} + +struct Inner { + waiters: Mutex, +} + +impl Drop for Inner { + fn drop(&mut self) { + Waiters::wake_all(&self.waiters); + } +} + +pub struct WaitGroupFuture { + id: Option, + inner: Weak, +} + +impl Future for WaitGroupFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let Self { id, inner } = self.get_mut(); + match inner.upgrade() { + Some(inner) => { + let mut lock = inner.waiters.lock(); + lock.upsert(id, cx.waker()); + Poll::Pending + } + None => Poll::Ready(()), + } + } +} + +pub struct WaitGroupTimeoutFuture { + id: Option, + inner: Weak, + timer: T, +} + +impl Future for WaitGroupTimeoutFuture { + type Output = MaybeTimedOut; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: `id` and `inner` are Unpin; `timer` must be pinned before poll. + let Self { id, inner, timer } = unsafe { self.get_unchecked_mut() }; + let timer = unsafe { Pin::new_unchecked(timer) }; + + match inner.upgrade() { + Some(inner) => match timer.poll(cx) { + Poll::Ready(o) => { + let mut lock = inner.waiters.lock(); + lock.remove(id); + Poll::Ready(MaybeTimedOut::TimedOut(o)) + } + Poll::Pending => { + let mut lock = inner.waiters.lock(); + lock.upsert(id, cx.waker()); + Poll::Pending + } + }, + None => Poll::Ready(MaybeTimedOut::Completed), + } + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use super::*; + + fn test_runtime() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + } + + #[test] + fn test_wait_group_drop() { + let test_runtime = test_runtime(); + + let wg = WaitGroup::new(); + for _i in 0..100 { + let w = wg.clone(); + test_runtime.spawn(async move { + drop(w); + }); + } + pollster::block_on(wg.into_future()); + } + + #[test] + fn test_wait_group_await() { + let test_runtime = test_runtime(); + + let wg = WaitGroup::new(); + for _i in 0..100 { + let w = wg.clone(); + test_runtime.spawn(async move { + w.await; + }); + } + pollster::block_on(wg.into_future()); + } + + #[test] + fn test_wait_group_timeout() { + let test_runtime = test_runtime(); + + let wg = WaitGroup::new(); + let _wg_clone = wg.clone(); + let out = test_runtime.block_on(async move { + let timer = tokio::time::sleep(Duration::from_millis(50)); + wg.wait_timeout(timer).await + }); + assert_eq!(out, MaybeTimedOut::TimedOut(())); + } +}