Skip to content

Commit

Permalink
Merge pull request #4 from gmgeorg/cleanup-1
Browse files Browse the repository at this point in the history
numpy/cython fix; adding workflow for testing flake8 and pytest
  • Loading branch information
gmgeorg authored Apr 13, 2024
2 parents b6dbd5c + d1adfba commit 7f3d779
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 7 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python application

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --ignore=C901
- name: Test with pytest
run: |
pytest
2 changes: 1 addition & 1 deletion pylambertw/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version."""

__version__ = "0.0.2"
__version__ = "0.0.3"
4 changes: 2 additions & 2 deletions pylambertw/igmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ def fit(self, data: np.ndarray):
self._initialize_params(data)

tau_trace = np.zeros(shape=(self.max_iter + 1, 3))
tau_trace[0,] = (
tau_trace[0,] = np.array([
self.tau_init.loc,
self.tau_init.scale,
self.tau_init.lambertw_params.delta,
)
]).reshape(1, -1)

for kk in range(self.max_iter):
current = tau_trace[kk, :]
Expand Down
6 changes: 3 additions & 3 deletions pylambertw/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@


def _dict_torch_to_series(dict_torch: Dict[str, torch.Tensor]) -> pd.Series:
return pd.Series({k: float(v.detach().numpy()) for k, v in dict_torch.items()})
"""Turns a dictionary of torch tensors (float) into a pandas Series."""
return pd.Series({k: float(v.detach().numpy().item()) for k, v in dict_torch.items()})


class MLE(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
Expand All @@ -49,8 +50,7 @@ def __init__(
"""Initializes the class."""
self.distribution_name = distribution_name
self.distribution_constructor = (
distribution_constructor
or lwd.utils.get_distribution_constructor(self.distribution_name)
distribution_constructor or lwd.utils.get_distribution_constructor(self.distribution_name)
)
self.lambertw_type = base.LambertWType(lambertw_type)
self.max_iter = max_iter
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
cython>=3.0.10
dataclasses>=0.6
git+https://github.com/gmgeorg/torchlambertw.git
matplotlib>=3.3.0
numpy>=1.0.1
pandas>=1.0.0
pytest>=6.1.1
scikit-learn>=1.0.1
scipy~=1.6.0
scipy>=1.7.0
seaborn>=0.11.1
statsmodels>=0.12.0
torch>=2.0.1
Expand Down

0 comments on commit 7f3d779

Please sign in to comment.