diff --git a/src/model.rs b/src/model.rs index a2ad8a3cb..1f1d4f7f8 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; -use crate::{updates::observations::observation_update, utils::function_pointer::FnType}; +use crate::utils::function_pointer::FnType; use crate::utils::set_sequence::set_update_sequence; +use crate::utils::beliefs_propagation::belief_propagation; use crate::utils::function_pointer::get_func_map; use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; @@ -129,33 +130,6 @@ impl Network { self.update_sequence = set_update_sequence(self); } - /// Single time slice belief propagation. - /// - /// # Arguments - /// * `observations` - A vector of values, each value is one new observation associated - /// with one node. - pub fn belief_propagation(&mut self, observations_set: Vec) { - - let predictions = self.update_sequence.predictions.clone(); - let updates = self.update_sequence.updates.clone(); - - // 1. prediction steps - for (idx, step) in predictions.iter() { - step(self, *idx); - } - - // 2. observation steps - for (i, observations) in observations_set.iter().enumerate() { - let idx = self.inputs[i]; - observation_update(self, idx, *observations); - } - - // 3. update steps - for (idx, step) in updates.iter() { - step(self, *idx); - } - } - /// Add a sequence of observations. /// /// # Arguments @@ -164,6 +138,8 @@ impl Network { pub fn input_data(&mut self, input_data: Vec) { let n_time = input_data.len(); + let predictions = self.update_sequence.predictions.clone(); + let updates = self.update_sequence.updates.clone(); // initialize the belief trajectories result struture let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()}; @@ -193,7 +169,7 @@ impl Network { for observation in input_data { // 1. belief propagation for one time slice - self.belief_propagation(vec![observation]); + belief_propagation(self, vec![observation], &predictions, &updates); // 2. append the new beliefs in the trajectories structure // iterate over the float hashmap diff --git a/src/utils/beliefs_propagation.rs b/src/utils/beliefs_propagation.rs new file mode 100644 index 000000000..789887643 --- /dev/null +++ b/src/utils/beliefs_propagation.rs @@ -0,0 +1,25 @@ +use crate::{utils::function_pointer::FnType, model::Network, updates::observations::observation_update}; + +/// Single time slice belief propagation. +/// +/// # Arguments +/// * `observations` - A vector of values, each value is one new observation associated +/// with one node. +pub fn belief_propagation(network: &mut Network, observations_set: Vec, predictions: & Vec<(usize, FnType)>, updates: & Vec<(usize, FnType)>) { + + // 1. prediction steps + for (idx, step) in predictions.iter() { + step(network, *idx); + } + + // 2. observation steps + for (i, observations) in observations_set.iter().enumerate() { + let idx = network.inputs[i]; + observation_update(network, idx, *observations); + } + + // 3. update steps + for (idx, step) in updates.iter() { + step(network, *idx); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1918ac547..6c65a8867 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1,3 @@ pub mod set_sequence; -pub mod function_pointer; \ No newline at end of file +pub mod function_pointer; +pub mod beliefs_propagation; \ No newline at end of file