diff --git a/Cargo.lock b/Cargo.lock index 1d4b58a05..e4c2a367f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "cfg-if" version = "1.0.0" @@ -16,9 +22,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "heck" -version = "0.5.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "indoc" @@ -32,6 +38,16 @@ version = "0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "memoffset" version = "0.9.1" @@ -47,6 +63,29 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "portable-atomic" version = "1.9.0" @@ -64,15 +103,15 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51" +checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", - "once_cell", + "parking_lot", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -82,9 +121,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179" +checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" dependencies = [ "once_cell", "target-lexicon", @@ -92,9 +131,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d" +checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" dependencies = [ "libc", "pyo3-build-config", @@ -102,9 +141,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e" +checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -114,9 +153,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.22.5" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce" +checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" dependencies = [ "heck", "proc-macro2", @@ -134,6 +173,15 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags", +] + [[package]] name = "rshgf" version = "0.1.0" @@ -141,6 +189,18 @@ dependencies = [ "pyo3", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + [[package]] name = "syn" version = "2.0.79" @@ -169,3 +229,67 @@ name = "unindent" version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/Cargo.toml b/Cargo.toml index 9a30c4362..0673e99a1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,4 @@ crate-type = ["cdylib"] path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = { version = "0.22.5", features = ["extension-module"] } \ No newline at end of file +pyo3 = { version = "0.21.2", features = ["extension-module"] } \ No newline at end of file diff --git a/src/network.rs b/src/network.rs index 2822d5e9d..61408539e 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,12 +1,17 @@ use std::collections::HashMap; use crate::updates::posterior; -use pyo3::prelude::*; +use pyo3::{prelude::*, types::{PyList, PyDict}}; #[derive(Debug)] +#[pyclass] pub struct AdjacencyLists{ + #[pyo3(get, set)] pub value_parents: Option, + #[pyo3(get, set)] pub value_children: Option, + #[pyo3(get, set)] pub volatility_parents: Option, + #[pyo3(get, set)] pub volatility_children: Option, } #[derive(Debug, Clone)] @@ -54,9 +59,18 @@ impl Network { } } - // Add a node to the graph - #[pyo3(signature = (kind, value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] - pub fn add_node(&mut self, kind: String, value_parents: Option, value_children: Option, volatility_children: Option, volatility_parents: Option) { + /// Add nodes to the network. + /// + /// # Arguments + /// * `kind` - The type of node that should be added. + /// * `value_parents` - The indexes of the node's value parents. + /// * `value_children` - The indexes of the node's value children. + /// * `volatility_children` - The indexes of the node's volatility children. + /// * `volatility_parents` - The indexes of the node's volatility parents. + #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))] + pub fn add_nodes(&mut self, kind: &str, value_parents: Option, + value_children: Option, volatility_children: Option, + volatility_parents: Option) { // the node ID is equal to the number of nodes already in the network let node_id: usize = self.nodes.len(); @@ -143,6 +157,11 @@ impl Network { } } + /// One time step belief propagation. + /// + /// # Arguments + /// * `observations` - A vector of values, each value is one new observation associated + /// with one node. pub fn belief_propagation(&mut self, observations: Vec) { // 1. prediction propagation @@ -158,13 +177,44 @@ impl Network { } + /// Add a sequence of observations. + /// + /// # Arguments + /// * `input_data` - A vector of vectors. Each vector is a time series of observations + /// associated with one node. pub fn input_data(&mut self, input_data: Vec>) { for observation in input_data { self.belief_propagation(observation); } } + + #[getter] + pub fn get_inputs<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + let py_list = PyList::new(py, &self.inputs); // Create a PyList from Vec + Ok(py_list) + } + + #[getter] + pub fn get_edges<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + // Create a new Python list + let py_list = PyList::empty(py); + + // Convert each struct in the Vec to a Python object and add to PyList + for s in &self.edges { + // Create a new Python dictionary for each MyStruct + let py_dict = PyDict::new(py); + py_dict.set_item("value_parents", s.value_parents)?; + py_dict.set_item("value_children", s.value_children)?; + py_dict.set_item("volatility_parents", s.volatility_parents)?; + py_dict.set_item("volatility_children", s.volatility_children)?; + + // Add the dictionary to the list + py_list.append(py_dict)?; + } + Ok(py_list) } +} // Create a module to expose the class to Python #[pymodule] @@ -186,15 +236,15 @@ mod tests { let mut network = Network::new(); // create a network with two exponential family state nodes - network.add_node( - String::from("exponential-state"), + network.add_nodes( + "exponential-state", None, None, None, None ); - network.add_node( - String::from("exponential-state"), + network.add_nodes( + "exponential-state", None, None, None, diff --git a/src/utils.rs b/src/utils.rs index 87acda9e6..ab5da35a3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -59,29 +59,29 @@ mod tests { let mut network = Network::new(); // create a network - network.add_node( - String::from("continuous-state"), + network.add_nodes( + "continuous-state", Some(1), None, None, Some(2), ); - network.add_node( - String::from("continuous-state"), + network.add_nodes( + "continuous-state", None, Some(0), None, None, ); - network.add_node( - String::from("continuous-state"), + network.add_nodes( + "continuous-state", None, None, Some(0), None, ); - network.add_node( - String::from("exponential-node"), + network.add_nodes( + "exponential-node", None, None, None, diff --git a/tests/test_exponential_family.py b/tests/test_exponential_family.py new file mode 100644 index 000000000..837b685f8 --- /dev/null +++ b/tests/test_exponential_family.py @@ -0,0 +1,25 @@ +# Author: Nicolas Legrand + +from rshgf import Network as RsNetwork + +from pyhgf import load_data +from pyhgf.model import Network as PyNetwork + + +def test_1d_gaussain(): + + timeseries = load_data("continuous") + + # Rust ----------------------------------------------------------------------------- + rs_network = RsNetwork() + rs_network.add_nodes(kind="exponential-state") + rs_network.add_nodes(kind="exponential-state") + rs_network.inputs + rs_network.edges + + rs_network.input_data([[0, 0], [1, 1]]) + rs_network.input_data([timeseries, timeseries]) + + # Python --------------------------------------------------------------------------- + py_network = PyNetwork().add_nodes(kind="continuous-state") + py_network.attributes