Skip to content

Commit

Permalink
feat: implement WaitGroup
Browse files Browse the repository at this point in the history
Signed-off-by: tison <wander4096@gmail.com>
  • Loading branch information
tisonkun committed Oct 25, 2024
1 parent 0c87843 commit 4189862
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 4 deletions.
18 changes: 14 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,22 @@

[package]
name = "mea"
version = "0.0.1"

description = "Async Rust utilities that are runtime agnostic."
edition = "2021"
rust-version = "1.75.0"

description = "Async Rust utilities that are runtime agnostic."
documentation = "https://docs.rs/mea"
homepage = "https://github.com/tisonkun/mea"
license = "Apache-2.0"
readme = "README.md"
rust-version = "1.75.0"
version = "0.0.1"
repository = "https://github.com/tisonkun/mea"

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[dependencies]
[dev-dependencies]
pollster = { version = "0.3.0" }
tokio = { version = "1.41.0", features = ["full"] }
104 changes: 104 additions & 0 deletions src/internal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2024 tison <wander4096@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::mem::take;
use std::task::Waker;

pub(crate) struct Mutex<T>(std::sync::Mutex<T>);

impl<T> Mutex<T> {
#[must_use]
#[inline]
pub(crate) const fn new(t: T) -> Self {
Self(std::sync::Mutex::new(t))
}

pub(crate) fn lock(&self) -> std::sync::MutexGuard<'_, T> {
self.0
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}

pub(crate) struct Waiters {
data: Vec<Waiter>,
next: usize,
}

enum Waiter {
Occupied(Waker),
Vacant(usize),
}

impl Waiters {
pub(crate) const fn new() -> Self {
Self {
data: Vec::new(),
next: 0,
}
}

pub(crate) fn upsert(&mut self, id: &mut Option<usize>, waker: &Waker) {
match *id {
Some(key) => match self.data.get_mut(key) {
Some(Waiter::Occupied(w)) => {
if !w.will_wake(waker) {
*w = waker.clone()
}
}
_ => unreachable!("update non-existent waker"),
},
None => {
let key = self.next;

if self.data.len() == key {
self.data.push(Waiter::Occupied(waker.clone()));
self.next = key + 1;
} else if let Some(&Waiter::Vacant(n)) = self.data.get(key) {
self.data[key] = Waiter::Occupied(waker.clone());
self.next = n;
} else {
unreachable!();
}

*id = Some(key);
}
}
}

pub(crate) fn remove(&mut self, id: &mut Option<usize>) {
if let Some(key) = id.take() {
if let Some(waiter) = self.data.get_mut(key) {
if let Waiter::Occupied(_) = waiter {
*waiter = Waiter::Vacant(self.next);
self.next = key;
}
}
}
}

pub(crate) fn wake_all(mutex: &Mutex<Self>) {
let waiters = {
let mut lock = mutex.lock();
lock.next = 0;
take(&mut lock.data)
};

for waiter in waiters {
if let Waiter::Occupied(w) = waiter {
w.wake();
}
}
}
}
18 changes: 18 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright 2024 tison <wander4096@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

mod internal;

pub mod timeout;
pub mod waitgroup;
21 changes: 21 additions & 0 deletions src/timeout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright 2024 tison <wander4096@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum MaybeTimedOut<T> {
/// Action completed before timeout.
Completed,
/// Action timed out. `T` contains the asynchronous timer result.
TimedOut(T),
}
189 changes: 189 additions & 0 deletions src/waitgroup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
// Copyright 2024 tison <wander4096@gmail.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::future::IntoFuture;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Weak;
use std::task::Context;
use std::task::Poll;

use crate::internal::Mutex;
use crate::internal::Waiters;
use crate::timeout::MaybeTimedOut;

#[derive(Clone)]
pub struct WaitGroup {
inner: Arc<Inner>,
}

impl WaitGroup {
pub fn new() -> Self {
Self {
inner: Arc::new(Inner {
waiters: Mutex::new(Waiters::new()),
}),
}
}

pub fn wait(self) -> WaitGroupFuture {
self.into_future()
}

pub fn wait_timeout<T>(self, timer: T) -> WaitGroupTimeoutFuture<T> {
WaitGroupTimeoutFuture {
id: None,
inner: Arc::downgrade(&self.inner),
timer,
}
}
}

impl Default for WaitGroup {
fn default() -> Self {
Self::new()
}
}

impl IntoFuture for WaitGroup {
type Output = ();

type IntoFuture = WaitGroupFuture;

fn into_future(self) -> Self::IntoFuture {
WaitGroupFuture {
id: None,
inner: Arc::downgrade(&self.inner),
}
}
}

struct Inner {
waiters: Mutex<Waiters>,
}

impl Drop for Inner {
fn drop(&mut self) {
Waiters::wake_all(&self.waiters);
}
}

pub struct WaitGroupFuture {
id: Option<usize>,
inner: Weak<Inner>,
}

impl Future for WaitGroupFuture {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self { id, inner } = self.get_mut();
match inner.upgrade() {
Some(inner) => {
let mut lock = inner.waiters.lock();
lock.upsert(id, cx.waker());
Poll::Pending
}
None => Poll::Ready(()),
}
}
}

pub struct WaitGroupTimeoutFuture<T> {
id: Option<usize>,
inner: Weak<Inner>,
timer: T,
}

impl<T: Future> Future for WaitGroupTimeoutFuture<T> {
type Output = MaybeTimedOut<T::Output>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// SAFETY: `id` and `inner` are Unpin; `timer` must be pinned before poll.
let Self { id, inner, timer } = unsafe { self.get_unchecked_mut() };
let timer = unsafe { Pin::new_unchecked(timer) };

match inner.upgrade() {
Some(inner) => match timer.poll(cx) {
Poll::Ready(o) => {
let mut lock = inner.waiters.lock();
lock.remove(id);
Poll::Ready(MaybeTimedOut::TimedOut(o))
}
Poll::Pending => {
let mut lock = inner.waiters.lock();
lock.upsert(id, cx.waker());
Poll::Pending
}
},
None => Poll::Ready(MaybeTimedOut::Completed),
}
}
}

#[cfg(test)]
mod test {
use std::time::Duration;

use super::*;

fn test_runtime() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
}

#[test]
fn test_wait_group_drop() {
let test_runtime = test_runtime();

let wg = WaitGroup::new();
for _i in 0..100 {
let w = wg.clone();
test_runtime.spawn(async move {
drop(w);
});
}
pollster::block_on(wg.into_future());
}

#[test]
fn test_wait_group_await() {
let test_runtime = test_runtime();

let wg = WaitGroup::new();
for _i in 0..100 {
let w = wg.clone();
test_runtime.spawn(async move {
w.await;
});
}
pollster::block_on(wg.into_future());
}

#[test]
fn test_wait_group_timeout() {
let test_runtime = test_runtime();

let wg = WaitGroup::new();
let _wg_clone = wg.clone();
let out = test_runtime.block_on(async move {
let timer = tokio::time::sleep(Duration::from_millis(50));
wg.wait_timeout(timer).await
});
assert_eq!(out, MaybeTimedOut::TimedOut(()));
}
}

0 comments on commit 4189862

Please sign in to comment.