Skip to content

Commit

Permalink
perf
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Nov 11, 2024
1 parent 78c5f7f commit d98ffe8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 30 deletions.
34 changes: 5 additions & 29 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use crate::{updates::observations::observation_update, utils::function_pointer::FnType};
use crate::utils::function_pointer::FnType;
use crate::utils::set_sequence::set_update_sequence;
use crate::utils::beliefs_propagation::belief_propagation;
use crate::utils::function_pointer::get_func_map;
use pyo3::types::PyTuple;
use pyo3::{prelude::*, types::{PyList, PyDict}};
Expand Down Expand Up @@ -129,33 +130,6 @@ impl Network {
self.update_sequence = set_update_sequence(self);
}

/// Single time slice belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(&mut self, observations_set: Vec<f64>) {

let predictions = self.update_sequence.predictions.clone();
let updates = self.update_sequence.updates.clone();

// 1. prediction steps
for (idx, step) in predictions.iter() {
step(self, *idx);
}

// 2. observation steps
for (i, observations) in observations_set.iter().enumerate() {
let idx = self.inputs[i];
observation_update(self, idx, *observations);
}

// 3. update steps
for (idx, step) in updates.iter() {
step(self, *idx);
}
}

/// Add a sequence of observations.
///
/// # Arguments
Expand All @@ -164,6 +138,8 @@ impl Network {
pub fn input_data(&mut self, input_data: Vec<f64>) {

let n_time = input_data.len();
let predictions = self.update_sequence.predictions.clone();
let updates = self.update_sequence.updates.clone();

// initialize the belief trajectories result struture
let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()};
Expand Down Expand Up @@ -193,7 +169,7 @@ impl Network {
for observation in input_data {

// 1. belief propagation for one time slice
self.belief_propagation(vec![observation]);
belief_propagation(self, vec![observation], &predictions, &updates);

// 2. append the new beliefs in the trajectories structure
// iterate over the float hashmap
Expand Down
25 changes: 25 additions & 0 deletions src/utils/beliefs_propagation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::{utils::function_pointer::FnType, model::Network, updates::observations::observation_update};

/// Single time slice belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(network: &mut Network, observations_set: Vec<f64>, predictions: & Vec<(usize, FnType)>, updates: & Vec<(usize, FnType)>) {

// 1. prediction steps
for (idx, step) in predictions.iter() {
step(network, *idx);
}

// 2. observation steps
for (i, observations) in observations_set.iter().enumerate() {
let idx = network.inputs[i];
observation_update(network, idx, *observations);
}

// 3. update steps
for (idx, step) in updates.iter() {
step(network, *idx);
}
}
3 changes: 2 additions & 1 deletion src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod set_sequence;
pub mod function_pointer;
pub mod function_pointer;
pub mod beliefs_propagation;

0 comments on commit d98ffe8

Please sign in to comment.