Skip to content

Commit

Permalink
add a proper get_update_sequence function
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 24, 2024
1 parent 198245f commit db5910c
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 111 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod network;
pub mod model;
pub mod utils;
pub mod math;
pub mod updates;
125 changes: 50 additions & 75 deletions src/network.rs → src/model.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
use std::collections::HashMap;
use crate::updates::posterior;
use crate::updates::observations::observation_update;
use crate::utils::get_update_sequence;
use pyo3::{prelude::*, types::{PyList, PyDict}};

#[derive(Debug)]
#[pyclass]
pub struct AdjacencyLists{
#[pyo3(get, set)]
pub value_parents: Option<usize>,
pub value_parents: Option<Vec<usize>>,
#[pyo3(get, set)]
pub value_children: Option<usize>,
pub value_children: Option<Vec<usize>>,
#[pyo3(get, set)]
pub volatility_parents: Option<usize>,
pub volatility_parents: Option<Vec<usize>>,
#[pyo3(get, set)]
pub volatility_children: Option<usize>,
pub volatility_children: Option<Vec<usize>>,
}
#[derive(Debug, Clone)]
pub struct ContinuousStateNode{
Expand All @@ -38,12 +39,22 @@ pub enum Node {
Exponential(ExponentialFamiliyStateNode),
}

// Create a default signature for update functions
pub type FnType = fn(&mut Network, usize);

#[derive(Debug)]
pub struct UpdateSequence {
pub predictions: Vec<(usize, FnType)>,
pub updates: Vec<(usize, FnType)>,
}

#[derive(Debug)]
#[pyclass]
pub struct Network{
pub nodes: HashMap<usize, Node>,
pub edges: Vec<AdjacencyLists>,
pub inputs: Vec<usize>,
pub update_sequence: UpdateSequence,
}

#[pymethods]
Expand All @@ -56,6 +67,7 @@ impl Network {
nodes: HashMap::new(),
edges: Vec::new(),
inputs: Vec::new(),
update_sequence: UpdateSequence {predictions: Vec::new(), updates: Vec::new()}
}
}

Expand All @@ -68,12 +80,17 @@ impl Network {
/// * `volatility_children` - The indexes of the node's volatility children.
/// * `volatility_parents` - The indexes of the node's volatility parents.
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))]
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<usize>,
value_children: Option<usize>, volatility_children: Option<usize>,
volatility_parents: Option<usize>) {
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<Vec<usize>>,
value_children: Option<Vec<usize>>, volatility_children: Option<Vec<usize>>,
volatility_parents: Option<Vec<usize>>) {

// the node ID is equal to the number of nodes already in the network
let node_id: usize = self.nodes.len();

// if this node has no children, this is an input node
if (value_children == None) & (volatility_children == None) {
self.inputs.push(node_id);
}

let edges = AdjacencyLists{
value_parents: value_children,
Expand All @@ -92,89 +109,48 @@ impl Network {
self.nodes.insert(node_id, node);
self.edges.push(edges);

// if this node has no children, this is an input node
if (value_children == None) & (volatility_children == None) {
self.inputs.push(node_id);
}
} else if kind == "exponential-state" {
let exponential_node: ExponentialFamiliyStateNode = ExponentialFamiliyStateNode{
mean: 0.0, expected_mean: 0.0, nus: 0.0, xis: [0.0, 0.0]
};
let node = Node::Exponential(exponential_node);
self.nodes.insert(node_id, node);
self.edges.push(edges);

// if this node has no children, this is an input node
if (value_children == None) & (volatility_children == None) {
self.inputs.push(node_id);
}

} else {
println!("Invalid type of node provided ({}).", kind);
}
}

pub fn prediction_error(&mut self, node_idx: usize) {

// get the observation value
let mean;
match self.nodes[&node_idx] {
Node::Continuous(ref node) => {
mean = node.mean;
}
Node::Exponential(ref node) => {
mean = node.mean;
}
}

let value_parent_idx = &self.edges[node_idx].value_parents;
match value_parent_idx {
Some(idx) => {
match self.nodes.get_mut(idx) {
Some(Node::Continuous(ref mut parent)) => {
parent.mean = mean
}
Some(Node::Exponential(ref mut parent)) => {
parent.mean = mean
}
None => println!("No prediction error for this type of node."),
}
}
None => println!("No value parent"),
}
}

pub fn posterior_update(&mut self, node_idx: usize, observation: f64) {

match self.nodes.get_mut(&node_idx) {
Some(Node::Continuous(ref mut node)) => {
node.mean = observation
}
Some(Node::Exponential(ref mut node)) => {
posterior::exponential::posterior_update_exponential(node)
}
None => println!("No posterior update for this type of node.")
}
pub fn get_update_sequence(&mut self) {
self.update_sequence = get_update_sequence(self);
}

/// One time step belief propagation.
/// Single time slice belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(&mut self, observations: Vec<f64>) {
pub fn belief_propagation(&mut self, observations_set: Vec<f64>) {

// 1. prediction propagation
let predictions = self.update_sequence.predictions.clone();
let updates = self.update_sequence.updates.clone();

for i in 0..observations.len() {

let input_node_idx = self.inputs[i];
// 2. inject the observations into the input nodes
self.posterior_update(input_node_idx, observations[i]);
// 3. posterior update - prediction errors propagation
self.prediction_error(input_node_idx);
// 1. prediction steps
for (idx, step) in predictions.iter() {
step(self, *idx);
}

// 2. observation steps
for (i, observations) in observations_set.iter().enumerate() {
let idx = self.inputs[i];
observation_update(self, idx, *observations);
}

// 3. update steps
for (idx, step) in updates.iter() {
step(self, *idx);
}

}

/// Add a sequence of observations.
Expand Down Expand Up @@ -203,17 +179,16 @@ impl Network {
for s in &self.edges {
// Create a new Python dictionary for each MyStruct
let py_dict = PyDict::new(py);
py_dict.set_item("value_parents", s.value_parents)?;
py_dict.set_item("value_children", s.value_children)?;
py_dict.set_item("volatility_parents", s.volatility_parents)?;
py_dict.set_item("volatility_children", s.volatility_children)?;
py_dict.set_item("value_parents", &s.value_parents)?;
py_dict.set_item("value_children", &s.value_children)?;
py_dict.set_item("volatility_parents", &s.volatility_parents)?;
py_dict.set_item("volatility_children", &s.volatility_children)?;

// Add the dictionary to the list
py_list.append(py_dict)?;
}
Ok(py_list)
}

}

// Create a module to expose the class to Python
Expand Down
3 changes: 2 additions & 1 deletion src/updates/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod posterior;
pub mod prediction;
pub mod prediction_error;
pub mod prediction_error;
pub mod observations;
21 changes: 21 additions & 0 deletions src/updates/observations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use crate::model::{Network, Node};


/// Inject new observations into an input node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The input node index.
/// * `observations` - The new observations.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn observation_update(network: &mut Network, node_idx: usize, observations: f64) {

match network.nodes.get_mut(&node_idx) {
Some(Node::Exponential(ref mut node)) => {
node.mean = observations;
}
_ => (),
}
}
13 changes: 13 additions & 0 deletions src/updates/posterior/continuous.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use crate::model::Network;

/// Posterior update from a continuous state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The node index.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn posterior_update_continuous_state_node(network: &mut Network, node_idx: usize) {
let a = 1;
}
24 changes: 19 additions & 5 deletions src/updates/posterior/exponential.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
use crate::network::ExponentialFamiliyStateNode;
use crate::model::{Network, Node};
use crate::math::sufficient_statistics;

pub fn posterior_update_exponential(node: &mut ExponentialFamiliyStateNode) {
let suf_stats = sufficient_statistics(&node.mean);
for i in 0..suf_stats.len() {
node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]);
/// Updating an exponential family state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The node index.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn posterior_update_exponential_state_node(network: &mut Network, node_idx: usize) {

match network.nodes.get_mut(&node_idx) {
Some(Node::Exponential(ref mut node)) => {
let suf_stats = sufficient_statistics(&node.mean);
for i in 0..suf_stats.len() {
node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]);
}
}
_ => (),
}
}
13 changes: 13 additions & 0 deletions src/updates/prediction/continuous.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use crate::model::Network;

/// Prediction from a continuous state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The node index.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn prediction_continuous_state_node(network: &mut Network, node_idx: usize) {
let a = 1;
}
1 change: 1 addition & 0 deletions src/updates/prediction/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod continuous;
13 changes: 13 additions & 0 deletions src/updates/prediction_error/continuous.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use crate::model::Network;

/// Prediction error from a continuous state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The node index.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn prediction_error_continuous_state_node(network: &mut Network, node_idx: usize) {
let a = 1;
}
2 changes: 1 addition & 1 deletion src/updates/prediction_error/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub mod nodes;
pub mod continuous;
Empty file.
1 change: 0 additions & 1 deletion src/updates/prediction_error/nodes/mod.rs

This file was deleted.

Loading

0 comments on commit db5910c

Please sign in to comment.