Skip to content

Commit

Permalink
improve generalised filtering nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Nov 13, 2024
1 parent c4ab757 commit cca3663
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 479 deletions.
471 changes: 40 additions & 431 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn main() {

// create a network with two exponential family state nodes
network.add_nodes(
"exponential-state",
"ef-state",
None,
None,
None,
Expand Down
40 changes: 28 additions & 12 deletions pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import MultivariateNormal, Normal
from pyhgf.plots import plot_correlations, plot_network, plot_nodes, plot_trajectories
from pyhgf.typing import (
AdjacencyLists,
Expand Down Expand Up @@ -63,6 +64,7 @@ def __init__(self) -> None:
self.attributes: Attributes = {-1: {"time_step": 0.0}}
self.update_sequence: Optional[UpdateSequence] = None
self.scan_fn: Optional[Callable] = None
self.additional_parameters: Dict = {}

@property
def input_idxs(self):
Expand Down Expand Up @@ -329,7 +331,7 @@ def add_nodes(
be a regular state node that can have value and/or volatility
parents/children. If `"binary-state"`, the node should be the
value parent of a binary input. State nodes filtering distribution from the
exponential family can be created using `"exponential-state"`.
exponential family can be created using `"ef-state"`.
.. note::
When using a categorical state node, the `binary_parameters` can be used to
Expand Down Expand Up @@ -381,7 +383,7 @@ def add_nodes(
"""
if kind not in [
"DP-state",
"exponential-state",
"ef-state",
"categorical-state",
"continuous-state",
"binary-state",
Expand Down Expand Up @@ -463,13 +465,11 @@ def add_nodes(
"value_prediction_error": 0.0,
},
}
elif kind == "generic-state":
default_parameters = {
"mean": 0.0,
"observed": 1,
}
elif "exponential-state" in kind:
elif "ef-state" in kind:
default_parameters = {
"dimension": 1,
"distribution": "normal",
"learning": "filtering",
"nus": 3.0,
"xis": jnp.array([0.0, 1.0]),
"mean": 0.0,
Expand Down Expand Up @@ -556,14 +556,23 @@ def add_nodes(
node_parameters = default_parameters

# define the type of node that is created
if kind == "generic-state":
node_type = 0
elif kind == "binary-state":
if kind == "binary-state":
node_type = 1
elif kind == "continuous-state":
node_type = 2
elif kind == "exponential-state":
elif kind == "ef-state":
node_type = 3

# create the update function and drop side parameters
if node_parameters["distribution"] == "normal":
sufficient_stats_fn = Normal().sufficient_statistics
elif node_parameters["distribution"] == "multivariate-normal":
sufficient_stats_fn = MultivariateNormal().sufficient_statistics

node_parameters.pop("dimension")
node_parameters.pop("distribution")
node_parameters.pop("learning")

elif kind == "DP-state":
node_type = 4
elif kind == "categorical-state":
Expand Down Expand Up @@ -597,6 +606,13 @@ def add_nodes(
# update the node structure
self.attributes[node_idx] = deepcopy(node_parameters)

# if we are creating an ef-state, add the sufficient statistics function
# in the side parameters
if kind == "ef-state":
self.additional_parameters[node_idx][
"sufficient_stats_fn"
] = sufficient_stats_fn

# Update the edges of the parents and children accordingly
# --------------------------------------------------------
if value_parents[0] is not None:
Expand Down
12 changes: 10 additions & 2 deletions pyhgf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax.tree_util import Partial
from jax.typing import ArrayLike

from pyhgf.math import Normal, binary_surprise, gaussian_surprise
from pyhgf.math import binary_surprise, gaussian_surprise
from pyhgf.typing import AdjacencyLists, Attributes, Edges, Sequence, UpdateSequence
from pyhgf.updates.observation import set_observation
from pyhgf.updates.posterior.categorical import categorical_state_update
Expand Down Expand Up @@ -402,11 +402,19 @@ def get_update_sequence(
# unless this is an exponential family state node
if len(all_parents) == 0:
if network.edges[idx].node_type == 3:

# retrieve the desired sufficient statistics function
# from the side parameter dictionary
sufficient_stats_fn = network.additional_parameters[idx][
"sufficient_stats_fn"
]
network.additional_parameters[idx].pop("sufficient_stats_fn")

# create the sufficient statistic function
# for the exponential family node
ef_update = Partial(
prediction_error_update_exponential_family,
sufficient_stats_fn=Normal().sufficient_statistics,
sufficient_stats_fn=sufficient_stats_fn,
)
update_fn = ef_update
no_update = False
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod model;
pub mod utils;
pub mod math;
pub mod maths;
pub mod updates;
3 changes: 0 additions & 3 deletions src/math.rs

This file was deleted.

1 change: 1 addition & 0 deletions src/maths/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod sufficient_statistics;
15 changes: 15 additions & 0 deletions src/maths/sufficient_statistics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub fn normal(x: &f64) -> Vec<f64> {
vec![*x, x.powf(2.0)]
}

pub fn multivariate_normal(x: &Vec<f64>) -> Vec<f64> {
vec![*x, x.powf(2.0)]
}

pub fn get_sufficient_statistics_fn(distribution: String) {
if distribution == "normal" {
normal;
} else if distribution == "multivariate_normal" {
multivariate_normal;
}
}
66 changes: 47 additions & 19 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,33 @@ impl Network {
/// * `value_children` - The indexes of the node's value children.
/// * `volatility_children` - The indexes of the node's volatility children.
/// * `volatility_parents` - The indexes of the node's volatility parents.
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))]
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<Vec<usize>>,
#[pyo3(signature = (
kind="continuous-state",
value_parents=None,
value_children=None,
volatility_parents=None,
volatility_children=None,
ef_dimension=None,
ef_distribution=None,
ef_learning=None,
)
)]
pub fn add_nodes(
&mut self,
kind: &str,
value_parents: Option<Vec<usize>>,
value_children: Option<Vec<usize>>,
volatility_parents: Option<Vec<usize>>, volatility_children: Option<Vec<usize>>, ) {
volatility_parents: Option<Vec<usize>>,
volatility_children: Option<Vec<usize>>,
ef_dimension: Option<usize>,
ef_distribution: Option<String>,
ef_learning: Option<String>,
) {

// set default values for optional parameters
let ef_dimension = ef_dimension.unwrap_or(1);
let ef_distribution = ef_distribution.unwrap_or(String::from("normal"));
let ef_learning = ef_learning.unwrap_or(String::from("filtering"));

// the node ID is equal to the number of nodes already in the network
let node_id: usize = self.edges.len();
Expand All @@ -86,6 +109,7 @@ impl Network {
self.inputs.push(node_id);
}

// Update the edges variable
let edges = AdjacencyLists{
node_type: String::from(kind),
value_parents: value_parents,
Expand All @@ -94,6 +118,11 @@ impl Network {
volatility_children: volatility_children,
};

// Add emtpy adjacency lists in the new node
self.edges.insert(node_id, edges);

// TODO: Update the edges of parents and children accordingly

// add edges and attributes
if kind == "continuous-state" {

Expand All @@ -107,24 +136,23 @@ impl Network {
(String::from("autoconnection_strength"), 1.0)].into_iter().collect();

self.attributes.floats.insert(node_id, attributes);
self.edges.insert(node_id, edges);

} else if kind == "exponential-state" {

let floats_attributes = [
(String::from("mean"), 0.0),
(String::from("nus"), 3.0)].into_iter().collect();
let vector_attributes = [
(String::from("xis"), vec![0.0, 1.0])].into_iter().collect();

self.attributes.floats.insert(node_id, floats_attributes);
self.attributes.vectors.insert(node_id, vector_attributes);
self.edges.insert(node_id, edges);
} else if kind == "ef-state" {

// Extract the dimensions of the distribution - Defaults to 0 if parsing fails
if ef_distribution == "normal" {
let floats_attributes = [
(String::from("mean"), 0.0),
(String::from("nus"), 3.0)].into_iter().collect();
self.attributes.floats.insert(node_id, floats_attributes);
let vector_attributes = [
(String::from("xis"), vec![0.0, 1.0])].into_iter().collect();
self.attributes.vectors.insert(node_id, vector_attributes);
} else if ef_distribution == "multivariate-normal" {

} else {
println!("Invalid type of node provided ({}).", kind);
}
}
}
}

pub fn set_update_sequence(&mut self) {
self.update_sequence = set_update_sequence(self);
Expand Down Expand Up @@ -299,7 +327,7 @@ mod tests {

// create a network with two exponential family state nodes
network.add_nodes(
"exponential-state",
"ef-state",
None,
None,
None,
Expand Down
2 changes: 1 addition & 1 deletion src/updates/prediction/continuous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::model::Network;
/// Prediction from a continuous state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `network` - The main network structure.
/// * `node_idx` - The node index.
///
/// # Returns
Expand Down
4 changes: 2 additions & 2 deletions src/updates/prediction_error/exponential.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use crate::model::Network;
use crate::math::sufficient_statistics;

/// Updating an exponential family state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The node index.
/// * `sufficient_statistics` - A function computing the sufficient statistics of an exponential family distribution.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) {
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize, sufficient_statistics: fn(&f64) -> Vec<f64>) {

let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes");
let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes");
Expand Down
8 changes: 6 additions & 2 deletions src/utils/set_sequence.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{model::{AdjacencyLists, Network, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}};
use crate::utils::function_pointer::FnType;
use crate::maths::sufficient_statistics::get_sufficient_statistics_fn;

pub fn set_update_sequence(network: &Network) -> UpdateSequence {
let predictions = get_predictions_sequence(network);
Expand Down Expand Up @@ -136,7 +137,10 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> {
has_update = true;
break;
}
(Some(AdjacencyLists {node_type, ..}), _) if node_type == "exponential-state" => {
(Some(AdjacencyLists {node_type, ..}), _) if node_type == "ef-state" => {

// create a closure that preallocate the sufficient statistics function
let update_fn = |network, node_idx| prediction_error_exponential_state_node(network, node_idx, get_sufficient_statistics_fn(distribution))
updates.push((idx, prediction_error_exponential_state_node));
// remove the node from the to-be-updated list
pe_nodes_idxs.retain(|&x| x != idx);
Expand Down Expand Up @@ -259,7 +263,7 @@ mod tests {
// initialize network
let mut exp_network = Network::new();
exp_network.add_nodes(
"exponential-state",
"ef-state",
None,
None,
None,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_exponential_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def test_1d_gaussain():

# Rust -----------------------------------------------------------------------------
rs_network = RsNetwork()
rs_network.add_nodes(kind="exponential-state")
rs_network.add_nodes(kind="ef-state")
rs_network.set_update_sequence()
rs_network.input_data(timeseries)

# Python ---------------------------------------------------------------------------
py_network = PyNetwork().add_nodes(kind="exponential-state")
py_network = PyNetwork().add_nodes(kind="ef-state")
py_network.input_data(timeseries)

# Ensure identical results
Expand Down
2 changes: 1 addition & 1 deletion tests/test_updates/prediction_errors/test_dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_dirichlet_node_prediction_error():
.add_nodes(kind="generic-state")
.add_nodes(kind="DP-state", value_children=0, batch_size=2)
.add_nodes(
kind="exponential-state",
kind="ef-state",
n_nodes=2,
value_children=1,
xis=jnp.array([0.0, 1 / 8]),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_set_update_sequence():
network3 = (
Network()
.add_nodes(kind="generic-state")
.add_nodes(kind="exponential-state", value_children=0)
.add_nodes(kind="ef-state", value_children=0)
.create_belief_propagation_fn()
)
predictions, updates = network3.update_sequence
Expand All @@ -87,7 +87,7 @@ def test_set_update_sequence():
.add_nodes(kind="generic-state")
.add_nodes(kind="DP-state", value_children=0, alpha=0.1, batch_size=2)
.add_nodes(
kind="exponential-state",
kind="ef-state",
n_nodes=2,
value_children=1,
xis=jnp.array([0.0, 1 / 8]),
Expand Down

0 comments on commit cca3663

Please sign in to comment.