Skip to content

Commit 83b5266

Browse files
[bug] fix ensemble algo _onedal_gpu_supported logic (#1696) (#1710)
* Update _forest.py * Update deselected_tests.yaml * this is definitely going to fail CI * Update deselected_tests.yaml * Update deselected_tests.yaml * Update deselected_tests.yaml * Update test_forest.py * Update test_forest.py (cherry picked from commit e0a405c) Co-authored-by: Ian Faust <[email protected]>
1 parent 54f64c2 commit 83b5266

File tree

3 files changed

+36
-34
lines changed

3 files changed

+36
-34
lines changed

deselected_tests.yaml

+19-13
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,6 @@ gpu:
441441
- ensemble/tests/test_bagging.py::test_gridsearch
442442
- ensemble/tests/test_bagging.py::test_estimators_samples
443443
- ensemble/tests/test_common.py::test_ensemble_heterogeneous_estimators_behavior
444-
- ensemble/tests/test_forest.py::test_min_samples_split[RandomForestClassifier]
445-
- ensemble/tests/test_forest.py::test_min_weight_fraction_leaf
446444
- ensemble/tests/test_voting.py::test_parallel_fit
447445
- ensemble/tests/test_voting.py::test_sample_weight
448446

@@ -640,8 +638,6 @@ gpu:
640638
- model_selection/tests/test_search.py::test_random_search_cv_results
641639

642640
# Segmentation faults on GPU
643-
- ensemble/tests/test_forest.py::test_forest_classifier_oob
644-
- ensemble/tests/test_forest.py::test_forest_regressor_oob
645641
- tests/test_common.py::test_search_cv
646642
- manifold/tests/test_t_sne.py::test_n_iter_without_progress
647643

@@ -736,15 +732,25 @@ gpu:
736732
- tests/test_common.py::test_f_contiguous_array_estimator[TSNE]
737733
- manifold/tests/test_t_sne.py::test_tsne_works_with_pandas_output
738734

739-
# GPU ensemble (Random Forest and Extra Trees) algorithms have a different
740-
# implementation compared to CPU and require further validation
741-
- ensemble/tests/test_forest.py::test_importances[ExtraTreesClassifier-gini-float64]
742-
- ensemble/tests/test_forest.py::test_importances[ExtraTreesClassifier-gini-float32]
743-
- ensemble/tests/test_forest.py::test_importances[ExtraTreesRegressor-squared_error-float64]
744-
- ensemble/tests/test_forest.py::test_importances[ExtraTreesRegressor-squared_error-float32]
745-
- ensemble/tests/test_forest.py::test_importances[RandomForestClassifier-gini-float32]
746-
- ensemble/tests/test_forest.py::test_importances[RandomForestRegressor-squared_error-float64]
747-
- ensemble/tests/test_forest.py::test_importances[RandomForestRegressor-squared_error-float32]
735+
# GPU Forest algorithm implementation does not follow certain Scikit-learn standards
736+
- ensemble/tests/test_forest.py::test_max_leaf_nodes_max_depth
737+
- ensemble/tests/test_forest.py::test_min_samples_split[ExtraTreesClassifier]
738+
- ensemble/tests/test_forest.py::test_min_samples_split[RandomForestClassifier]
739+
- ensemble/tests/test_forest.py::test_min_samples_split[ExtraTreesRegressor]
740+
- ensemble/tests/test_forest.py::test_max_samples_boundary_regressors
741+
742+
# numerical issues in GPU Forest algorithms which require further investigation
743+
- ensemble/tests/test_forest.py::test_forest_classifier_oob[X0-y0-0.9-array-ExtraTreesClassifier]
744+
- ensemble/tests/test_forest.py::test_forest_classifier_oob[X0-y0-0.9-array-RandomForestClassifier]
745+
- ensemble/tests/test_forest.py::test_forest_classifier_oob[X1-y1-0.65-array-RandomForestClassifier]
746+
- ensemble/tests/test_forest.py::test_forest_classifier_oob[X2-y2-0.65-array-ExtraTreesClassifier]
747+
- ensemble/tests/test_forest.py::test_forest_classifier_oob[X2-y2-0.65-array-RandomForestClassifier]
748+
- ensemble/tests/test_forest.py::test_forest_regressor_oob[X0-y0-0.7-array-RandomForestRegressor]
749+
- ensemble/tests/test_stacking.py::test_stacking_regressor_drop_estimator
750+
- ensemble/tests/test_voting.py::test_predict_on_toy_problem[42]
751+
- tests/test_common.py::test_estimators[ExtraTreesClassifier()-check_class_weight_classifiers]
752+
- tests/test_common.py::test_estimators[ExtraTreesRegressor()-check_sample_weights_invariance(kind=zeros)]
753+
- tests/test_common.py::test_estimators[RandomForestRegressor()-check_regressor_data_not_an_array]
748754

749755
# GPU implementation of Extra Trees doesn't support sample_weights
750756
# comparisons to GPU with sample weights will use different algorithms

sklearnex/ensemble/_forest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def _onedal_gpu_supported(self, method_name, *data):
745745
or self.estimator.__class__ == DecisionTreeClassifier,
746746
"ExtraTrees only supported starting from oneDAL version 2023.1",
747747
),
748-
(sample_weight is not None, "sample_weight is not supported."),
748+
(sample_weight is None, "sample_weight is not supported."),
749749
]
750750
)
751751

@@ -1052,7 +1052,7 @@ def _onedal_gpu_supported(self, method_name, *data):
10521052
or self.estimator.__class__ == DecisionTreeClassifier,
10531053
"ExtraTrees only supported starting from oneDAL version 2023.1",
10541054
),
1055-
(sample_weight is not None, "sample_weight is not supported."),
1055+
(sample_weight is None, "sample_weight is not supported."),
10561056
]
10571057
)
10581058

sklearnex/ensemble/tests/test_forest.py

+15-19
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,7 @@ def test_sklearnex_import_rf_classifier(dataframe, queue):
4545
assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
4646

4747

48-
# TODO:
49-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
50-
@pytest.mark.parametrize(
51-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
52-
)
48+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
5349
def test_sklearnex_import_rf_regression(dataframe, queue):
5450
from sklearnex.ensemble import RandomForestRegressor
5551

@@ -59,17 +55,17 @@ def test_sklearnex_import_rf_regression(dataframe, queue):
5955
rf = RandomForestRegressor(max_depth=2, random_state=0).fit(X, y)
6056
assert "sklearnex" in rf.__module__
6157
pred = _as_numpy(rf.predict([[0, 0, 0, 0]]))
62-
if daal_check_version((2024, "P", 0)):
63-
assert_allclose([-6.971], pred, atol=1e-2)
58+
59+
if queue is not None and queue.sycl_device.is_gpu:
60+
assert_allclose([-0.011208], pred, atol=1e-2)
6461
else:
65-
assert_allclose([-6.839], pred, atol=1e-2)
62+
if daal_check_version((2024, "P", 0)):
63+
assert_allclose([-6.971], pred, atol=1e-2)
64+
else:
65+
assert_allclose([-6.839], pred, atol=1e-2)
6666

6767

68-
# TODO:
69-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
70-
@pytest.mark.parametrize(
71-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
72-
)
68+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
7369
def test_sklearnex_import_et_classifier(dataframe, queue):
7470
from sklearnex.ensemble import ExtraTreesClassifier
7571

@@ -90,11 +86,7 @@ def test_sklearnex_import_et_classifier(dataframe, queue):
9086
assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
9187

9288

93-
# TODO:
94-
# investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
95-
@pytest.mark.parametrize(
96-
"dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
97-
)
89+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
9890
def test_sklearnex_import_et_regression(dataframe, queue):
9991
from sklearnex.ensemble import ExtraTreesRegressor
10092

@@ -114,4 +106,8 @@ def test_sklearnex_import_et_regression(dataframe, queue):
114106
]
115107
)
116108
)
117-
assert_allclose([0.445], pred, atol=1e-2)
109+
110+
if queue is not None and queue.sycl_device.is_gpu:
111+
assert_allclose([1.909769], pred, atol=1e-2)
112+
else:
113+
assert_allclose([0.445], pred, atol=1e-2)

0 commit comments

Comments
 (0)