Skip to content

Commit

Permalink
remove input nodes and add volatility couplings
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 21, 2024
1 parent 0fd0346 commit d2e3a55
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 75 deletions.
34 changes: 17 additions & 17 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[package]
name = "hgf"
name = "rshgf"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "hgf"
name = "rshgf"
crate-type = ["cdylib"]
path = "src/lib.rs" # The source file of the target.

[dependencies]
pyo3 = { version = "0.22.4", features = ["extension-module"] }
pyo3 = { version = "0.22.5", features = ["extension-module"] }
95 changes: 57 additions & 38 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,31 @@ use pyo3::prelude::*;
pub struct AdjacencyLists{
pub value_parents: Option<usize>,
pub value_children: Option<usize>,
pub volatility_parents: Option<usize>,
pub volatility_children: Option<usize>,
}
#[derive(Debug, Clone)]
pub struct GenericInputNode{
pub observation: f64,
pub time_step: f64,
pub struct ContinuousStateNode{
pub mean: f64,
pub expected_mean: f64,
pub precision: f64,
pub expected_precision: f64,
pub tonic_volatility: f64,
pub tonic_drift: f64,
pub autoconnection_strength: f64,
}
#[derive(Debug, Clone)]
pub struct ExponentialNode {
pub observation: f64,
pub struct ExponentialFamiliyStateNode {
pub mean: f64,
pub expected_mean: f64,
pub nus: f64,
pub xis: [f64; 2],
}

#[derive(Debug, Clone)]
pub enum Node {
Generic(GenericInputNode),
Exponential(ExponentialNode),
Continuous(ContinuousStateNode),
Exponential(ExponentialFamiliyStateNode),
}

#[derive(Debug)]
Expand All @@ -47,29 +55,46 @@ impl Network {
}

// Add a node to the graph
#[pyo3(signature = (kind, value_parents=None, value_childrens=None))]
pub fn add_node(&mut self, kind: String, value_parents: Option<usize>, value_childrens: Option<usize>) {
#[pyo3(signature = (kind, value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))]
pub fn add_node(&mut self, kind: String, value_parents: Option<usize>, value_children: Option<usize>, volatility_children: Option<usize>, volatility_parents: Option<usize>) {

// the node ID is equal to the number of nodes already in the network
let node_id: usize = self.nodes.len();

let edges = AdjacencyLists{
value_parents: value_children,
value_children: value_parents,
value_parents: value_childrens,
volatility_parents: volatility_parents,
volatility_children: volatility_children,
};

// add edges and attributes
if kind == "generic-input" {
let generic_input = GenericInputNode{observation: 0.0, time_step: 0.0};
let node = Node::Generic(generic_input);
if kind == "continuous-state" {
let continuous_state = ContinuousStateNode{
mean: 0.0, expected_mean: 0.0, precision: 1.0, expected_precision: 1.0,
tonic_drift: 0.0, tonic_volatility: -4.0, autoconnection_strength: 1.0
};
let node = Node::Continuous(continuous_state);
self.nodes.insert(node_id, node);
self.edges.push(edges);
self.inputs.push(node_id);
} else if kind == "exponential-node" {
let exponential_node: ExponentialNode = ExponentialNode{observation: 0.0, nus: 0.0, xis: [0.0, 0.0]};

// if this node has no children, this is an input node
if (value_children == None) & (volatility_children == None) {
self.inputs.push(node_id);
}
} else if kind == "exponential-state" {
let exponential_node: ExponentialFamiliyStateNode = ExponentialFamiliyStateNode{
mean: 0.0, expected_mean: 0.0, nus: 0.0, xis: [0.0, 0.0]
};
let node = Node::Exponential(exponential_node);
self.nodes.insert(node_id, node);
self.edges.push(edges);

// if this node has no children, this is an input node
if (value_children == None) & (volatility_children == None) {
self.inputs.push(node_id);
}

} else {
println!("Invalid type of node provided ({}).", kind);
}
Expand All @@ -78,25 +103,25 @@ impl Network {
pub fn prediction_error(&mut self, node_idx: usize) {

// get the observation value
let observation;
let mean;
match self.nodes[&node_idx] {
Node::Generic(ref node) => {
observation = node.observation;
Node::Continuous(ref node) => {
mean = node.mean;
}
Node::Exponential(ref node) => {
observation = node.observation;
mean = node.mean;
}
}

let value_parent_idx = &self.edges[node_idx].value_parents;
match value_parent_idx {
Some(idx) => {
match self.nodes.get_mut(idx) {
Some(Node::Generic(ref mut parent)) => {
parent.observation = observation
Some(Node::Continuous(ref mut parent)) => {
parent.mean = mean
}
Some(Node::Exponential(ref mut parent)) => {
parent.observation = observation
parent.mean = mean
}
None => println!("No prediction error for this type of node."),
}
Expand All @@ -108,11 +133,11 @@ impl Network {
pub fn posterior_update(&mut self, node_idx: usize, observation: f64) {

match self.nodes.get_mut(&node_idx) {
Some(Node::Generic(ref mut node)) => {
node.observation = observation
Some(Node::Continuous(ref mut node)) => {
node.mean = observation
}
Some(Node::Exponential(ref mut node)) => {
posterior::continuous::posterior_update_exponential(node)
posterior::exponential::posterior_update_exponential(node)
}
None => println!("No posterior update for this type of node.")
}
Expand Down Expand Up @@ -160,26 +185,20 @@ mod tests {
// initialize network
let mut network = Network::new();

// create a network
// create a network with two exponential family state nodes
network.add_node(
String::from("generic-input"),
String::from("exponential-state"),
None,
None,
);
network.add_node(
String::from("generic-input"),
None,
None,
None
);
network.add_node(
String::from("exponential-node"),
String::from("exponential-state"),
None,
None,
Some(0),
);
network.add_node(
String::from("exponential-node"),
None,
Some(1),
None
);

// println!("Graph before belief propagation: {:?}", network);
Expand Down
9 changes: 0 additions & 9 deletions src/updates/posterior/continuous.rs
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
use crate::network::ExponentialNode;
use crate::math::sufficient_statistics;

pub fn posterior_update_exponential(node: &mut ExponentialNode) {
let suf_stats = sufficient_statistics(&node.observation);
for i in 0..suf_stats.len() {
node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]);
}
}
9 changes: 9 additions & 0 deletions src/updates/posterior/exponential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use crate::network::ExponentialFamiliyStateNode;
use crate::math::sufficient_statistics;

pub fn posterior_update_exponential(node: &mut ExponentialFamiliyStateNode) {
let suf_stats = sufficient_statistics(&node.mean);
for i in 0..suf_stats.len() {
node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]);
}
}
3 changes: 2 additions & 1 deletion src/updates/posterior/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod continuous;
pub mod continuous;
pub mod exponential;
22 changes: 15 additions & 7 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,32 @@ mod tests {

// create a network
network.add_node(
String::from("generic-input"),
String::from("continuous-state"),
Some(1),
None,
None,
Some(2),
);
network.add_node(
String::from("exponential-node"),
Some(2),
String::from("continuous-state"),
None,
Some(0),
None,
None,
);
network.add_node(
String::from("exponential-node"),
Some(3),
Some(1),
String::from("continuous-state"),
None,
None,
Some(0),
None,
);
network.add_node(
String::from("exponential-node"),
None,
Some(2),
None,
None,
None,
);

println!("Network: {:?}", network);
Expand Down

0 comments on commit d2e3a55

Please sign in to comment.