diff --git a/onedal/basic_statistics/incremental_basic_statistics.py b/onedal/basic_statistics/incremental_basic_statistics.py index b3c304a2af..4935a57a47 100644 --- a/onedal/basic_statistics/incremental_basic_statistics.py +++ b/onedal/basic_statistics/incremental_basic_statistics.py @@ -19,6 +19,7 @@ from daal4py.sklearn._utils import get_dtype from ..datatypes import _convert_to_supported, from_table, to_table +from ..utils import _check_array from .basic_statistics import BaseBasicStatistics @@ -96,6 +97,17 @@ def partial_fit(self, X, weights=None, queue=None): policy = self._get_policy(queue, X) X, weights = _convert_to_supported(policy, X, weights) + X = _check_array( + X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False + ) + if weights is not None: + weights = _check_array( + weights, + dtype=[np.float64, np.float32], + ensure_2d=False, + force_all_finite=False, + ) + if not hasattr(self, "_onedal_params"): dtype = get_dtype(X) self._onedal_params = self._get_onedal_params(False, dtype=dtype)