Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: dpnp interop for sklearnex #1374

Merged
merged 10 commits into from
Jul 26, 2023
1 change: 1 addition & 0 deletions .ci/pipeline/build-and-test-lnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ steps:
bash .ci/scripts/setup_sklearn.sh $(SKLEARN_VERSION)
pip install --upgrade -r requirements-test.txt -r requirements-test-optional.txt
pip install $(python .ci/scripts/get_compatible_scipy_version.py)
if [ $(echo $(PYTHON_VERSION) | grep '3.8\|3.9\|3.10') ]; then conda install -q -y -c intel dpnp; fi
ethanglaser marked this conversation as resolved.
Show resolved Hide resolved
pip list
displayName: 'Install testing requirements'
- script: |
Expand Down
58 changes: 58 additions & 0 deletions examples/sklearnex/knn_bf_classification_dpnp_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# ===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

# sklearnex kNN example for GPU offloading with DPNP ndarray:
# python ./knn_bf_classification_dpnp_batch.py.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# python ./knn_bf_classification_dpnp_batch.py.py
# python ./knn_bf_classification_dpnp_batch.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!


import dpctl
import dpnp
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from sklearnex.neighbors import KNeighborsClassifier

X, y = make_classification(
n_samples=1000,
n_features=4,
n_informative=2,
n_redundant=0,
random_state=0,
shuffle=False,
)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# Make sure that all DPNP ndarrays using the same device.
q = dpctl.SyclQueue("gpu") # GPU

dpnp_X_train = dpnp.asarray(X_train, usm_type="device", sycl_queue=q)
dpnp_y_train = dpnp.asarray(y_train, usm_type="device", sycl_queue=q)
dpnp_X_test = dpnp.asarray(X_test, usm_type="device", sycl_queue=q)

knn_mdl = KNeighborsClassifier(
algorithm="brute", n_neighbors=20, weights="uniform", p=2, metric="minkowski"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose weights="uniform", p=2, metri="minkowski" - correspond to the Euclidean metric and normal mode of classification which are default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be patched for both spmd and batch examples. Let's do it on separate PR

)
knn_mdl.fit(dpnp_X_train, dpnp_y_train)

y_predict = knn_mdl.predict(dpnp_X_test)

print("Brute Force Distributed kNN classification results:")
print("Ground truth (first 5 observations):\n{}".format(y_test[:5]))
print("Classification results (first 5 observations):\n{}".format(y_predict[:5]))
print("Accuracy (2 classes): {}\n".format(accuracy_score(y_test, y_predict.asnumpy())))
ethanglaser marked this conversation as resolved.
Show resolved Hide resolved
print("Are predicted results on GPU: {}".format(y_predict.sycl_device.is_gpu))
53 changes: 53 additions & 0 deletions examples/sklearnex/random_forest_classifier_dpctl_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# ===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

# sklearnex RF example for GPU offloading with DPCtl tensor:
# python ./random_forest_classifier_dpctl_batch.py

import dpctl
import dpctl.tensor as dpt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from sklearnex.preview.ensemble import RandomForestClassifier

# Make sure that all DPCtl tensors using the same device.
q = dpctl.SyclQueue("gpu") # GPU

X, y = make_classification(
n_samples=1000,
n_features=4,
n_informative=2,
n_redundant=0,
random_state=0,
shuffle=False,
)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

dpt_X_train = dpt.asarray(X_train, usm_type="device", sycl_queue=q)
dpt_y_train = dpt.asarray(y_train, usm_type="device", sycl_queue=q)
dpt_X_test = dpt.asarray(X_test, usm_type="device", sycl_queue=q)

rf = RandomForestClassifier(max_depth=2, random_state=0).fit(dpt_X_train, dpt_y_train)

pred = rf.predict(dpt_X_test)

print("Random Forest classification results:")
print("Ground truth (first 5 observations):\n{}".format(y_test[:5]))
print("Classification results (first 5 observations):\n{}".format(pred[:5]))
print("Are predicted results on GPU: {}".format(pred.sycl_device.is_gpu))
46 changes: 46 additions & 0 deletions examples/sklearnex/random_forest_regressor_dpnp_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# ===============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================

# sklearnex RF example for GPU offloading with DPNP ndarray:
# python ./random_forest_regressor_dpnp_batch.py.py

import dpnp
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

from sklearnex.preview.ensemble import RandomForestRegressor

sycl_device = "gpu:0"

X, y = make_regression(
n_samples=1000, n_features=4, n_informative=2, random_state=0, shuffle=False
)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

dpnp_X_train = dpnp.asarray(X_train, device=sycl_device)
dpnp_y_train = dpnp.asarray(y_train, device=sycl_device)
dpnp_X_test = dpnp.asarray(X_test, device=sycl_device)

rf = RandomForestRegressor(max_depth=2, random_state=0).fit(dpnp_X_train, dpnp_y_train)

pred = rf.predict(dpnp_X_test)

print("Random Forest regression results:")
print("Ground truth (first 5 observations):\n{}".format(y_test[:5]))
print("Regression results (first 5 observations):\n{}".format(pred[:5]))
print("Are predicted results on GPU: {}".format(pred.sycl_device.is_gpu))
125 changes: 75 additions & 50 deletions sklearnex/_device_offload.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#===============================================================================
# ===============================================================================
# Copyright 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,44 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================
# ===============================================================================

from ._config import get_config
from ._utils import get_patch_message
import logging
import sys
from functools import wraps

import numpy as np
import sys
import logging

try:
from dpctl import SyclQueue
from dpctl.memory import MemoryUSMDevice, as_usm_memory
from dpctl.tensor import usm_ndarray

dpctl_available = True
except ImportError:
dpctl_available = False

oneapi_is_available = 'daal4py.oneapi' in sys.modules
try:
import dpnp

dpnp_available = True
except ImportError:
dpnp_available = False

from ._config import get_config
from ._utils import get_patch_message

oneapi_is_available = "daal4py.oneapi" in sys.modules
if oneapi_is_available:
from daal4py.oneapi import _get_device_name_sycl_ctxt, _get_sycl_ctxt_params


class DummySyclQueue:
'''This class is designed to act like dpctl.SyclQueue
to allow device dispatching in scenarios when dpctl is not available'''
"""This class is designed to act like dpctl.SyclQueue
to allow device dispatching in scenarios when dpctl is not available"""

class DummySyclDevice:
def __init__(self, filter_string):
self._filter_string = filter_string
self.is_cpu = 'cpu' in filter_string
self.is_gpu = 'gpu' in filter_string
self.is_cpu = "cpu" in filter_string
self.is_gpu = "gpu" in filter_string
# TODO: check for possibility of fp64 support
# on other devices in this dummy class
self.has_aspect_fp64 = self.is_cpu

if not (self.is_cpu):
logging.warning("Device support is limited. "
"Please install dpctl for full experience")
logging.warning(
"Device support is limited. "
"Please install dpctl for full experience"
)

def get_filter_string(self):
return self._filter_string
Expand All @@ -65,23 +77,26 @@ def _get_device_info_from_daal4py():


def _get_global_queue():
target = get_config()['target_offload']
target = get_config()["target_offload"]
d4p_target, _ = _get_device_info_from_daal4py()
if d4p_target == 'host':
d4p_target = 'cpu'
if d4p_target == "host":
d4p_target = "cpu"

QueueClass = DummySyclQueue if not dpctl_available else SyclQueue

if target != 'auto':
if d4p_target is not None and \
d4p_target != target:
if target != "auto":
if d4p_target is not None and d4p_target != target:
if not isinstance(target, str):
if d4p_target not in target.sycl_device.get_filter_string():
raise RuntimeError("Cannot use target offload option "
"inside daal4py.oneapi.sycl_context")
raise RuntimeError(
"Cannot use target offload option "
"inside daal4py.oneapi.sycl_context"
)
else:
raise RuntimeError("Cannot use target offload option "
"inside daal4py.oneapi.sycl_context")
raise RuntimeError(
"Cannot use target offload option "
"inside daal4py.oneapi.sycl_context"
)
if isinstance(target, QueueClass):
return target
return QueueClass(target)
Expand All @@ -95,22 +110,25 @@ def _transfer_to_host(queue, *data):

host_data = []
for item in data:
usm_iface = getattr(item, '__sycl_usm_array_interface__', None)
usm_iface = getattr(item, "__sycl_usm_array_interface__", None)
if usm_iface is not None:
if not dpctl_available:
raise RuntimeError("dpctl need to be installed to work "
"with __sycl_usm_array_interface__")
raise RuntimeError(
"dpctl need to be installed to work "
"with __sycl_usm_array_interface__"
)
if queue is not None:
if queue.sycl_device != usm_iface['syclobj'].sycl_device:
raise RuntimeError('Input data shall be located '
'on single target device')
if queue.sycl_device != usm_iface["syclobj"].sycl_device:
raise RuntimeError(
"Input data shall be located " "on single target device"
)
else:
queue = usm_iface['syclobj']
queue = usm_iface["syclobj"]

buffer = as_usm_memory(item).copy_to_host()
item = np.ndarray(shape=usm_iface['shape'],
dtype=usm_iface['typestr'],
buffer=buffer)
item = np.ndarray(
shape=usm_iface["shape"], dtype=usm_iface["typestr"], buffer=buffer
)
has_usm_data = True
else:
has_host_data = True
Expand All @@ -119,7 +137,7 @@ def _transfer_to_host(queue, *data):
mismatch_usm_item = usm_iface is not None and has_host_data

if mismatch_host_item or mismatch_usm_item:
raise RuntimeError('Input data shall be located on single target device')
raise RuntimeError("Input data shall be located on single target device")

host_data.append(item)
return queue, host_data
Expand All @@ -129,20 +147,22 @@ def _get_backend(obj, queue, method_name, *data):
cpu_device = queue is None or queue.sycl_device.is_cpu
gpu_device = queue is not None and queue.sycl_device.is_gpu

if (cpu_device and obj._onedal_cpu_supported(method_name, *data)) or \
(gpu_device and obj._onedal_gpu_supported(method_name, *data)):
return 'onedal', queue
if (cpu_device and obj._onedal_cpu_supported(method_name, *data)) or (
gpu_device and obj._onedal_gpu_supported(method_name, *data)
):
return "onedal", queue
if cpu_device:
return 'sklearn', None
return "sklearn", None

_, d4p_options = _get_device_info_from_daal4py()
allow_fallback_to_host = get_config()['allow_fallback_to_host'] or \
d4p_options.get('host_offload_on_fail', False)
allow_fallback_to_host = get_config()["allow_fallback_to_host"] or d4p_options.get(
"host_offload_on_fail", False
)

if gpu_device and allow_fallback_to_host:
if obj._onedal_cpu_supported(method_name, *data):
return 'onedal', None
return 'sklearn', None
return "onedal", None
return "sklearn", None

raise RuntimeError("Device support is not implemented")

Expand All @@ -155,18 +175,20 @@ def dispatch(obj, method_name, branches, *args, **kwargs):

backend, q = _get_backend(obj, q, method_name, *hostargs)

if backend == 'onedal':
if backend == "onedal":
return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
if backend == 'sklearn':
if backend == "sklearn":
return branches[backend](obj, *hostargs, **hostkwargs)
raise RuntimeError(f'Undefined backend {backend} in '
f'{obj.__class__.__name__}.{method_name}')
raise RuntimeError(
f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}"
)


def _copy_to_usm(queue, array):
if not dpctl_available:
raise RuntimeError("dpctl need to be installed to work "
"with __sycl_usm_array_interface__")
raise RuntimeError(
"dpctl need to be installed to work " "with __sycl_usm_array_interface__"
)
mem = MemoryUSMDevice(array.nbytes, queue=queue)
mem.copy_from_host(array.tobytes())
return usm_ndarray(array.shape, array.dtype, buffer=mem)
Expand All @@ -179,9 +201,12 @@ def wrapper(self, *args, **kwargs):
if len(data) == 0:
usm_iface = None
else:
usm_iface = getattr(data[0], '__sycl_usm_array_interface__', None)
usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
result = func(self, *args, **kwargs)
if usm_iface is not None:
return _copy_to_usm(usm_iface['syclobj'], result)
result = _copy_to_usm(usm_iface["syclobj"], result)
if dpnp_available and isinstance(data[0], dpnp.ndarray):
result = dpnp.array(result, copy=False)
return result
Comment on lines +207 to +209
Copy link
Contributor Author

@samir-nasibli samir-nasibli Jul 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and related dpnp import are only functional changes in this file, the rest is formatting.


return wrapper
Loading
Loading