Skip to content

Commit

Permalink
mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Dec 16, 2024
1 parent 26ecf07 commit 6c59b36
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
26 changes: 9 additions & 17 deletions pyhgf/model/add_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp

Expand All @@ -27,7 +27,6 @@ def add_continuous_state(
coupling_fn: Tuple[Optional[Callable], ...],
):
"""Add continuous state node(s) to a network."""

node_type = 2

default_parameters = {
Expand Down Expand Up @@ -80,7 +79,6 @@ def add_binary_state(
additional_parameters: Dict,
):
"""Add binary state node(s) to a network."""

# define the type of node that is created
node_type = 1

Expand Down Expand Up @@ -119,10 +117,9 @@ def add_ef_state(
n_nodes: int,
node_parameters: Dict,
additional_parameters: Dict,
value_children: Optional[Tuple[Optional[Tuple]]],
value_children: Tuple = (None, None),
):
"""Add exponential family state node(s) to a network."""

node_type = 3

default_parameters = {
Expand Down Expand Up @@ -174,7 +171,6 @@ def add_categorical_state(
network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict
) -> Network:
"""Add categorical state node(s) to a network."""

node_type = 5

if "n_categories" in node_parameters:
Expand Down Expand Up @@ -233,7 +229,6 @@ def add_dp_state(
network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict
):
"""Add a Dirichlet Process node to a network."""

node_type = 4

if "batch_size" in additional_parameters.keys():
Expand Down Expand Up @@ -273,13 +268,12 @@ def add_dp_state(


def get_couplings(
value_parents: Optional[Tuple],
volatility_parents: Optional[Tuple],
value_children: Optional[Tuple],
volatility_children: Optional[Tuple],
value_parents: Optional[Union[Tuple, List, int]],
volatility_parents: Optional[Union[Tuple, List, int]],
value_children: Optional[Union[Tuple, List, int]],
volatility_children: Optional[Union[Tuple, List, int]],
) -> Tuple[Tuple, ...]:
"""Transform coupling parameter into tuple of indexes and strenghts."""

couplings = []
for indexes in [
value_parents,
Expand All @@ -301,14 +295,13 @@ def get_couplings(
coupling_idxs, coupling_strengths = None, None
couplings.append((coupling_idxs, coupling_strengths))

return couplings
return tuple(couplings)


def update_parameters(
node_parameters: Dict, default_parameters: Dict, additional_parameters: Dict
) -> Dict:
"""Update the default node parameters using keywords args and dictonary"""

"""Update the default node parameters using keywords args and dictonary."""
if bool(additional_parameters):
# ensure that all passed values are valid keys
invalid_keys = [
Expand Down Expand Up @@ -345,10 +338,9 @@ def insert_nodes(
volatility_parents: Tuple = (None, None),
value_children: Tuple = (None, None),
volatility_children: Tuple = (None, None),
coupling_fn: Optional[Tuple[Optional[Callable], ...]] = (None,),
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
) -> Network:
"""Insert a set of parametrised node in a network."""

# ensure that the set of coupling functions match with the number of child nodes
if value_children[0] is not None:
if value_children[0] is None:
Expand Down
3 changes: 0 additions & 3 deletions pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,3 @@ def add_edges(
self.edges = edges

return self


# Functions to be added

0 comments on commit 6c59b36

Please sign in to comment.