Skip to content

Commit

Permalink
Merge pull request #95 from j5ik2o/refactor-2024-10-04
Browse files Browse the repository at this point in the history
refactor: remove ThrottleCallback struct and simplify callback usage
  • Loading branch information
j5ik2o authored Oct 4, 2024
2 parents 840f9f6 + 55f9ceb commit 1376302
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 125 deletions.
3 changes: 2 additions & 1 deletion core/src/actor/context/actor_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ impl ActorContext {
.await
.get_process_registry()
.await
.remove_process(&self.get_self_opt().await.unwrap());
.remove_process(&self.get_self_opt().await.unwrap())
.await;
let result = self
.invoke_user_message(MessageHandle::new(AutoReceiveMessage::PostStop))
.await;
Expand Down
7 changes: 3 additions & 4 deletions core/src/actor/dispatch/dead_letter_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::actor::metrics::metrics_impl::{Metrics, EXTENSION_ID};
use crate::actor::process::{Process, ProcessHandle};
use crate::generated::actor::{DeadLetterResponse, Terminated};

use crate::actor::dispatch::throttler::{Throttle, ThrottleCallback, Valve};
use crate::actor::dispatch::throttler::{Throttle, Valve};
use crate::metrics::ActorMetrics;
use async_trait::async_trait;
use nexus_actor_message_derive_rs::Message;
Expand All @@ -36,9 +36,8 @@ impl DeadLetterProcess {
.await
.dead_letter_throttle_interval
.clone();
let func = ThrottleCallback::new(move |i: usize| async move {
tracing::info!("DeadLetterProcess: Throttling dead letters, count: {}", i)
});
let func =
move |i: usize| async move { tracing::info!("DeadLetterProcess: Throttling dead letters, count: {}", i) };
let dispatcher = myself.actor_system.get_config().await.system_dispatcher.clone();
let throttle = Throttle::new(
dispatcher,
Expand Down
31 changes: 7 additions & 24 deletions core/src/actor/dispatch/throttler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::actor::dispatch::{Dispatcher, Runnable};
use futures::future::BoxFuture;
use tokio::sync::Mutex;
use tokio::time::{interval, Duration};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand All @@ -20,31 +18,16 @@ pub struct Throttle {
max_events_in_period: usize,
}

pub struct ThrottleCallback(Arc<Mutex<dyn FnMut(usize) -> BoxFuture<'static, ()> + Send + 'static>>);

impl ThrottleCallback {
pub fn new<F, Fut>(f: F) -> Self
where
F: Fn(usize) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static, {
Self(Arc::new(Mutex::new(move |size: usize| {
Box::pin(f(size)) as BoxFuture<'static, ()>
})))
}

pub async fn run(&self, times_called: usize) {
let mut f = self.0.lock().await;
f(times_called).await;
}
}

impl Throttle {
pub async fn new(
pub async fn new<F, Fut>(
dispatcher: Arc<dyn Dispatcher>,
max_events_in_period: usize,
period: Duration,
throttled_callback: ThrottleCallback,
) -> Arc<Self> {
mut throttled_callback: F,
) -> Arc<Self>
where
F: FnMut(usize) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static, {
let throttle = Arc::new(Self {
current_events: Arc::new(AtomicUsize::new(0)),
max_events_in_period,
Expand All @@ -59,7 +42,7 @@ impl Throttle {
interval.tick().await;
let times_called = throttle_clone.current_events.swap(0, Ordering::SeqCst);
if times_called > max_events_in_period {
throttled_callback.run(times_called - max_events_in_period).await;
throttled_callback(times_called - max_events_in_period).await;
}
}
}))
Expand Down
21 changes: 8 additions & 13 deletions core/src/actor/dispatch/throttler_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[cfg(test)]
mod tests {
use crate::actor::dispatch::throttler::{Throttle, ThrottleCallback, Valve};
use crate::actor::dispatch::throttler::{Throttle, Valve};
use crate::actor::dispatch::TokioRuntimeContextDispatcher;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -12,18 +12,13 @@ mod tests {
let callback_called_clone = Arc::clone(&callback_called);
let dispatcher = Arc::new(TokioRuntimeContextDispatcher::new().unwrap());

let throttle = Throttle::new(
dispatcher,
10,
Duration::from_millis(100),
ThrottleCallback::new(move |_| {
let callback_called = callback_called_clone.clone();
async move {
let mut called = callback_called.lock().await;
*called = true;
}
}),
)
let throttle = Throttle::new(dispatcher, 10, Duration::from_millis(100), move |_| {
let callback_called = callback_called_clone.clone();
async move {
let mut called = callback_called.lock().await;
*called = true;
}
})
.await;

assert_eq!(throttle.should_throttle(), Valve::Open);
Expand Down
1 change: 1 addition & 0 deletions core/src/actor/message/continuation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ impl Message for Continuation {
}
}

#[allow(clippy::type_complexity)]
#[derive(Clone)]
pub struct ContinuationCallback(Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static>);

Expand Down
17 changes: 11 additions & 6 deletions core/src/actor/process/process_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl AddressResolver {

impl ProcessRegistry {
pub fn new(actor_system: ActorSystem) -> Self {
ProcessRegistry {
Self {
sequence_id: Arc::new(AtomicU64::new(0)),
actor_system,
address: Arc::new(RwLock::new(LOCAL_ADDRESS.to_string())),
Expand Down Expand Up @@ -133,7 +133,7 @@ impl ProcessRegistry {
(pid, inserted)
}

pub fn remove_process(&self, pid: &ExtendedPid) {
pub async fn remove_process(&self, pid: &ExtendedPid) {
let bucket = self.local_pids.get_bucket(pid.id());
if let Some((_, process)) = bucket.remove(pid.id()) {
if let Some(actor_process) = process.as_any().downcast_ref::<ActorProcess>() {
Expand Down Expand Up @@ -168,11 +168,14 @@ impl ProcessRegistry {
}
}

pub(crate) fn uint64_to_id(u: u64) -> String {
const DIGITS: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ~+";
const DIGITS: &[u8; 64] = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ~+";

pub fn uint64_to_id(u: u64) -> String {
// 最大13文字 (62進数で11桁 + '$' + 潜在的な追加の1文字)
let mut buf = [0u8; 13];
let mut i = 12;
let mut i = buf.len() - 1;
let mut u = u;

while u >= 64 {
buf[i] = DIGITS[(u & 0x3f) as usize];
u >>= 6;
Expand All @@ -181,5 +184,7 @@ pub(crate) fn uint64_to_id(u: u64) -> String {
buf[i] = DIGITS[u as usize];
i -= 1;
buf[i] = b'$';
String::from_utf8(buf[i..].to_vec()).unwrap()

// 使用された部分のスライスを文字列に変換
unsafe { std::str::from_utf8_unchecked(&buf[i..]).to_string() }
}
20 changes: 10 additions & 10 deletions utils/src/collections/queue/mpsc_bounded_channel_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::collections::element::Element;
use crate::collections::{QueueBase, QueueError, QueueReader, QueueSize, QueueWriter};
use async_trait::async_trait;
use tokio::sync::mpsc::error::{SendError, TryRecvError};
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{mpsc, RwLock};

#[derive(Debug)]
struct MpscBoundedQueueInner<E> {
Expand All @@ -18,15 +18,15 @@ struct MpscBoundedQueueInner<E> {
#[derive(Debug, Clone)]
pub struct MpscBoundedChannelQueue<E> {
sender: mpsc::Sender<E>,
inner: Arc<Mutex<MpscBoundedQueueInner<E>>>,
inner: Arc<RwLock<MpscBoundedQueueInner<E>>>,
}

impl<T> MpscBoundedChannelQueue<T> {
pub fn new(buffer: usize) -> Self {
let (sender, receiver) = mpsc::channel(buffer);
Self {
sender,
inner: Arc::new(Mutex::new(MpscBoundedQueueInner {
inner: Arc::new(RwLock::new(MpscBoundedQueueInner {
receiver,
count: 0,
capacity: buffer,
Expand All @@ -36,15 +36,15 @@ impl<T> MpscBoundedChannelQueue<T> {
}

async fn try_recv(&self) -> Result<T, TryRecvError> {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
if inner_mg.is_closed {
return Err(TryRecvError::Disconnected);
}
inner_mg.receiver.try_recv()
}

async fn try_send(&self, element: T) -> Result<(), SendError<T>> {
let inner_mg = self.inner.lock().await;
let inner_mg = self.inner.read().await;
if inner_mg.is_closed {
return Err(SendError(element));
}
Expand All @@ -60,25 +60,25 @@ impl<T> MpscBoundedChannelQueue<T> {
}

async fn increment_count(&self) {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
inner_mg.count += 1;
}

async fn decrement_count(&self) {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
inner_mg.count = inner_mg.count.saturating_sub(1);
}
}

#[async_trait]
impl<E: Element> QueueBase<E> for MpscBoundedChannelQueue<E> {
async fn len(&self) -> QueueSize {
let inner_mg = self.inner.lock().await;
let inner_mg = self.inner.read().await;
QueueSize::Limited(inner_mg.count)
}

async fn capacity(&self) -> QueueSize {
let inner_mg = self.inner.lock().await;
let inner_mg = self.inner.read().await;
QueueSize::Limited(inner_mg.capacity)
}
}
Expand Down Expand Up @@ -110,7 +110,7 @@ impl<E: Element> QueueReader<E> for MpscBoundedChannelQueue<E> {
}

async fn clean_up(&mut self) {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
inner_mg.count = 0;
inner_mg.receiver.close();
inner_mg.is_closed = true;
Expand Down
18 changes: 9 additions & 9 deletions utils/src/collections/queue/mpsc_unbounded_channel_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::collections::element::Element;
use crate::collections::{QueueBase, QueueError, QueueReader, QueueSize, QueueWriter};
use async_trait::async_trait;
use tokio::sync::mpsc::error::{SendError, TryRecvError};
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{mpsc, RwLock};

#[derive(Debug)]
struct MpscUnboundedChannelQueueInner<E> {
Expand All @@ -17,15 +17,15 @@ struct MpscUnboundedChannelQueueInner<E> {
#[derive(Debug, Clone)]
pub struct MpscUnboundedChannelQueue<E> {
sender: mpsc::UnboundedSender<E>,
inner: Arc<Mutex<MpscUnboundedChannelQueueInner<E>>>,
inner: Arc<RwLock<MpscUnboundedChannelQueueInner<E>>>,
}

impl<T> MpscUnboundedChannelQueue<T> {
pub fn new() -> Self {
let (sender, receiver) = mpsc::unbounded_channel();
Self {
sender,
inner: Arc::new(Mutex::new(MpscUnboundedChannelQueueInner {
inner: Arc::new(RwLock::new(MpscUnboundedChannelQueueInner {
receiver,
count: 0,
is_closed: false,
Expand All @@ -34,15 +34,15 @@ impl<T> MpscUnboundedChannelQueue<T> {
}

async fn try_recv(&self) -> Result<T, TryRecvError> {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
if inner_mg.is_closed {
return Err(TryRecvError::Disconnected);
}
inner_mg.receiver.try_recv()
}

async fn send(&self, element: T) -> Result<(), SendError<T>> {
let inner_mg = self.inner.lock().await;
let inner_mg = self.inner.read().await;
if inner_mg.is_closed {
return Err(SendError(element));
}
Expand All @@ -51,20 +51,20 @@ impl<T> MpscUnboundedChannelQueue<T> {
}

async fn increment_count(&self) {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
inner_mg.count += 1;
}

async fn decrement_count(&self) {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
inner_mg.count = inner_mg.count.saturating_sub(1);
}
}

#[async_trait]
impl<E: Element> QueueBase<E> for MpscUnboundedChannelQueue<E> {
async fn len(&self) -> QueueSize {
let inner_mg = self.inner.lock().await;
let inner_mg = self.inner.read().await;
QueueSize::Limited(inner_mg.count)
}

Expand Down Expand Up @@ -106,7 +106,7 @@ impl<E: Element> QueueReader<E> for MpscUnboundedChannelQueue<E> {
}

async fn clean_up(&mut self) {
let mut inner_mg = self.inner.lock().await;
let mut inner_mg = self.inner.write().await;
inner_mg.count = 0;
inner_mg.receiver.close();
inner_mg.is_closed = true;
Expand Down
Loading

0 comments on commit 1376302

Please sign in to comment.