diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 0000000..c043973 --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,39 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python application + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --ignore=C901 + - name: Test with pytest + run: | + pytest diff --git a/pylambertw/_version.py b/pylambertw/_version.py index 771ca48..f1fb5f5 100644 --- a/pylambertw/_version.py +++ b/pylambertw/_version.py @@ -1,3 +1,3 @@ """Version.""" -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/pylambertw/igmm.py b/pylambertw/igmm.py index 5a000f1..161a401 100644 --- a/pylambertw/igmm.py +++ b/pylambertw/igmm.py @@ -216,11 +216,11 @@ def fit(self, data: np.ndarray): self._initialize_params(data) tau_trace = np.zeros(shape=(self.max_iter + 1, 3)) - tau_trace[0,] = ( + tau_trace[0,] = np.array([ self.tau_init.loc, self.tau_init.scale, self.tau_init.lambertw_params.delta, - ) + ]).reshape(1, -1) for kk in range(self.max_iter): current = tau_trace[kk, :] diff --git a/pylambertw/mle.py b/pylambertw/mle.py index 48c8b4c..287edde 100644 --- a/pylambertw/mle.py +++ b/pylambertw/mle.py @@ -29,7 +29,8 @@ def _dict_torch_to_series(dict_torch: Dict[str, torch.Tensor]) -> pd.Series: - return pd.Series({k: float(v.detach().numpy()) for k, v in dict_torch.items()}) + """Turns a dictionary of torch tensors (float) into a pandas Series.""" + return pd.Series({k: float(v.detach().numpy().item()) for k, v in dict_torch.items()}) class MLE(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin): @@ -49,8 +50,7 @@ def __init__( """Initializes the class.""" self.distribution_name = distribution_name self.distribution_constructor = ( - distribution_constructor - or lwd.utils.get_distribution_constructor(self.distribution_name) + distribution_constructor or lwd.utils.get_distribution_constructor(self.distribution_name) ) self.lambertw_type = base.LambertWType(lambertw_type) self.max_iter = max_iter diff --git a/requirements.txt b/requirements.txt index c06d61a..d98c0fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +cython>=3.0.10 dataclasses>=0.6 git+https://github.com/gmgeorg/torchlambertw.git matplotlib>=3.3.0 @@ -5,7 +6,7 @@ numpy>=1.0.1 pandas>=1.0.0 pytest>=6.1.1 scikit-learn>=1.0.1 -scipy~=1.6.0 +scipy>=1.7.0 seaborn>=0.11.1 statsmodels>=0.12.0 torch>=2.0.1