Skip to content

Commit 4abff0d

Browse files
ENH: added splitter_mode for Random Forest (#1223)
* ENH: added splitter_mode for Random Forest
* condition for splitter mode
* Update onedal/ensemble/forest.cpp
* added warnings
* updated tests
* Add splitter_mode to RandomForest _parameter_constraints

---------

Co-authored-by: Alexander Andreev <[email protected]>
1 parent ef57673 commit 4abff0d

File tree

4 files changed

+92
-7
lines changed

4 files changed

+92
-7
lines changed

onedal/ensemble/forest.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,21 @@ auto get_infer_mode(const py::dict& params) {
109109
return result_mode;
110110
}
111111

112+
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230101
113+
/// Reads the "splitter_mode" entry of the Python parameter dictionary and
/// returns the matching decision_forest::splitter_mode enumerator.
///
/// Recognised strings are "best" and "random"; any other value is handed to
/// ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE.
auto get_splitter_mode(const py::dict& params) {
    using namespace decision_forest;

    // The local keeps the name `mode` because it is passed to the macro below unchanged.
    std::string mode = params["splitter_mode"].cast<std::string>();

    if (mode == "random") {
        return splitter_mode::random;
    }
    if (mode == "best") {
        return splitter_mode::best;
    }

    ONEDAL_PARAM_DISPATCH_THROW_INVALID_VALUE(mode);
}
125+
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION>=20230101
126+
112127
auto get_variable_importance_mode(const py::dict& params) {
113128
using namespace decision_forest;
114129

@@ -171,6 +186,9 @@ struct params2desc {
171186
.set_min_bin_size(params["min_bin_size"].cast<std::int64_t>())
172187
.set_memory_saving_mode(params["memory_saving_mode"].cast<bool>())
173188
.set_bootstrap(params["bootstrap"].cast<bool>())
189+
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230101
190+
.set_splitter_mode(get_splitter_mode(params))
191+
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION>=20230101
174192
.set_error_metric_mode(get_error_metric_mode(params))
175193
.set_variable_importance_mode(get_variable_importance_mode(params));
176194

onedal/ensemble/forest.py

+8
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
max_bins,
7878
min_bin_size,
7979
infer_mode,
80+
splitter_mode,
8081
voting_mode,
8182
error_metric_mode,
8283
variable_importance_mode,
@@ -102,6 +103,7 @@ def __init__(
102103
self.max_bins = max_bins
103104
self.min_bin_size = min_bin_size
104105
self.infer_mode = infer_mode
106+
self.splitter_mode = splitter_mode
105107
self.voting_mode = voting_mode
106108
self.error_metric_mode = error_metric_mode
107109
self.variable_importance_mode = variable_importance_mode
@@ -230,6 +232,8 @@ def _get_onedal_params(self, data):
230232
if self.is_classification:
231233
onedal_params['class_count'] = 0 if self.classes_ is None else len(
232234
self.classes_)
235+
if daal_check_version((2023, 'P', 101)):
236+
onedal_params['splitter_mode'] = self.splitter_mode
233237
return onedal_params
234238

235239
def _check_parameters(self):
@@ -434,6 +438,7 @@ def __init__(self,
434438
max_bins=256,
435439
min_bin_size=1,
436440
infer_mode='class_responses',
441+
splitter_mode='best',
437442
voting_mode='weighted',
438443
error_metric_mode='none',
439444
variable_importance_mode='none',
@@ -460,6 +465,7 @@ def __init__(self,
460465
max_bins=max_bins,
461466
min_bin_size=min_bin_size,
462467
infer_mode=infer_mode,
468+
splitter_mode=splitter_mode,
463469
voting_mode=voting_mode,
464470
error_metric_mode=error_metric_mode,
465471
variable_importance_mode=variable_importance_mode,
@@ -516,6 +522,7 @@ def __init__(self,
516522
max_bins=256,
517523
min_bin_size=1,
518524
infer_mode='class_responses',
525+
splitter_mode='best',
519526
voting_mode='weighted',
520527
error_metric_mode='none',
521528
variable_importance_mode='none',
@@ -542,6 +549,7 @@ def __init__(self,
542549
max_bins=max_bins,
543550
min_bin_size=min_bin_size,
544551
infer_mode=infer_mode,
552+
splitter_mode=splitter_mode,
545553
voting_mode=voting_mode,
546554
error_metric_mode=error_metric_mode,
547555
variable_importance_mode=variable_importance_mode,

onedal/ensemble/tests/test_random_forest.py

+25
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
from numpy.testing import assert_allclose
2020

21+
from daal4py.sklearn._utils import daal_check_version
2122
from onedal.ensemble import RandomForestClassifier, RandomForestRegressor
2223
from onedal.tests.utils._device_selection import get_queues
2324

@@ -42,3 +43,27 @@ def test_rf_regression(queue):
4243
max_depth=2, random_state=0).fit(X, y, queue=queue)
4344
assert_allclose(
4445
[-6.83], rf.predict([[0, 0, 0, 0]], queue=queue), atol=1e-2)
46+
47+
48+
@pytest.mark.skipif(not daal_check_version((2023, 'P', 101)),
                    reason='requires OneDAL 2023.1.1')
@pytest.mark.parametrize('queue', get_queues('gpu'))
def test_rf_classifier_random_splitter(queue):
    """Fit a depth-2 classifier with splitter_mode='random' and check one prediction."""
    features, labels = make_classification(
        n_samples=100,
        n_features=4,
        n_informative=2,
        n_redundant=0,
        random_state=0,
        shuffle=False,
    )

    estimator = RandomForestClassifier(
        max_depth=2,
        random_state=0,
        splitter_mode='random',
    )
    # Keep using the object returned by fit(), exactly as the chained call did.
    forest = estimator.fit(features, labels, queue=queue)

    predicted = forest.predict([[0, 0, 0, 0]], queue=queue)
    assert_allclose([1], predicted)
59+
60+
61+
# Skipped on older oneDAL: the onedal estimator only forwards `splitter_mode`
# to the backend when daal_check_version((2023, 'P', 101)) holds, so without
# this guard the test would run without exercising the 'random' splitter.
@pytest.mark.skipif(not daal_check_version((2023, 'P', 101)),
                    reason='requires OneDAL 2023.1.1')
@pytest.mark.parametrize('queue', get_queues('gpu'))
def test_rf_regression_random_splitter(queue):
    """Fit a depth-2 regressor with splitter_mode='random' and check one prediction.

    Parameters
    ----------
    queue : object
        Device queue supplied by ``get_queues('gpu')``; passed to both
        ``fit`` and ``predict``.
    """
    X, y = make_regression(n_samples=100, n_features=4, n_informative=2,
                           random_state=0, shuffle=False)
    rf = RandomForestRegressor(
        max_depth=2, random_state=0,
        splitter_mode='random').fit(X, y, queue=queue)
    # NOTE(review): the expected value and tolerance are identical to the
    # 'best'-splitter check in test_rf_regression -- confirm that -6.83 is
    # really what the 'random' splitter produces on a GPU queue.
    assert_allclose(
        [-6.83], rf.predict([[0, 0, 0, 0]], queue=queue), atol=1e-2)

sklearnex/preview/ensemble/forest.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from scipy import sparse as sp
5959

6060
if sklearn_check_version('1.2'):
61-
from sklearn.utils._param_validation import Interval
61+
from sklearn.utils._param_validation import Interval, StrOptions
6262

6363

6464
class BaseRandomForest(ABC):
@@ -193,7 +193,8 @@ class RandomForestClassifier(sklearn_RandomForestClassifier, BaseRandomForest):
193193
_parameter_constraints: dict = {
194194
**sklearn_RandomForestClassifier._parameter_constraints,
195195
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
196-
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")]
196+
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
197+
"splitter_mode": [StrOptions({"best", "random"})]
197198
}
198199

199200
if sklearn_check_version('1.0'):
@@ -218,7 +219,8 @@ def __init__(
218219
ccp_alpha=0.0,
219220
max_samples=None,
220221
max_bins=256,
221-
min_bin_size=1):
222+
min_bin_size=1,
223+
splitter_mode='best'):
222224
super(RandomForestClassifier, self).__init__(
223225
n_estimators=n_estimators,
224226
criterion=criterion,
@@ -243,6 +245,7 @@ def __init__(
243245
self.max_bins = max_bins
244246
self.min_bin_size = min_bin_size
245247
self.min_impurity_split = None
248+
self.splitter_mode = splitter_mode
246249
# self._estimator = DecisionTreeClassifier()
247250
else:
248251
def __init__(self,
@@ -266,7 +269,8 @@ def __init__(self,
266269
ccp_alpha=0.0,
267270
max_samples=None,
268271
max_bins=256,
269-
min_bin_size=1):
272+
min_bin_size=1,
273+
splitter_mode='best'):
270274
super(RandomForestClassifier, self).__init__(
271275
n_estimators=n_estimators,
272276
criterion=criterion,
@@ -294,6 +298,7 @@ def __init__(self,
294298
self.max_bins = max_bins
295299
self.min_bin_size = min_bin_size
296300
self.min_impurity_split = None
301+
self.splitter_mode = splitter_mode
297302
# self._estimator = DecisionTreeClassifier()
298303

299304
def fit(self, X, y, sample_weight=None):
@@ -529,6 +534,11 @@ def _estimators_(self):
529534
def _onedal_cpu_supported(self, method_name, *data):
530535
if method_name == 'ensemble.RandomForestClassifier.fit':
531536
ready, X, y, sample_weight = self._onedal_ready(*data)
537+
if self.splitter_mode == 'random':
538+
warnings.warn("'random' splitter mode supports GPU devices only "
539+
"and requires oneDAL version >= 2023.1.1. "
540+
"Using 'best' mode instead.", RuntimeWarning)
541+
self.splitter_mode = 'best'
532542
if not ready:
533543
return False
534544
elif sp.issparse(X):
@@ -570,6 +580,11 @@ def _onedal_cpu_supported(self, method_name, *data):
570580
def _onedal_gpu_supported(self, method_name, *data):
571581
if method_name == 'ensemble.RandomForestClassifier.fit':
572582
ready, X, y, sample_weight = self._onedal_ready(*data)
583+
if self.splitter_mode == 'random' and \
584+
not daal_check_version((2023, 'P', 101)):
585+
warnings.warn("'random' splitter mode requires OneDAL >= 2023.1.1. "
586+
"Using 'best' mode instead.", RuntimeWarning)
587+
self.splitter_mode = 'best'
573588
if not ready:
574589
return False
575590
elif sp.issparse(X):
@@ -687,6 +702,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
687702
'min_bin_size': self.min_bin_size,
688703
'max_samples': self.max_samples
689704
}
705+
if daal_check_version((2023, 'P', 101)):
706+
onedal_params['splitter_mode'] = self.splitter_mode
690707
self._cached_estimators_ = None
691708

692709
# Compute
@@ -729,7 +746,8 @@ class RandomForestRegressor(sklearn_RandomForestRegressor, BaseRandomForest):
729746
_parameter_constraints: dict = {
730747
**sklearn_RandomForestRegressor._parameter_constraints,
731748
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
732-
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")]
749+
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
750+
"splitter_mode": [StrOptions({"best", "random"})]
733751
}
734752

735753
if sklearn_check_version('1.0'):
@@ -754,7 +772,8 @@ def __init__(
754772
ccp_alpha=0.0,
755773
max_samples=None,
756774
max_bins=256,
757-
min_bin_size=1):
775+
min_bin_size=1,
776+
splitter_mode='best'):
758777
super(RandomForestRegressor, self).__init__(
759778
n_estimators=n_estimators,
760779
criterion=criterion,
@@ -778,6 +797,7 @@ def __init__(
778797
self.max_bins = max_bins
779798
self.min_bin_size = min_bin_size
780799
self.min_impurity_split = None
800+
self.splitter_mode = splitter_mode
781801
else:
782802
def __init__(self,
783803
n_estimators=100, *,
@@ -799,7 +819,8 @@ def __init__(self,
799819
ccp_alpha=0.0,
800820
max_samples=None,
801821
max_bins=256,
802-
min_bin_size=1):
822+
min_bin_size=1,
823+
splitter_mode='best'):
803824
super(RandomForestRegressor, self).__init__(
804825
n_estimators=n_estimators,
805826
criterion=criterion,
@@ -826,6 +847,7 @@ def __init__(self,
826847
self.max_bins = max_bins
827848
self.min_bin_size = min_bin_size
828849
self.min_impurity_split = None
850+
self.splitter_mode = splitter_mode
829851

830852
@property
831853
def _estimators_(self):
@@ -902,6 +924,11 @@ def _onedal_ready(self, X, y, sample_weight):
902924
def _onedal_cpu_supported(self, method_name, *data):
903925
if method_name == 'ensemble.RandomForestRegressor.fit':
904926
ready, X, y, sample_weight = self._onedal_ready(*data)
927+
if self.splitter_mode == 'random':
928+
warnings.warn("'random' splitter mode supports GPU devices only "
929+
"and requires oneDAL version >= 2023.1.1. "
930+
"Using 'best' mode instead.", RuntimeWarning)
931+
self.splitter_mode = 'best'
905932
if not ready:
906933
return False
907934
elif not (self.oob_score and daal_check_version(
@@ -947,6 +974,11 @@ def _onedal_cpu_supported(self, method_name, *data):
947974
def _onedal_gpu_supported(self, method_name, *data):
948975
if method_name == 'ensemble.RandomForestRegressor.fit':
949976
ready, X, y, sample_weight = self._onedal_ready(*data)
977+
if self.splitter_mode == 'random' and \
978+
not daal_check_version((2023, 'P', 101)):
979+
warnings.warn("'random' splitter mode requires OneDAL >= 2023.1.1. "
980+
"Using 'best' mode instead.", RuntimeWarning)
981+
self.splitter_mode = 'best'
950982
if not ready:
951983
return False
952984
elif not (self.oob_score and daal_check_version(
@@ -1035,6 +1067,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
10351067
'variable_importance_mode': 'mdi',
10361068
'max_samples': self.max_samples
10371069
}
1070+
if daal_check_version((2023, 'P', 101)):
1071+
onedal_params['splitter_mode'] = self.splitter_mode
10381072
self._cached_estimators_ = None
10391073
self._onedal_estimator = self._onedal_regressor(**onedal_params)
10401074
self._onedal_estimator.fit(X, y, sample_weight, queue=queue)

0 commit comments

Comments
 (0)