Skip to content

Commit

Permalink
split attributes into floats and vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Oct 28, 2024
1 parent c76d108 commit d46529d
Show file tree
Hide file tree
Showing 11 changed files with 234 additions and 132 deletions.
68 changes: 68 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ crate-type = ["cdylib"]
path = "src/lib.rs" # The source file of the target.

[dependencies]
pyo3 = { version = "0.21.2", features = ["extension-module"] }
pyo3 = { version = "0.21.2", features = ["extension-module"] }
ndarray = "0.16.1"
23 changes: 4 additions & 19 deletions pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,29 +329,14 @@ def add_nodes(
be a regular state node that can have value and/or volatility
parents/children. If `"binary-state"`, the node should be the
value parent of a binary input. State nodes filtering distribution from the
exponential family can be created using the `"ef-"` prefix (e.g.
`"ef-normal"` for a univariate normal distribution). Note that only a few
distributions are implemented at the moment.
In addition to state nodes, four types of input nodes are supported:
- `generic-input`: receive a value or an array and pass it to the parent
nodes.
- `continuous-input`: receive a continuous observation as input.
- `binary-input` receives a single boolean as observation. The parameters
provided to the binary input node contain: 1. `binary_precision`, the binary
input precision, which defaults to `jnp.inf`. 2. `eta0`, the lower bound of
the binary process, which defaults to `0.0`. 3. `eta1`, the higher bound of
the binary process, which defaults to `1.0`.
- `categorical-input` receives a boolean array as observation. The
parameters provided to the categorical input node contain: 1.
`n_categories`, the number of categories implied by the categorical state.
exponential family can be created using `"exponential-state"`.
.. note::
When using a categorical state node, the `binary_parameters` can be used to
parametrize the implied collection of binary HGFs.
.. note:
When using `categorical-input`, the implied `n` binary HGFs are
When using `categorical-state`, the implied `n` binary HGFs are
automatically created with a shared volatility parent at the third level,
resulting in a network with `3n + 2` nodes in total.
Expand Down Expand Up @@ -396,7 +381,7 @@ def add_nodes(
"""
if kind not in [
"DP-state",
"ef-normal",
"exponential-state",
"categorical-state",
"continuous-state",
"binary-state",
Expand Down Expand Up @@ -483,7 +468,7 @@ def add_nodes(
"mean": 0.0,
"observed": 1,
}
elif "ef-normal" in kind:
elif "exponential-state" in kind:
default_parameters = {
"nus": 3.0,
"xis": jnp.array([0.0, 1.0]),
Expand Down
4 changes: 2 additions & 2 deletions src/math.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pub fn sufficient_statistics(x: &f64) -> [f64; 2] {
[*x, x.powf(2.0)]
pub fn sufficient_statistics(x: &f64) -> Vec<f64> {
vec![*x, x.powf(2.0)]
}
Loading

0 comments on commit d46529d

Please sign in to comment.