Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Aug 12, 2024
1 parent c979475 commit b82d850
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,11 +647,11 @@ def add_nodes(

node_idx = len(self.attributes) # the index of the new node

# add a new edge
# for mutiple value children, set a default tuple with corresponding length
if isinstance(value_children, tuple):
coupling_fn = len(value_children) * (None,)
if len(value_children) != len(coupling_fn):
coupling_fn = len(value_children) * coupling_fn

# add a new edge
edges_as_list.append(
AdjacencyLists(
node_type, None, None, None, None, coupling_fn=coupling_fn
Expand Down Expand Up @@ -693,6 +693,7 @@ def add_nodes(
parent_idxs=node_idx,
children_idxs=value_children[0],
coupling_strengths=value_children[1], # type: ignore
coupling_fn=coupling_fn,
)
if volatility_children[0] is not None:
self.add_edges(
Expand Down Expand Up @@ -797,6 +798,7 @@ def add_edges(
parent_idxs=Union[int, List[int]],
children_idxs=Union[int, List[int]],
coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0,
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
) -> "Network":
"""Add a value or volatility coupling link between a set of nodes.
Expand All @@ -810,6 +812,14 @@ def add_edges(
The index(es) of the children node(s).
coupling_strengths :
The coupling strength betwen the parents and children.
coupling_fn :
Coupling function(s) between the current node and its value children.
It has to be provided as a tuple. If multiple value children are specified,
the coupling functions must be stated in the same order of the children.
Note: if a node has multiple parents nodes with different coupling
functions, a coupling function should be indicated for all the parent nodes.
If no coupling function is stated, the relationship between nodes is assumed
linear.
"""
attributes, edges = add_edges(
Expand All @@ -819,6 +829,7 @@ def add_edges(
parent_idxs=parent_idxs,
children_idxs=children_idxs,
coupling_strengths=coupling_strengths,
coupling_fn=coupling_fn,
)

self.attributes = attributes
Expand Down

0 comments on commit b82d850

Please sign in to comment.