[bug] fix sample_weight check for IncrementalBasicStatistics (#1799)
* Update incremental_basic_statistics.py

* formatting

* Update incremental_basic_statistics.py

* Update test_incremental_basic_statistics.py

* Update incremental_basic_statistics.py

* Update incremental_basic_statistics.py

* Update incremental_basic_statistics.py

* Update incremental_basic_statistics.py

* Update incremental_basic_statistics.py

* Update test_incremental_basic_statistics.py
icfaust authored Apr 17, 2024
1 parent 3925eef commit 68ee7ab
Showing 2 changed files with 33 additions and 21 deletions.
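In effect, the public API moves from a bespoke positional `weights` argument to scikit-learn's standard `sample_weight` keyword (with an ignored `y` slot in `fit`). A minimal sketch of the new call style, assuming sklearnex is installed; the data here is synthetic:

```python
import numpy as np
from sklearnex.basic_statistics import IncrementalBasicStatistics

rng = np.random.default_rng(0)
X = rng.uniform(size=(20, 3))
w = rng.uniform(size=20)

incbs = IncrementalBasicStatistics()
incbs.fit(X, sample_weight=w)          # new: keyword, not positional
incbs.partial_fit(X, sample_weight=w)  # same convention on the incremental path
```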
34 changes: 23 additions & 11 deletions sklearnex/basic_statistics/incremental_basic_statistics.py
@@ -17,6 +17,7 @@
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_array, gen_batches
from sklearn.utils.validation import _check_sample_weight

from daal4py.sklearn._n_jobs_support import control_n_jobs
from daal4py.sklearn._utils import sklearn_check_version
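The new import is scikit-learn's standard weight validator: it coerces an array-like to a float ndarray of shape `(n_samples,)`, broadcasts scalars, substitutes ones for `None`, and raises on a length mismatch. A quick illustration of that behavior:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((3, 2))

_check_sample_weight([1, 2, 3], X)  # array([1., 2., 3.])
_check_sample_weight(2.0, X)        # scalar broadcast -> array([2., 2., 2.])
_check_sample_weight(None, X)       # defaults to ones -> array([1., 1., 1.])
# _check_sample_weight([1, 2], X)   # ValueError: inconsistent length
```

Note that the estimator below guards the call with `if sample_weight is not None`, so an absent weight vector is forwarded as `None` and the oneDAL backend keeps its unweighted path.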
@@ -139,7 +140,7 @@ def _onedal_finalize_fit(self):
self._onedal_estimator.finalize_fit()
self._need_to_finalize = False

def _onedal_partial_fit(self, X, weights, queue):
def _onedal_partial_fit(self, X, sample_weight=None, queue=None):
first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0

if sklearn_check_version("1.0"):
@@ -152,9 +153,11 @@ def _onedal_partial_fit(self, X, weights, queue):
X = check_array(
X,
dtype=[np.float64, np.float32],
copy=self.copy_X,
)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

if first_pass:
self.n_samples_seen_ = X.shape[0]
self.n_features_in_ = X.shape[1]
@@ -168,15 +171,18 @@
self._onedal_estimator = self._onedal_incremental_basic_statistics(
**onedal_params
)
self._onedal_estimator.partial_fit(X, weights, queue)
self._onedal_estimator.partial_fit(X, sample_weight, queue)
self._need_to_finalize = True

def _onedal_fit(self, X, weights, queue=None):
def _onedal_fit(self, X, sample_weight=None, queue=None):
if sklearn_check_version("1.0"):
X = self._validate_data(X, dtype=[np.float64, np.float32])
else:
X = check_array(X, dtype=[np.float64, np.float32])

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

n_samples, n_features = X.shape
if self.batch_size is None:
self.batch_size_ = 5 * n_features
@@ -189,7 +195,7 @@

for batch in gen_batches(X.shape[0], self.batch_size_):
X_batch = X[batch]
weights_batch = weights[batch] if weights is not None else None
weights_batch = sample_weight[batch] if sample_weight is not None else None
self._onedal_partial_fit(X_batch, weights_batch, queue=queue)

if sklearn_check_version("1.2"):
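`_onedal_fit` reuses the incremental path, slicing `X` and the validated `sample_weight` with identical `gen_batches` slices so each row stays paired with its weight. A sketch of what those slices look like:

```python
from sklearn.utils import gen_batches

print(list(gen_batches(7, 3)))
# [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
# Applied as X[batch] and sample_weight[batch], the same slice keeps
# sample i and weight i together in every partial fit.
```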
@@ -217,7 +223,7 @@ def __getattr__(self, attr):
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)

def partial_fit(self, X, weights=None):
def partial_fit(self, X, sample_weight=None):
"""Incremental fit with X. All of X is processed as a single batch.
Parameters
@@ -226,7 +232,10 @@
Data for compute, where `n_samples` is the number of samples and
`n_features` is the number of features.
weights : array-like of shape (n_samples,)
y : Ignored
Not used, present for API consistency by convention.
sample_weight : array-like of shape (n_samples,), default=None
Weights for computing weighted statistics, where `n_samples` is the number of samples.
Returns
@@ -242,11 +251,11 @@
"sklearn": None,
},
X,
weights,
sample_weight,
)
return self

def fit(self, X, weights=None):
def fit(self, X, y=None, sample_weight=None):
"""Compute statistics with X, using minibatches of size batch_size.
Parameters
@@ -255,7 +264,10 @@
Data for compute, where `n_samples` is the number of samples and
`n_features` is the number of features.
weights : array-like of shape (n_samples,)
y : Ignored
Not used, present for API consistency by convention.
sample_weight : array-like of shape (n_samples,), default=None
Weights for computing weighted statistics, where `n_samples` is the number of samples.
Returns
@@ -271,6 +283,6 @@
"sklearn": None,
},
X,
weights,
sample_weight,
)
return self
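With the signatures fixed, the incremental path can be driven chunk by chunk, each chunk carrying its own slice of the weights. A minimal sketch, assuming sklearnex is installed and using synthetic chunks:

```python
import numpy as np
from sklearnex.basic_statistics import IncrementalBasicStatistics

rng = np.random.default_rng(0)
incbs = IncrementalBasicStatistics(result_options=["mean", "sum"])

for _ in range(3):  # three streamed chunks
    X_chunk = rng.uniform(size=(50, 4))
    w_chunk = rng.uniform(size=50)
    incbs.partial_fit(X_chunk, sample_weight=w_chunk)
```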
20 changes: 10 additions & 10 deletions sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py
@@ -52,7 +52,7 @@ def test_partial_fit_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
weights_split_df = _convert_to_dataframe(
weights_split[i], sycl_queue=queue, target_df=dataframe
)
result = incbs.partial_fit(X_split_df, weights_split_df)
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
else:
result = incbs.partial_fit(X_split_df)

@@ -103,7 +103,7 @@ def test_partial_fit_single_option_on_random_data(
weights_split_df = _convert_to_dataframe(
weights_split[i], sycl_queue=queue, target_df=dataframe
)
result = incbs.partial_fit(X_split_df, weights_split_df)
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
else:
result = incbs.partial_fit(X_split_df)

@@ -146,7 +146,7 @@ def test_partial_fit_multiple_options_on_random_data(
weights_split_df = _convert_to_dataframe(
weights_split[i], sycl_queue=queue, target_df=dataframe
)
result = incbs.partial_fit(X_split_df, weights_split_df)
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
else:
result = incbs.partial_fit(X_split_df)

@@ -199,7 +199,7 @@ def test_partial_fit_all_option_on_random_data(
weights_split_df = _convert_to_dataframe(
weights_split[i], sycl_queue=queue, target_df=dataframe
)
result = incbs.partial_fit(X_split_df, weights_split_df)
result = incbs.partial_fit(X_split_df, sample_weight=weights_split_df)
else:
result = incbs.partial_fit(X_split_df)

@@ -233,7 +233,7 @@ def test_fit_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
incbs = IncrementalBasicStatistics(batch_size=1)

if weighted:
result = incbs.fit(X_df, weights_df)
result = incbs.fit(X_df, sample_weight=weights_df)
else:
result = incbs.fit(X_df)
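The keyword form in these `fit` tests is not cosmetic: with the new signature `fit(self, X, y=None, sample_weight=None)`, a positional second argument binds to the ignored `y`, so the old call shape would silently compute unweighted statistics. Passing `sample_weight=` by keyword is what actually applies the weights:

```python
# fit(self, X, y=None, sample_weight=None) after this commit:
incbs.fit(X_df, weights_df)                # weights bind to the ignored y
incbs.fit(X_df, sample_weight=weights_df)  # weights are actually used
```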

@@ -272,15 +272,15 @@ def test_fit_single_option_on_random_data(
X = X.astype(dtype=dtype)
X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
if weighted:
weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
weights = weights.astype(dtype=dtype)
weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
incbs = IncrementalBasicStatistics(
result_options=result_option, batch_size=batch_size
)

if weighted:
result = incbs.fit(X_df, weights_df)
result = incbs.fit(X_df, sample_weight=weights_df)
else:
result = incbs.fit(X_df)

@@ -311,15 +311,15 @@ def test_partial_fit_multiple_options_on_random_data(
X = X.astype(dtype=dtype)
X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
if weighted:
weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
weights = weights.astype(dtype=dtype)
weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
incbs = IncrementalBasicStatistics(
result_options=["mean", "max", "sum"], batch_size=batch_size
)

if weighted:
result = incbs.fit(X_df, weights_df)
result = incbs.fit(X_df, sample_weight=weights_df)
else:
result = incbs.fit(X_df)

@@ -366,7 +366,7 @@ def test_fit_all_option_on_random_data(
incbs = IncrementalBasicStatistics(result_options="all", batch_size=batch_size)

if weighted:
result = incbs.fit(X_df, weights_df)
result = incbs.fit(X_df, sample_weight=weights_df)
else:
result = incbs.fit(X_df)
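The random-data tests check the estimator against NumPy references built from row-weighted data; the exact helper functions live in the test module and are not shown in this diff. A rough sketch of that kind of reference, under that assumption:

```python
import numpy as np

rng = np.random.default_rng(77)
X = rng.uniform(size=(100, 5))
w = rng.uniform(low=-0.5, high=1.0, size=100)  # same range the tests draw from

weighted_rows = w[:, None] * X  # scale each sample by its weight
expected_sum = weighted_rows.sum(axis=0)
expected_mean = weighted_rows.mean(axis=0)
```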
