Skip to content

Commit

Permalink
Check that SAX and PAA are fitted (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
isaksamsten committed Dec 7, 2022
1 parent 5d4cc5e commit 91da40e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
12 changes: 12 additions & 0 deletions docs/more/versions/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Wildboar 1.1 requires Python 3.8+, numpy 1.17.3+, scipy 1.3.2+ and scikit-learn
Version 1.1.1
=============

Major changes
-------------
Correctly depend on the oldest supported Numpy release.

Changelog
---------

Expand All @@ -28,6 +32,14 @@ Changelog

- |Fix| Correctly return ``max_motif`` motifs from :func:`annotate.motifs`.

.. grid-item-card::

:mod:`wildboar.transform`
^^^

- |Fix| Check that :class:`transform.SAX` and :class:`transform.PAA` is fitted
before ``transform``.


Version 1.1.0
=============
Expand Down
6 changes: 5 additions & 1 deletion src/wildboar/transform/_sax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from scipy.stats import norm, uniform
from sklearn.base import TransformerMixin
from sklearn.utils.validation import check_scalar
from sklearn.utils.validation import check_is_fitted, check_scalar

from wildboar.utils.validation import check_array, check_option

Expand Down Expand Up @@ -147,6 +147,7 @@ def fit(self, x, y=None):
return self

def transform(self, x):
check_is_fitted(self)
x = self._validate_data(x, reset=False, dtype=float)
x_paa = self.paa_.transform(x)
thresholds = self.binning_.get_thresholds(x, estimate=self.estimate)
Expand All @@ -156,6 +157,7 @@ def transform(self, x):
return x_out

def inverse_transform(self, x):
check_is_fitted(self)
if self.estimate:
raise ValueError("Unable to inverse_transform with estimate=True")

Expand Down Expand Up @@ -208,10 +210,12 @@ def fit(self, x, y=None):
return self

def transform(self, x):
check_is_fitted(self)
x = self._validate_data(x, dtype=float, reset=False)
return self.interval_transform_.transform(x)

def inverse_transform(self, x):
check_is_fitted(self)
x = check_array(x, dtype=float)

x_inverse = np.empty((x.shape[0], self.n_timesteps_in_), dtype=float)
Expand Down

0 comments on commit 91da40e

Please sign in to comment.