Skip to content

Commit

Permalink
improve generalised filtering nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Nov 21, 2024
1 parent eb05544 commit 1c84178
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 13 deletions.
4 changes: 2 additions & 2 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "1798765e-3d65-4bfd-964b-7f9b6b0902be",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -358,7 +358,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "2d921e51-a940-42b2-88f2-e25bd7ab5a01",
"metadata": {
"editable": true,
Expand Down
2 changes: 1 addition & 1 deletion examples/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn main() {

// create a network with two exponential family state nodes
network.add_nodes(
"exponential-state",
"ef-state",
None,
None,
None,
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod model;
pub mod utils;
pub mod math;
pub mod maths;
pub mod updates;
3 changes: 0 additions & 3 deletions src/math.rs

This file was deleted.

1 change: 1 addition & 0 deletions src/maths/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod sufficient_statistics;
15 changes: 15 additions & 0 deletions src/maths/sufficient_statistics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub fn normal(x: &f64) -> Vec<f64> {
vec![*x, x.powf(2.0)]
}

pub fn multivariate_normal(x: &Vec<f64>) -> Vec<f64> {
vec![*x, x.powf(2.0)]
}

pub fn get_sufficient_statistics_fn(distribution: String) {
if distribution == "normal" {
normal;
} else if distribution == "multivariate_normal" {
multivariate_normal;
}
}
25 changes: 22 additions & 3 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,21 @@ impl Network {
/// * `value_children` - The indexes of the node's value children.
/// * `volatility_children` - The indexes of the node's volatility children.
/// * `volatility_parents` - The indexes of the node's volatility parents.
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))]
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<Vec<usize>>,
#[pyo3(signature = (
kind="continuous-state",
value_parents=None,
value_children=None,
volatility_parents=None,
volatility_children=None,
ef_dimension=None,
ef_distribution=None,
ef_learning=None,
)
)]
pub fn add_nodes(
&mut self,
kind: &str,
value_parents: Option<Vec<usize>>,
value_children: Option<Vec<usize>>,
volatility_parents: Option<Vec<usize>>, volatility_children: Option<Vec<usize>>, ) {

Expand All @@ -86,6 +99,7 @@ impl Network {
self.inputs.push(node_id);
}

// Update the edges variable
let edges = AdjacencyLists{
node_type: String::from(kind),
value_parents: value_parents,
Expand All @@ -94,6 +108,11 @@ impl Network {
volatility_children: volatility_children,
};

// Add emtpy adjacency lists in the new node
self.edges.insert(node_id, edges);

// TODO: Update the edges of parents and children accordingly

// add edges and attributes
if kind == "continuous-state" {

Expand All @@ -107,7 +126,6 @@ impl Network {
(String::from("autoconnection_strength"), 1.0)].into_iter().collect();

self.attributes.floats.insert(node_id, attributes);
self.edges.insert(node_id, edges);

} else if kind == "ef-state" {

Expand All @@ -123,6 +141,7 @@ impl Network {

}
}
}

pub fn set_update_sequence(&mut self) {
self.update_sequence = set_update_sequence(self);
Expand Down
2 changes: 1 addition & 1 deletion src/updates/prediction/continuous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::model::Network;
/// Prediction from a continuous state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `network` - The main network structure.
/// * `node_idx` - The node index.
///
/// # Returns
Expand Down
4 changes: 2 additions & 2 deletions src/updates/prediction_error/exponential.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use crate::model::Network;
use crate::math::sufficient_statistics;

/// Updating an exponential family state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The node index.
/// * `sufficient_statistics` - A function computing the sufficient statistics of an exponential family distribution.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) {
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize, sufficient_statistics: fn(&f64) -> Vec<f64>) {

let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes");
let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes");
Expand Down
1 change: 1 addition & 0 deletions src/utils/set_sequence.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{model::{AdjacencyLists, Network, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}};
use crate::utils::function_pointer::FnType;
use crate::maths::sufficient_statistics::get_sufficient_statistics_fn;

pub fn set_update_sequence(network: &Network) -> UpdateSequence {
let predictions = get_predictions_sequence(network);
Expand Down

0 comments on commit 1c84178

Please sign in to comment.