Skip to content

Commit

Permalink
returns attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 22, 2024
1 parent d2e3a55 commit 198245f
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 30 deletions.
150 changes: 137 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ crate-type = ["cdylib"]
path = "src/lib.rs" # The source file of the target.

[dependencies]
pyo3 = { version = "0.22.5", features = ["extension-module"] }
pyo3 = { version = "0.21.2", features = ["extension-module"] }
66 changes: 58 additions & 8 deletions src/network.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use std::collections::HashMap;
use crate::updates::posterior;
use pyo3::prelude::*;
use pyo3::{prelude::*, types::{PyList, PyDict}};

#[derive(Debug)]
#[pyclass]
pub struct AdjacencyLists{
#[pyo3(get, set)]
pub value_parents: Option<usize>,
#[pyo3(get, set)]
pub value_children: Option<usize>,
#[pyo3(get, set)]
pub volatility_parents: Option<usize>,
#[pyo3(get, set)]
pub volatility_children: Option<usize>,
}
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -54,9 +59,18 @@ impl Network {
}
}

// Add a node to the graph
#[pyo3(signature = (kind, value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))]
pub fn add_node(&mut self, kind: String, value_parents: Option<usize>, value_children: Option<usize>, volatility_children: Option<usize>, volatility_parents: Option<usize>) {
/// Add nodes to the network.
///
/// # Arguments
/// * `kind` - The type of node that should be added.
/// * `value_parents` - The indexes of the node's value parents.
/// * `value_children` - The indexes of the node's value children.
/// * `volatility_children` - The indexes of the node's volatility children.
/// * `volatility_parents` - The indexes of the node's volatility parents.
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_children=None, volatility_parents=None))]
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<usize>,
value_children: Option<usize>, volatility_children: Option<usize>,
volatility_parents: Option<usize>) {

// the node ID is equal to the number of nodes already in the network
let node_id: usize = self.nodes.len();
Expand Down Expand Up @@ -143,6 +157,11 @@ impl Network {
}
}

/// One time step belief propagation.
///
/// # Arguments
/// * `observations` - A vector of values, each value is one new observation associated
/// with one node.
pub fn belief_propagation(&mut self, observations: Vec<f64>) {

// 1. prediction propagation
Expand All @@ -158,13 +177,44 @@ impl Network {

}

/// Add a sequence of observations.
///
/// # Arguments
/// * `input_data` - A vector of vectors. Each vector is a time series of observations
/// associated with one node.
pub fn input_data(&mut self, input_data: Vec<Vec<f64>>) {
for observation in input_data {
self.belief_propagation(observation);
}
}

#[getter]
pub fn get_inputs<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> {
let py_list = PyList::new(py, &self.inputs); // Create a PyList from Vec<usize>
Ok(py_list)
}

#[getter]
pub fn get_edges<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> {
// Create a new Python list
let py_list = PyList::empty(py);

// Convert each struct in the Vec to a Python object and add to PyList
for s in &self.edges {
// Create a new Python dictionary for each MyStruct
let py_dict = PyDict::new(py);
py_dict.set_item("value_parents", s.value_parents)?;
py_dict.set_item("value_children", s.value_children)?;
py_dict.set_item("volatility_parents", s.volatility_parents)?;
py_dict.set_item("volatility_children", s.volatility_children)?;

// Add the dictionary to the list
py_list.append(py_dict)?;
}
Ok(py_list)
}

}

// Create a module to expose the class to Python
#[pymodule]
Expand All @@ -186,15 +236,15 @@ mod tests {
let mut network = Network::new();

// create a network with two exponential family state nodes
network.add_node(
String::from("exponential-state"),
network.add_nodes(
"exponential-state",
None,
None,
None,
None
);
network.add_node(
String::from("exponential-state"),
network.add_nodes(
"exponential-state",
None,
None,
None,
Expand Down
Loading

0 comments on commit 198245f

Please sign in to comment.