From 8606c17cea0f36da9b8bcb65feb82d2dc9e13d9e Mon Sep 17 00:00:00 2001 From: Raphael Hetzel Date: Mon, 11 Sep 2023 19:35:05 +0000 Subject: [PATCH] Fix controller resource handling. Add basic controller tests. --- edgeless_api/src/orc/mod.rs | 2 +- edgeless_api/src/workflow_instance/mod.rs | 2 +- edgeless_con/src/controller.rs | 344 +++++++++++++------- edgeless_con/src/controller/test.rs | 370 ++++++++++++++++++++++ edgeless_con/src/lib.rs | 2 +- 5 files changed, 594 insertions(+), 126 deletions(-) create mode 100644 edgeless_con/src/controller/test.rs diff --git a/edgeless_api/src/orc/mod.rs b/edgeless_api/src/orc/mod.rs index 1b03cdd1..8363f660 100644 --- a/edgeless_api/src/orc/mod.rs +++ b/edgeless_api/src/orc/mod.rs @@ -1,3 +1,3 @@ -pub trait OrchestratorAPI { +pub trait OrchestratorAPI: Send { fn function_instance_api(&mut self) -> Box; } diff --git a/edgeless_api/src/workflow_instance/mod.rs b/edgeless_api/src/workflow_instance/mod.rs index 8bce0cc2..bca7af57 100644 --- a/edgeless_api/src/workflow_instance/mod.rs +++ b/edgeless_api/src/workflow_instance/mod.rs @@ -4,7 +4,7 @@ use crate::function_instance::InstanceId; const WORKFLOW_ID_NONE: uuid::Uuid = uuid::uuid!("00000000-0000-0000-0000-ffff00000000"); -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct WorkflowId { pub workflow_id: uuid::Uuid, } diff --git a/edgeless_con/src/controller.rs b/edgeless_con/src/controller.rs index bc9cf9e7..3e14f45f 100644 --- a/edgeless_con/src/controller.rs +++ b/edgeless_con/src/controller.rs @@ -1,11 +1,13 @@ -use std::str::FromStr; +use std::{collections::HashSet, str::FromStr}; use edgeless_api::workflow_instance::WorkflowInstance; use futures::{Future, SinkExt, StreamExt}; +#[cfg(test)] +pub mod test; + pub struct Controller { sender: futures::channel::mpsc::UnboundedSender, - // controller_settings: crate::EdgelessConSettings, } enum ControllerRequest { @@ -26,29 +28,40 @@ struct ResourceHandle { config_api: Box, } -impl Controller { - pub fn new(controller_settings: crate::EdgelessConSettings) -> (Self, std::pin::Pin + Send>>) { - let (sender, receiver) = futures::channel::mpsc::unbounded(); - - let main_task = Box::pin(async move { - Self::main_task(receiver, controller_settings).await; - }); +#[derive(Clone)] +struct ActiveWorkflow { + _desired_state: edgeless_api::workflow_instance::SpawnWorkflowRequest, + function_instances: std::collections::HashMap>, + resource_instances: std::collections::HashMap>, +} - (Controller { sender }, main_task) +impl ActiveWorkflow { + fn instances(&self, alias: &str) -> Vec { + let mut all_instances = Vec::new(); + if let Some(function_instances) = self.function_instances.get(alias) { + all_instances.append(&mut function_instances.clone()); + } + if let Some(resource_instances) = self.resource_instances.get(alias) { + all_instances.extend(&mut resource_instances.iter().map(|(_provider, id)| id.clone())); + } + return all_instances; } +} + +impl Controller { - async fn main_task(receiver: futures::channel::mpsc::UnboundedReceiver, settings: crate::EdgelessConSettings) { - let mut orc_clients = std::collections::HashMap::>::new(); + pub async fn new_from_config(controller_settings: crate::EdgelessConSettings) -> (Self, std::pin::Pin + Send>>) { + let mut orc_clients = std::collections::HashMap::>::new(); let mut resources = std::collections::HashMap::::new(); - for orc in &settings.orchestrators { + for orc in &controller_settings.orchestrators { orc_clients.insert( orc.domain_id.to_string(), Box::new(edgeless_api::grpc_impl::orc::OrchestratorAPIClient::new(&orc.orchestrator_url).await), ); } - for resource in &settings.resources { + for resource in &controller_settings.resources { resources.insert( resource.resource_provider_id.clone(), ResourceHandle { @@ -61,8 +74,31 @@ impl Controller { ); } + Self::new(orc_clients, resources) + } + + fn new( + orchestrators: std::collections::HashMap>, + resources: std::collections::HashMap, + ) -> (Self, std::pin::Pin + Send>>) { + let (sender, receiver) = futures::channel::mpsc::unbounded(); + + let main_task = Box::pin(async move { + Self::main_task(receiver, orchestrators, resources).await; + }); + + (Controller { sender }, main_task) + } + + async fn main_task( + receiver: futures::channel::mpsc::UnboundedReceiver, + orchestrators: std::collections::HashMap>, + resources: std::collections::HashMap, + ) { + let mut resources = resources; + let mut receiver = receiver; - let mut client = match orc_clients.into_values().next() { + let mut client = match orchestrators.into_values().next() { Some(c) => c, None => { return; @@ -70,121 +106,154 @@ impl Controller { }; let mut fn_client = client.function_instance_api(); - let mut active_workflows = std::collections::HashMap::>::new(); + let mut active_workflows = std::collections::HashMap::::new(); + while let Some(req) = receiver.next().await { match req { ControllerRequest::START(spawn_workflow_request, reply_sender) => { - let mut f_ids = std::collections::HashMap::::new(); - let mut to_patch = Vec::::new(); - for fun in spawn_workflow_request.workflow_functions.clone() { - let outputs: std::collections::HashMap = fun - .output_callback_definitions - .iter() - .filter_map(|(output_id, output_alias)| match f_ids.get(output_alias) { - Some(mapping) => Some((output_id.to_string(), mapping.instances[0].clone())), - None => None, - }) - .collect(); - if outputs.len() != fun.output_callback_definitions.len() { - to_patch.push(fun.function_alias.clone()); - } + let mut wf = ActiveWorkflow { + _desired_state: spawn_workflow_request.clone(), + function_instances: std::collections::HashMap::new(), + resource_instances: std::collections::HashMap::new(), + }; - let state_id = match fun.function_alias.as_str() { - "pinger" => uuid::Uuid::from_str("86699b23-6c24-4ca2-a2a0-b843b7c5e193").unwrap(), - "ponger" => uuid::Uuid::from_str("7dd076cc-2606-40ae-b46b-97628e0094be").unwrap(), - _ => uuid::Uuid::new_v4(), - }; - - let f_id = fn_client - .start(edgeless_api::function_instance::SpawnFunctionRequest { - instance_id: None, - code: fun.function_class_specification, - annotations: fun.function_annotations, - output_callback_definitions: outputs, - state_specification: edgeless_api::function_instance::StateSpecification { - state_id: state_id, - state_policy: edgeless_api::function_instance::StatePolicy::NodeLocal, - }, - }) - .await; - if let Ok(id) = f_id { - f_ids.insert( - fun.function_alias.clone(), - edgeless_api::workflow_instance::WorkflowFunctionMapping { - function_alias: fun.function_alias.clone(), - instances: vec![id], - }, - ); - } - } - for resource in spawn_workflow_request.workflow_resources { - if let Some((provider_id, handle)) = resources - .iter_mut() - .find(|(_id, spec)| spec.resource_type == resource.resource_class_type) - { - let wf_id = handle - .config_api - .start(edgeless_api::resource_configuration::ResourceInstanceSpecification { - provider_id: provider_id.clone(), - output_callback_definitions: resource - .output_callback_definitions - .iter() - .map(|(callback, alias)| (callback.to_string(), f_ids.get(alias).unwrap().instances[0].clone())) - .collect(), - configuration: resource.configurations.clone(), - }) - .await; - if let Ok(id) = wf_id { - log::info!("Insert {}", resource.alias.clone()); - f_ids.insert( - resource.alias.clone(), - edgeless_api::workflow_instance::WorkflowFunctionMapping { - function_alias: resource.alias.clone(), - instances: vec![id], - }, - ); - } + let mut to_upsert = HashSet::::new(); + to_upsert.extend(spawn_workflow_request.workflow_functions.iter().map(|f| f.function_alias.to_string())); + to_upsert.extend(spawn_workflow_request.workflow_resources.iter().map(|w| w.alias.to_string())); + + let mut iteration_count = 100; + + loop { + if iteration_count == 0 || to_upsert.len() == 0 { + break; } - } - for workflow_fid_alias in to_patch { - if let Some(mapping) = f_ids.get(&workflow_fid_alias) { - if let Some(config) = spawn_workflow_request - .workflow_functions - .iter() - .filter(|fun| fun.function_alias == workflow_fid_alias) - .next() - { - for instance in &mapping.instances { - let res = fn_client - .update_links(edgeless_api::function_instance::UpdateFunctionLinksRequest { - instance_id: Some(instance.clone()), - output_callback_definitions: config - .output_callback_definitions - .iter() - .filter_map(|(output_id, output_alias)| match f_ids.get(output_alias) { - Some(peer_function) => Some((output_id.to_string(), peer_function.instances[0].clone())), - None => None, - }) - .collect(), + iteration_count = iteration_count - 1; + + for fun in &spawn_workflow_request.workflow_functions { + if to_upsert.contains(&fun.function_alias) { + let outputs: std::collections::HashMap = fun + .output_callback_definitions + .iter() + .filter_map(|(output_id, output_alias)| { + let instances = wf.instances(&output_alias); + if instances.len() > 0 { + Some((output_id.to_string(), instances[0].clone())) + } else { + None + } + }) + .collect(); + + let all_outputs_mapped = outputs.len() == fun.output_callback_definitions.len(); + + let state_id = match fun.function_alias.as_str() { + "pinger" => uuid::Uuid::from_str("86699b23-6c24-4ca2-a2a0-b843b7c5e193").unwrap(), + "ponger" => uuid::Uuid::from_str("7dd076cc-2606-40ae-b46b-97628e0094be").unwrap(), + _ => uuid::Uuid::new_v4(), + }; + + // Update spawned instance + if let Some(existing_instances) = wf.function_instances.get(&fun.function_alias) { + for instance in existing_instances { + let res = fn_client + .update_links(edgeless_api::function_instance::UpdateFunctionLinksRequest { + instance_id: Some(instance.clone()), + output_callback_definitions: outputs.clone(), + }) + .await; + match res { + Ok(_) => { + if all_outputs_mapped { + to_upsert.remove(&fun.function_alias); + } + } + Err(err) => { + log::error!("Unhandled exception during update: {:?}", err); + } + } + } + } else { + // Create new instance + let f_id = fn_client + .start(edgeless_api::function_instance::SpawnFunctionRequest { + instance_id: None, + code: fun.function_class_specification.clone(), + annotations: fun.function_annotations.clone(), + output_callback_definitions: outputs.clone(), + state_specification: edgeless_api::function_instance::StateSpecification { + state_id: state_id, + state_policy: edgeless_api::function_instance::StatePolicy::NodeLocal, + }, }) .await; - match res { - Ok(_) => {} - Err(err) => { - log::error!("Unhandled: {:?}", err); + + if let Ok(id) = f_id { + wf.function_instances.insert(fun.function_alias.clone(), vec![id]); + if all_outputs_mapped { + to_upsert.remove(&fun.function_alias); + } + } + } + } + } + + for resource in &spawn_workflow_request.workflow_resources { + if to_upsert.contains(&resource.alias) { + let output_mapping: std::collections::HashMap = resource + .output_callback_definitions + .iter() + .map(|(callback, alias)| (callback.to_string(), wf.function_instances.get(alias).unwrap()[0].clone())) + .collect(); + + // Update resource instance + if let Some(_instances) = wf.resource_instances.get(&resource.alias) { + // resources currently don't have an update function. + todo!(); + } else { + // Create new resource instance + if let Some((provider_id, handle)) = resources + .iter_mut() + .find(|(_id, spec)| spec.resource_type == resource.resource_class_type) + { + let wf_id = handle + .config_api + .start(edgeless_api::resource_configuration::ResourceInstanceSpecification { + provider_id: provider_id.clone(), + output_callback_definitions: output_mapping.clone(), + configuration: resource.configurations.clone(), + }) + .await; + + if let Ok(id) = wf_id { + wf.resource_instances.insert(resource.alias.clone(), vec![(provider_id.clone(), id)]); + if output_mapping.len() == resource.output_callback_definitions.len() { + to_upsert.remove(&resource.alias); + } } } } } } } - active_workflows.insert( - spawn_workflow_request.workflow_id.workflow_id.to_string(), - f_ids.clone().into_values().collect(), - ); + + // Everything should be mapped now. + // Fails if there is invalid mappings or large dependency loops. + if to_upsert.len() > 0 { + reply_sender.send(Err(anyhow::anyhow!("Failed to resolve alias-links."))).unwrap(); + continue; + } + + active_workflows.insert(spawn_workflow_request.workflow_id.clone(), wf.clone()); match reply_sender.send(Ok(edgeless_api::workflow_instance::WorkflowInstance { workflow_id: spawn_workflow_request.workflow_id, - functions: f_ids.into_values().collect(), + functions: wf + .function_instances + .iter() + .map(|(alias, instances)| edgeless_api::workflow_instance::WorkflowFunctionMapping { + function_alias: alias.to_string(), + instances: instances.clone(), + }) + .collect(), })) { Ok(_) => {} Err(err) => { @@ -193,9 +262,9 @@ impl Controller { } } ControllerRequest::STOP(workflow_id) => { - if let Some(workflow_functions) = active_workflows.remove(&workflow_id.workflow_id.to_string()) { - for mapping in workflow_functions { - for f_id in mapping.instances { + if let Some(workflow_to_remove) = active_workflows.remove(&workflow_id) { + for (_alias, instances) in workflow_to_remove.function_instances { + for f_id in instances { match fn_client.stop(f_id).await { Ok(_) => {} Err(err) => { @@ -204,6 +273,21 @@ impl Controller { } } } + for (_alias, instances) in workflow_to_remove.resource_instances { + for (provider, instance_id) in instances { + match resources.get_mut(&provider) { + Some(provider) => match provider.config_api.stop(instance_id).await { + Ok(()) => {} + Err(err) => { + log::warn!("Stop resource failed: {:?}", err); + } + }, + None => { + log::warn!("Provider for previously spawned resource does not exist (anymore)."); + } + } + } + } } else { log::warn!("cannot stop non-existing workflow: {:?}", workflow_id); } @@ -211,18 +295,32 @@ impl Controller { ControllerRequest::LIST(workflow_id, reply_sender) => { let mut ret: Vec = vec![]; if let Some(w_id) = workflow_id.is_valid() { - if let Some(f_ids) = active_workflows.get(&w_id.to_string()) { + if let Some(wf) = active_workflows.get(&w_id) { ret = vec![WorkflowInstance { workflow_id: w_id.clone(), - functions: f_ids.clone(), + functions: wf + .function_instances + .iter() + .map(|(alias, instances)| edgeless_api::workflow_instance::WorkflowFunctionMapping { + function_alias: alias.to_string(), + instances: instances.clone(), + }) + .collect(), }]; } } else { ret = active_workflows .iter() - .map(|(w_id, f_ids)| WorkflowInstance { - workflow_id: edgeless_api::workflow_instance::WorkflowId::from_string(w_id.as_str()), - functions: f_ids.clone(), + .map(|(w_id, wf)| WorkflowInstance { + workflow_id: w_id.clone(), + functions: wf + .function_instances + .iter() + .map(|(alias, instances)| edgeless_api::workflow_instance::WorkflowFunctionMapping { + function_alias: alias.to_string(), + instances: instances.clone(), + }) + .collect(), }) .collect(); } diff --git a/edgeless_con/src/controller/test.rs b/edgeless_con/src/controller/test.rs new file mode 100644 index 00000000..26797e70 --- /dev/null +++ b/edgeless_con/src/controller/test.rs @@ -0,0 +1,370 @@ +use super::*; + +enum MockFunctionInstanceEvent { + Start( + ( + // this is the id passed from the orchestrator to the controller + edgeless_api::function_instance::InstanceId, + edgeless_api::function_instance::SpawnFunctionRequest, + ), + ), + Stop(edgeless_api::function_instance::InstanceId), + Update(edgeless_api::function_instance::UpdateFunctionLinksRequest), +} + +struct MockOrchestrator { + node_id: uuid::Uuid, + sender: futures::channel::mpsc::UnboundedSender, +} + +impl edgeless_api::orc::OrchestratorAPI for MockOrchestrator { + fn function_instance_api(&mut self) -> Box { + Box::new(MockFunctionInstanceAPI { + node_id: self.node_id.clone(), + sender: self.sender.clone(), + }) + } +} + +#[derive(Clone)] +struct MockFunctionInstanceAPI { + node_id: uuid::Uuid, + sender: futures::channel::mpsc::UnboundedSender, +} + +#[async_trait::async_trait] +impl edgeless_api::function_instance::FunctionInstanceAPI for MockFunctionInstanceAPI { + async fn start( + &mut self, + spawn_request: edgeless_api::function_instance::SpawnFunctionRequest, + ) -> anyhow::Result { + let new_id = edgeless_api::function_instance::InstanceId::new(self.node_id); + self.sender + .send(MockFunctionInstanceEvent::Start((new_id.clone(), spawn_request))) + .await + .unwrap(); + Ok(new_id) + } + async fn stop(&mut self, id: edgeless_api::function_instance::InstanceId) -> anyhow::Result<()> { + self.sender.send(MockFunctionInstanceEvent::Stop(id)).await.unwrap(); + Ok(()) + } + async fn update_links(&mut self, update: edgeless_api::function_instance::UpdateFunctionLinksRequest) -> anyhow::Result<()> { + self.sender.send(MockFunctionInstanceEvent::Update(update)).await.unwrap(); + Ok(()) + } +} + +enum MockResourceEvent { + Start( + ( + // this is the id passed from the orchestrator to the controller + edgeless_api::function_instance::InstanceId, + edgeless_api::resource_configuration::ResourceInstanceSpecification, + ), + ), + Stop(edgeless_api::function_instance::InstanceId), +} + +struct MockResourceProvider { + node_id: uuid::Uuid, + sender: futures::channel::mpsc::UnboundedSender, +} + +#[async_trait::async_trait] +impl edgeless_api::resource_configuration::ResourceConfigurationAPI for MockResourceProvider { + async fn start( + &mut self, + instance_specification: edgeless_api::resource_configuration::ResourceInstanceSpecification, + ) -> anyhow::Result { + let new_id = edgeless_api::function_instance::InstanceId::new(self.node_id); + self.sender + .send(MockResourceEvent::Start((new_id.clone(), instance_specification))) + .await + .unwrap(); + Ok(new_id) + } + async fn stop(&mut self, resource_id: edgeless_api::function_instance::InstanceId) -> anyhow::Result<()> { + self.sender.send(MockResourceEvent::Stop(resource_id)).await.unwrap(); + Ok(()) + } +} + +async fn test_setup() -> ( + Box, + futures::channel::mpsc::UnboundedReceiver, + futures::channel::mpsc::UnboundedReceiver, +) { + let (mock_orc_sender, mock_orc_receiver) = futures::channel::mpsc::unbounded::(); + let mock_orc = MockOrchestrator { + node_id: uuid::Uuid::new_v4(), + sender: mock_orc_sender, + }; + + let (mock_res_sender, mock_res_receiver) = futures::channel::mpsc::unbounded::(); + let mock_res = MockResourceProvider { + node_id: uuid::Uuid::new_v4(), + sender: mock_res_sender, + }; + + let orc_clients = std::collections::HashMap::>::from([( + "domain-1".to_string(), + Box::new(mock_orc) as Box, + )]); + let resources = std::collections::HashMap::::from([( + "resource-1".to_string(), + ResourceHandle { + resource_type: "test-res".to_string(), + _output_callback_declarations: vec!["test_out".to_string()], + config_api: Box::new(mock_res) as Box, + }, + )]); + + let (mut controller, controller_task) = Controller::new(orc_clients, resources); + tokio::spawn(controller_task); + let mut client = controller.get_api_client(); + let wf_client = client.workflow_instance_api(); + + (wf_client, mock_orc_receiver, mock_res_receiver) +} + +#[tokio::test] +async fn single_function_start_stop() { + let (mut wf_client, mut mock_orc_receiver, mut mock_res_receiver) = test_setup().await; + + assert!(mock_orc_receiver.try_next().is_err()); + assert!(mock_res_receiver.try_next().is_err()); + + let wf_id = edgeless_api::workflow_instance::WorkflowId { + workflow_id: uuid::Uuid::new_v4(), + }; + + let returned_id = wf_client + .start(edgeless_api::workflow_instance::SpawnWorkflowRequest { + workflow_id: wf_id.clone(), + workflow_functions: vec![edgeless_api::workflow_instance::WorkflowFunction { + function_alias: "f1".to_string(), + function_class_specification: edgeless_api::function_instance::FunctionClassSpecification { + function_class_id: "fc1".to_string(), + function_class_type: "RUST_WASM".to_string(), + function_class_version: "0.1".to_string(), + function_class_inlude_code: vec![], + output_callback_declarations: vec![], + }, + output_callback_definitions: std::collections::HashMap::new(), + function_annotations: std::collections::HashMap::new(), + }], + workflow_resources: vec![], + workflow_annotations: std::collections::HashMap::new(), + }) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let res = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Start((id, _spawn_req)) = res { + assert_eq!(returned_id.functions[0].instances[0], id); + } else { + panic!(); + } + assert!(mock_res_receiver.try_next().is_err()); + + assert!(mock_orc_receiver.try_next().is_err()); + + wf_client.stop(wf_id).await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let stop_res = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Stop(id) = stop_res { + assert_eq!(returned_id.functions[0].instances[0], id); + } else { + panic!(); + } + + assert!(mock_orc_receiver.try_next().is_err()); + assert!(mock_res_receiver.try_next().is_err()); +} + +#[tokio::test] +async fn resource_to_function_start_stop() { + let (mut wf_client, mut mock_orc_receiver, mut mock_res_receiver) = test_setup().await; + + assert!(mock_orc_receiver.try_next().is_err()); + assert!(mock_res_receiver.try_next().is_err()); + + let wf_id = edgeless_api::workflow_instance::WorkflowId { + workflow_id: uuid::Uuid::new_v4(), + }; + + let returned_id = wf_client + .start(edgeless_api::workflow_instance::SpawnWorkflowRequest { + workflow_id: wf_id.clone(), + workflow_functions: vec![edgeless_api::workflow_instance::WorkflowFunction { + function_alias: "f1".to_string(), + function_class_specification: edgeless_api::function_instance::FunctionClassSpecification { + function_class_id: "fc1".to_string(), + function_class_type: "RUST_WASM".to_string(), + function_class_version: "0.1".to_string(), + function_class_inlude_code: vec![], + output_callback_declarations: vec![], + }, + output_callback_definitions: std::collections::HashMap::new(), + function_annotations: std::collections::HashMap::new(), + }], + workflow_resources: vec![edgeless_api::workflow_instance::WorkflowResource { + alias: "r1".to_string(), + resource_class_type: "test-res".to_string(), + output_callback_definitions: std::collections::HashMap::from([("test_out".to_string(), "f1".to_string())]), + configurations: std::collections::HashMap::new(), + }], + workflow_annotations: std::collections::HashMap::new(), + }) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let res = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Start((id, _spawn_req)) = res { + assert_eq!(returned_id.functions[0].instances[0], id); + } else { + panic!(); + } + + let resource_res = mock_res_receiver.try_next().unwrap().unwrap(); + if let MockResourceEvent::Start((_id, spawn_req)) = resource_res { + assert_eq!( + *spawn_req.output_callback_definitions.get("test_out").unwrap(), + returned_id.functions[0].instances[0] + ); + } else { + panic!(); + } + + assert!(mock_orc_receiver.try_next().is_err()); + assert!(mock_res_receiver.try_next().is_err()); + + wf_client.stop(wf_id).await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let stop_res = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Stop(id) = stop_res { + assert_eq!(returned_id.functions[0].instances[0], id); + } else { + panic!(); + } + + let stop_resource_res = mock_res_receiver.try_next().unwrap().unwrap(); + if let MockResourceEvent::Stop(_id) = stop_resource_res { + } else { + panic!(); + } + + assert!(mock_orc_receiver.try_next().is_err()); + assert!(mock_res_receiver.try_next().is_err()); +} + +#[tokio::test] +async fn function_link_loop_start_stop() { + let (mut wf_client, mut mock_orc_receiver, mut mock_res_receiver) = test_setup().await; + + assert!(mock_orc_receiver.try_next().is_err()); + assert!(mock_res_receiver.try_next().is_err()); + + let wf_id = edgeless_api::workflow_instance::WorkflowId { + workflow_id: uuid::Uuid::new_v4(), + }; + + let returned_wf_state = wf_client + .start(edgeless_api::workflow_instance::SpawnWorkflowRequest { + workflow_id: wf_id.clone(), + workflow_functions: vec![ + edgeless_api::workflow_instance::WorkflowFunction { + function_alias: "f1".to_string(), + function_class_specification: edgeless_api::function_instance::FunctionClassSpecification { + function_class_id: "fc1".to_string(), + function_class_type: "RUST_WASM".to_string(), + function_class_version: "0.1".to_string(), + function_class_inlude_code: vec![], + output_callback_declarations: vec!["output-1".to_string()], + }, + output_callback_definitions: std::collections::HashMap::from([("output-1".to_string(), "f2".to_string())]), + function_annotations: std::collections::HashMap::new(), + }, + edgeless_api::workflow_instance::WorkflowFunction { + function_alias: "f2".to_string(), + function_class_specification: edgeless_api::function_instance::FunctionClassSpecification { + function_class_id: "fc2".to_string(), + function_class_type: "RUST_WASM".to_string(), + function_class_version: "0.1".to_string(), + function_class_inlude_code: vec![], + output_callback_declarations: vec!["output-2".to_string()], + }, + output_callback_definitions: std::collections::HashMap::from([("output-2".to_string(), "f1".to_string())]), + function_annotations: std::collections::HashMap::new(), + }, + ], + workflow_resources: vec![], + workflow_annotations: std::collections::HashMap::new(), + }) + .await + .unwrap(); + + let fids: std::collections::HashSet<_> = returned_wf_state + .functions + .iter() + .flat_map(|instances| instances.instances.clone()) + .collect(); + let to_patch: Option; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let res = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Start((id, spawn_req)) = res { + assert!(fids.contains(&id)); + assert_eq!(spawn_req.output_callback_definitions.len(), 0); + to_patch = Some(id); + } else { + panic!(); + } + let res2 = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Start((id, spawn_req)) = res2 { + assert!(fids.contains(&id)); + assert_eq!(spawn_req.output_callback_definitions.len(), 1); + } else { + panic!(); + } + let res3 = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Update(update_req) = res3 { + assert_eq!(update_req.instance_id.unwrap(), to_patch.unwrap()); + assert_eq!(update_req.output_callback_definitions.len(), 1); + } else { + panic!(); + } + + assert!(mock_res_receiver.try_next().is_err()); + assert!(mock_orc_receiver.try_next().is_err()); + + wf_client.stop(wf_id).await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let stop_res = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Stop(id) = stop_res { + assert!(fids.contains(&id)); + } else { + panic!(); + } + + let stop_res2 = mock_orc_receiver.try_next().unwrap().unwrap(); + if let MockFunctionInstanceEvent::Stop(id) = stop_res2 { + assert!(fids.contains(&id)); + } else { + panic!(); + } + + assert!(mock_res_receiver.try_next().is_err()); +} diff --git a/edgeless_con/src/lib.rs b/edgeless_con/src/lib.rs index 90ca29f7..b8a83f84 100644 --- a/edgeless_con/src/lib.rs +++ b/edgeless_con/src/lib.rs @@ -25,7 +25,7 @@ pub async fn edgeless_con_main(settings: EdgelessConSettings) { log::info!("Starting Edgeless Controller at {}", settings.controller_url); log::debug!("Settings: {:?}", settings); - let (mut controller, controller_task) = controller::Controller::new(settings.clone()); + let (mut controller, controller_task) = controller::Controller::new_from_config(settings.clone()).await; let server_task = edgeless_api::grpc_impl::controller::WorkflowInstanceAPIServer::run(controller.get_api_client(), settings.controller_url.clone());