@@ -453,14 +453,12 @@ def __init__(
453
453
454
454
# The estimator is checked against the class attribute for conformance.
455
455
# This should only trigger if the user uses this class directly.
456
- if (
457
- self .estimator .__class__ == DecisionTreeClassifier
458
- and self ._onedal_factory != onedal_RandomForestClassifier
456
+ if self .estimator .__class__ == DecisionTreeClassifier and not issubclass (
457
+ self ._onedal_factory , onedal_RandomForestClassifier
459
458
):
460
459
self ._onedal_factory = onedal_RandomForestClassifier
461
- elif (
462
- self .estimator .__class__ == ExtraTreeClassifier
463
- and self ._onedal_factory != onedal_ExtraTreesClassifier
460
+ elif self .estimator .__class__ == ExtraTreeClassifier and not issubclass (
461
+ self ._onedal_factory , onedal_ExtraTreesClassifier
464
462
):
465
463
self ._onedal_factory = onedal_ExtraTreesClassifier
466
464
@@ -843,14 +841,12 @@ def __init__(
843
841
844
842
# The splitter is checked against the class attribute for conformance
845
843
# This should only trigger if the user uses this class directly.
846
- if (
847
- self .estimator .__class__ == DecisionTreeRegressor
848
- and self ._onedal_factory != onedal_RandomForestRegressor
844
+ if self .estimator .__class__ == DecisionTreeRegressor and not issubclass (
845
+ self ._onedal_factory , onedal_RandomForestRegressor
849
846
):
850
847
self ._onedal_factory = onedal_RandomForestRegressor
851
- elif (
852
- self .estimator .__class__ == ExtraTreeRegressor
853
- and self ._onedal_factory != onedal_ExtraTreesRegressor
848
+ elif self .estimator .__class__ == ExtraTreeRegressor and not issubclass (
849
+ self ._onedal_factory , onedal_ExtraTreesRegressor
854
850
):
855
851
self ._onedal_factory = onedal_ExtraTreesRegressor
856
852
0 commit comments