Skip to content

Commit

Permalink
Use Hashmaps, vector preallocation and avoid cloning [Rust] (#250)
Browse files Browse the repository at this point in the history
* get hashmap without match statements

* preallocate vectors

* avoid cloning
  • Loading branch information
LegrandNico authored Nov 11, 2024
1 parent e32e331 commit c4ab757
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "rshgf"
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]
path = "src/lib.rs" # The source file of the target.

[dependencies]
Expand Down
22 changes: 22 additions & 0 deletions examples/exponential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use rshgf::model::Network;

fn main() {

// initialize network
let mut network = Network::new();

// create a network with two exponential family state nodes
network.add_nodes(
"exponential-state",
None,
None,
None,
None
);

// belief propagation
let input_data = vec![1.0, 1.3, 1.5, 1.7];
network.set_update_sequence();
network.input_data(input_data);

}
78 changes: 26 additions & 52 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;
use crate::{updates::observations::observation_update, utils::function_pointer::FnType};
use crate::utils::function_pointer::FnType;
use crate::utils::set_sequence::set_update_sequence;
use crate::utils::beliefs_propagation::belief_propagation;
use crate::utils::function_pointer::get_func_map;
use pyo3::types::PyTuple;
use pyo3::{prelude::*, types::{PyList, PyDict}};
Expand Down Expand Up @@ -129,95 +130,68 @@ impl Network {
self.update_sequence = set_update_sequence(self);
}

/// Single time slice belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(&mut self, observations_set: Vec<f64>) {

let predictions = self.update_sequence.predictions.clone();
let updates = self.update_sequence.updates.clone();

// 1. prediction steps
for (idx, step) in predictions.iter() {
step(self, *idx);
}

// 2. observation steps
for (i, observations) in observations_set.iter().enumerate() {
let idx = self.inputs[i];
observation_update(self, idx, *observations);
}

// 3. update steps
for (idx, step) in updates.iter() {
step(self, *idx);
}
}

/// Add a sequence of observations.
///
/// # Arguments
/// * `input_data` - A vector of vectors. Each vector is a time series of observations
/// associated with one node.
pub fn input_data(&mut self, input_data: Vec<f64>) {

let n_time = input_data.len();
let predictions = self.update_sequence.predictions.clone();
let updates = self.update_sequence.updates.clone();

// initialize the belief trajectories result struture
let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()};

// add empty vectors in the floats hashmap
// preallocate empty vectors in the floats hashmap
for (node_idx, node) in &self.attributes.floats {
let new_map: HashMap<String, Vec<f64>> = HashMap::new();
node_trajectories.floats.insert(*node_idx, new_map);
if let Some(attr) = node_trajectories.floats.get_mut(node_idx) {
for key in node.keys() {
attr.insert(key.clone(), Vec::new());
}
let attr = node_trajectories.floats.get_mut(node_idx).expect("New map not found.");
for key in node.keys() {
attr.insert(key.clone(), Vec::with_capacity(n_time));
}
}
// add empty vectors in the vectors hashmap
}

// preallocate empty vectors in the vectors hashmap
for (node_idx, node) in &self.attributes.vectors {
let new_map: HashMap<String, Vec<Vec<f64>>> = HashMap::new();
node_trajectories.vectors.insert(*node_idx, new_map);
if let Some(attr) = node_trajectories.vectors.get_mut(node_idx) {
for key in node.keys() {
attr.insert(key.clone(), Vec::new());
}
let attr = node_trajectories.vectors.get_mut(node_idx).expect("New vector map not found.");
for key in node.keys() {
attr.insert(key.clone(), Vec::with_capacity(n_time));
}
}
}


// iterate over the observations
for observation in input_data {

// 1. belief propagation for one time slice
self.belief_propagation(vec![observation]);
belief_propagation(self, vec![observation], &predictions, &updates);

// 2. append the new beliefs in the trajectories structure
// iterate over the float hashmap
for (new_node_idx, new_node) in &self.attributes.floats {
for (new_key, new_value) in new_node {
// If the key exists in map1, append the vector from map2
if let Some(old_node) = node_trajectories.floats.get_mut(&new_node_idx) {
if let Some(old_value) = old_node.get_mut(new_key) {
old_value.push(*new_value);
}
let old_node = node_trajectories.floats.get_mut(&new_node_idx).expect("Old node not found.");
let old_value = old_node.get_mut(new_key).expect("Old value not found");
old_value.push(*new_value);
}
}
}

// iterate over the vector hashmap
for (new_node_idx, new_node) in &self.attributes.vectors {
for (new_key, new_value) in new_node {
// If the key exists in map1, append the vector from map2
if let Some(old_node) = node_trajectories.vectors.get_mut(&new_node_idx) {
if let Some(old_value) = old_node.get_mut(new_key) {
old_value.push(new_value.clone());
}
let old_node = node_trajectories.vectors.get_mut(&new_node_idx).expect("Old vector node not found.");
let old_value = old_node.get_mut(new_key).expect("Old vector value not found.");
old_value.push(new_value.clone());
}
}
}
}

self.node_trajectories = node_trajectories;
}

Expand Down
Empty file removed src/tests/exponential_family.rs
Empty file.
33 changes: 11 additions & 22 deletions src/updates/prediction_error/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,14 @@ use crate::math::sufficient_statistics;
/// * `network` - The network after message passing.
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) {

if let Some(floats_attributes) = network.attributes.floats.get_mut(&node_idx) {
if let Some(vectors_attributes) = network.attributes.vectors.get_mut(&node_idx) {
let mean = floats_attributes.get("mean");
let nus = floats_attributes.get("nus");
let xis = vectors_attributes.get("xis");
let new_xis = match (mean, nus, xis) {
(Some(mean), Some(nus), Some(xis)) => {
let suf_stats = sufficient_statistics(mean);
let mut new_xis = xis.clone();
for i in 0..suf_stats.len() {
new_xis[i] = new_xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]);
}
new_xis
}
_ => Vec::new(),
};
if let Some(xis) = vectors_attributes.get_mut("xis") {
*xis = new_xis; // Modify the value directly
}
}
}
}
let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes");
let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes");
let mean = floats_attributes.get("mean").expect("Mean not found");
let nus = floats_attributes.get("nus").expect("Nus not found");
let xis = vectors_attributes.get_mut("xis").expect("Xis not found");

let suf_stats = sufficient_statistics(mean);
for i in 0..suf_stats.len() {
xis[i] = xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]);
}
}
25 changes: 25 additions & 0 deletions src/utils/beliefs_propagation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::{utils::function_pointer::FnType, model::Network, updates::observations::observation_update};

/// Single time slice belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(network: &mut Network, observations_set: Vec<f64>, predictions: & Vec<(usize, FnType)>, updates: & Vec<(usize, FnType)>) {

// 1. prediction steps
for (idx, step) in predictions.iter() {
step(network, *idx);
}

// 2. observation steps
for (i, observations) in observations_set.iter().enumerate() {
let idx = network.inputs[i];
observation_update(network, idx, *observations);
}

// 3. update steps
for (idx, step) in updates.iter() {
step(network, *idx);
}
}
3 changes: 2 additions & 1 deletion src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod set_sequence;
pub mod function_pointer;
pub mod function_pointer;
pub mod beliefs_propagation;

0 comments on commit c4ab757

Please sign in to comment.