From 330ea21bfb59186fe0c9abf03a5b234344efd334 Mon Sep 17 00:00:00 2001 From: doyoubi Date: Wed, 4 Aug 2021 22:34:18 +0800 Subject: [PATCH] Optimize WaitableTask --- src/proxy/migration_backend.rs | 203 ++++++++++++++++----------------- 1 file changed, 98 insertions(+), 105 deletions(-) diff --git a/src/proxy/migration_backend.rs b/src/proxy/migration_backend.rs index af0743c..b4ac1d4 100644 --- a/src/proxy/migration_backend.rs +++ b/src/proxy/migration_backend.rs @@ -7,6 +7,7 @@ use crate::common::utils::{generate_lock_slot, pretty_print_bytes, RetryError, W use crate::migration::scan_migration::{pttl_to_restore_expire_time, PTTL_KEY_NOT_FOUND}; use crate::migration::stats::MigrationStats; use crate::protocol::{Array, BinSafeStr, BulkStr, RFunctor, Resp, RespVec, VFunctor}; +use arc_swap::ArcSwapOption; use atomic_option::AtomicOption; use dashmap::DashSet; use either::Either; @@ -17,38 +18,79 @@ use futures::channel::{ use futures::{select, Future, FutureExt, StreamExt}; use std::fmt; use std::pin::Pin; -use std::sync::atomic::Ordering; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; const KEY_NOT_EXISTS: &str = "0"; const FAILED_TO_ACCESS_SOURCE: &str = "MIGRATION_FORWARD: failed to access source node"; +struct WaitRegistry { + pending: AtomicU64, + signal: ArcSwapOption>, +} + +impl WaitRegistry { + fn new() -> (Self, WaitHandle) { + let (s, r) = oneshot::channel(); + let registry = Self { + pending: AtomicU64::new(0), + signal: ArcSwapOption::new(Some(Arc::new(s))), + }; + (registry, r) + } +} + +impl WaitRegistry { + fn register(&self) { + self.pending.fetch_add(1, Ordering::Relaxed); + } + + fn unregister(&self) { + if self.pending.fetch_sub(1, Ordering::Relaxed) == 1 { + self.signal.swap(None); + } + } +} + type WaitHandle = oneshot::Receiver<()>; +struct AutoDropWaitRegistry { + inner: Arc, +} + +impl AutoDropWaitRegistry { + fn new(inner: Arc) -> Self { + Self { inner } + } +} + +impl Drop for AutoDropWaitRegistry { + fn drop(&mut self) { + self.inner.unregister(); + } +} + pub struct WaitableTask { inner: T, // We reply on the `drop` function of `sender` #[allow(dead_code)] - sender: Option>, + _registry: Option, } impl WaitableTask { - fn new_with_handle(inner: T) -> (Self, WaitHandle) { - let (sender, receiver) = oneshot::channel(); - ( - Self { - inner, - sender: Some(sender), - }, - receiver, - ) + fn new_with_registry(inner: T, registry: Arc) -> Self { + registry.register(); + Self { + inner, + _registry: Some(AutoDropWaitRegistry::new(registry)), + } } - fn new_without_handle(inner: T) -> Self { + fn new_without_registry(inner: T) -> Self { Self { inner, - sender: None, + _registry: None, } } @@ -127,7 +169,7 @@ impl MgrCmdStateExists { let resp = Self::gen_exists_resp(&key); let (cmd_task, reply_fut) = cmd_task_factory.create_with_ctx(inner_task.get_context(), resp); - let task = ReqTask::Simple(WaitableTask::new_without_handle(cmd_task)); + let task = ReqTask::Simple(WaitableTask::new_without_registry(cmd_task)); let state = Self { inner_task, key, @@ -158,40 +200,40 @@ struct MgrCmdStateForward { impl MgrCmdStateForward { fn from_state_exists( state: MgrCmdStateExists, - ) -> (Self, ReqTask>, WaitHandle) { - let (inner_task, wait_handle) = WaitableTask::new_with_handle(state.into_inner()); + registry: Arc, + ) -> (Self, ReqTask>) { + let inner_task = WaitableTask::new_with_registry(state.into_inner(), registry); ( MgrCmdStateForward { _lock_guard: None }, ReqTask::Simple(inner_task), - wait_handle, ) } fn from_state_dump_pttl( state: MgrCmdStateDumpPttl, - ) -> (Self, ReqTask>, WaitHandle) { - let (inner_task, wait_handle) = WaitableTask::new_with_handle(state.into_inner()); + registry: Arc, + ) -> (Self, ReqTask>) { + let inner_task = WaitableTask::new_with_registry(state.into_inner(), registry); ( MgrCmdStateForward { _lock_guard: None }, ReqTask::Simple(inner_task), - wait_handle, ) } fn from_state_umsync( state: MgrCmdStateUmSync, - ) -> (Self, ReqTask>, WaitHandle) { + registry: Arc, + ) -> (Self, ReqTask>) { let MgrCmdStateUmSync { inner_task, lock_guard, } = state; - let (inner_task, wait_handle) = WaitableTask::new_with_handle(inner_task); + let inner_task = WaitableTask::new_with_registry(inner_task, registry); ( MgrCmdStateForward { _lock_guard: Some(lock_guard), }, ReqTask::Simple(inner_task), - wait_handle, ) } } @@ -330,12 +372,8 @@ impl MgrCmdStateRestoreForward { state: MgrCmdStateDumpPttl, entry: DataEntry, cmd_task_factory: &F, - ) -> ( - Self, - ReqTask>, - ReplyFuture, - [WaitHandle; 2], - ) { + registry: Arc, + ) -> (Self, ReqTask>, ReplyFuture) { let MgrCmdStateDumpPttl { inner_task, key, @@ -347,9 +385,8 @@ impl MgrCmdStateRestoreForward { let (restore_cmd_task, restore_reply_fut) = cmd_task_factory.create_with_ctx(inner_task.get_context(), resp); - let (restore_cmd_task, restore_wait_handle) = - WaitableTask::new_with_handle(restore_cmd_task); - let (inner_task, inner_wait_handle) = WaitableTask::new_with_handle(inner_task); + let restore_cmd_task = WaitableTask::new_with_registry(restore_cmd_task, registry.clone()); + let inner_task = WaitableTask::new_with_registry(inner_task, registry); let task = ReqTask::Multi(vec![restore_cmd_task, inner_task]); ( @@ -360,7 +397,6 @@ impl MgrCmdStateRestoreForward { }, task, restore_reply_fut, - [restore_wait_handle, inner_wait_handle], ) } @@ -419,8 +455,6 @@ type UmSyncTaskSender = UnboundedSender<(MgrCmdStateUmSync, ReplyFuture)>; type UmSyncTaskReceiver = UnboundedReceiver<(MgrCmdStateUmSync, ReplyFuture)>; type DeleteKeyTaskSender = UnboundedSender; type DeleteKeyTaskReceiver = UnboundedReceiver; -type DstWaitHandleSender = UnboundedSender; -type DstWaitHandleReceiver = UnboundedReceiver; pub struct RestoreDataCmdTaskHandler where @@ -439,7 +473,6 @@ where pending_umsync_task_sender: PendingUmSyncTaskSender, umsync_task_sender: UmSyncTaskSender, del_task_sender: DeleteKeyTaskSender, - wait_handle_sender: DstWaitHandleSender, #[allow(clippy::type_complexity)] task_receivers: AtomicOption<( ExistsTaskReceiver, @@ -448,11 +481,12 @@ where PendingUmSyncTaskReceiver, UmSyncTaskReceiver, DeleteKeyTaskReceiver, - DstWaitHandleReceiver, + WaitHandle, )>, cmd_task_factory: Arc, key_lock: Arc, stats: Arc, + registry: Arc, } impl RestoreDataCmdTaskHandler @@ -479,7 +513,7 @@ where let (pending_umsync_task_sender, pending_umsync_task_receiver) = unbounded(); let (umsync_task_sender, umsync_task_receiver) = unbounded(); let (del_task_sender, del_task_receiver) = unbounded(); - let (wait_handle_sender, wait_handle_receiver) = unbounded(); + let (registry, wait_handle) = WaitRegistry::new(); let task_receivers = AtomicOption::new(Box::new(( exists_task_receiver, dump_pttl_task_receiver, @@ -487,7 +521,7 @@ where pending_umsync_task_receiver, umsync_task_receiver, del_task_receiver, - wait_handle_receiver, + wait_handle, ))); let key_lock = Arc::new(KeyLock::new(LOCK_SHARD_SIZE)); Self { @@ -500,11 +534,11 @@ where pending_umsync_task_sender, umsync_task_sender, del_task_sender, - wait_handle_sender, task_receivers, cmd_task_factory, key_lock, stats, + registry: Arc::new(registry), } } @@ -521,7 +555,6 @@ where let dump_pttl_task_sender = self.dump_pttl_task_sender.clone(); let umsync_task_sender = self.umsync_task_sender.clone(); let del_task_sender = self.del_task_sender.clone(); - let wait_handle_sender = self.wait_handle_sender.clone(); let src_sender = self.src_sender.clone(); let dst_sender = self.dst_sender.clone(); let src_proxy_sender = self.src_proxy_sender.clone(); @@ -537,7 +570,7 @@ where pending_umsync_task_receiver, umsync_task_receiver, del_task_receiver, - wait_handle_receiver, + wait_handle, ) = match receiver_opt { Some(r) => r, None => { @@ -550,12 +583,12 @@ where exists_task_sender, exists_task_receiver, dump_pttl_task_sender, - wait_handle_sender.clone(), src_sender.clone(), dst_sender.clone(), cmd_task_factory.clone(), key_lock.clone(), self.stats.clone(), + self.registry.clone(), ); let pending_umsync_task_handler = Self::handle_pending_umsync_task( @@ -563,19 +596,19 @@ where umsync_task_sender, src_proxy_sender, dst_sender.clone(), - wait_handle_sender.clone(), key_lock.clone(), cmd_task_factory.clone(), self.stats.clone(), + self.registry.clone(), ); let dump_pttl_task_handler = Self::handle_dump_pttl_task( dump_pttl_task_receiver, restore_task_sender, - wait_handle_sender.clone(), dst_sender.clone(), cmd_task_factory.clone(), self.stats.clone(), + self.registry.clone(), ); let restore_task_handler = Self::handle_restore( @@ -587,22 +620,20 @@ where let umsync_task_handler = Self::handle_umsync_task( umsync_task_receiver, - wait_handle_sender.clone(), dst_sender, self.stats.clone(), + self.registry.clone(), ); let del_task_handler = Self::handle_del_task(del_task_receiver); - let handle_wait_handler = Self::handle_wait_handle(wait_handle_receiver); - let mut exists_task_handler = Box::pin(exists_task_handler.fuse()); let mut pending_umsync_task_handler = Box::pin(pending_umsync_task_handler.fuse()); let mut dump_pttl_task_handler = Box::pin(dump_pttl_task_handler.fuse()); let mut restore_task_handler = Box::pin(restore_task_handler.fuse()); let mut umsync_task_handler = Box::pin(umsync_task_handler.fuse()); let mut del_task_handler = Box::pin(del_task_handler.fuse()); - let mut handle_wait_handler = Box::pin(handle_wait_handler.fuse()); + let mut handle_wait_handler = Box::pin(wait_handle.map(|_| ()).fuse()); select! { () = exists_task_handler => {}, @@ -663,7 +694,6 @@ where () = handle_wait_handler => {}, } - self.wait_handle_sender.close_channel(); info!("wait for wait_handle task"); handle_wait_handler.await; info!("All remaining tasks in migration backend are finished"); @@ -674,12 +704,12 @@ where exists_task_sender: ExistsTaskSender, mut exists_task_receiver: ExistsTaskReceiver, dump_pttl_task_sender: DumpPttlTaskSender, - wait_handle_sender: DstWaitHandleSender, src_sender: Arc, dst_sender: Arc, cmd_task_factory: Arc, key_lock: Arc, stats: Arc, + registry: Arc, ) { while let Some((state, reply_receiver)) = exists_task_receiver.next().await { let res = reply_receiver.await; @@ -693,16 +723,11 @@ where .importing_dst_key_existed .fetch_add(1, Ordering::Relaxed); - let (_state, req_task, wait_handle) = MgrCmdStateForward::from_state_exists(state); + let (_state, req_task) = + MgrCmdStateForward::from_state_exists(state, registry.clone()); if let Err(err) = dst_sender.send(req_task) { debug!("failed to forward: {:?}", err); } - if let Err(err) = wait_handle_sender.unbounded_send(wait_handle) { - warn!( - "failed to send wait_handle in handle_exists_task: {:?}", - err - ); - } continue; } stats @@ -756,17 +781,11 @@ where Err(()) => continue, }; if key_exists { - let (_state, req_task, wait_handle) = - MgrCmdStateForward::from_state_exists(state); + let (_state, req_task) = + MgrCmdStateForward::from_state_exists(state, registry.clone()); if let Err(err) = dst_sender.send(req_task) { debug!("failed to forward: {:?}", err); } - if let Err(err) = wait_handle_sender.unbounded_send(wait_handle) { - warn!( - "failed to send wait_handle in handle_exists_task: {:?}", - err - ); - } continue; } @@ -817,10 +836,10 @@ where async fn handle_dump_pttl_task( mut dump_pttl_task_receiver: DumpPttlTaskReceiver, restore_task_sender: RestoreTaskSender, - wait_handle_sender: DstWaitHandleSender, dst_sender: Arc, cmd_task_factory: Arc, stats: Arc, + registry: Arc, ) { while let Some((state, reply_fut)) = dump_pttl_task_receiver.next().await { let res = reply_fut.await; @@ -836,17 +855,11 @@ where .importing_src_key_not_existed .fetch_add(1, Ordering::Relaxed); // The key also does not exist in source node. - let (_state, req_task, wait_handle) = - MgrCmdStateForward::from_state_dump_pttl(state); + let (_state, req_task) = + MgrCmdStateForward::from_state_dump_pttl(state, registry.clone()); if let Err(err) = dst_sender.send(req_task) { debug!("failed to send forward: {:?}", err); } - if let Err(err) = wait_handle_sender.unbounded_send(wait_handle) { - warn!( - "failed to send wait_handle in handle_dump_pttl_task: {:?}", - err - ); - } continue; } Err(err) => { @@ -860,19 +873,17 @@ where } }; - let (state, req_task, reply_receiver, wait_handles) = - MgrCmdStateRestoreForward::from_state_dump(state, entry, &(*cmd_task_factory)); + let (state, req_task, reply_receiver) = MgrCmdStateRestoreForward::from_state_dump( + state, + entry, + &(*cmd_task_factory), + registry.clone(), + ); if let Err(err) = dst_sender.send(req_task) { debug!("failed to send restore and forward: {:?}", err); } - for wait_handle in std::array::IntoIter::new(wait_handles) { - if let Err(err) = wait_handle_sender.unbounded_send(wait_handle) { - warn!("failed to send wait_handle to queue: {:?}", err); - } - } - if let Err(err) = restore_task_sender.unbounded_send((state, reply_receiver)) { debug!("failed to send restore task to queue: {:?}", err); } @@ -932,10 +943,10 @@ where umsync_task_sender: UmSyncTaskSender, src_proxy_sender: Arc, dst_sender: Arc, - wait_handle_sender: DstWaitHandleSender, key_lock: Arc, cmd_task_factory: Arc, stats: Arc, + registry: Arc, ) { while let Some(pending_task) = pending_umsync_task_receiver.next().await { let PendingUmSyncTask { @@ -978,7 +989,7 @@ where // The queue has been closed so the migration should be done. // We can safely send it to the dst_sender. let MgrCmdStateUmSync { inner_task, .. } = err.into_inner().0; - let (task, wait_handle) = WaitableTask::new_with_handle(inner_task); + let task = WaitableTask::new_with_registry(inner_task, registry.clone()); if let Err(err) = dst_sender.send(ReqTask::Simple(task)) { error!( "failed to send to dst sender after migration is done: {:?}", @@ -986,21 +997,15 @@ where ); continue; } - if let Err(err) = wait_handle_sender.unbounded_send(wait_handle) { - warn!( - "failed to send wait_handle in handle_pending_umsync_task: {:?}", - err - ); - } } } } async fn handle_umsync_task( mut umsync_task_receiver: UmSyncTaskReceiver, - wait_handle_sender: DstWaitHandleSender, dst_sender: Arc, stats: Arc, + registry: Arc, ) { while let Some((state, reply_fut)) = umsync_task_receiver.next().await { // The DUMP and RESTORE has already been processed in the source proxy. @@ -1034,16 +1039,10 @@ where _ => (), }; - let (_state, req_task, wait_handle) = MgrCmdStateForward::from_state_umsync(state); + let (_state, req_task) = MgrCmdStateForward::from_state_umsync(state, registry.clone()); if let Err(err) = dst_sender.send(req_task) { debug!("failed to forward: {:?}", err); } - if let Err(err) = wait_handle_sender.unbounded_send(wait_handle) { - warn!( - "failed to send wait_handle in handle_umsync_task: {:?}", - err - ); - } } } @@ -1065,12 +1064,6 @@ where } } - async fn handle_wait_handle(mut wait_handle_receiver: DstWaitHandleReceiver) { - while let Some(wait_handle) = wait_handle_receiver.next().await { - wait_handle.await.unwrap_or(()); - } - } - pub fn handle_cmd_task(&self, cmd_task: F::Task) -> Result<(), RetryError> { let (key, lock_slot) = match cmd_task.get_key() { Some(key) => (key.to_vec(), generate_lock_slot(key)),