Skip to content

Commit

Permalink
ENH: SPMD interface for IncrementalBasicStatistics (uxlfoundation#1961)
Browse files Browse the repository at this point in the history
* Added SPMD interface for IncrementalBasicStatistics
* Changed the policy-saving workflow: the queue is now saved as an attribute instead of the policy. This is necessary because, on the oneDAL side, finalize_fit requires spmd_policy while partial_fit requires data_parallel_policy.
* finalize_fit now uses the provided queue for computations on the onedal4py side.
  • Loading branch information
olegkkruglov authored Sep 2, 2024
1 parent 18d0428 commit 7ecc9f1
Show file tree
Hide file tree
Showing 10 changed files with 451 additions and 27 deletions.
1 change: 1 addition & 0 deletions onedal/basic_statistics/basic_statistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ ONEDAL_PY_INIT_MODULE(basic_statistics) {

#ifdef ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_spmd, task::compute);
ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_spmd, task::compute);
#else // ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_partial_compute_ops, sub, policy_list, task::compute);
Expand Down
43 changes: 27 additions & 16 deletions onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ class IncrementalBasicStatistics(BaseBasicStatistics):

def __init__(self, result_options="all"):
super().__init__(result_options, algorithm="by_default")
module = self._get_backend("basic_statistics")
self._partial_result = module.partial_compute_result()
self._reset()

def _reset(self):
module = self._get_backend("basic_statistics")
self._partial_result = module.partial_train_result()
self._partial_result = self._get_backend(
"basic_statistics", None, "partial_compute_result"
)

def partial_fit(self, X, weights=None, queue=None):
"""
Expand All @@ -92,19 +92,20 @@ def partial_fit(self, X, weights=None, queue=None):
self : object
Returns the instance itself.
"""
if not hasattr(self, "_policy"):
self._policy = self._get_policy(queue, X)

X, weights = _convert_to_supported(self._policy, X, weights)
self._queue = queue
policy = self._get_policy(queue, X)
X, weights = _convert_to_supported(policy, X, weights)

if not hasattr(self, "_onedal_params"):
dtype = get_dtype(X)
self._onedal_params = self._get_onedal_params(dtype)
self._onedal_params = self._get_onedal_params(False, dtype=dtype)

X_table, weights_table = to_table(X, weights)
module = self._get_backend("basic_statistics")
self._partial_result = module.partial_compute(
self._policy,
self._partial_result = self._get_backend(
"basic_statistics",
None,
"partial_compute",
policy,
self._onedal_params,
self._partial_result,
X_table,
Expand All @@ -119,16 +120,26 @@ def finalize_fit(self, queue=None):
Parameters
----------
queue : dpctl.SyclQueue
Not used here, added for API conformance
If not None, use this queue for computations.
Returns
-------
self : object
Returns the instance itself.
"""
module = self._get_backend("basic_statistics")
result = module.finalize_compute(
self._policy, self._onedal_params, self._partial_result

if queue is not None:
policy = self._get_policy(queue)
else:
policy = self._get_policy(self._queue)

result = self._get_backend(
"basic_statistics",
None,
"finalize_compute",
policy,
self._onedal_params,
self._partial_result,
)
options = self._get_result_options(self.options).split("|")
for opt in options:
Expand Down
3 changes: 2 additions & 1 deletion onedal/spmd/basic_statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
# ==============================================================================

from .basic_statistics import BasicStatistics
from .incremental_basic_statistics import IncrementalBasicStatistics

__all__ = ["BasicStatistics"]
__all__ = ["BasicStatistics", "IncrementalBasicStatistics"]
2 changes: 0 additions & 2 deletions onedal/spmd/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
# ==============================================================================

import warnings

from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch

from ..._device_offload import support_usm_ndarray
Expand Down
69 changes: 69 additions & 0 deletions onedal/spmd/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from daal4py.sklearn._utils import get_dtype

from ...basic_statistics import (
IncrementalBasicStatistics as base_IncrementalBasicStatistics,
)
from ...datatypes import _convert_to_supported, to_table
from .._base import BaseEstimatorSPMD


class IncrementalBasicStatistics(BaseEstimatorSPMD, base_IncrementalBasicStatistics):
    """SPMD variant of the incremental basic statistics estimator.

    Only ``_reset`` and ``partial_fit`` are overridden here. Both resolve
    ``_get_backend`` / ``_get_policy`` through
    ``super(base_IncrementalBasicStatistics, self)``, i.e. the attribute
    lookup starts *after* ``base_IncrementalBasicStatistics`` in the MRO, so
    the versions supplied by ``BaseEstimatorSPMD`` itself are bypassed for
    these two steps. Everything else (including ``finalize_fit``) is
    inherited unchanged.

    NOTE(review): the intent appears to be that partial computation runs
    with the regular data-parallel policy/backend while finalization uses
    the SPMD one -- confirm against ``BaseEstimatorSPMD``.
    """

    def _reset(self):
        """Replace the accumulated partial result with a freshly created one."""
        # Two-argument super(): deliberately skip BaseEstimatorSPMD's
        # _get_backend and use the one found further down the MRO.
        self._partial_result = super(base_IncrementalBasicStatistics, self)._get_backend(
            "basic_statistics", None, "partial_compute_result"
        )

    def partial_fit(self, X, weights=None, queue=None):
        """
        Computes partial data for basic statistics
        from data batch X and saves it to `_partial_result`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data batch, where `n_samples` is the number of samples
            in the batch, and `n_features` is the number of features.

        weights : array-like, default=None
            Optional weights for the batch; converted together with `X` and
            passed to the backend alongside it. Presumably one weight per
            sample -- verify against the base estimator.

        queue : dpctl.SyclQueue
            If not None, use this queue for computations.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        # Remember the queue so that a later finalize_fit() call without an
        # explicit queue can build its policy from the same one.
        self._queue = queue

        # Resolve policy/backend past BaseEstimatorSPMD in the MRO (see the
        # class docstring for the rationale).
        non_spmd = super(base_IncrementalBasicStatistics, self)
        policy = non_spmd._get_policy(queue, X)
        X, weights = _convert_to_supported(policy, X, weights)

        # Parameters are derived from the dtype of the first batch only and
        # then reused for every subsequent batch.
        if not hasattr(self, "_onedal_params"):
            dtype = get_dtype(X)
            self._onedal_params = self._get_onedal_params(False, dtype=dtype)

        X_table, weights_table = to_table(X, weights)
        self._partial_result = non_spmd._get_backend(
            "basic_statistics",
            None,
            "partial_compute",
            policy,
            self._onedal_params,
            self._partial_result,
            X_table,
            weights_table,
        )
        # The docstring promises the instance; previously the method fell off
        # the end and returned None, which broke call chaining.
        return self
10 changes: 5 additions & 5 deletions sklearnex/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(self, result_options="all", batch_size=None):

def _onedal_supported(self, method_name, *data):
patching_status = PatchingConditionsChain(
f"sklearn.covariance.{self.__class__.__name__}.{method_name}"
f"sklearn.basic_statistics.{self.__class__.__name__}.{method_name}"
)
return patching_status

Expand All @@ -135,9 +135,9 @@ def _get_onedal_result_options(self, options):
assert isinstance(onedal_options, str)
return options

def _onedal_finalize_fit(self):
def _onedal_finalize_fit(self, queue=None):
assert hasattr(self, "_onedal_estimator")
self._onedal_estimator.finalize_fit()
self._onedal_estimator.finalize_fit(queue=queue)
self._need_to_finalize = False

def _onedal_partial_fit(self, X, sample_weight=None, queue=None):
Expand Down Expand Up @@ -171,7 +171,7 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None):
self._onedal_estimator = self._onedal_incremental_basic_statistics(
**onedal_params
)
self._onedal_estimator.partial_fit(X, sample_weight, queue)
self._onedal_estimator.partial_fit(X, weights=sample_weight, queue=queue)
self._need_to_finalize = True

def _onedal_fit(self, X, sample_weight=None, queue=None):
Expand Down Expand Up @@ -203,7 +203,7 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):

self.n_features_in_ = X.shape[1]

self._onedal_finalize_fit()
self._onedal_finalize_fit(queue=queue)

return self

Expand Down
3 changes: 2 additions & 1 deletion sklearnex/spmd/basic_statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
# ==============================================================================

from .basic_statistics import BasicStatistics
from .incremental_basic_statistics import IncrementalBasicStatistics

__all__ = ["BasicStatistics"]
__all__ = ["BasicStatistics", "IncrementalBasicStatistics"]
30 changes: 30 additions & 0 deletions sklearnex/spmd/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from onedal.spmd.basic_statistics import (
IncrementalBasicStatistics as onedalSPMD_IncrementalBasicStatistics,
)

from ...basic_statistics import (
IncrementalBasicStatistics as base_IncrementalBasicStatistics,
)


class IncrementalBasicStatistics(base_IncrementalBasicStatistics):
    # Swap in the oneDAL SPMD estimator as the class-level factory attribute;
    # all other behaviour is inherited from the base estimator unchanged.
    # staticmethod() keeps the class from being bound as a method when the
    # attribute is accessed through an instance.
    _onedal_incremental_basic_statistics = staticmethod(
        onedalSPMD_IncrementalBasicStatistics
    )
Loading

0 comments on commit 7ecc9f1

Please sign in to comment.