diff --git a/onedal/cluster/dbscan.py b/onedal/cluster/dbscan.py
index f91325b65c..9919436472 100644
--- a/onedal/cluster/dbscan.py
+++ b/onedal/cluster/dbscan.py
@@ -15,13 +15,22 @@
 # ===============================================================================
 
 import numpy as np
+from sklearn.utils import check_array
 
-from daal4py.sklearn._utils import get_dtype, make2d
 from ..common._base import BaseEstimator
 from ..common._mixin import ClusterMixin
 from ..datatypes import _convert_to_supported, from_table, to_table
-from ..utils import _check_array
+from ..utils._array_api import (
+    _asarray,
+    _convert_to_numpy,
+    _ravel,
+    get_dtype,
+    get_namespace,
+    make2d,
+    sklearn_array_api_dispatch,
+)
 
 
 class BaseDBSCAN(BaseEstimator, ClusterMixin):
@@ -46,9 +55,9 @@ def __init__(
         self.p = p
         self.n_jobs = n_jobs
 
-    def _get_onedal_params(self, dtype=np.float32):
+    def _get_onedal_params(self, xp, dtype):
         return {
-            "fptype": "float" if dtype == np.float32 else "double",
+            "fptype": "float" if dtype == xp.float32 else "double",
             "method": "by_default",
             "min_observations": int(self.min_samples),
             "epsilon": float(self.eps),
@@ -56,28 +65,70 @@
             "result_options": "core_observation_indices|responses",
         }
 
-    def _fit(self, X, y, sample_weight, module, queue):
+    @sklearn_array_api_dispatch()
+    def _fit(self, X, sua_iface, xp, is_array_api_compliant, y, sample_weight, queue):
         policy = self._get_policy(queue, X)
-        X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
+        # TODO:
+        # check the dispatching and warn:
+        # using scikit-learn primitives here requires array_api_dispatch=True.
+        X = check_array(X, accept_sparse="csr", dtype=[xp.float64, xp.float32])
+        sample_weight = make2d(sample_weight) if sample_weight is not None else None
         X = make2d(X)
+        # X_device = X.device if xp else None
+
+        # TODO:
+        # move into _convert_to_supported to do the astype conversion
+        # in one place.
+        types = [xp.float32, xp.float64]
-        types = [np.float32, np.float64]
+        # TODO:
+        # the conversion could be impossible if the device doesn't support fp64;
+        # it makes sense to update _convert_to_supported for that case.
         if get_dtype(X) not in types:
-            X = X.astype(np.float64)
-        X = _convert_to_supported(policy, X)
+            X = X.astype(xp.float64)
+        X = _convert_to_supported(policy, X, xp=xp)
+        # TODO:
+        # remove if not required.
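+        # Keep sample_weight on the same device and with the same dtype
+        # handling as X so the oneDAL backend receives consistent inputs.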
+        sample_weight = (
+            _convert_to_supported(policy, sample_weight, xp=xp)
+            if sample_weight is not None
+            else None
+        )
         dtype = get_dtype(X)
-        params = self._get_onedal_params(dtype)
-        result = module.compute(policy, params, to_table(X), to_table(sample_weight))
+        params = self._get_onedal_params(xp, dtype)
+        X_table = to_table(X)
+        sample_weight_table = to_table(sample_weight)
 
-        self.labels_ = from_table(result.responses).ravel()
-        if result.core_observation_indices is not None:
-            self.core_sample_indices_ = from_table(
-                result.core_observation_indices
-            ).ravel()
+        result = self._get_backend("dbscan", "clustering", None).compute(
+            policy, params, X_table, sample_weight_table
+        )
+        self.labels_ = _ravel(
+            from_table(result.responses, sua_iface=sua_iface, sycl_queue=queue, xp=xp), xp
+        )
+        if (
+            result.core_observation_indices is not None
+            and result.core_observation_indices.kind != "empty"
+        ):
+            self.core_sample_indices_ = _ravel(
+                from_table(
+                    result.core_observation_indices,
+                    sycl_queue=queue,
+                    sua_iface=sua_iface,
+                    xp=xp,
+                ),
+                xp,
+            )
         else:
-            self.core_sample_indices_ = np.array([], dtype=np.intc)
-        self.components_ = np.take(X, self.core_sample_indices_, axis=0)
+            # TODO:
+            # self.core_sample_indices_ = _asarray([], xp, sycl_queue=queue, dtype=xp.int32)
+            if sua_iface:
+                self.core_sample_indices_ = xp.asarray(
+                    [], sycl_queue=queue, dtype=xp.int32
+                )
+            else:
+                self.core_sample_indices_ = xp.asarray([], dtype=xp.int32)
+        self.components_ = xp.take(X, self.core_sample_indices_, axis=0)
         self.n_features_in_ = X.shape[1]
         return self
@@ -105,6 +156,11 @@ def __init__(
         self.n_jobs = n_jobs
 
     def fit(self, X, y=None, sample_weight=None, queue=None):
+        sua_iface, xp, is_array_api_compliant = get_namespace(X)
+        # TODO:
+        # update how the queue is obtained.
+        if sua_iface:
+            queue = X.sycl_queue
         return super()._fit(
-            X, y, sample_weight, self._get_backend("dbscan", "clustering", None), queue
+            X, sua_iface, xp, is_array_api_compliant, y, sample_weight, queue
         )
diff --git a/onedal/cluster/tests/test_dbscan.py b/onedal/cluster/tests/test_dbscan.py
index d309cc8767..2ad7e7fa2b 100644
--- a/onedal/cluster/tests/test_dbscan.py
+++ b/onedal/cluster/tests/test_dbscan.py
@@ -16,10 +16,16 @@
 
 import numpy as np
 import pytest
+from numpy.testing import assert_allclose
 from sklearn.cluster import DBSCAN as DBSCAN_SKLEARN
 from sklearn.cluster.tests.common import generate_clustered_data
 
 from onedal.cluster import DBSCAN as ONEDAL_DBSCAN
+from onedal.tests.utils._dataframes_support import (
+    _as_numpy,
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
 from onedal.tests.utils._device_selection import get_queues
 
 
@@ -123,3 +129,17 @@ def _test_across_grid_parameter_numpy_gen(queue, metric, use_weights: bool):
 @pytest.mark.parametrize("queue", get_queues())
 def test_across_grid_parameter_numpy_gen(queue, metric, use_weights: bool):
     _test_across_grid_parameter_numpy_gen(queue, metric=metric, use_weights=use_weights)
+
+
+# TODO:
+# extend for more dtypes.
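+# Smoke test for the Array API path: labels computed through oneDAL must match
+# the expected NumPy result for every supported dataframe type.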
+@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
+def test_base_dbscan(dataframe, queue):
+    X = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]])
+    X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
+    dbscan = ONEDAL_DBSCAN(eps=3, min_samples=2).fit(X)
+
+    result = dbscan.labels_
+    expected = np.array([0, 0, 0, 1, 1, -1], dtype=np.int32)
+    assert_allclose(expected, _as_numpy(result))
diff --git a/onedal/datatypes/_data_conversion.py b/onedal/datatypes/_data_conversion.py
index 0d91bcdfb0..4b2118425c 100644
--- a/onedal/datatypes/_data_conversion.py
+++ b/onedal/datatypes/_data_conversion.py
@@ -81,7 +81,7 @@ def _table_to_array(table, xp=None):
 
     from ..common._policy import _HostInteropPolicy
 
-    def _convert_to_supported(policy, *data):
+    def _convert_to_supported(policy, *data, xp=np):
         def func(x):
             return x
 
@@ -93,13 +93,13 @@ def func(x):
         device = policy._queue.sycl_device
 
         def convert_or_pass(x):
-            if (x is not None) and (x.dtype == np.float64):
+            if (x is not None) and (x.dtype == xp.float64):
                 warnings.warn(
                     "Data will be converted into float32 from "
                     "float64 because device does not support it",
                     RuntimeWarning,
                 )
-                return x.astype(np.float32)
+                return xp.astype(x, xp.float32)
             else:
                 return x
 
@@ -132,7 +132,7 @@ def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None):
 
 else:
 
-    def _convert_to_supported(policy, *data):
+    def _convert_to_supported(policy, *data, xp=np):
         def func(x):
             return x
diff --git a/onedal/utils/__init__.py b/onedal/utils/__init__.py
index 0a1b05fbc2..22794748b2 100644
--- a/onedal/utils/__init__.py
+++ b/onedal/utils/__init__.py
@@ -18,6 +18,7 @@
     _check_array,
     _check_classification_targets,
     _check_n_features,
+    _check_sample_weight,
     _check_X_y,
     _column_or_1d,
     _is_arraylike,
diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py
index 47da103da9..a5049a6c11 100644
--- a/onedal/utils/_array_api.py
+++ b/onedal/utils/_array_api.py
@@ -17,12 +17,39 @@
 """Tools to support array_api."""
 
 from collections.abc import Iterable
+from functools import wraps
+
+from daal4py.sklearn._utils import sklearn_check_version
+
+if sklearn_check_version("1.4"):
+    from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
+
+import numpy as np
+from sklearn import config_context, get_config
+
+from daal4py.sklearn._utils import get_dtype
+from daal4py.sklearn._utils import make2d as d4p_make2d
 
 from ._dpep_helpers import dpctl_available, dpnp_available
 
 if dpctl_available:
+    import dpctl.tensor as dpt
     from dpctl.tensor import usm_ndarray
+
+
+# TODO:
+# move to the Array API module.
+# TODO:
+# def make2d(arg, xp=None, is_array_api_compliant=None):
+def make2d(arg, xp=None):
+    if xp and not _is_numpy_namespace(xp) and arg.ndim == 1:
+        return xp.reshape(arg, (arg.size, 1))
+    # TODO:
+    # reimplement via is_array_api_compliant usage.
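+    # NumPy-backed inputs fall back to the daal4py implementation.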
+    return d4p_make2d(arg)
+
+
 if dpnp_available:
     import dpnp
 
@@ -38,6 +65,20 @@ def _convert_to_dpnp(array):
     return array
 
 
+def _convert_to_numpy(array, xp):
+    """Convert array into a NumPy ndarray on the CPU."""
+    xp_name = xp.__name__
+
+    if dpctl_available and xp_name in {
+        "dpctl.tensor",
+    }:
+        return dpt.to_numpy(array)
+    elif dpnp_available and isinstance(array, dpnp.ndarray):
+        return dpnp.asnumpy(array)
+    else:
+        return _asarray(array, xp)
+
+
 def _asarray(data, xp, *args, **kwargs):
     """Converted input object to array format of xp namespace provided."""
     if hasattr(data, "__array_namespace__"):
@@ -54,6 +95,17 @@
     return data
 
 
+def _ravel(array, xp):
+    """Return a flattened array.
+
+    Note
+    ----
+    The input array is expected to be contiguous.
+    """
+
+    return xp.reshape(array, (array.size,))
+
+
 def _is_numpy_namespace(xp):
     """Return True if xp is backed by NumPy."""
     return xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"}
@@ -79,3 +131,140 @@ def _get_sycl_namespace(*arrays):
         raise ValueError(f"SYCL type not recognized: {sua_iface}")
 
     return sua_iface, None, False
+
+
+# TODO:
+# replace this placeholder with a real capability check.
+sklearn_array_api_version = True
+
+
+def sklearn_array_api_dispatch(freefunc=False):
+    """Run the wrapped callable with scikit-learn Array API dispatching enabled.
+
+    The call is executed under ``sklearn.config_context(array_api_dispatch=True)``
+    so that the scikit-learn primitives used inside (e.g. ``check_array``)
+    accept Array API inputs.
+
+    Parameters
+    ----------
+    freefunc : bool, default=False
+        Whether the wrapped callable is a free function rather than a method.
+    """
+
+    def decorator(func):
+        def wrapper_impl(obj, *args, **kwargs):
+            # if sklearn_array_api_version and not get_config["array_api_dispatch"]:
+            if sklearn_array_api_version:
+                with config_context(array_api_dispatch=True):
+                    return func(obj, *args, **kwargs)
+            return func(obj, *args, **kwargs)
+
+        if freefunc:
+
+            @wraps(func)
+            def wrapper_free(*args, **kwargs):
+                return wrapper_impl(None, *args, **kwargs)
+
+            return wrapper_free
+
+        @wraps(func)
+        def wrapper_with_self(self, *args, **kwargs):
+            return wrapper_impl(self, *args, **kwargs)
+
+        return wrapper_with_self
+
+    return decorator
+
+
+if sklearn_check_version("1.5"):
+
+    def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
+        """Get namespace of arrays.
+
+        Extends stock scikit-learn's `get_namespace` primitive to support DPCTL
+        usm_ndarrays and DPNP ndarrays.
+        If there are no DPCTL usm_ndarray or DPNP ndarray inputs and the backend
+        scikit-learn version supports Array API, the results of
+        :obj:`sklearn.utils._array_api.get_namespace` are returned.
+        Otherwise, the numpy namespace is returned.
+
+        Designed to work for numpy.ndarray, DPCTL usm_ndarrays and DPNP ndarrays
+        without `array-api-compat` or backend scikit-learn Array API support.
+
+        For full documentation refer to :obj:`sklearn.utils._array_api.get_namespace`.
+
+        Parameters
+        ----------
+        *arrays : array objects
+            Array objects.
+
+        remove_none : bool, default=True
+            Whether to ignore None objects passed in arrays.
+
+        remove_types : tuple or list, default=(str,)
+            Types to ignore in the arrays.
+
+        xp : module, default=None
+            Precomputed array namespace module. When passed, typically from a caller
+            that has already performed inspection of its own inputs, skips array
+            namespace inspection.
+
+        Returns
+        -------
+        usm_iface : object or None
+            SYCL USM array interface information of the inputs if they are
+            USM-backed arrays, otherwise None.
+
+        namespace : module
+            Namespace shared by array objects.
+
+        is_array_api : bool
+            True if the arrays are containers that implement the Array API spec.
+ """ + + usm_iface, xp_sycl_namespace, is_array_api_compliant = _get_sycl_namespace( + *arrays + ) + + if usm_iface: + return usm_iface, xp_sycl_namespace, is_array_api_compliant + elif sklearn_check_version("1.4"): + xp, is_array_api_compliant = sklearn_get_namespace( + *arrays, remove_none=remove_none, remove_types=remove_types, xp=xp + ) + return usm_iface, xp, is_array_api_compliant + else: + return usm_iface, np, False + +else: + + def get_namespace(*arrays): + """Get namespace of arrays. + + Extends stock scikit-learn's `get_namespace` primitive to support DPCTL usm_ndarrays + and DPNP ndarrays. + If no DPCTL usm_ndarray or DPNP ndarray inputs and backend scikit-learn version supports + Array API then :obj:`sklearn.utils._array_api.get_namespace(*arrays)` results are drawn. + Otherwise, numpy namespace will be returned. + + Designed to work for numpy.ndarray, DPCTL usm_ndarrays and DPNP ndarrays without + `array-api-compat` or backend scikit-learn Array API support. + + For full documentation refer to :obj:`sklearn.utils._array_api.get_namespace`. + + Parameters + ---------- + *arrays : array objects + Array objects. + + Returns + ------- + usm_iface : TBD + + namespace : module + Namespace shared by array objects. + + is_array_api : bool + True of the arrays are containers that implement the Array API spec. + """ + + usm_iface, xp_sycl_namespace, is_array_api_compliant = _get_sycl_namespace( + *arrays + ) + + if usm_iface: + return usm_iface, xp_sycl_namespace, is_array_api_compliant + elif sklearn_check_version("1.4"): + xp, is_array_api_compliant = sklearn_get_namespace(*arrays) + return usm_iface, xp, is_array_api_compliant + else: + return usm_iface, np, False diff --git a/onedal/utils/validation.py b/onedal/utils/validation.py index c97b77a577..068e15a696 100644 --- a/onedal/utils/validation.py +++ b/onedal/utils/validation.py @@ -16,7 +16,7 @@ import warnings from collections.abc import Sequence -from numbers import Integral +from numbers import Integral, Number import numpy as np from scipy import sparse as sp @@ -29,10 +29,12 @@ from numpy import VisibleDeprecationWarning from sklearn.preprocessing import LabelEncoder -from sklearn.utils.validation import check_array +from sklearn.utils.validation import check_array, check_non_negative from daal4py.sklearn.utils.validation import _assert_all_finite +from ..utils._array_api import get_namespace + class DataConversionWarning(UserWarning): """Warning used to notify implicit data conversions happening in the code.""" @@ -410,10 +412,12 @@ def _num_samples(x): if hasattr(x, "fit") and callable(x.fit): # Don't get num_samples from an ensembles length! raise TypeError(message) - + xp, _ = get_namespace(x) if not hasattr(x, "__len__") and not hasattr(x, "shape"): if hasattr(x, "__array__"): - x = np.asarray(x) + # TODO: + # use sycl_queue if required. + x = xp.asarray(x) else: raise TypeError(message) @@ -438,3 +442,47 @@ def _is_csr(x): return isinstance(x, sp.csr_matrix) or ( hasattr(sp, "csr_array") and isinstance(x, sp.csr_array) ) + + +def _check_sample_weight( + sample_weight, X, dtype=None, copy=False, ensure_non_negative=False +): + """Validate sample weights. 
+
+    Parameters
+    ----------
+    sample_weight : ndarray, Number or None
+        Input sample weights.
+
+    X : array-like
+        Input data.
+
+    dtype : dtype, default=None
+        dtype of the validated `sample_weight`.
+
+    copy : bool, default=False
+        If True, a copy of sample_weight will be created.
+
+    ensure_non_negative : bool, default=False
+        Whether or not the weights are expected to be non-negative.
+
+    Returns
+    -------
+    sample_weight : ndarray of shape (n_samples,)
+        Validated sample weights.
+    """
+    _, xp, _ = get_namespace(X)
+    n_samples = _num_samples(X)
+
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = xp.float64
+
+    if sample_weight is None:
+        sample_weight = xp.ones(n_samples, dtype=dtype)
+    elif isinstance(sample_weight, Number):
+        sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
+    else:
+        if dtype is None:
+            dtype = [xp.float64, xp.float32]
+        sample_weight = check_array(
+            sample_weight,
+            accept_sparse=False,
+            ensure_2d=False,
+            dtype=dtype,
+            order="C",
+            copy=copy,
+            input_name="sample_weight",
+        )
+        if sample_weight.ndim != 1:
+            raise ValueError("Sample weights must be 1D array or scalar")
+
+        if sample_weight.shape != (n_samples,):
+            raise ValueError(
+                "sample_weight.shape == {}, expected {}!".format(
+                    sample_weight.shape, (n_samples,)
+                )
+            )
+
+    if ensure_non_negative:
+        check_non_negative(sample_weight, "`sample_weight`")
+
+    return sample_weight
diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py
index 7e299f07e0..3b3204e787 100644
--- a/sklearnex/_device_offload.py
+++ b/sklearnex/_device_offload.py
@@ -16,6 +16,7 @@
 
 from functools import wraps
 
+from daal4py.sklearn._utils import sklearn_check_version
 from onedal._device_offload import _copy_to_usm, _get_global_queue, _transfer_to_host
 from onedal.utils._array_api import _asarray
 from onedal.utils._dpep_helpers import dpnp_available
@@ -27,7 +28,7 @@
 
 from ._config import get_config
 
 
-def _get_backend(obj, queue, method_name, *data):
+def _get_backend(obj, queue, method_name, sua_iface, xp, is_array_api_compliant, *data):
     cpu_device = queue is None or queue.sycl_device.is_cpu
     gpu_device = queue is not None and queue.sycl_device.is_gpu
 
@@ -57,43 +58,58 @@ def _get_backend(obj, queue, method_name, *data):
     raise RuntimeError("Device support is not implemented")
 
 
-def dispatch(obj, method_name, branches, *args, **kwargs):
+# TODO:
+# update.
+def dispatch(
+    obj,
+    method_name,
+    branches,
+    sua_iface=None,
+    xp=None,
+    is_array_api_compliant=None,
+    *args,
+    **kwargs,
+):
+    is_array_api_dispatch = get_config()["array_api_dispatch"]
     q = _get_global_queue()
-    has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
-    has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
-    hostkwargs = dict(zip(kwargs.keys(), hostvalues))
-
-    backend, q, patching_status = _get_backend(obj, q, method_name, *hostargs)
-    has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
-    if backend == "onedal":
-        # Host args only used before onedal backend call.
-        # Device will be offloaded when onedal backend will be called.
-        patching_status.write_log(queue=q, transferred_to_host=False)
-        return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
-    if backend == "sklearn":
-        if (
-            "array_api_dispatch" in get_config()
-            and get_config()["array_api_dispatch"]
-            and "array_api_support" in obj._get_tags()
-            and obj._get_tags()["array_api_support"]
-            and not has_usm_data
-        ):
-            # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is
-            # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant,
-            # except for the linalg module. There is no guarantee that stock scikit-learn will
-            # work with such input data. The condition will be updated after DPNP.ndarray and
-            # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance
-            # of the fallback cases.
-            # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
-            # then raw inputs are used for the fallback.
-            patching_status.write_log(transferred_to_host=False)
-            return branches[backend](obj, *args, **kwargs)
-        else:
-            patching_status.write_log()
-            return branches[backend](obj, *hostargs, **hostkwargs)
-    raise RuntimeError(
-        f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}"
-    )
+    if is_array_api_dispatch:
+        backend, q, patching_status = _get_backend(
+            obj, q, method_name, sua_iface, xp, is_array_api_compliant, *args
+        )
+        if backend == "onedal":
+            # Raw inputs stay on the device; the oneDAL backend handles them
+            # through the Array API path.
+            patching_status.write_log(queue=q, transferred_to_host=False)
+            return branches[backend](obj, *args, **kwargs, queue=q)
+        if backend == "sklearn":
+            patching_status.write_log(transferred_to_host=False)
+            return branches[backend](obj, *args, **kwargs)
+    else:
+        has_usm_data_for_args, q, hostargs = _transfer_to_host(q, *args)
+        has_usm_data_for_kwargs, q, hostvalues = _transfer_to_host(q, *kwargs.values())
+        hostkwargs = dict(zip(kwargs.keys(), hostvalues))
+
+        backend, q, patching_status = _get_backend(
+            obj, q, method_name, sua_iface, xp, is_array_api_compliant, *hostargs
+        )
+        has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
+        if backend == "onedal":
+            # Host args only used before onedal backend call.
+            # Device will be offloaded when onedal backend will be called.
+            patching_status.write_log(queue=q, transferred_to_host=False)
+            return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
+        if backend == "sklearn":
+            if (
+                sklearn_check_version("1.4")
+                and get_config()["array_api_dispatch"]
+                and "array_api_support" in obj._get_tags()
+                and obj._get_tags()["array_api_support"]
+                and not has_usm_data
+            ):
+                # USM ndarrays are also excluded for the fallback Array API. Currently,
+                # DPNP.ndarray is not compliant with the Array API standard, and the DPCTL
+                # usm_ndarray Array API is compliant, except for the linalg module. There is
+                # no guarantee that stock scikit-learn will work with such input data. The
+                # condition will be updated once DPNP.ndarray and DPCTL usm_ndarray are
+                # enabled for conformance testing and support the fallback cases.
+                # If `array_api_dispatch` is enabled and Array API is supported by stock
+                # scikit-learn, then raw inputs are used for the fallback.
+                patching_status.write_log(transferred_to_host=False)
+                return branches[backend](obj, *args, **kwargs)
+            else:
+                patching_status.write_log()
+                return branches[backend](obj, *hostargs, **hostkwargs)
+    raise RuntimeError(
+        f"Undefined backend {backend} in {obj.__class__.__name__}.{method_name}"
+    )
 
 
 def wrap_output_data(func):
diff --git a/sklearnex/cluster/dbscan.py b/sklearnex/cluster/dbscan.py
index ef5f6b78d9..165de06175 100755
--- a/sklearnex/cluster/dbscan.py
+++ b/sklearnex/cluster/dbscan.py
@@ -19,14 +19,15 @@
 
 from scipy import sparse as sp
 from sklearn.cluster import DBSCAN as _sklearn_DBSCAN
-from sklearn.utils.validation import _check_sample_weight
 
 from daal4py.sklearn._n_jobs_support import control_n_jobs
 from daal4py.sklearn._utils import sklearn_check_version
 from onedal.cluster import DBSCAN as onedal_DBSCAN
+from onedal.utils._array_api import get_namespace
+from onedal.utils.validation import _check_sample_weight
 
 from .._device_offload import dispatch
 from .._utils import PatchingConditionsChain
 
 if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
     from sklearn.utils import check_scalar
@@ -89,6 +90,7 @@ def __init__(
         self.n_jobs = n_jobs
 
     def _onedal_fit(self, X, y, sample_weight=None, queue=None):
+        sua_iface, xp, is_array_api_compliant = get_namespace(X)
         if sklearn_check_version("1.0"):
             X = validate_data(self, X, force_all_finite=False)
 
@@ -104,7 +106,9 @@
         }
         self._onedal_estimator = self._onedal_dbscan(**onedal_params)
 
-        self._onedal_estimator.fit(X, y=y, sample_weight=sample_weight, queue=queue)
+        self._onedal_estimator._fit(
+            X, sua_iface, xp, is_array_api_compliant, y, sample_weight, queue=queue
+        )
         self._save_attributes()
 
     def _onedal_supported(self, method_name, *data):
@@ -140,6 +144,7 @@ def _onedal_gpu_supported(self, method_name, *data):
         return self._onedal_supported(method_name, *data)
 
     def fit(self, X, y=None, sample_weight=None):
+        sua_iface, xp, is_array_api_compliant = get_namespace(X)
         if sklearn_check_version("1.2"):
             self._validate_params()
         elif sklearn_check_version("1.1"):
@@ -180,6 +185,8 @@
 
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
+        # TODO:
+        # add new dispatching with the Array API context.
         dispatch(
             self,
             "fit",
             {
                 "onedal": self.__class__._onedal_fit,
                 "sklearn": _sklearn_DBSCAN.fit,
             },
+            sua_iface,
+            xp,
+            is_array_api_compliant,
             X,
             y,
             sample_weight,
         )
 
         return self
 
+    # TODO:
+    # check this in case of the fallback
+    # to stock scikit-learn.
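+    # Declares Array API support to scikit-learn so that inputs are not
+    # forcibly converted to NumPy during validation.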
+    def _more_tags(self):
+        return {"array_api_support": True}
+
     fit.__doc__ = _sklearn_DBSCAN.fit.__doc__
diff --git a/sklearnex/dispatcher.py b/sklearnex/dispatcher.py
index a4a62556f6..9e3601ff14 100644
--- a/sklearnex/dispatcher.py
+++ b/sklearnex/dispatcher.py
@@ -128,6 +128,9 @@ def get_patch_map_core(preview=False):
     from ._config import get_config as get_config_sklearnex
     from ._config import set_config as set_config_sklearnex
 
+    if sklearn_check_version("1.4"):
+        import sklearn.utils._array_api as _array_api_module
+
     if sklearn_check_version("1.2.1"):
         from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex
     else:
@@ -165,6 +168,10 @@ def get_patch_map_core(preview=False):
     from .svm import NuSVC as NuSVC_sklearnex
     from .svm import NuSVR as NuSVR_sklearnex
 
+    if sklearn_check_version("1.4"):
+        from .utils._array_api import _convert_to_numpy as _convert_to_numpy_sklearnex
+        from .utils._array_api import get_namespace as get_namespace_sklearnex
+
     # DBSCAN
     mapping.pop("dbscan")
     mapping["dbscan"] = [[(cluster_module, "DBSCAN", DBSCAN_sklearnex), None]]
@@ -440,6 +447,24 @@ def get_patch_map_core(preview=False):
     mapping["_funcwrapper"] = [
         [(parallel_module, "_FuncWrapper", _FuncWrapper_sklearnex), None]
    ]
+    if sklearn_check_version("1.4"):
+        # Necessary for array_api support
+        mapping["get_namespace"] = [
+            [
+                (
+                    _array_api_module,
+                    "get_namespace",
+                    get_namespace_sklearnex,
+                ),
+                None,
+            ]
+        ]
+        mapping["_convert_to_numpy"] = [
+            [
+                (_array_api_module, "_convert_to_numpy", _convert_to_numpy_sklearnex),
+                None,
+            ]
+        ]
     return mapping
diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py
index 6e7fdb72b5..5a9f57f640 100644
--- a/sklearnex/tests/test_memory_usage.py
+++ b/sklearnex/tests/test_memory_usage.py
@@ -52,6 +52,7 @@
 
 CPU_SKIP_LIST = (
+    "_convert_to_numpy",  # additional memory allocation is expected, proportional to the input data
     "TSNE",  # too slow for using in testing on common data size
     "config_context",  # does not malloc
     "get_config",  # does not malloc
@@ -66,6 +67,7 @@
 )
 
 GPU_SKIP_LIST = (
+    "_convert_to_numpy",  # additional memory allocation is expected, proportional to the input data
     "TSNE",  # too slow for using in testing on common data size
     "RandomForestRegressor",  # too slow for using in testing on common data size
     "KMeans",  # does not support GPU offloading
diff --git a/sklearnex/tests/test_patching.py b/sklearnex/tests/test_patching.py
index 897f19172d..c7ec3b1475 100755
--- a/sklearnex/tests/test_patching.py
+++ b/sklearnex/tests/test_patching.py
@@ -307,10 +307,13 @@ def list_all_attr(string):
 
     module_map = {i: i for i in sklearnex__all__.intersection(sklearn__all__)}
 
-    # _assert_all_finite patches an internal sklearn function which isn't
-    # exposed via __all__ in sklearn. It is a special case where this rule
-    # is not applied (e.g. it is grandfathered in).
+    # _assert_all_finite, _convert_to_numpy and get_namespace patch internal
+    # sklearn functions which aren't exposed via __all__ in sklearn. They are
+    # special cases where this rule is not applied (e.g. they are grandfathered in).
     del patched["_assert_all_finite"]
+    if sklearn_check_version("1.4"):
+        del patched["_convert_to_numpy"]
+        del patched["get_namespace"]
 
     # remove all scikit-learn-intelex-only estimators
     for i in patched.copy():
diff --git a/sklearnex/utils/_array_api.py b/sklearnex/utils/_array_api.py
index bc30be5021..daa4bc94e0 100644
--- a/sklearnex/utils/_array_api.py
+++ b/sklearnex/utils/_array_api.py
@@ -19,64 +19,113 @@
 import numpy as np
 
 from daal4py.sklearn._utils import sklearn_check_version
-from onedal.utils._array_api import _get_sycl_namespace
+from onedal.utils._array_api import _asarray
+from onedal.utils._array_api import get_namespace as onedal_get_namespace
 
-if sklearn_check_version("1.2"):
+if sklearn_check_version("1.4"):
     from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
+    from sklearn.utils._array_api import _convert_to_numpy as _sklearn_convert_to_numpy
 
+from onedal._device_offload import dpctl_available, dpnp_available
 
-def get_namespace(*arrays):
-    """Get namespace of arrays.
+if dpctl_available:
+    import dpctl.tensor as dpt
 
-    Introspect `arrays` arguments and return their common Array API
-    compatible namespace object, if any. NumPy 1.22 and later can
-    construct such containers using the `numpy.array_api` namespace
-    for instance.
+if dpnp_available:
+    import dpnp
 
-    This function will return the namespace of SYCL-related arrays
-    which define the __sycl_usm_array_interface__ attribute
-    regardless of array_api support, the configuration of
-    array_api_dispatch, or scikit-learn version.
 
-    See: https://numpy.org/neps/nep-0047-array-api-standard.html
+def _convert_to_numpy(array, xp):
+    """Convert array into a NumPy ndarray on the CPU."""
+    xp_name = xp.__name__
 
-    If `arrays` are regular numpy arrays, an instance of the
-    `_NumPyApiWrapper` compatibility wrapper is returned instead.
+    if dpctl_available and isinstance(array, dpt.usm_ndarray):
+        return dpt.to_numpy(array)
+    elif dpnp_available and isinstance(array, dpnp.ndarray):
+        return dpnp.asnumpy(array)
+    elif sklearn_check_version("1.4"):
+        return _sklearn_convert_to_numpy(array, xp)
+    else:
+        return _asarray(array, xp)
+
+
+# TODO:
+# refactor
+if sklearn_check_version("1.5"):
+
+    def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
+        """Get namespace of arrays.
 
-    Namespace support is not enabled by default. To enabled it
-    call:
+        Extends stock scikit-learn's `get_namespace` primitive to support DPCTL
+        usm_ndarrays and DPNP ndarrays.
+        If there are no DPCTL usm_ndarray or DPNP ndarray inputs and the backend
+        scikit-learn version supports Array API, the results of
+        :obj:`sklearn.utils._array_api.get_namespace` are returned.
+        Otherwise, the numpy namespace is returned.
 
-    sklearn.set_config(array_api_dispatch=True)
+        Designed to work for numpy.ndarray, DPCTL usm_ndarrays and DPNP ndarrays
+        without `array-api-compat` or backend scikit-learn Array API support.
 
-    or:
+        For full documentation refer to :obj:`sklearn.utils._array_api.get_namespace`.
 
-    with sklearn.config_context(array_api_dispatch=True):
-        # your code here
+        Parameters
+        ----------
+        *arrays : array objects
+            Array objects.
 
-    Otherwise an instance of the `_NumPyApiWrapper`
-    compatibility wrapper is always returned irrespective of
-    the fact that arrays implement the `__array_namespace__`
-    protocol or not.
+        remove_none : bool, default=True
+            Whether to ignore None objects passed in arrays.
 
-    Parameters
-    ----------
-    *arrays : array objects
-        Array objects.
+        remove_types : tuple or list, default=(str,)
+            Types to ignore in the arrays.
-    Returns
-    -------
-    namespace : module
-        Namespace shared by array objects.
+        xp : module, default=None
+            Precomputed array namespace module. When passed, typically from a caller
+            that has already performed inspection of its own inputs, skips array
+            namespace inspection.
 
-    is_array_api : bool
-        True of the arrays are containers that implement the Array API spec.
-    """
+        Returns
+        -------
+        namespace : module
+            Namespace shared by array objects.
 
-    sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
+        is_array_api : bool
+            True if the arrays are containers that implement the Array API spec.
+        """
 
-    if sycl_type:
+        _, xp, is_array_api_compliant = onedal_get_namespace(
+            *arrays, remove_none=remove_none, remove_types=remove_types, xp=xp
+        )
+        return xp, is_array_api_compliant
+
+else:
+
+    def get_namespace(*arrays):
+        """Get namespace of arrays.
+
+        Extends stock scikit-learn's `get_namespace` primitive to support DPCTL
+        usm_ndarrays and DPNP ndarrays.
+        If there are no DPCTL usm_ndarray or DPNP ndarray inputs and the backend
+        scikit-learn version supports Array API, the results of
+        :obj:`sklearn.utils._array_api.get_namespace(*arrays)` are returned.
+        Otherwise, the numpy namespace is returned.
+
+        Designed to work for numpy.ndarray, DPCTL usm_ndarrays and DPNP ndarrays
+        without `array-api-compat` or backend scikit-learn Array API support.
+
+        For full documentation refer to :obj:`sklearn.utils._array_api.get_namespace`.
+
+        Parameters
+        ----------
+        *arrays : array objects
+            Array objects.
+
+        Returns
+        -------
+        namespace : module
+            Namespace shared by array objects.
+
+        is_array_api : bool
+            True if the arrays are containers that implement the Array API spec.
+        """
+
+        _, xp, is_array_api_compliant = onedal_get_namespace(*arrays)
         return xp, is_array_api_compliant
-    elif sklearn_check_version("1.2"):
-        return sklearn_get_namespace(*arrays)
-    else:
-        return np, False
diff --git a/sklearnex/utils/tests/test_array_api.py b/sklearnex/utils/tests/test_array_api.py
new file mode 100644
index 0000000000..bc4756ba84
--- /dev/null
+++ b/sklearnex/utils/tests/test_array_api.py
@@ -0,0 +1,182 @@
+# ==============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import pytest
+from numpy.testing import assert_allclose
+
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.tests.utils._dataframes_support import (
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
+
+# TODO:
+# add a test suite for dpctl.tensor, dpnp.ndarray, numpy.ndarray without config_context(array_api_dispatch=True).
+# TODO:
+# extend for DPNP inputs.
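+# The tests below check that the patched `get_namespace` and `_convert_to_numpy`
+# behave consistently across the supported dataframe types under
+# config_context(array_api_dispatch=True).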
+
+
+@pytest.mark.skipif(
+    not sklearn_check_version("1.4"),
+    reason="Array API dispatch requires sklearn >= 1.4",
+)
+@pytest.mark.parametrize(
+    "dataframe,queue",
+    get_dataframes_and_queues(
+        dataframe_filter_="numpy,dpctl,array_api", device_filter_="cpu,gpu"
+    ),
+)
+def test_get_namespace_with_config_context(dataframe, queue):
+    """Test `get_namespace` with `array_api_dispatch` enabled."""
+    from sklearnex import config_context
+    from sklearnex.utils._array_api import get_namespace
+
+    array_api_compat = pytest.importorskip("array_api_compat")
+
+    X_np = np.asarray([[1, 2, 3]])
+    X = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe)
+
+    with config_context(array_api_dispatch=True):
+        xp_out, is_array_api_compliant = get_namespace(X)
+        assert is_array_api_compliant
+        if dataframe not in "numpy,array_api":
+            # For `numpy.ndarray` and `array-api-strict` inputs `get_namespace`
+            # returns specific wrapper classes rather than the raw
+            # `array_api_compat.get_namespace` output, so compare only for
+            # the other dataframe types.
+            assert xp_out == array_api_compat.get_namespace(X)
+
+
+@pytest.mark.skipif(
+    not sklearn_check_version("1.4"),
+    reason="Array API dispatch requires sklearn >= 1.4",
+)
+@pytest.mark.parametrize(
+    "dataframe,queue",
+    get_dataframes_and_queues(
+        dataframe_filter_="numpy,dpctl,array_api", device_filter_="cpu,gpu"
+    ),
+)
+def test_get_namespace_with_patching(dataframe, queue):
+    """Test `get_namespace` with `array_api_dispatch` and
+    `patch_sklearn` enabled.
+    """
+    array_api_compat = pytest.importorskip("array_api_compat")
+
+    from sklearnex import patch_sklearn
+
+    patch_sklearn()
+
+    from sklearn import config_context
+    from sklearn.utils._array_api import get_namespace
+
+    X_np = np.asarray([[1, 2, 3]])
+    X = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe)
+
+    with config_context(array_api_dispatch=True):
+        xp_out, is_array_api_compliant = get_namespace(X)
+        assert is_array_api_compliant
+        if dataframe not in "numpy,array_api":
+            # For `numpy.ndarray` and `array-api-strict` inputs `get_namespace`
+            # returns specific wrapper classes rather than the raw
+            # `array_api_compat.get_namespace` output, so compare only for
+            # the other dataframe types.
+            assert xp_out == array_api_compat.get_namespace(X)
+
+
+@pytest.mark.skipif(
+    not sklearn_check_version("1.4"),
+    reason="Array API dispatch requires sklearn >= 1.4",
+)
+@pytest.mark.parametrize(
+    "dataframe,queue",
+    get_dataframes_and_queues(
+        dataframe_filter_="dpctl,array_api", device_filter_="cpu,gpu"
+    ),
+)
+def test_convert_to_numpy_with_patching(dataframe, queue):
+    """Test `_convert_to_numpy` with `array_api_dispatch` and
+    `patch_sklearn` enabled.
+ """ + pytest.importorskip("array_api_compat") + + from sklearnex import patch_sklearn + + patch_sklearn() + + from sklearn import config_context + from sklearn.utils._array_api import _convert_to_numpy, get_namespace + + X_np = np.asarray([[1, 2, 3]]) + X = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe) + + with config_context(array_api_dispatch=True): + xp, _ = get_namespace(X) + x_np = _convert_to_numpy(X, xp) + assert type(X_np) == type(x_np) + assert_allclose(X_np, x_np) + + +@pytest.mark.skipif( + not sklearn_check_version("1.4"), + reason="Array API dispatch requires sklearn 1.4 version", +) +@pytest.mark.parametrize( + "dataframe,queue", + get_dataframes_and_queues( + dataframe_filter_="numpy,dpctl,array_api", device_filter_="cpu,gpu" + ), +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(np.float32, id=np.dtype(np.float32).name), + pytest.param(np.float64, id=np.dtype(np.float64).name), + ], +) +def test_validate_data_with_patching(dataframe, queue, dtype): + """Test validate_data with `array_api_dispatch` and + `patch_sklearn` enabled. + """ + pytest.importorskip("array_api_compat") + + from sklearnex import patch_sklearn + + patch_sklearn() + + from sklearn import config_context + from sklearn.base import BaseEstimator + + if sklearn_check_version("1.6"): + from sklearn.utils.validation import validate_data + else: + validate_data = BaseEstimator._validate_data + + from sklearn.utils._array_api import _convert_to_numpy, get_namespace + + X_np = np.asarray([[1, 2, 3], [4, 5, 6]], dtype=dtype) + X_df = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe) + with config_context(array_api_dispatch=True): + est = BaseEstimator() + xp, _ = get_namespace(X_df) + X_df_res = validate_data( + est, X_df, accept_sparse="csr", dtype=[xp.float64, xp.float32] + ) + assert type(X_df) == type(X_df_res) + if dataframe != "numpy": + # _convert_to_numpy not designed for numpy.ndarray inputs. + assert_allclose(_convert_to_numpy(X_df, xp), _convert_to_numpy(X_df_res, xp)) + else: + assert_allclose(X_df, X_df_res)