From 4a40ded92bff905200554dcbfe6fc75c754a3142 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Thu, 26 Oct 2023 13:13:26 +0200 Subject: [PATCH] use scan in belief propagation --- src/pyhgf/networks.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/pyhgf/networks.py b/src/pyhgf/networks.py index ee3298d78..d34d5ded2 100644 --- a/src/pyhgf/networks.py +++ b/src/pyhgf/networks.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd from jax import jit -from jax.lax import switch +from jax.lax import scan, switch from jax.typing import ArrayLike from pyhgf.math import gaussian_surprise @@ -79,8 +79,9 @@ def beliefs_propagation( input_nodes_idx = jnp.asarray(input_nodes_idx) - # Fit the model with the current time and value variables, given the model structure - for node_idx, branch_idx in zip(nodes_idxs, branches_idxs): + def switching_branches(attributes, scan_input, values=values, time_step=time_step): + node_idx, branch_idx = scan_input + # if we are updating an input node, select the value that should be passed # otherwise, just pass 0.0 and the value will be ignored value = jnp.sum(jnp.equal(input_nodes_idx, node_idx) * values) @@ -95,6 +96,14 @@ def beliefs_propagation( value, ) + return attributes, attributes # ("carryover", "accumulated") + + # wrap the inputs + scan_input = jnp.array(nodes_idxs), jnp.array(branches_idxs) + + # scan over the input data and apply the switching belief propagation functions + attributes, _ = scan(switching_branches, attributes, scan_input) + return ( attributes, attributes,