diff --git a/edgeless_con/src/controller.rs b/edgeless_con/src/controller.rs index cf754bf9..52a0116b 100644 --- a/edgeless_con/src/controller.rs +++ b/edgeless_con/src/controller.rs @@ -49,7 +49,10 @@ pub(crate) enum DomainRegisterRequest { } pub(crate) enum InternalRequest { - Poll(), + Refresh( + // Reply Channel + tokio::sync::oneshot::Sender<()>, + ), } #[derive(Clone)] @@ -73,10 +76,11 @@ impl Controller { let refresh_task = Box::pin(async move { let mut sender = internal_sender; - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1)); loop { - interval.tick().await; - let _ = sender.send(InternalRequest::Poll()).await; + let (reply_sender, reply_receiver) = tokio::sync::oneshot::channel::<()>(); + let _ = sender.send(InternalRequest::Refresh(reply_sender)).await; + let _ = reply_receiver.await; + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; } }); diff --git a/edgeless_con/src/controller/controller_task.rs b/edgeless_con/src/controller/controller_task.rs index 83eab60b..138e1129 100644 --- a/edgeless_con/src/controller/controller_task.rs +++ b/edgeless_con/src/controller/controller_task.rs @@ -113,8 +113,9 @@ impl ControllerTask { } Some(req) = self.internal_receiver.next() => { match req { - super::InternalRequest::Poll() => { + super::InternalRequest::Refresh(reply_sender) => { self.check_domains().await; + let _ = reply_sender.send(()); } } } diff --git a/edgeless_orc/src/node_register.rs b/edgeless_orc/src/node_register.rs index 44d9f7bd..9c0a9e15 100644 --- a/edgeless_orc/src/node_register.rs +++ b/edgeless_orc/src/node_register.rs @@ -18,7 +18,10 @@ pub enum NodeRegisterRequest { } pub(crate) enum InternalRequest { - Poll(), + Refresh( + // Reply Channel + tokio::sync::oneshot::Sender<()>, + ), } struct NodeRegisterEntry { @@ -44,10 +47,11 @@ impl NodeRegister { let refresh_task = Box::pin(async move { let mut sender = internal_sender; - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1)); loop { - interval.tick().await; - let _ = sender.send(InternalRequest::Poll()).await; + let (reply_sender, reply_receiver) = tokio::sync::oneshot::channel::<()>(); + let _ = sender.send(InternalRequest::Refresh(reply_sender)).await; + let _ = reply_receiver.await; + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; } }); @@ -71,7 +75,7 @@ impl NodeRegister { tokio::select! { Some(req) = internal_receiver.next() => { match req { - InternalRequest::Poll() => { + InternalRequest::Refresh(reply_sender) => { // Find all nodes that are stale, i.e., which have not been // refreshed by their own indicated deadline. let mut stale_nodes = vec![]; @@ -88,6 +92,8 @@ impl NodeRegister { let _ = orchestrator_sender.send(super::orchestrator::OrchestratorRequest::DelNode(stale_node)).await; } + + let _ = reply_sender.send(()); } } }, diff --git a/edgeless_orc/src/orchestrator.rs b/edgeless_orc/src/orchestrator.rs index bd5e23f5..94d2270b 100644 --- a/edgeless_orc/src/orchestrator.rs +++ b/edgeless_orc/src/orchestrator.rs @@ -52,7 +52,10 @@ pub enum OrchestratorRequest { Vec, ), DelNode(uuid::Uuid), - Refresh(), + Refresh( + // Reply Channel + tokio::sync::oneshot::Sender<()>, + ), } pub struct OrchestratorClient { @@ -112,10 +115,11 @@ impl Orchestrator { let refresh_sender = sender.clone(); let refresh_task = Box::pin(async move { let mut refresh_sender = refresh_sender; - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1)); loop { - interval.tick().await; - let _ = refresh_sender.send(OrchestratorRequest::Refresh()).await; + let (reply_sender, reply_receiver) = tokio::sync::oneshot::channel::<()>(); + let _ = refresh_sender.send(OrchestratorRequest::Refresh(reply_sender)).await; + let _ = reply_receiver.await; + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; } }); diff --git a/edgeless_orc/src/orchestrator/test.rs b/edgeless_orc/src/orchestrator/test.rs index fcc3a5e2..1fda5cc0 100644 --- a/edgeless_orc/src/orchestrator/test.rs +++ b/edgeless_orc/src/orchestrator/test.rs @@ -1251,7 +1251,10 @@ async fn test_recreate_fun_after_disconnect() { assert_eq!(Some(&1), num_events.get("patch-function")); for _ in 0..5 { - let _ = orc_sender.send(OrchestratorRequest::Refresh()).await; + let (reply_sender, reply_receiver) = tokio::sync::oneshot::channel::<()>(); + let _ = orc_sender.send(OrchestratorRequest::Refresh(reply_sender)).await; + let _ = reply_receiver.await; + let mut num_events = std::collections::HashMap::new(); loop { if let Some((_node_id, event)) = wait_for_events_if_any(&mut nodes).await { diff --git a/edgeless_orc/src/orchestrator_task.rs b/edgeless_orc/src/orchestrator_task.rs index f5a82405..d8db8578 100644 --- a/edgeless_orc/src/orchestrator_task.rs +++ b/edgeless_orc/src/orchestrator_task.rs @@ -128,8 +128,9 @@ impl OrchestratorTask { self.update_domain().await; self.refresh().await; } - crate::orchestrator::OrchestratorRequest::Refresh() => { + crate::orchestrator::OrchestratorRequest::Refresh(reply_sender) => { self.refresh().await; + let _ = reply_sender.send(()); } } }