From 8d1bc7c94ae185b78f3761503d3087622ad37bae Mon Sep 17 00:00:00 2001 From: Nicolas Legrand Date: Mon, 16 Dec 2024 15:35:49 +0100 Subject: [PATCH] fix pre-commit paths and Rust CI (#260) * fix pre-commit path and Rust CI * mypy errors * remove musllinus --- .github/workflows/ci.yml | 51 +++++++--------------------------------- .pre-commit-config.yaml | 4 ++-- pyhgf/model/add_nodes.py | 26 +++++++------------- pyhgf/model/network.py | 3 --- 4 files changed, 20 insertions(+), 64 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb31c169d..a2df2a068 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,19 +28,13 @@ jobs: target: x86_64 - runner: ubuntu-22.04 target: x86 - - runner: ubuntu-22.04 - target: aarch64 - - runner: ubuntu-22.04 - target: armv7 - - runner: ubuntu-22.04 - target: s390x - - runner: ubuntu-22.04 - target: ppc64le + python: + - "3.12" steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.x + python-version: ${{ matrix.python }} - name: Build wheels uses: PyO3/maturin-action@v1 with: @@ -54,37 +48,6 @@ jobs: name: wheels-linux-${{ matrix.platform.target }} path: dist - musllinux: - runs-on: ${{ matrix.platform.runner }} - strategy: - matrix: - platform: - - runner: ubuntu-22.04 - target: x86_64 - - runner: ubuntu-22.04 - target: x86 - - runner: ubuntu-22.04 - target: aarch64 - - runner: ubuntu-22.04 - target: armv7 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: 3.x - - name: Build wheels - uses: PyO3/maturin-action@v1 - with: - target: ${{ matrix.platform.target }} - args: --release --out dist --find-interpreter - sccache: 'true' - manylinux: musllinux_1_2 - - name: Upload wheels - uses: actions/upload-artifact@v4 - with: - name: wheels-musllinux-${{ matrix.platform.target }} - path: dist - windows: runs-on: ${{ matrix.platform.runner }} strategy: @@ -94,11 +57,13 @@ jobs: target: x64 - runner: windows-latest target: x86 + python: + - "3.12" steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.x + python-version: ${{ matrix.python }} architecture: ${{ matrix.platform.target }} - name: Build wheels uses: PyO3/maturin-action@v1 @@ -121,11 +86,13 @@ jobs: target: x86_64 - runner: macos-14 target: aarch64 + python: + - "3.12" steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: 3.x + python-version: ${{ matrix.python }} - name: Build wheels uses: PyO3/maturin-action@v1 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 47e7125eb..807d9e2b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,10 +22,10 @@ repos: hooks: - id: pydocstyle args: ['--ignore', 'D213,D100,D203,D104'] - files: ^src/ + files: ^pyhgf/ - repo: https://github.com/pre-commit/mirrors-mypy rev: 'v1.13.0' hooks: - id: mypy - files: ^src/ + files: ^pyhgf/ args: [--ignore-missing-imports] \ No newline at end of file diff --git a/pyhgf/model/add_nodes.py b/pyhgf/model/add_nodes.py index 4a40ab525..ea928cc05 100644 --- a/pyhgf/model/add_nodes.py +++ b/pyhgf/model/add_nodes.py @@ -3,7 +3,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import jax.numpy as jnp @@ -27,7 +27,6 @@ def add_continuous_state( coupling_fn: Tuple[Optional[Callable], ...], ): """Add continuous state node(s) to a network.""" - node_type = 2 default_parameters = { @@ -80,7 +79,6 @@ def add_binary_state( additional_parameters: Dict, ): """Add binary state node(s) to a network.""" - # define the type of node that is created node_type = 1 @@ -119,10 +117,9 @@ def add_ef_state( n_nodes: int, node_parameters: Dict, additional_parameters: Dict, - value_children: Optional[Tuple[Optional[Tuple]]], + value_children: Tuple = (None, None), ): """Add exponential family state node(s) to a network.""" - node_type = 3 default_parameters = { @@ -174,7 +171,6 @@ def add_categorical_state( network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict ) -> Network: """Add categorical state node(s) to a network.""" - node_type = 5 if "n_categories" in node_parameters: @@ -233,7 +229,6 @@ def add_dp_state( network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict ): """Add a Dirichlet Process node to a network.""" - node_type = 4 if "batch_size" in additional_parameters.keys(): @@ -273,13 +268,12 @@ def add_dp_state( def get_couplings( - value_parents: Optional[Tuple], - volatility_parents: Optional[Tuple], - value_children: Optional[Tuple], - volatility_children: Optional[Tuple], + value_parents: Optional[Union[Tuple, List, int]], + volatility_parents: Optional[Union[Tuple, List, int]], + value_children: Optional[Union[Tuple, List, int]], + volatility_children: Optional[Union[Tuple, List, int]], ) -> Tuple[Tuple, ...]: """Transform coupling parameter into tuple of indexes and strenghts.""" - couplings = [] for indexes in [ value_parents, @@ -301,14 +295,13 @@ def get_couplings( coupling_idxs, coupling_strengths = None, None couplings.append((coupling_idxs, coupling_strengths)) - return couplings + return tuple(couplings) def update_parameters( node_parameters: Dict, default_parameters: Dict, additional_parameters: Dict ) -> Dict: - """Update the default node parameters using keywords args and dictonary""" - + """Update the default node parameters using keywords args and dictonary.""" if bool(additional_parameters): # ensure that all passed values are valid keys invalid_keys = [ @@ -345,10 +338,9 @@ def insert_nodes( volatility_parents: Tuple = (None, None), value_children: Tuple = (None, None), volatility_children: Tuple = (None, None), - coupling_fn: Optional[Tuple[Optional[Callable], ...]] = (None,), + coupling_fn: Tuple[Optional[Callable], ...] = (None,), ) -> Network: """Insert a set of parametrised node in a network.""" - # ensure that the set of coupling functions match with the number of child nodes if value_children[0] is not None: if value_children[0] is None: diff --git a/pyhgf/model/network.py b/pyhgf/model/network.py index c6700cc8c..3eec09908 100644 --- a/pyhgf/model/network.py +++ b/pyhgf/model/network.py @@ -567,6 +567,3 @@ def add_edges( self.edges = edges return self - - -# Functions to be added