Skip to content

Commit 54f64c2

Browse files
enh: re-enabling spmd rf interfaces (#1700) (#1711)
* enh: re-enabling spmd rf interfaces * restoring onedal_factor usage but adding __class__ * lint * another attempt at it * isinstance to issubclass and re-adding self (cherry picked from commit bfa470b) Co-authored-by: ethanglaser <[email protected]>
1 parent 6c72039 commit 54f64c2

File tree

2 files changed

+12
-22
lines changed

2 files changed

+12
-22
lines changed

sklearnex/ensemble/_forest.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -453,14 +453,12 @@ def __init__(
453453

454454
# The estimator is checked against the class attribute for conformance.
455455
# This should only trigger if the user uses this class directly.
456-
if (
457-
self.estimator.__class__ == DecisionTreeClassifier
458-
and self._onedal_factory != onedal_RandomForestClassifier
456+
if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
457+
self._onedal_factory, onedal_RandomForestClassifier
459458
):
460459
self._onedal_factory = onedal_RandomForestClassifier
461-
elif (
462-
self.estimator.__class__ == ExtraTreeClassifier
463-
and self._onedal_factory != onedal_ExtraTreesClassifier
460+
elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
461+
self._onedal_factory, onedal_ExtraTreesClassifier
464462
):
465463
self._onedal_factory = onedal_ExtraTreesClassifier
466464

@@ -843,14 +841,12 @@ def __init__(
843841

844842
# The splitter is checked against the class attribute for conformance
845843
# This should only trigger if the user uses this class directly.
846-
if (
847-
self.estimator.__class__ == DecisionTreeRegressor
848-
and self._onedal_factory != onedal_RandomForestRegressor
844+
if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
845+
self._onedal_factory, onedal_RandomForestRegressor
849846
):
850847
self._onedal_factory = onedal_RandomForestRegressor
851-
elif (
852-
self.estimator.__class__ == ExtraTreeRegressor
853-
and self._onedal_factory != onedal_ExtraTreesRegressor
848+
elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
849+
self._onedal_factory, onedal_ExtraTreesRegressor
854850
):
855851
self._onedal_factory = onedal_ExtraTreesRegressor
856852

sklearnex/spmd/ensemble/forest.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,9 @@
2323
from ...ensemble import RandomForestRegressor as RandomForestRegressor_Batch
2424

2525

26-
class BaseForestSPMD(ABC):
27-
def _onedal_classifier(self, **onedal_params):
28-
return onedal_RandomForestClassifier(**onedal_params)
29-
30-
def _onedal_regressor(self, **onedal_params):
31-
return onedal_RandomForestRegressor(**onedal_params)
32-
33-
34-
class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch):
26+
class RandomForestClassifier(RandomForestClassifier_Batch):
3527
__doc__ = RandomForestClassifier_Batch.__doc__
28+
_onedal_factory = onedal_RandomForestClassifier
3629

3730
def _onedal_cpu_supported(self, method_name, *data):
3831
# TODO:
@@ -55,8 +48,9 @@ def _onedal_gpu_supported(self, method_name, *data):
5548
return ready
5649

5750

58-
class RandomForestRegressor(BaseForestSPMD, RandomForestRegressor_Batch):
51+
class RandomForestRegressor(RandomForestRegressor_Batch):
5952
__doc__ = RandomForestRegressor_Batch.__doc__
53+
_onedal_factory = onedal_RandomForestRegressor
6054

6155
def _onedal_cpu_supported(self, method_name, *data):
6256
# TODO:

0 commit comments

Comments
 (0)