From d46529d7b9ef92f895e446de68d914a87f9ef2b4 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Mon, 28 Oct 2024 20:55:11 +0100 Subject: [PATCH] split attributes into floats and vectors --- Cargo.lock | 68 +++++++ Cargo.toml | 3 +- pyhgf/model/network.py | 23 +-- src/math.rs | 4 +- src/model.rs | 174 +++++++++++------- src/updates/observations.rs | 11 +- src/updates/prediction_error/exponential.rs | 32 +++- src/utils/set_sequence.rs | 40 ++-- tests/test_exponential_family.py | 7 +- .../prediction_errors/test_dirichlet.py | 2 +- tests/test_utils.py | 2 +- 11 files changed, 234 insertions(+), 132 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e4c2a367f..a247d39c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,16 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memoffset" version = "0.9.1" @@ -57,6 +67,48 @@ dependencies = [ "autocfg", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -92,6 +144,15 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" +[[package]] +name = "portable-atomic-util" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" +dependencies = [ + "portable-atomic", +] + [[package]] name = "proc-macro2" version = "1.0.87" @@ -173,6 +234,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "redox_syscall" version = "0.5.7" @@ -186,6 +253,7 @@ dependencies = [ name = "rshgf" version = "0.1.0" dependencies = [ + "ndarray", "pyo3", ] diff --git a/Cargo.toml b/Cargo.toml index 0673e99a1..eb8e726cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,5 @@ crate-type = ["cdylib"] path = "src/lib.rs" # The source file of the target. [dependencies] -pyo3 = { version = "0.21.2", features = ["extension-module"] } \ No newline at end of file +pyo3 = { version = "0.21.2", features = ["extension-module"] } +ndarray = "0.16.1" \ No newline at end of file diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index f2b9b1e9c..b5096df9a 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -329,29 +329,14 @@ def add_nodes( be a regular state node that can have value and/or volatility parents/children. If `"binary-state"`, the node should be the value parent of a binary input. State nodes filtering distribution from the - exponential family can be created using the `"ef-"` prefix (e.g. - `"ef-normal"` for a univariate normal distribution). Note that only a few - distributions are implemented at the moment. - - In addition to state nodes, four types of input nodes are supported: - - `generic-input`: receive a value or an array and pass it to the parent - nodes. - - `continuous-input`: receive a continuous observation as input. - - `binary-input` receives a single boolean as observation. The parameters - provided to the binary input node contain: 1. `binary_precision`, the binary - input precision, which defaults to `jnp.inf`. 2. `eta0`, the lower bound of - the binary process, which defaults to `0.0`. 3. `eta1`, the higher bound of - the binary process, which defaults to `1.0`. - - `categorical-input` receives a boolean array as observation. The - parameters provided to the categorical input node contain: 1. - `n_categories`, the number of categories implied by the categorical state. + exponential family can be created using `"exponential-state"`. .. note:: When using a categorical state node, the `binary_parameters` can be used to parametrize the implied collection of binary HGFs. .. note: - When using `categorical-input`, the implied `n` binary HGFs are + When using `categorical-state`, the implied `n` binary HGFs are automatically created with a shared volatility parent at the third level, resulting in a network with `3n + 2` nodes in total. @@ -396,7 +381,7 @@ def add_nodes( """ if kind not in [ "DP-state", - "ef-normal", + "exponential-state", "categorical-state", "continuous-state", "binary-state", @@ -483,7 +468,7 @@ def add_nodes( "mean": 0.0, "observed": 1, } - elif "ef-normal" in kind: + elif "exponential-state" in kind: default_parameters = { "nus": 3.0, "xis": jnp.array([0.0, 1.0]), diff --git a/src/math.rs b/src/math.rs index 20a0e568b..80cd8d4dd 100644 --- a/src/math.rs +++ b/src/math.rs @@ -1,3 +1,3 @@ -pub fn sufficient_statistics(x: &f64) -> [f64; 2] { - [*x, x.powf(2.0)] +pub fn sufficient_statistics(x: &f64) -> Vec { + vec![*x, x.powf(2.0)] } \ No newline at end of file diff --git a/src/model.rs b/src/model.rs index 5f23e09e9..931064346 100644 --- a/src/model.rs +++ b/src/model.rs @@ -4,10 +4,13 @@ use crate::utils::set_sequence::set_update_sequence; use crate::utils::function_pointer::get_func_map; use pyo3::types::PyTuple; use pyo3::{prelude::*, types::{PyList, PyDict}}; +use ndarray::{Array2, Axis, stack}; #[derive(Debug)] #[pyclass] pub struct AdjacencyLists{ + #[pyo3(get, set)] + pub node_type: String, #[pyo3(get, set)] pub value_parents: Option>, #[pyo3(get, set)] @@ -17,29 +20,6 @@ pub struct AdjacencyLists{ #[pyo3(get, set)] pub volatility_children: Option>, } -#[derive(Debug, Clone)] -pub struct ContinuousStateNode{ - pub mean: f64, - pub expected_mean: f64, - pub precision: f64, - pub expected_precision: f64, - pub tonic_volatility: f64, - pub tonic_drift: f64, - pub autoconnection_strength: f64, -} -#[derive(Debug, Clone)] -pub struct ExponentialFamiliyStateNode { - pub mean: f64, - pub expected_mean: f64, - pub nus: f64, - pub xis: [f64; 2], -} - -#[derive(Debug, Clone)] -pub enum Node { - Continuous(ContinuousStateNode), - Exponential(ExponentialFamiliyStateNode), -} #[derive(Debug)] pub struct UpdateSequence { @@ -47,13 +27,26 @@ pub struct UpdateSequence { pub updates: Vec<(usize, FnType)>, } +#[derive(Debug)] +pub struct Attributes { + pub floats: HashMap>, + pub vectors: HashMap>>, +} + +#[derive(Debug)] +pub struct NodeTrajectories { + pub floats: HashMap>>, + pub vectors: HashMap>>>, +} + #[derive(Debug)] #[pyclass] pub struct Network{ - pub nodes: HashMap, - pub edges: Vec, + pub attributes: Attributes, + pub edges: HashMap, pub inputs: Vec, pub update_sequence: UpdateSequence, + pub node_trajectories: NodeTrajectories, } #[pymethods] @@ -63,12 +56,13 @@ impl Network { #[new] // Define the constructor accessible from Python pub fn new() -> Self { Network { - nodes: HashMap::new(), - edges: Vec::new(), + attributes: Attributes {floats: HashMap::new(), vectors: HashMap::new()}, + edges: HashMap::new(), inputs: Vec::new(), - update_sequence: UpdateSequence {predictions: Vec::new(), updates: Vec::new()} + update_sequence: UpdateSequence {predictions: Vec::new(), updates: Vec::new()}, + node_trajectories: NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()} + } } - } /// Add nodes to the network. /// @@ -84,7 +78,7 @@ impl Network { volatility_parents: Option>, volatility_children: Option>, ) { // the node ID is equal to the number of nodes already in the network - let node_id: usize = self.nodes.len(); + let node_id: usize = self.edges.len(); // if this node has no children, this is an input node if (value_children == None) & (volatility_children == None) { @@ -92,29 +86,39 @@ impl Network { } let edges = AdjacencyLists{ + node_type: String::from(kind), value_parents: value_parents, value_children: value_children, volatility_parents: volatility_parents, volatility_children: volatility_children, }; - + // add edges and attributes if kind == "continuous-state" { - let continuous_state = ContinuousStateNode{ - mean: 0.0, expected_mean: 0.0, precision: 1.0, expected_precision: 1.0, - tonic_drift: 0.0, tonic_volatility: -4.0, autoconnection_strength: 1.0 - }; - let node = Node::Continuous(continuous_state); - self.nodes.insert(node_id, node); - self.edges.push(edges); + + let attributes = [ + (String::from("mean"), 0.0), + (String::from("expected_mean"), 0.0), + (String::from("precision"), 1.0), + (String::from("expected_precision"), 1.0), + (String::from("tonic_volatility"), -4.0), + (String::from("tonic_drift"), 0.0), + (String::from("autoconnection_strength"), 1.0)].into_iter().collect(); + + self.attributes.floats.insert(node_id, attributes); + self.edges.insert(node_id, edges); } else if kind == "exponential-state" { - let exponential_node: ExponentialFamiliyStateNode = ExponentialFamiliyStateNode{ - mean: 0.0, expected_mean: 0.0, nus: 0.0, xis: [0.0, 0.0] - }; - let node = Node::Exponential(exponential_node); - self.nodes.insert(node_id, node); - self.edges.push(edges); + + let floats_attributes = [ + (String::from("mean"), 0.0), + (String::from("nus"), 3.0)].into_iter().collect(); + let vector_attributes = [ + (String::from("xis"), vec![0.0, 1.0])].into_iter().collect(); + + self.attributes.floats.insert(node_id, floats_attributes); + self.attributes.vectors.insert(node_id, vector_attributes); + self.edges.insert(node_id, edges); } else { println!("Invalid type of node provided ({}).", kind); @@ -157,10 +161,57 @@ impl Network { /// # Arguments /// * `input_data` - A vector of vectors. Each vector is a time series of observations /// associated with one node. - pub fn input_data(&mut self, input_data: Vec>) { + pub fn input_data(&mut self, input_data: Vec) { + + // initialize the belief trajectories result struture + let mut node_trajectories = NodeTrajectories {floats: HashMap::new(), vectors: HashMap::new()}; + for (node_idx, node) in &self.attributes.floats { + let new_map: HashMap> = HashMap::new(); + node_trajectories.floats.insert(*node_idx, new_map); + if let Some(attr) = node_trajectories.floats.get_mut(node_idx) { + for key in node.keys() { + attr.insert(key.clone(), Vec::new()); + } + } + } + // iterate over the observations for observation in input_data { - self.belief_propagation(observation); + + // 1. belief propagation for one time slice + self.belief_propagation(vec![observation]); + + // 2. append the new states in the result vector + for (new_node_idx, new_node) in &self.attributes.floats { + for (new_key, new_value) in new_node { + // If the key exists in map1, append the vector from map2 + if let Some(old_node) = node_trajectories.floats.get_mut(&new_node_idx) { + if let Some(old_value) = old_node.get_mut(new_key) { + old_value.push(*new_value); + } + } + } + } } + + self.node_trajectories = node_trajectories; + } + + #[getter] + pub fn get_node_trajectories<'py>(&self, py: Python<'py>) -> PyResult<&'py PyList> { + let py_list = PyList::empty(py); + + + // Iterate over the Rust HashMap and insert key-value pairs into the PyDict + for (node_idx, node) in &self.node_trajectories.floats { + let py_dict = PyDict::new(py); + for (key, value) in node { + // Create a new Python dictionary + py_dict.set_item(key, value).expect("Failed to set item in PyDict"); + } + py_list.append(py_dict)?; + } + // Create a PyList from Vec + Ok(py_list) } #[getter] @@ -175,13 +226,13 @@ impl Network { let py_list = PyList::empty(py); // Convert each struct in the Vec to a Python object and add to PyList - for s in &self.edges { + for i in 0..self.edges.len() { // Create a new Python dictionary for each MyStruct let py_dict = PyDict::new(py); - py_dict.set_item("value_parents", &s.value_parents)?; - py_dict.set_item("value_children", &s.value_children)?; - py_dict.set_item("volatility_parents", &s.volatility_parents)?; - py_dict.set_item("volatility_children", &s.volatility_children)?; + py_dict.set_item("value_parents", &self.edges[&i].value_parents)?; + py_dict.set_item("value_children", &self.edges[&i].value_children)?; + py_dict.set_item("volatility_parents", &self.edges[&i].volatility_parents)?; + py_dict.set_item("volatility_children", &self.edges[&i].volatility_children)?; // Add the dictionary to the list py_list.append(py_dict)?; @@ -246,29 +297,16 @@ mod tests { None, None ); - network.add_nodes( - "exponential-state", - None, - None, - None, - None - ); // println!("Graph before belief propagation: {:?}", network); // belief propagation - let input_data = vec![ - vec![1.1, 2.2], - vec![1.2, 2.1], - vec![1.0, 2.0], - vec![1.3, 2.2], - vec![1.1, 2.5], - vec![1.0, 2.6], - ]; - + let input_data = vec![1.0, 1.3, 1.5, 1.7]; + network.set_update_sequence(); network.input_data(input_data); - // println!("Graph after belief propagation: {:?}", network); + println!("Update sequence: {:?}", network.update_sequence); + println!("Node trajectories: {:?}", network.node_trajectories); } } diff --git a/src/updates/observations.rs b/src/updates/observations.rs index 1d64ec7a9..1037b83b2 100644 --- a/src/updates/observations.rs +++ b/src/updates/observations.rs @@ -1,4 +1,4 @@ -use crate::model::{Network, Node}; +use crate::model::Network; /// Inject new observations into an input node @@ -12,10 +12,9 @@ use crate::model::{Network, Node}; /// * `network` - The network after message passing. pub fn observation_update(network: &mut Network, node_idx: usize, observations: f64) { - match network.nodes.get_mut(&node_idx) { - Some(Node::Exponential(ref mut node)) => { - node.mean = observations; - } - _ => (), + if let Some(node) = network.attributes.floats.get_mut(&node_idx) { + if let Some(mean) = node.get_mut("mean") { + *mean = observations; } + } } \ No newline at end of file diff --git a/src/updates/prediction_error/exponential.rs b/src/updates/prediction_error/exponential.rs index 35695b9c3..e2cebca05 100644 --- a/src/updates/prediction_error/exponential.rs +++ b/src/updates/prediction_error/exponential.rs @@ -1,4 +1,4 @@ -use crate::model::{Network, Node}; +use crate::model::Network; use crate::math::sufficient_statistics; /// Updating an exponential family state node @@ -11,13 +11,25 @@ use crate::math::sufficient_statistics; /// * `network` - The network after message passing. pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) { - match network.nodes.get_mut(&node_idx) { - Some(Node::Exponential(ref mut node)) => { - let suf_stats = sufficient_statistics(&node.mean); - for i in 0..suf_stats.len() { - node.xis[i] = node.xis[i] + (1.0 / (1.0 + node.nus)) * (suf_stats[i] - node.xis[i]); - } + if let Some(floats_attributes) = network.attributes.floats.get_mut(&node_idx) { + if let Some(vectors_attributes) = network.attributes.vectors.get_mut(&node_idx) { + let mean = floats_attributes.get("mean"); + let nus = floats_attributes.get("nus"); + let xis = vectors_attributes.get("xis"); + let new_xis = match (mean, nus, xis) { + (Some(mean), Some(nus), Some(xis)) => { + let suf_stats = sufficient_statistics(mean); + let mut new_xis = xis.clone(); + for i in 0..suf_stats.len() { + new_xis[i] = new_xis[i] + (1.0 / (1.0 + nus)) * (suf_stats[i] - xis[i]); + } + new_xis + } + _ => Vec::new(), + }; + if let Some(xis) = vectors_attributes.get_mut("xis") { + *xis = new_xis; // Modify the value directly + } + } } - _ => (), - } -} \ No newline at end of file + } \ No newline at end of file diff --git a/src/utils/set_sequence.rs b/src/utils/set_sequence.rs index 2a130980e..191d734ab 100644 --- a/src/utils/set_sequence.rs +++ b/src/utils/set_sequence.rs @@ -1,4 +1,4 @@ -use crate::{model::{Network, Node, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; +use crate::{model::{AdjacencyLists, Network, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; use crate::utils::function_pointer::FnType; pub fn set_update_sequence(network: &Network) -> UpdateSequence { @@ -18,7 +18,7 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { // 1. get prediction sequence ------------------------------------------------------ // list all nodes availables in the network - let mut nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + let mut nodes_idxs: Vec = network.edges.keys().cloned().collect(); // iterate over all nodes and add the prediction step if all criteria are met let mut n_remaining = nodes_idxs.len(); @@ -34,9 +34,9 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { let idx = nodes_idxs[i]; // list the node's parents - let value_parents_idxs = &network.edges[idx].value_parents; - let volatility_parents_idxs = &network.edges[idx].volatility_parents; - + let value_parents_idxs = &network.edges[&idx].value_parents; + let volatility_parents_idxs = &network.edges[&idx].volatility_parents; + let parents_idxs = match (value_parents_idxs, volatility_parents_idxs) { // If both are Some, merge the vectors (Some(ref vec1), Some(ref vec2)) => { @@ -50,6 +50,7 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { (None, None) => None, }; + // check if there is any parent node that is still found in the to-be-updated list let contains_common = match parents_idxs { Some(vec) => vec.iter().any(|item| nodes_idxs.contains(item)), @@ -60,12 +61,11 @@ pub fn get_predictions_sequence(network: &Network) -> Vec<(usize, FnType)> { if !(contains_common) { // add the node in the update list - match network.nodes.get(&idx) { - Some(Node::Continuous(_)) => { + match network.edges.get(&idx) { + Some(AdjacencyLists {node_type, ..}) if node_type == "continuous-state" => { predictions.push((idx, prediction_continuous_state_node)); } - Some(Node::Exponential(_)) => (), - None => () + _ => () } @@ -93,8 +93,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // 1. get update sequence ---------------------------------------------------------- // list all nodes availables in the network - let mut pe_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); - let mut po_nodes_idxs: Vec = network.nodes.keys().cloned().collect(); + let mut pe_nodes_idxs: Vec = network.edges.keys().cloned().collect(); + let mut po_nodes_idxs: Vec = network.edges.keys().cloned().collect(); // remove the input nodes from the to-be-visited nodes for posterior updates po_nodes_idxs.retain(|x| !network.inputs.contains(x)); @@ -117,8 +117,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { if !(po_nodes_idxs.contains(&idx)) { // only send a prediction error if this node has any parent - let value_parents_idxs = &network.edges[idx].value_parents; - let volatility_parents_idxs = &network.edges[idx].volatility_parents; + let value_parents_idxs = &network.edges[&idx].value_parents; + let volatility_parents_idxs = &network.edges[&idx].volatility_parents; let has_parents = match (value_parents_idxs, volatility_parents_idxs) { // If both are None, return false @@ -127,8 +127,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { }; // add the node in the update list - match (network.nodes.get(&idx), has_parents) { - (Some(Node::Continuous(_)), true) => { + match (network.edges.get(&idx), has_parents) { + (Some(AdjacencyLists {node_type, ..}), true) if node_type == "continuous-state" => { updates.push((idx, prediction_error_continuous_state_node)); // remove the node from the to-be-updated list pe_nodes_idxs.retain(|&x| x != idx); @@ -136,7 +136,7 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { has_update = true; break; } - (Some(Node::Exponential(_)), _) => { + (Some(AdjacencyLists {node_type, ..}), _) if node_type == "exponential-state" => { updates.push((idx, prediction_error_exponential_state_node)); // remove the node from the to-be-updated list pe_nodes_idxs.retain(|&x| x != idx); @@ -158,8 +158,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { // to start a posterior update, all children should have sent prediction errors // 1. get a list of all children - let value_children_idxs = &network.edges[idx].value_children; - let volatility_children_idxs = &network.edges[idx].volatility_children; + let value_children_idxs = &network.edges[&idx].value_children; + let volatility_children_idxs = &network.edges[&idx].volatility_children; let children_idxs = match (value_children_idxs, volatility_children_idxs) { // If both are Some, merge the vectors @@ -185,8 +185,8 @@ pub fn get_updates_sequence(network: &Network) -> Vec<(usize, FnType)> { if !missing_pe { // add the node in the update list - match network.nodes.get(&idx) { - Some(Node::Continuous(_)) => { + match network.edges.get(&idx) { + Some(AdjacencyLists {node_type, ..}) if node_type == "continuous-state" => { updates.push((idx, posterior_update_continuous_state_node)); } _ => () diff --git a/tests/test_exponential_family.py b/tests/test_exponential_family.py index 837b685f8..847200e27 100644 --- a/tests/test_exponential_family.py +++ b/tests/test_exponential_family.py @@ -13,13 +13,12 @@ def test_1d_gaussain(): # Rust ----------------------------------------------------------------------------- rs_network = RsNetwork() rs_network.add_nodes(kind="exponential-state") - rs_network.add_nodes(kind="exponential-state") rs_network.inputs rs_network.edges + rs_network.set_update_sequence() - rs_network.input_data([[0, 0], [1, 1]]) - rs_network.input_data([timeseries, timeseries]) + rs_network.input_data(timeseries) # Python --------------------------------------------------------------------------- - py_network = PyNetwork().add_nodes(kind="continuous-state") + py_network = PyNetwork().add_nodes(kind="exponential-state") py_network.attributes diff --git a/tests/test_updates/prediction_errors/test_dirichlet.py b/tests/test_updates/prediction_errors/test_dirichlet.py index e6e12c3bd..d044ca80d 100644 --- a/tests/test_updates/prediction_errors/test_dirichlet.py +++ b/tests/test_updates/prediction_errors/test_dirichlet.py @@ -29,7 +29,7 @@ def test_dirichlet_node_prediction_error(): .add_nodes(kind="generic-state") .add_nodes(kind="DP-state", value_children=0, batch_size=2) .add_nodes( - kind="ef-normal", + kind="exponential-state", n_nodes=2, value_children=1, xis=jnp.array([0.0, 1 / 8]), diff --git a/tests/test_utils.py b/tests/test_utils.py index f3763ef4a..93a37dad4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -87,7 +87,7 @@ def test_set_update_sequence(): .add_nodes(kind="generic-state") .add_nodes(kind="DP-state", value_children=0, alpha=0.1, batch_size=2) .add_nodes( - kind="ef-normal", + kind="exponential-state", n_nodes=2, value_children=1, xis=jnp.array([0.0, 1 / 8]),