Skip to content

Commit

Permalink
Fixes in LinearRegression SPMD (#1195)
Browse files Browse the repository at this point in the history
  • Loading branch information
KulikovNikita authored Feb 27, 2023
1 parent e5e8a13 commit 9928722
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
15 changes: 4 additions & 11 deletions onedal/datatypes/_data_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,18 @@ def to_table(*args):
if _is_dpc_backend:
import numpy as np

from ..common._spmd_policy import _SPMDDataParallelInteropPolicy
from ..common._policy import _HostInteropPolicy, _DataParallelInteropPolicy
from ..common._policy import _HostInteropPolicy

def _convert_to_supported_impl(policy, *data):
# CPUs support FP64 by default
is_host = isinstance(policy, _HostInteropPolicy)
no_dpcpp = not _is_dpc_backend
if is_host or no_dpcpp:
if isinstance(policy, _HostInteropPolicy):
return data

# There is only one option of data parallel policy
is_dpcpp_policy = isinstance(policy, _DataParallelInteropPolicy)
is_spmd_policy = isinstance(policy, _SPMDDataParallelInteropPolicy)
assert is_spmd_policy or is_dpcpp_policy

# It can be either SPMD or DPCPP policy
device = policy._queue.sycl_device

def convert_or_pass(x):
if x.dtype is not np.float32:
if x.dtype is np.float64:
return x.astype(np.float32)
else:
return x
Expand Down
9 changes: 3 additions & 6 deletions onedal/primitives/tree_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
#include <limits>
#include <vector>

#include <iostream>
#include <utility>

#define ONEDAL_PY_TERMINAL_NODE -1
#define ONEDAL_PY_NO_FEATURE -2

Expand All @@ -45,7 +42,7 @@ inline static const double get_nan64() {

// equivalent for numpy arange
template <typename T>
std::vector<T> arange(T start, T stop, T step = 1) {
inline std::vector<T> arange(T start, T stop, T step = 1) {
std::vector<T> res;
for (T i = start; i < stop; i += step)
res.push_back(i);
Expand Down Expand Up @@ -128,7 +125,7 @@ class node_visitor {
template <typename Task>
class to_sklearn_tree_object_visitor : public tree_state<Task> {
public:
to_sklearn_tree_object_visitor(size_t _depth,
to_sklearn_tree_object_visitor(std::size_t _depth,
std::size_t _n_nodes,
std::size_t _n_leafs,
std::size_t _max_n_classes);
Expand All @@ -143,7 +140,7 @@ class to_sklearn_tree_object_visitor : public tree_state<Task> {
};

template <typename Task>
to_sklearn_tree_object_visitor<Task>::to_sklearn_tree_object_visitor(size_t _depth,
to_sklearn_tree_object_visitor<Task>::to_sklearn_tree_object_visitor(std::size_t _depth,
std::size_t _n_nodes,
std::size_t _n_leafs,
std::size_t _max_n_classes)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
except ImportError:
dpctl_available = False

build_distribute = dpcpp and dpctl_available and not no_dist
build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN


daal_lib_dir = lib_dir if (IS_MAC or os.path.isdir(
Expand Down
16 changes: 15 additions & 1 deletion setup_sklearnex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# System imports
import os
import sys
import time
from setuptools import setup
from scripts.version import get_onedal_version
Expand All @@ -25,6 +26,19 @@
sklearnex_version = (os.environ["SKLEARNEX_VERSION"] if "SKLEARNEX_VERSION" in os.environ
else time.strftime("%Y%m%d.%H%M%S"))

IS_WIN = False
IS_MAC = False
IS_LIN = False

if 'linux' in sys.platform:
IS_LIN = True
elif sys.platform == 'darwin':
IS_MAC = True
elif sys.platform in ['win32', 'cygwin']:
IS_WIN = True
else:
assert False, sys.platform + ' not supported'

dal_root = os.environ.get('DALROOT')

if dal_root is None:
Expand All @@ -41,7 +55,7 @@
except ImportError:
dpctl_available = False

build_distribute = dpcpp and dpctl_available and not no_dist
build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN

ONEDAL_VERSION = get_onedal_version(dal_root)

Expand Down

0 comments on commit 9928722

Please sign in to comment.