Skip to content

Commit 9928722

Browse files
author
KulikovNikita
authored
Fixes in LinearRegression SPMD (#1195)
1 parent e5e8a13 commit 9928722

File tree

4 files changed

+23
-19
lines changed

4 files changed

+23
-19
lines changed

onedal/datatypes/_data_conversion.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,18 @@ def to_table(*args):
4040
if _is_dpc_backend:
4141
import numpy as np
4242

43-
from ..common._spmd_policy import _SPMDDataParallelInteropPolicy
44-
from ..common._policy import _HostInteropPolicy, _DataParallelInteropPolicy
43+
from ..common._policy import _HostInteropPolicy
4544

4645
def _convert_to_supported_impl(policy, *data):
4746
# CPUs support FP64 by default
48-
is_host = isinstance(policy, _HostInteropPolicy)
49-
no_dpcpp = not _is_dpc_backend
50-
if is_host or no_dpcpp:
47+
if isinstance(policy, _HostInteropPolicy):
5148
return data
5249

53-
# There is only one option of data parallel policy
54-
is_dpcpp_policy = isinstance(policy, _DataParallelInteropPolicy)
55-
is_spmd_policy = isinstance(policy, _SPMDDataParallelInteropPolicy)
56-
assert is_spmd_policy or is_dpcpp_policy
57-
50+
# It can be either SPMD or DPCPP policy
5851
device = policy._queue.sycl_device
5952

6053
def convert_or_pass(x):
61-
if x.dtype is not np.float32:
54+
if x.dtype is np.float64:
6255
return x.astype(np.float32)
6356
else:
6457
return x

onedal/primitives/tree_visitor.cpp

+3-6
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626
#include <limits>
2727
#include <vector>
2828

29-
#include <iostream>
30-
#include <utility>
31-
3229
#define ONEDAL_PY_TERMINAL_NODE -1
3330
#define ONEDAL_PY_NO_FEATURE -2
3431

@@ -45,7 +42,7 @@ inline static const double get_nan64() {
4542

4643
// equivalent for numpy arange
4744
template <typename T>
48-
std::vector<T> arange(T start, T stop, T step = 1) {
45+
inline std::vector<T> arange(T start, T stop, T step = 1) {
4946
std::vector<T> res;
5047
for (T i = start; i < stop; i += step)
5148
res.push_back(i);
@@ -128,7 +125,7 @@ class node_visitor {
128125
template <typename Task>
129126
class to_sklearn_tree_object_visitor : public tree_state<Task> {
130127
public:
131-
to_sklearn_tree_object_visitor(size_t _depth,
128+
to_sklearn_tree_object_visitor(std::size_t _depth,
132129
std::size_t _n_nodes,
133130
std::size_t _n_leafs,
134131
std::size_t _max_n_classes);
@@ -143,7 +140,7 @@ class to_sklearn_tree_object_visitor : public tree_state<Task> {
143140
};
144141

145142
template <typename Task>
146-
to_sklearn_tree_object_visitor<Task>::to_sklearn_tree_object_visitor(size_t _depth,
143+
to_sklearn_tree_object_visitor<Task>::to_sklearn_tree_object_visitor(std::size_t _depth,
147144
std::size_t _n_nodes,
148145
std::size_t _n_leafs,
149146
std::size_t _max_n_classes)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
except ImportError:
8383
dpctl_available = False
8484

85-
build_distribute = dpcpp and dpctl_available and not no_dist
85+
build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN
8686

8787

8888
daal_lib_dir = lib_dir if (IS_MAC or os.path.isdir(

setup_sklearnex.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
# System imports
1919
import os
20+
import sys
2021
import time
2122
from setuptools import setup
2223
from scripts.version import get_onedal_version
@@ -25,6 +26,19 @@
2526
sklearnex_version = (os.environ["SKLEARNEX_VERSION"] if "SKLEARNEX_VERSION" in os.environ
2627
else time.strftime("%Y%m%d.%H%M%S"))
2728

29+
IS_WIN = False
30+
IS_MAC = False
31+
IS_LIN = False
32+
33+
if 'linux' in sys.platform:
34+
IS_LIN = True
35+
elif sys.platform == 'darwin':
36+
IS_MAC = True
37+
elif sys.platform in ['win32', 'cygwin']:
38+
IS_WIN = True
39+
else:
40+
assert False, sys.platform + ' not supported'
41+
2842
dal_root = os.environ.get('DALROOT')
2943

3044
if dal_root is None:
@@ -41,7 +55,7 @@
4155
except ImportError:
4256
dpctl_available = False
4357

44-
build_distribute = dpcpp and dpctl_available and not no_dist
58+
build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN
4559

4660
ONEDAL_VERSION = get_onedal_version(dal_root)
4761

0 commit comments

Comments
 (0)