Skip to content

Commit

Permalink
fix pre-commit paths and Rust CI (#260)
Browse files Browse the repository at this point in the history
* fix pre-commit path and Rust CI

* mypy errors

* remove musllinus
  • Loading branch information
LegrandNico authored Dec 16, 2024
1 parent 9045dc0 commit 8d1bc7c
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 64 deletions.
51 changes: 9 additions & 42 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,13 @@ jobs:
target: x86_64
- runner: ubuntu-22.04
target: x86
- runner: ubuntu-22.04
target: aarch64
- runner: ubuntu-22.04
target: armv7
- runner: ubuntu-22.04
target: s390x
- runner: ubuntu-22.04
target: ppc64le
python:
- "3.12"
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.x
python-version: ${{ matrix.python }}
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
Expand All @@ -54,37 +48,6 @@ jobs:
name: wheels-linux-${{ matrix.platform.target }}
path: dist

musllinux:
runs-on: ${{ matrix.platform.runner }}
strategy:
matrix:
platform:
- runner: ubuntu-22.04
target: x86_64
- runner: ubuntu-22.04
target: x86
- runner: ubuntu-22.04
target: aarch64
- runner: ubuntu-22.04
target: armv7
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.x
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.platform.target }}
args: --release --out dist --find-interpreter
sccache: 'true'
manylinux: musllinux_1_2
- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: wheels-musllinux-${{ matrix.platform.target }}
path: dist

windows:
runs-on: ${{ matrix.platform.runner }}
strategy:
Expand All @@ -94,11 +57,13 @@ jobs:
target: x64
- runner: windows-latest
target: x86
python:
- "3.12"
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.x
python-version: ${{ matrix.python }}
architecture: ${{ matrix.platform.target }}
- name: Build wheels
uses: PyO3/maturin-action@v1
Expand All @@ -121,11 +86,13 @@ jobs:
target: x86_64
- runner: macos-14
target: aarch64
python:
- "3.12"
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.x
python-version: ${{ matrix.python }}
- name: Build wheels
uses: PyO3/maturin-action@v1
with:
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ repos:
hooks:
- id: pydocstyle
args: ['--ignore', 'D213,D100,D203,D104']
files: ^src/
files: ^pyhgf/
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.13.0'
hooks:
- id: mypy
files: ^src/
files: ^pyhgf/
args: [--ignore-missing-imports]
26 changes: 9 additions & 17 deletions pyhgf/model/add_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp

Expand All @@ -27,7 +27,6 @@ def add_continuous_state(
coupling_fn: Tuple[Optional[Callable], ...],
):
"""Add continuous state node(s) to a network."""

node_type = 2

default_parameters = {
Expand Down Expand Up @@ -80,7 +79,6 @@ def add_binary_state(
additional_parameters: Dict,
):
"""Add binary state node(s) to a network."""

# define the type of node that is created
node_type = 1

Expand Down Expand Up @@ -119,10 +117,9 @@ def add_ef_state(
n_nodes: int,
node_parameters: Dict,
additional_parameters: Dict,
value_children: Optional[Tuple[Optional[Tuple]]],
value_children: Tuple = (None, None),
):
"""Add exponential family state node(s) to a network."""

node_type = 3

default_parameters = {
Expand Down Expand Up @@ -174,7 +171,6 @@ def add_categorical_state(
network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict
) -> Network:
"""Add categorical state node(s) to a network."""

node_type = 5

if "n_categories" in node_parameters:
Expand Down Expand Up @@ -233,7 +229,6 @@ def add_dp_state(
network: Network, n_nodes: int, node_parameters: Dict, additional_parameters: Dict
):
"""Add a Dirichlet Process node to a network."""

node_type = 4

if "batch_size" in additional_parameters.keys():
Expand Down Expand Up @@ -273,13 +268,12 @@ def add_dp_state(


def get_couplings(
value_parents: Optional[Tuple],
volatility_parents: Optional[Tuple],
value_children: Optional[Tuple],
volatility_children: Optional[Tuple],
value_parents: Optional[Union[Tuple, List, int]],
volatility_parents: Optional[Union[Tuple, List, int]],
value_children: Optional[Union[Tuple, List, int]],
volatility_children: Optional[Union[Tuple, List, int]],
) -> Tuple[Tuple, ...]:
"""Transform coupling parameter into tuple of indexes and strenghts."""

couplings = []
for indexes in [
value_parents,
Expand All @@ -301,14 +295,13 @@ def get_couplings(
coupling_idxs, coupling_strengths = None, None
couplings.append((coupling_idxs, coupling_strengths))

return couplings
return tuple(couplings)


def update_parameters(
node_parameters: Dict, default_parameters: Dict, additional_parameters: Dict
) -> Dict:
"""Update the default node parameters using keywords args and dictonary"""

"""Update the default node parameters using keywords args and dictonary."""
if bool(additional_parameters):
# ensure that all passed values are valid keys
invalid_keys = [
Expand Down Expand Up @@ -345,10 +338,9 @@ def insert_nodes(
volatility_parents: Tuple = (None, None),
value_children: Tuple = (None, None),
volatility_children: Tuple = (None, None),
coupling_fn: Optional[Tuple[Optional[Callable], ...]] = (None,),
coupling_fn: Tuple[Optional[Callable], ...] = (None,),
) -> Network:
"""Insert a set of parametrised node in a network."""

# ensure that the set of coupling functions match with the number of child nodes
if value_children[0] is not None:
if value_children[0] is None:
Expand Down
3 changes: 0 additions & 3 deletions pyhgf/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,3 @@ def add_edges(
self.edges = edges

return self


# Functions to be added

0 comments on commit 8d1bc7c

Please sign in to comment.