Skip to content

Commit

Permalink
[fix] Add forgotten _check_array to IncrementalBasicStatistics.partia…
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov authored Sep 3, 2024
1 parent 4fd4568 commit e799ec4
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from daal4py.sklearn._utils import get_dtype

from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
from .basic_statistics import BaseBasicStatistics


Expand Down Expand Up @@ -96,6 +97,17 @@ def partial_fit(self, X, weights=None, queue=None):
policy = self._get_policy(queue, X)
X, weights = _convert_to_supported(policy, X, weights)

X = _check_array(
X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
)
if weights is not None:
weights = _check_array(
weights,
dtype=[np.float64, np.float32],
ensure_2d=False,
force_all_finite=False,
)

if not hasattr(self, "_onedal_params"):
dtype = get_dtype(X)
self._onedal_params = self._get_onedal_params(False, dtype=dtype)
Expand Down

0 comments on commit e799ec4

Please sign in to comment.