diff --git a/src/lib.rs b/src/lib.rs index 504294410..55ad98cd0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -pub mod network; +pub mod model; pub mod utils; pub mod math; pub mod updates; \ No newline at end of file diff --git a/src/network.rs b/src/model.rs similarity index 66% rename from src/network.rs rename to src/model.rs index 61408539e..2b5845f80 100644 --- a/src/network.rs +++ b/src/model.rs @@ -1,18 +1,19 @@ use std::collections::HashMap; -use crate::updates::posterior; +use crate::updates::observations::observation_update; +use crate::utils::get_update_sequence; use pyo3::{prelude::*, types::{PyList, PyDict}}; #[derive(Debug)] #[pyclass] pub struct AdjacencyLists{ #[pyo3(get, set)] - pub value_parents: Option, + pub value_parents: Option>, #[pyo3(get, set)] - pub value_children: Option, + pub value_children: Option>, #[pyo3(get, set)] - pub volatility_parents: Option, + pub volatility_parents: Option>, #[pyo3(get, set)] - pub volatility_children: Option, + pub volatility_children: Option>, } #[derive(Debug, Clone)] pub struct ContinuousStateNode{ @@ -38,12 +39,22 @@ pub enum Node { Exponential(ExponentialFamiliyStateNode), } +// Create a default signature for update functions +pub type FnType = fn(&mut Network, usize); + +#[derive(Debug)] +pub struct UpdateSequence { + pub predictions: Vec<(usize, FnType)>, + pub updates: Vec<(usize, FnType)>, +} + #[derive(Debug)] #[pyclass] pub struct Network{ pub nodes: HashMap, pub edges: Vec, pub inputs: Vec, + pub update_sequence: UpdateSequence, } #[pymethods] @@ -56,6 +67,7 @@ impl Network { nodes: HashMap::new(), edges: Vec::new(), inputs: Vec::new(), + update_sequence: UpdateSequence {predictions: Vec::new(), updates: Vec::new()} } } @@ -68,12 +80,17 @@ impl Network { /// * `volatility_children` - The indexes of the node's volatility children. /// * `volatility_parents` - The indexes of the node's volatility parents. #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] - pub fn add_nodes(&mut self, kind: &str, value_parents: Option, - value_children: Option, volatility_children: Option, - volatility_parents: Option) { - + pub fn add_nodes(&mut self, kind: &str, value_parents: Option>, + value_children: Option>, volatility_children: Option>, + volatility_parents: Option>) { + // the node ID is equal to the number of nodes already in the network let node_id: usize = self.nodes.len(); + + // if this node has no children, this is an input node + if (value_children == None) & (volatility_children == None) { + self.inputs.push(node_id); + } let edges = AdjacencyLists{ value_parents: value_children, @@ -92,10 +109,6 @@ impl Network { self.nodes.insert(node_id, node); self.edges.push(edges); - // if this node has no children, this is an input node - if (value_children == None) & (volatility_children == None) { - self.inputs.push(node_id); - } } else if kind == "exponential-state" { let exponential_node: ExponentialFamiliyStateNode = ExponentialFamiliyStateNode{ mean: 0.0, expected_mean: 0.0, nus: 0.0, xis: [0.0, 0.0] @@ -103,78 +116,41 @@ impl Network { let node = Node::Exponential(exponential_node); self.nodes.insert(node_id, node); self.edges.push(edges); - - // if this node has no children, this is an input node - if (value_children == None) & (volatility_children == None) { - self.inputs.push(node_id); - } } else { println!("Invalid type of node provided ({}).", kind); } } - pub fn prediction_error(&mut self, node_idx: usize) { - - // get the observation value - let mean; - match self.nodes[&node_idx] { - Node::Continuous(ref node) => { - mean = node.mean; - } - Node::Exponential(ref node) => { - mean = node.mean; - } - } - - let value_parent_idx = &self.edges[node_idx].value_parents; - match value_parent_idx { - Some(idx) => { - match self.nodes.get_mut(idx) { - Some(Node::Continuous(ref mut parent)) => { - parent.mean = mean - } - Some(Node::Exponential(ref mut parent)) => { - parent.mean = mean - } - None => println!("No prediction error for this type of node."), - } - } - None => println!("No value parent"), - } - } - - pub fn posterior_update(&mut self, node_idx: usize, observation: f64) { - - match self.nodes.get_mut(&node_idx) { - Some(Node::Continuous(ref mut node)) => { - node.mean = observation - } - Some(Node::Exponential(ref mut node)) => { - posterior::exponential::posterior_update_exponential(node) - } - None => println!("No posterior update for this type of node.") - } + pub fn get_update_sequence(&mut self) { + self.update_sequence = get_update_sequence(self); } - /// One time step belief propagation. + /// Single time slice belief propagation. /// /// # Arguments /// * `observations` - A vector of values, each value is one new observation associated /// with one node. - pub fn belief_propagation(&mut self, observations: Vec) { + pub fn belief_propagation(&mut self, observations_set: Vec) { - // 1. prediction propagation + let predictions = self.update_sequence.predictions.clone(); + let updates = self.update_sequence.updates.clone(); - for i in 0..observations.len() { - - let input_node_idx = self.inputs[i]; - // 2. inject the observations into the input nodes - self.posterior_update(input_node_idx, observations[i]); - // 3. posterior update - prediction errors propagation - self.prediction_error(input_node_idx); + // 1. prediction steps + for (idx, step) in predictions.iter() { + step(self, *idx); + } + + // 2. observation steps + for (i, observations) in observations_set.iter().enumerate() { + let idx = self.inputs[i]; + observation_update(self, idx, *observations); + } + + // 3. update steps + for (idx, step) in updates.iter() { + step(self, *idx); } - } /// Add a sequence of observations. @@ -203,17 +179,16 @@ impl Network { for s in &self.edges { // Create a new Python dictionary for each MyStruct let py_dict = PyDict::new(py); - py_dict.set_item("value_parents", s.value_parents)?; - py_dict.set_item("value_children", s.value_children)?; - py_dict.set_item("volatility_parents", s.volatility_parents)?; - py_dict.set_item("volatility_children", s.volatility_children)?; + py_dict.set_item("value_parents", &s.value_parents)?; + py_dict.set_item("value_children", &s.value_children)?; + py_dict.set_item("volatility_parents", &s.volatility_parents)?; + py_dict.set_item("volatility_children", &s.volatility_children)?; // Add the dictionary to the list py_list.append(py_dict)?; } Ok(py_list) } - } // Create a module to expose the class to Python diff --git a/src/updates/mod.rs b/src/updates/mod.rs index dd8404066..28479cc79 100644 --- a/src/updates/mod.rs +++ b/src/updates/mod.rs @@ -1,3 +1,4 @@ pub mod posterior; pub mod prediction; -pub mod prediction_error; \ No newline at end of file +pub mod prediction_error; +pub mod observations; \ No newline at end of file diff --git a/src/updates/observations.rs b/src/updates/observations.rs new file mode 100644 index 000000000..1d64ec7a9 --- /dev/null +++ b/src/updates/observations.rs @@ -0,0 +1,21 @@ +use crate::model::{Network, Node}; + + +/// Inject new observations into an input node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The input node index. +/// * `observations` - The new observations. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn observation_update(network: &mut Network, node_idx: usize, observations: f64) { + + match network.nodes.get_mut(&node_idx) { + Some(Node::Exponential(ref mut node)) => { + node.mean = observations; + } + _ => (), + } +} \ No newline at end of file diff --git a/src/updates/posterior/continuous.rs b/src/updates/posterior/continuous.rs index e69de29bb..b1941915d 100644 --- a/src/updates/posterior/continuous.rs +++ b/src/updates/posterior/continuous.rs @@ -0,0 +1,13 @@ +use crate::model::Network; + +/// Posterior update from a continuous state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn posterior_update_continuous_state_node(network: &mut Network, node_idx: usize) { + let a = 1; +} \ No newline at end of file diff --git a/src/updates/posterior/exponential.rs b/src/updates/posterior/exponential.rs index 97b07e8cf..20dcf2d15 100644 --- a/src/updates/posterior/exponential.rs +++ b/src/updates/posterior/exponential.rs @@ -1,9 +1,23 @@ -use crate::network::ExponentialFamiliyStateNode; +use crate::model::{Network, Node}; use crate::math::sufficient_statistics; -pub fn posterior_update_exponential(node: &mut ExponentialFamiliyStateNode) { - let suf_stats = sufficient_statistics(&node.mean); - for i in 0..suf_stats.len() { - node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); +/// Updating an exponential family state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn posterior_update_exponential_state_node(network: &mut Network, node_idx: usize) { + + match network.nodes.get_mut(&node_idx) { + Some(Node::Exponential(ref mut node)) => { + let suf_stats = sufficient_statistics(&node.mean); + for i in 0..suf_stats.len() { + node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); + } + } + _ => (), } } \ No newline at end of file diff --git a/src/updates/prediction/continuous.rs b/src/updates/prediction/continuous.rs new file mode 100644 index 000000000..f40d2c35e --- /dev/null +++ b/src/updates/prediction/continuous.rs @@ -0,0 +1,13 @@ +use crate::model::Network; + +/// Prediction from a continuous state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn prediction_continuous_state_node(network: &mut Network, node_idx: usize) { + let a = 1; +} \ No newline at end of file diff --git a/src/updates/prediction/mod.rs b/src/updates/prediction/mod.rs index e69de29bb..6817d49e7 100644 --- a/src/updates/prediction/mod.rs +++ b/src/updates/prediction/mod.rs @@ -0,0 +1 @@ +pub mod continuous; \ No newline at end of file diff --git a/src/updates/prediction_error/continuous.rs b/src/updates/prediction_error/continuous.rs new file mode 100644 index 000000000..96e8990ca --- /dev/null +++ b/src/updates/prediction_error/continuous.rs @@ -0,0 +1,13 @@ +use crate::model::Network; + +/// Prediction error from a continuous state node +/// +/// # Arguments +/// * `network` - The main network containing the node. +/// * `node_idx` - The node index. +/// +/// # Returns +/// * `network` - The network after message passing. +pub fn prediction_error_continuous_state_node(network: &mut Network, node_idx: usize) { + let a = 1; +} \ No newline at end of file diff --git a/src/updates/prediction_error/mod.rs b/src/updates/prediction_error/mod.rs index 69e3fea51..c53f4e7ae 100644 --- a/src/updates/prediction_error/mod.rs +++ b/src/updates/prediction_error/mod.rs @@ -1 +1 @@ -pub mod nodes; +pub mod continuous; diff --git a/src/updates/prediction_error/nodes/continuous.rs b/src/updates/prediction_error/nodes/continuous.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/updates/prediction_error/nodes/mod.rs b/src/updates/prediction_error/nodes/mod.rs deleted file mode 100644 index 6817d49e7..000000000 --- a/src/updates/prediction_error/nodes/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod continuous; \ No newline at end of file diff --git a/src/utils.rs b/src/utils.rs index ab5da35a3..178d1ba6a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,51 +1,198 @@ -use crate::network::Network; +use crate::{model::{FnType, Network, Node, UpdateSequence}, updates::{posterior::{continuous::posterior_update_continuous_state_node, exponential::posterior_update_exponential_state_node}, prediction::continuous::prediction_continuous_state_node, prediction_error::continuous::prediction_error_continuous_state_node}}; +pub fn get_update_sequence(network: &Network) -> UpdateSequence { + let predictions = get_predictions_sequence(network); + let updates = get_updates_sequence(network); -pub fn get_update_order(network: Network) -> Vec { - - let mut update_list = Vec::new(); + // return the update sequence + let update_sequence = UpdateSequence {predictions: predictions, updates: updates}; + update_sequence +} - // list all nodes availables in the network - let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); - // remove the input nodes - nodes_idxs.retain(|x| !network.inputs.contains(x)); +pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { - let mut remaining = nodes_idxs.len(); + let mut predictions : Vec<(usize, FnType)> = Vec::new(); + + // 1. get prediction sequence ------------------------------------------------------ + + // list all nodes availables in the network + let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); - while remaining > 0 { + // iterate over all nodes and add the prediction step if all criteria are met + let mut n_remaining = nodes_idxs.len(); + while n_remaining > 0 { + + // were we able to add an update step in the list on that iteration? let mut has_update = false; - // loop over all available + + // loop over all the remaining nodes for i in 0..nodes_idxs.len() { let idx = nodes_idxs[i]; - let value_children_idxs = network.edges[idx].value_children; - - // check if there is any element in value children - // that is found in the to-be-updated list of nodes - let contains_common = value_children_idxs.iter().any(|&item| nodes_idxs.contains(&item)); - + + // list the node's parents + let value_parents_idxs = &network.edges[idx].value_parents; + let volatility_parents_idxs = &network.edges[idx].volatility_parents; + + let parents_idxs = match (value_parents_idxs, volatility_parents_idxs) { + // If both are Some, merge the vectors + (Some(ref vec1), Some(ref vec2)) => { + // Create a new vector by merging the two + let merged_vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); + Some(merged_vec) // Return the merged vector wrapped in Some + } + // If one is Some and the other is None, return the one that's Some + (Some(vec), None) | (None, Some(vec)) => Some(vec.clone()), + // If both are None, return None + (None, None) => None, + }; + + // check if there is any parent node that is still found in the to-be-updated list + let contains_common = match parents_idxs { + Some(vec) => vec.iter().any(|item| nodes_idxs.contains(item)), + None => false + }; + + // if all parents have processed their prediction, this one can be added if !(contains_common) { // add the node in the update list - update_list.push(idx); + match network.nodes.get(&idx) { + Some(Node::Continuous(_)) => { + predictions.push((idx, prediction_continuous_state_node)); + } + Some(Node::Exponential(_)) => (), + None => () + + } - // remove the parent from the availables nodes list + // remove the node from the to-be-updated list nodes_idxs.retain(|&x| x != idx); - - remaining -= 1; + n_remaining -= 1; has_update = true; break; } } + // 2. get update sequence ------------------------------------------------------ + if !(has_update) { break; } } - update_list + predictions + } +pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { + + let mut updates : Vec<(usize, FnType)> = Vec::new(); + + // 1. get update sequence ---------------------------------------------------------- + + // list all nodes availables in the network + let mut pe_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + let mut po_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + + // remove the input nodes from the to-be-visited nodes for posterior updates + po_nodes_idxs.retain(|x| !network.inputs.contains(x)); + + // iterate over all nodes and add the prediction step if all criteria are met + let mut n_remaining = 2 * pe_nodes_idxs.len(); // posterior updates + prediction errors + + while n_remaining > 0 { + + // were we able to add an update step in the list on that iteration? + let mut has_update = false; + + // loop over all the remaining nodes for prediction errors --------------------- + for i in 0..pe_nodes_idxs.len() { + + let idx = pe_nodes_idxs[i]; + + // to send a prediction error, this node should have been updated first + if !(po_nodes_idxs.contains(&idx)) { + + // add the node in the update list + match network.nodes.get(&idx) { + Some(Node::Continuous(_)) => { + updates.push((idx, prediction_error_continuous_state_node)); + } + Some(Node::Exponential(_)) => (), + None => () + + } + + // remove the node from the to-be-updated list + pe_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; + } + } + + // loop over all the remaining nodes for posterior updates --------------------- + for i in 0..po_nodes_idxs.len() { + + let idx = po_nodes_idxs[i]; + + // to start a posterior update, all children should have sent prediction errors + // 1. get a list of all children + let value_children_idxs = &network.edges[idx].value_children; + let volatility_children_idxs = &network.edges[idx].volatility_children; + + let children_idxs = match (value_children_idxs, volatility_children_idxs) { + // If both are Some, merge the vectors + (Some(ref vec1), Some(ref vec2)) => { + // Create a new vector by merging the two + let merged_vec: Vec = vec1.iter().chain(vec2.iter()).cloned().collect(); + Some(merged_vec) // Return the merged vector wrapped in Some + } + // If one is Some and the other is None, return the one that's Some + (Some(vec), None) | (None, Some(vec)) => Some(vec.clone()), + // If both are None, return None + (None, None) => None, + }; + + // 2. check if any of the children is still on the to-be-visited list for prediction errors + // check if there is any parent node that is still found in the to-be-updated list + let missing_pe = match children_idxs { + Some(vec) => vec.iter().any(|item| pe_nodes_idxs.contains(item)), + None => false + }; + + // 3. if false, add the posterior update to the list + if !(missing_pe) { + + // add the node in the update list + match network.nodes.get(&idx) { + Some(Node::Continuous(_)) => { + updates.push((idx, posterior_update_continuous_state_node)); + } + Some(Node::Exponential(_)) => { + updates.push((idx, posterior_update_exponential_state_node)); + } + None => () + + } + + // remove the node from the to-be-updated list + po_nodes_idxs.retain(|&x| x != idx); + n_remaining -= 1; + has_update = true; + break; + } + } + // 2. get update sequence ---------------------------------------------------------- + + if !(has_update) { + break; + } + } + updates + +} // Tests module for unit tests #[cfg(test)] // Only compile and include this module when running tests @@ -61,15 +208,15 @@ mod tests { // create a network network.add_nodes( "continuous-state", - Some(1), + Some(vec![1]), None, None, - Some(2), + Some(vec![2]), ); network.add_nodes( "continuous-state", None, - Some(0), + Some(vec![0]), None, None, ); @@ -77,7 +224,7 @@ mod tests { "continuous-state", None, None, - Some(0), + Some(vec![0]), None, ); network.add_nodes( @@ -89,7 +236,7 @@ mod tests { ); println!("Network: {:?}", network); - println!("Update order: {:?}", get_update_order(network)); + println!("Update order: {:?}", get_update_sequence(&network)); } }