From c9706c43b84b43e3aba5807e12abced3827267ba Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 2 Sep 2024 11:27:01 +0200 Subject: [PATCH] package working with tests --- src/{rs-hgf => hgf}/.gitignore | 0 src/hgf/Cargo.lock | 171 +++++++++++ src/{rs-hgf => hgf}/Cargo.toml | 7 +- src/{rs-hgf => hgf}/pyproject.toml | 2 +- src/hgf/src/lib.rs | 4 + src/hgf/src/math.rs | 3 + src/hgf/src/network.rs | 189 ++++++++++++ src/hgf/src/updates/mod.rs | 3 + src/hgf/src/updates/posterior/continuous.rs | 9 + src/hgf/src/updates/posterior/mod.rs | 1 + src/hgf/src/updates/prediction/mod.rs | 0 .../updates/prediction_error/inputs/mod.rs | 0 src/hgf/src/updates/prediction_error/mod.rs | 2 + .../prediction_error/nodes/continuous.rs | 0 .../src/updates/prediction_error/nodes/mod.rs | 1 + src/hgf/src/utils.rs | 34 +++ src/hgf/tests/exponential_family.rs | 0 src/rs-hgf/.github/workflows/CI.yml | 120 -------- src/rs-hgf/Cargo.lock | 273 ------------------ src/rs-hgf/src/lib.rs | 14 - src/rs-hgf/src/main.rs | 220 -------------- 21 files changed, 422 insertions(+), 631 deletions(-) rename src/{rs-hgf => hgf}/.gitignore (100%) create mode 100644 src/hgf/Cargo.lock rename src/{rs-hgf => hgf}/Cargo.toml (65%) rename src/{rs-hgf => hgf}/pyproject.toml (95%) create mode 100644 src/hgf/src/lib.rs create mode 100644 src/hgf/src/math.rs create mode 100644 src/hgf/src/network.rs create mode 100644 src/hgf/src/updates/mod.rs create mode 100644 src/hgf/src/updates/posterior/continuous.rs create mode 100644 src/hgf/src/updates/posterior/mod.rs create mode 100644 src/hgf/src/updates/prediction/mod.rs create mode 100644 src/hgf/src/updates/prediction_error/inputs/mod.rs create mode 100644 src/hgf/src/updates/prediction_error/mod.rs create mode 100644 src/hgf/src/updates/prediction_error/nodes/continuous.rs create mode 100644 src/hgf/src/updates/prediction_error/nodes/mod.rs create mode 100644 src/hgf/src/utils.rs create mode 100644 src/hgf/tests/exponential_family.rs delete mode 100644 src/rs-hgf/.github/workflows/CI.yml delete mode 100644 src/rs-hgf/Cargo.lock delete mode 100644 src/rs-hgf/src/lib.rs delete mode 100644 src/rs-hgf/src/main.rs diff --git a/src/rs-hgf/.gitignore b/src/hgf/.gitignore similarity index 100% rename from src/rs-hgf/.gitignore rename to src/hgf/.gitignore diff --git a/src/hgf/Cargo.lock b/src/hgf/Cargo.lock new file mode 100644 index 000000000..a63f242ae --- /dev/null +++ b/src/hgf/Cargo.lock @@ -0,0 +1,171 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hgf" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" diff --git a/src/rs-hgf/Cargo.toml b/src/hgf/Cargo.toml similarity index 65% rename from src/rs-hgf/Cargo.toml rename to src/hgf/Cargo.toml index 19da1aa16..9e430d5e9 100644 --- a/src/rs-hgf/Cargo.toml +++ b/src/hgf/Cargo.toml @@ -1,12 +1,13 @@ [package] -name = "rs-hgf" +name = "hgf" version = "0.1.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] -name = "rs_hgf" +name = "hgf" crate-type = ["cdylib"] +path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = "0.19.0" +pyo3 = "0.22.2" diff --git a/src/rs-hgf/pyproject.toml b/src/hgf/pyproject.toml similarity index 95% rename from src/rs-hgf/pyproject.toml rename to src/hgf/pyproject.toml index c370be96a..399c9df9e 100644 --- a/src/rs-hgf/pyproject.toml +++ b/src/hgf/pyproject.toml @@ -3,7 +3,7 @@ requires = ["maturin>=1.4,<2.0"] build-backend = "maturin" [project] -name = "rs-hgf" +name = "hgf" requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", diff --git a/src/hgf/src/lib.rs b/src/hgf/src/lib.rs new file mode 100644 index 000000000..504294410 --- /dev/null +++ b/src/hgf/src/lib.rs @@ -0,0 +1,4 @@ +pub mod network; +pub mod utils; +pub mod math; +pub mod updates; \ No newline at end of file diff --git a/src/hgf/src/math.rs b/src/hgf/src/math.rs new file mode 100644 index 000000000..20a0e568b --- /dev/null +++ b/src/hgf/src/math.rs @@ -0,0 +1,3 @@ +pub fn sufficient_statistics(x: &f64) -> [f64; 2] { + [*x, x.powf(2.0)] +} \ No newline at end of file diff --git a/src/hgf/src/network.rs b/src/hgf/src/network.rs new file mode 100644 index 000000000..1d79a952e --- /dev/null +++ b/src/hgf/src/network.rs @@ -0,0 +1,189 @@ +use std::collections::HashMap; +use crate::updates::posterior; + +#[derive(Debug)] +pub struct AdjacencyLists{ + pub value_parents: Option, + pub value_children: Option, +} +#[derive(Debug, Clone)] +pub struct GenericInputNode{ + pub observation: f64, + pub time_step: f64, +} +#[derive(Debug, Clone)] +pub struct ExponentialNode { + pub observation: f64, + pub nus: f64, + pub xis: [f64; 2], +} + +#[derive(Debug, Clone)] +pub enum Node { + Generic(GenericInputNode), + Exponential(ExponentialNode), +} + +#[derive(Debug)] +pub struct Network{ + pub nodes: HashMap, + pub edges: Vec, + pub inputs: Vec, +} + +impl Network { + // Create a new graph + pub fn new() -> Self { + Network { + nodes: HashMap::new(), + edges: Vec::new(), + inputs: Vec::new(), + } + } + + // Add a node to the graph + pub fn add_node(&mut self, kind: String, value_parents: Option, value_childrens: Option) { + + // the node ID is equal to the number of nodes already in the network + let node_id: usize = self.nodes.len(); + + let edges = AdjacencyLists{ + value_children: value_parents, + value_parents: value_childrens, + }; + + // add edges and attributes + if kind == "generic-input" { + let generic_input = GenericInputNode{observation: 0.0, time_step: 0.0}; + let node = Node::Generic(generic_input); + self.nodes.insert(node_id, node); + self.edges.push(edges); + self.inputs.push(node_id); + } else if kind == "exponential-node" { + let exponential_node: ExponentialNode = ExponentialNode{observation: 0.0, nus: 0.0, xis: [0.0, 0.0]}; + let node = Node::Exponential(exponential_node); + self.nodes.insert(node_id, node); + self.edges.push(edges); + } else { + println!("Invalid type of node provided ({}).", kind); + } + } + + pub fn prediction_error(&mut self, node_idx: usize) { + + // get the observation value + let observation; + match self.nodes[&node_idx] { + Node::Generic(ref node) => { + observation = node.observation; + } + Node::Exponential(ref node) => { + observation = node.observation; + } + } + + let value_parent_idx = &self.edges[node_idx].value_parents; + match value_parent_idx { + Some(idx) => { + match self.nodes.get_mut(idx) { + Some(Node::Generic(ref mut parent)) => { + parent.observation = observation + } + Some(Node::Exponential(ref mut parent)) => { + parent.observation = observation + } + None => println!("No prediction error for this type of node."), + } + } + None => println!("No value parent"), + } + } + + pub fn posterior_update(&mut self, node_idx: &usize, observation: f64) { + + match self.nodes.get_mut(node_idx) { + Some(Node::Generic(ref mut node)) => { + node.observation = observation + } + Some(Node::Exponential(ref mut node)) => { + posterior::continuous::posterior_update_exponential(node) + } + None => println!("No posterior update for this type of node.") + } + } + + pub fn belief_propagation(&mut self, observations: Vec) { + + // 1. prediction propagation + + // 2. inject the observations into the input nodes + for i in 0..observations.len() { + + let input_node_idx = self.inputs[i]; + self.posterior_update(&input_node_idx, observations[i]); + self.prediction_error(input_node_idx); + } + + // 3. posterior update - prediction errors propagation + } + + pub fn input_data(&mut self, input_data: Vec>) { + for observation in input_data { + self.belief_propagation(observation); + } + } + } + + +// Tests module for unit tests +#[cfg(test)] // Only compile and include this module when running tests +mod tests { + use super::*; // Import the parent module's items to test them + + + #[test] + fn test_exponential_family_gaussian() { + + // initialize network + let mut network = Network::new(); + + // create a network + network.add_node( + String::from("generic-input"), + None, + None, + ); + network.add_node( + String::from("generic-input"), + None, + None, + ); + network.add_node( + String::from("exponential-node"), + None, + Some(0), + ); + network.add_node( + String::from("exponential-node"), + None, + Some(1), + ); + + println!("Graph before belief propagation: {:?}", network); + + // belief propagation + let input_data = vec![ + vec![1.1, 2.2], + vec![1.2, 2.1], + vec![1.0, 2.0], + vec![1.3, 2.2], + vec![1.1, 2.5], + vec![1.0, 2.6], + ]; + + network.input_data(input_data); + + println!("Graph after belief propagation: {:?}", network); + + } +} diff --git a/src/hgf/src/updates/mod.rs b/src/hgf/src/updates/mod.rs new file mode 100644 index 000000000..dd8404066 --- /dev/null +++ b/src/hgf/src/updates/mod.rs @@ -0,0 +1,3 @@ +pub mod posterior; +pub mod prediction; +pub mod prediction_error; \ No newline at end of file diff --git a/src/hgf/src/updates/posterior/continuous.rs b/src/hgf/src/updates/posterior/continuous.rs new file mode 100644 index 000000000..7fb5b78e6 --- /dev/null +++ b/src/hgf/src/updates/posterior/continuous.rs @@ -0,0 +1,9 @@ +use crate::network::ExponentialNode; +use crate::math::sufficient_statistics; + +pub fn posterior_update_exponential(node: &mut ExponentialNode) { + let suf_stats = sufficient_statistics(&node.observation); + for i in 0..suf_stats.len() { + node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); + } +} \ No newline at end of file diff --git a/src/hgf/src/updates/posterior/mod.rs b/src/hgf/src/updates/posterior/mod.rs new file mode 100644 index 000000000..6817d49e7 --- /dev/null +++ b/src/hgf/src/updates/posterior/mod.rs @@ -0,0 +1 @@ +pub mod continuous; \ No newline at end of file diff --git a/src/hgf/src/updates/prediction/mod.rs b/src/hgf/src/updates/prediction/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/hgf/src/updates/prediction_error/inputs/mod.rs b/src/hgf/src/updates/prediction_error/inputs/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/hgf/src/updates/prediction_error/mod.rs b/src/hgf/src/updates/prediction_error/mod.rs new file mode 100644 index 000000000..264f047ac --- /dev/null +++ b/src/hgf/src/updates/prediction_error/mod.rs @@ -0,0 +1,2 @@ +pub mod inputs; +pub mod nodes; diff --git a/src/hgf/src/updates/prediction_error/nodes/continuous.rs b/src/hgf/src/updates/prediction_error/nodes/continuous.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/hgf/src/updates/prediction_error/nodes/mod.rs b/src/hgf/src/updates/prediction_error/nodes/mod.rs new file mode 100644 index 000000000..6817d49e7 --- /dev/null +++ b/src/hgf/src/updates/prediction_error/nodes/mod.rs @@ -0,0 +1 @@ +pub mod continuous; \ No newline at end of file diff --git a/src/hgf/src/utils.rs b/src/hgf/src/utils.rs new file mode 100644 index 000000000..354166a33 --- /dev/null +++ b/src/hgf/src/utils.rs @@ -0,0 +1,34 @@ +use crate::network::Network; + + +pub fn get_update_order(network: Network) -> Vec { + + let mut update_list = Vec::new(); + + // list all nodes availables in the network + let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + + // remove the input nodes + nodes_idxs.retain(|x| !network.inputs.contains(x)); + + // start with the value parents of input nodes + for input_idx in network.inputs { + let value_parent_idxs = network.edges[input_idx].value_parents; + match value_parent_idxs { + Some(idx) => { + // if this parent is still in the list, update it now + if nodes_idxs.contains(&idx) { + + // add the node in the update list + update_list.push(idx); + + // remove the parent from the availables nodes list + nodes_idxs.retain(|&x| x != idx); + + } + } + None => println!("The value is None") + } + } + nodes_idxs +} diff --git a/src/hgf/tests/exponential_family.rs b/src/hgf/tests/exponential_family.rs new file mode 100644 index 000000000..e69de29bb diff --git a/src/rs-hgf/.github/workflows/CI.yml b/src/rs-hgf/.github/workflows/CI.yml deleted file mode 100644 index 1bae4be43..000000000 --- a/src/rs-hgf/.github/workflows/CI.yml +++ /dev/null @@ -1,120 +0,0 @@ -# This file is autogenerated by maturin v1.4.0 -# To update, run -# -# maturin generate-ci github -# -name: CI - -on: - push: - branches: - - main - - master - tags: - - '*' - pull_request: - workflow_dispatch: - -permissions: - contents: read - -jobs: - linux: - runs-on: ubuntu-latest - strategy: - matrix: - target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - manylinux: auto - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - windows: - runs-on: windows-latest - strategy: - matrix: - target: [x64, x86] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: ${{ matrix.target }} - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - macos: - runs-on: macos-latest - strategy: - matrix: - target: [x86_64, aarch64] - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Build sdist - uses: PyO3/maturin-action@v1 - with: - command: sdist - args: --out dist - - name: Upload sdist - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - release: - name: Release - runs-on: ubuntu-latest - if: "startsWith(github.ref, 'refs/tags/')" - needs: [linux, windows, macos, sdist] - steps: - - uses: actions/download-artifact@v3 - with: - name: wheels - - name: Publish to PyPI - uses: PyO3/maturin-action@v1 - env: - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} - with: - command: upload - args: --non-interactive --skip-existing * diff --git a/src/rs-hgf/Cargo.lock b/src/rs-hgf/Cargo.lock deleted file mode 100644 index 06994a31a..000000000 --- a/src/rs-hgf/Cargo.lock +++ /dev/null @@ -1,273 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "indoc" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" - -[[package]] -name = "libc" -version = "0.2.153" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" - -[[package]] -name = "lock_api" -version = "0.4.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "memoffset" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - -[[package]] -name = "proc-macro2" -version = "1.0.78" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pyo3" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "redox_syscall" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags", -] - -[[package]] -name = "rs-hgf" -version = "0.1.0" -dependencies = [ - "pyo3", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "smallvec" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "target-lexicon" -version = "0.12.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unindent" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/src/rs-hgf/src/lib.rs b/src/rs-hgf/src/lib.rs deleted file mode 100644 index 53f27b16d..000000000 --- a/src/rs-hgf/src/lib.rs +++ /dev/null @@ -1,14 +0,0 @@ -// use pyo3::prelude::*; - -// /// Formats the sum of two numbers as string. -// #[pyfunction] -// fn sum_as_string(a: usize, b: usize) -> PyResult { -// Ok((a + b).to_string()) -// } - -// /// A Python module implemented in Rust. -// #[pymodule] -// fn rs_hgf(_py: Python, m: &PyModule) -> PyResult<()> { -// m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; -// Ok(()) -// } diff --git a/src/rs-hgf/src/main.rs b/src/rs-hgf/src/main.rs deleted file mode 100644 index 7809c5dc4..000000000 --- a/src/rs-hgf/src/main.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug)] -struct AdjacencyLists{ - value_parents: Option, - value_children: Option, -} -#[derive(Debug, Clone)] -struct GenericInputNode{ - observation: f64, - time_step: f64, -} -#[derive(Debug, Clone)] -struct ExponentialNode { - observation: f64, - nus: f64, - xis: [f64; 2], -} - -#[derive(Debug, Clone)] -enum Node { - Generic(GenericInputNode), - Exponential(ExponentialNode), -} - -#[derive(Debug)] -struct Network{ - nodes: HashMap, - edges: Vec, - inputs: Vec, -} - -fn sufficient_statistics(x: &f64) -> [f64; 2] { - [*x, x.powf(2.0)] -} - -impl Network { - // Create a new graph - fn new() -> Self { - Network { - nodes: HashMap::new(), - edges: Vec::new(), - inputs: Vec::new(), - } - } - - // Add a node to the graph - fn add_node(&mut self, kind: String, value_parents: Option, value_childrens: Option) { - - // the node ID is equal to the number of nodes already in the network - let node_id: usize = self.nodes.len(); - - let edges = AdjacencyLists{ - value_children: value_parents, - value_parents: value_childrens, - }; - - // add edges and attributes - if kind == "generic-input" { - let generic_input = GenericInputNode{observation: 0.0, time_step: 0.0}; - let node = Node::Generic(generic_input); - self.nodes.insert(node_id, node); - self.edges.push(edges); - self.inputs.push(node_id); - } else if kind == "exponential-node" { - let exponential_node: ExponentialNode = ExponentialNode{observation: 0.0, nus: 0.0, xis: [0.0, 0.0]}; - let node = Node::Exponential(exponential_node); - self.nodes.insert(node_id, node); - self.edges.push(edges); - } else { - println!("Invalid type of node provided ({}).", kind); - } - } - - fn posterior_update(&mut self, node_idx: &usize, observation: f64) { - - match self.nodes.get_mut(node_idx) { - Some(Node::Generic(ref mut node)) => { - node.observation = observation - } - Some(Node::Exponential(ref mut node)) => { - let suf_stats = sufficient_statistics(&node.observation); - for i in 0..suf_stats.len() { - node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); - } - } - None => println!("The value is None") - } - } - - fn prediction_error(&mut self, node_idx: usize) { - - // get the observation value - let observation; - match self.nodes[&node_idx] { - Node::Generic(ref node) => { - observation = node.observation; - } - Node::Exponential(ref node) => { - observation = node.observation; - } - } - - let value_parent_idx = &self.edges[node_idx].value_parents; - match value_parent_idx { - Some(idx) => { - match self.nodes.get_mut(idx) { - Some(Node::Generic(ref mut parent)) => { - parent.observation = observation - } - Some(Node::Exponential(ref mut parent)) => { - parent.observation = observation - } - None => println!("The value is None"), - } - } - None => println!("The value is None"), - } - } - - fn belief_propagation(&mut self, observations: Vec) { - - // 1. prediction propagation - - // 2. inject the observations into the input nodes - for i in 0..observations.len() { - - let input_node_idx = self.inputs[i]; - self.posterior_update(&input_node_idx, observations[i]); - self.prediction_error(input_node_idx); - } - - // 3. posterior update - prediction errors propagation - } - - fn input_data(&mut self, input_data: Vec>) { - for observation in input_data { - self.belief_propagation(observation); - } - } - - fn get_update_order(self) -> Vec { - - let mut update_list = Vec::new(); - - // list all nodes availables in the network - let mut nodes_idxs: Vec = self.nodes.keys().cloned().collect(); - - // remove the input nodes - nodes_idxs.retain(|x| !self.inputs.contains(x)); - - // start with the value parents of input nodes - for input_idx in self.inputs { - let value_parent_idxs = self.edges[input_idx].value_parents; - match value_parent_idxs { - Some(idx) => { - // if this parent is still in the list, update it now - if nodes_idxs.contains(&idx) { - - // add the node in the update list - update_list.push(idx); - - // remove the parent from the availables nodes list - nodes_idxs.retain(|&x| x != idx); - - } - } - None => println!("The value is None") - } - } - nodes_idxs - } - - } - - -fn main() { - - // initialize network - let mut network = Network::new(); - - // create a network - network.add_node( - String::from("generic-input"), - None, - None, - ); - network.add_node( - String::from("generic-input"), - None, - None, - ); - network.add_node( - String::from("exponential-node"), - None, - Some(0), - ); - network.add_node( - String::from("exponential-node"), - None, - Some(1), - ); - - println!("Graph before belief propagation: {:?}", network); - - // belief propagation - let input_data = vec![ - vec![1.1, 2.2], - vec![1.2, 2.1], - vec![1.0, 2.0], - vec![1.3, 2.2], - vec![1.1, 2.5], - vec![1.0, 2.6], - ]; - - network.input_data(input_data); - - println!("Graph after belief propagation: {:?}", network); - -} \ No newline at end of file