File tree 4 files changed +23
-19
lines changed
4 files changed +23
-19
lines changed Original file line number Diff line number Diff line change @@ -40,25 +40,18 @@ def to_table(*args):
40
40
if _is_dpc_backend :
41
41
import numpy as np
42
42
43
- from ..common ._spmd_policy import _SPMDDataParallelInteropPolicy
44
- from ..common ._policy import _HostInteropPolicy , _DataParallelInteropPolicy
43
+ from ..common ._policy import _HostInteropPolicy
45
44
46
45
def _convert_to_supported_impl (policy , * data ):
47
46
# CPUs support FP64 by default
48
- is_host = isinstance (policy , _HostInteropPolicy )
49
- no_dpcpp = not _is_dpc_backend
50
- if is_host or no_dpcpp :
47
+ if isinstance (policy , _HostInteropPolicy ):
51
48
return data
52
49
53
- # There is only one option of data parallel policy
54
- is_dpcpp_policy = isinstance (policy , _DataParallelInteropPolicy )
55
- is_spmd_policy = isinstance (policy , _SPMDDataParallelInteropPolicy )
56
- assert is_spmd_policy or is_dpcpp_policy
57
-
50
+ # It can be either SPMD or DPCPP policy
58
51
device = policy ._queue .sycl_device
59
52
60
53
def convert_or_pass (x ):
61
- if x .dtype is not np .float32 :
54
+ if x .dtype is np .float64 :
62
55
return x .astype (np .float32 )
63
56
else :
64
57
return x
Original file line number Diff line number Diff line change 26
26
#include < limits>
27
27
#include < vector>
28
28
29
- #include < iostream>
30
- #include < utility>
31
-
32
29
#define ONEDAL_PY_TERMINAL_NODE -1
33
30
#define ONEDAL_PY_NO_FEATURE -2
34
31
@@ -45,7 +42,7 @@ inline static const double get_nan64() {
45
42
46
43
// equivalent for numpy arange
47
44
template <typename T>
48
- std::vector<T> arange (T start, T stop, T step = 1 ) {
45
+ inline std::vector<T> arange (T start, T stop, T step = 1 ) {
49
46
std::vector<T> res;
50
47
for (T i = start; i < stop; i += step)
51
48
res.push_back (i);
@@ -128,7 +125,7 @@ class node_visitor {
128
125
template <typename Task>
129
126
class to_sklearn_tree_object_visitor : public tree_state <Task> {
130
127
public:
131
- to_sklearn_tree_object_visitor (size_t _depth,
128
+ to_sklearn_tree_object_visitor (std:: size_t _depth,
132
129
std::size_t _n_nodes,
133
130
std::size_t _n_leafs,
134
131
std::size_t _max_n_classes);
@@ -143,7 +140,7 @@ class to_sklearn_tree_object_visitor : public tree_state<Task> {
143
140
};
144
141
145
142
template <typename Task>
146
- to_sklearn_tree_object_visitor<Task>::to_sklearn_tree_object_visitor(size_t _depth,
143
+ to_sklearn_tree_object_visitor<Task>::to_sklearn_tree_object_visitor(std:: size_t _depth,
147
144
std::size_t _n_nodes,
148
145
std::size_t _n_leafs,
149
146
std::size_t _max_n_classes)
Original file line number Diff line number Diff line change 82
82
except ImportError :
83
83
dpctl_available = False
84
84
85
- build_distribute = dpcpp and dpctl_available and not no_dist
85
+ build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN
86
86
87
87
88
88
daal_lib_dir = lib_dir if (IS_MAC or os .path .isdir (
Original file line number Diff line number Diff line change 17
17
18
18
# System imports
19
19
import os
20
+ import sys
20
21
import time
21
22
from setuptools import setup
22
23
from scripts .version import get_onedal_version
25
26
sklearnex_version = (os .environ ["SKLEARNEX_VERSION" ] if "SKLEARNEX_VERSION" in os .environ
26
27
else time .strftime ("%Y%m%d.%H%M%S" ))
27
28
29
+ IS_WIN = False
30
+ IS_MAC = False
31
+ IS_LIN = False
32
+
33
+ if 'linux' in sys .platform :
34
+ IS_LIN = True
35
+ elif sys .platform == 'darwin' :
36
+ IS_MAC = True
37
+ elif sys .platform in ['win32' , 'cygwin' ]:
38
+ IS_WIN = True
39
+ else :
40
+ assert False , sys .platform + ' not supported'
41
+
28
42
dal_root = os .environ .get ('DALROOT' )
29
43
30
44
if dal_root is None :
41
55
except ImportError :
42
56
dpctl_available = False
43
57
44
- build_distribute = dpcpp and dpctl_available and not no_dist
58
+ build_distribute = dpcpp and dpctl_available and not no_dist and IS_LIN
45
59
46
60
ONEDAL_VERSION = get_onedal_version (dal_root )
47
61
You can’t perform that action at this time.
0 commit comments