diff --git a/pyhgf/model/add_nodes.py b/pyhgf/model/add_nodes.py index 4a40ab525..ea928cc05 100644 --- a/pyhgf/model/add_nodes.py +++ b/pyhgf/model/add_nodes.py @@ -3,7 +3,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import jax.numpy as jnp @@ -27,7 +27,6 @@ def add_continuous_state( coupling_fn: Tuple[Optional[Callable], ...], ): """Add continuous state node(s) to a network.""" - node_type = 2 default_parameters = { @@ -80,7 +79,6 @@ def add_binary_state( additional_parameters: Dict, ): """Add binary state node(s) to a network.""" - # define the type of node that is created node_type = 1 @@ -119,10 +117,9 @@ def add_ef_state( n_nodes: int, node_parameters: Dict, additional_parameters: Dict, - value_children: Optional[Tuple[Optional[Tuple]]], + value_children: Tuple = (None, None), ): """Add exponential family state node(s) to a network.""" - node_type = 3 default_parameters = { @@ -174,7 +171,6 @@ def add_categorical_state( network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict ) -> Network: """Add categorical state node(s) to a network.""" - node_type = 5 if "n_categories" in node_parameters: @@ -233,7 +229,6 @@ def add_dp_state( network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict ): """Add a Dirichlet Process node to a network.""" - node_type = 4 if "batch_size" in additional_parameters.keys(): @@ -273,13 +268,12 @@ def add_dp_state( def get_couplings( - value_parents: Optional[Tuple], - volatility_parents: Optional[Tuple], - value_children: Optional[Tuple], - volatility_children: Optional[Tuple], + value_parents: Optional[Union[Tuple, List, int]], + volatility_parents: Optional[Union[Tuple, List, int]], + value_children: Optional[Union[Tuple, List, int]], + volatility_children: Optional[Union[Tuple, List, int]], ) -> Tuple[Tuple, ...]: """Transform coupling parameter into tuple of indexes and strenghts.""" - couplings = [] for indexes in [ value_parents, @@ -301,14 +295,13 @@ def get_couplings( coupling_idxs, coupling_strengths = None, None couplings.append((coupling_idxs, coupling_strengths)) - return couplings + return tuple(couplings) def update_parameters( node_parameters: Dict, default_parameters: Dict, additional_parameters: Dict ) -> Dict: - """Update the default node parameters using keywords args and dictonary""" - + """Update the default node parameters using keywords args and dictonary.""" if bool(additional_parameters): # ensure that all passed values are valid keys invalid_keys = [ @@ -345,10 +338,9 @@ def insert_nodes( volatility_parents: Tuple = (None, None), value_children: Tuple = (None, None), volatility_children: Tuple = (None, None), - coupling_fn: Optional[Tuple[Optional[Callable], ...]] = (None,), + coupling_fn: Tuple[Optional[Callable], ...] = (None,), ) -> Network: """Insert a set of parametrised node in a network.""" - # ensure that the set of coupling functions match with the number of child nodes if value_children[0] is not None: if value_children[0] is None: diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index c6700cc8c..3eec09908 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -567,6 +567,3 @@ def add_edges( self.edges = edges return self - - -# Functions to be added