diff --git a/asv_benchmarks/benchmarks/ensemble.py b/asv_benchmarks/benchmarks/ensemble.py index c336d1e5f8805..877fcdb09fe68 100644 --- a/asv_benchmarks/benchmarks/ensemble.py +++ b/asv_benchmarks/benchmarks/ensemble.py @@ -2,6 +2,7 @@ GradientBoostingClassifier, HistGradientBoostingClassifier, RandomForestClassifier, + RandomForestRegressor, ) from .common import Benchmark, Estimator, Predictor @@ -9,8 +10,50 @@ _20newsgroups_highdim_dataset, _20newsgroups_lowdim_dataset, _synth_classification_dataset, + _synth_regression_dataset, + _synth_regression_sparse_dataset, ) -from .utils import make_gen_classif_scorers +from .utils import make_gen_classif_scorers, make_gen_reg_scorers + + +class RandomForestRegressorBenchmark(Predictor, Estimator, Benchmark): + """ + Benchmarks for RandomForestRegressor. + """ + + param_names = ["representation", "n_jobs"] + params = (["dense", "sparse"], Benchmark.n_jobs_vals) + + def setup_cache(self): + super().setup_cache() + + def make_data(self, params): + representation, n_jobs = params + + if representation == "sparse": + data = _synth_regression_sparse_dataset() + else: + data = _synth_regression_dataset() + + return data + + def make_estimator(self, params): + representation, n_jobs = params + + n_estimators = 500 if Benchmark.data_size == "large" else 100 + + estimator = RandomForestRegressor( + n_estimators=n_estimators, + min_samples_split=10, + max_features="log2", + n_jobs=n_jobs, + random_state=0, + ) + + return estimator + + def make_scorers(self): + make_gen_reg_scorers(self) class RandomForestClassifierBenchmark(Predictor, Estimator, Benchmark): diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index 2a8cf413be9da..2e304ddc61b6a 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -8,6 +8,7 @@ from ._forest import ( ExtraTreesClassifier, ExtraTreesRegressor, + HonestRandomForestClassifier, RandomForestClassifier, RandomForestRegressor, RandomTreesEmbedding, @@ -24,6 +25,7 @@ __all__ = [ "BaseEnsemble", + "HonestRandomForestClassifier", "RandomForestClassifier", "RandomForestRegressor", "RandomTreesEmbedding", diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 57a4750c612bd..ed1eef3041683 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -84,9 +84,11 @@ class calls the ``fit`` method of each sub-estimator on random samples ExtraTreeClassifier, ExtraTreeRegressor, ) +from ..tree._honest_tree import HonestDecisionTree from ..tree._tree import DOUBLE, DTYPE __all__ = [ + "HonestRandomForestClassifier", "RandomForestClassifier", "RandomForestRegressor", "ExtraTreesClassifier", @@ -2089,7 +2091,7 @@ class labels (multi-output problem). dict, list, None, - ], + ] } _parameter_constraints.pop("splitter") @@ -2116,7 +2118,7 @@ def __init__( max_samples=None, max_bins=None, store_leaf_values=False, - monotonic_cst=None, + monotonic_cst=None ): super().__init__( estimator=DecisionTreeClassifier(), @@ -2169,6 +2171,492 @@ def __sklearn_tags__(self): return tags +class HonestRandomForestClassifier(ForestClassifier): + """ + A random forest classifier. + + A random forest is a meta estimator that fits a number of decision tree + classifiers on various sub-samples of the dataset and uses averaging to + improve the predictive accuracy and control over-fitting. + Trees in the forest use the best split strategy, i.e. equivalent to passing + `splitter="best"` to the underlying :class:`~sklearn.tree.DecisionTreeRegressor`. 
+ The sub-sample size is controlled with the `max_samples` parameter if + `bootstrap=True` (default), otherwise the whole dataset is used to build + each tree. + + For a comparison between tree-based ensemble models see the example + :ref:`sphx_glr_auto_examples_ensemble_plot_forest_hist_grad_boosting_comparison.py`. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_estimators : int, default=100 + The number of trees in the forest. + + .. versionchanged:: 0.22 + The default value of ``n_estimators`` changed from 10 to 100 + in 0.22. + + criterion : {"gini", "entropy", "log_loss"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity and "log_loss" and "entropy" both for the + Shannon information gain, see :ref:`tree_mathematical_formulation`. + Note: This parameter is tree-specific. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. + A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : {"sqrt", "log2", None}, int or float, default="sqrt" + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `max(1, int(max_features * n_features_in_))` features are considered at each + split. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + .. versionchanged:: 1.1 + The default of `max_features` changed from `"auto"` to `"sqrt"`. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + max_leaf_nodes : int, default=None + Grow trees with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. 
+ + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + .. versionadded:: 0.19 + + bootstrap : bool, default=True + Whether bootstrap samples are used when building trees. If False, the + whole dataset is used to build each tree. + + oob_score : bool or callable, default=False + Whether to use out-of-bag samples to estimate the generalization score. + By default, :func:`~sklearn.metrics.accuracy_score` is used. + Provide a callable with signature `metric(y_true, y_pred)` to use a + custom metric. Only available if `bootstrap=True`. + + n_jobs : int, default=None + The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`, + :meth:`decision_path` and :meth:`apply` are all parallelized over the + trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` + context. ``-1`` means using all processors. See :term:`Glossary + ` for more details. + + random_state : int, RandomState instance or None, default=None + Controls both the randomness of the bootstrapping of the samples used + when building trees (if ``bootstrap=True``) and the sampling of the + features to consider when looking for the best split at each node + (if ``max_features < n_features``). + See :term:`Glossary ` for details. + + verbose : int, default=0 + Controls the verbosity when fitting and predicting. + + warm_start : bool, default=False + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. See :term:`Glossary ` and + :ref:`tree_ensemble_warm_start` for details. + + class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \ + default=None + Weights associated with classes in the form ``{class_label: weight}``. + If not given, all classes are supposed to have weight one. For + multi-output problems, a list of dicts can be provided in the same + order as the columns of y. + + Note that for multioutput (including multilabel) weights should be + defined for each class of every column in its own dict. For example, + for four-class multilabel classification weights should be + [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of + [{1:1}, {2:5}, {3:1}, {4:1}]. + + The "balanced" mode uses the values of y to automatically adjust + weights inversely proportional to class frequencies in the input data + as ``n_samples / (n_classes * np.bincount(y))`` + + The "balanced_subsample" mode is the same as "balanced" except that + weights are computed based on the bootstrap sample for every tree + grown. + + For multi-output, the weights of each column of y will be multiplied. + + Note that these weights will be multiplied with sample_weight (passed + through the fit method) if sample_weight is specified. + + ccp_alpha : non-negative float, default=0.0 + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. 
versionadded:: 0.22 + + max_samples : int or float, default=None + If bootstrap is True, the number of samples to draw from X + to train each base estimator. + + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus, + `max_samples` should be in the interval `(0.0, 1.0]`. + + .. versionadded:: 0.22 + + max_bins : int, default=255 + The maximum number of bins to use for non-missing values. + + **This is an experimental feature**. + + store_leaf_values : bool, default=False + Whether to store the leaf values in the ``get_leaf_node_samples`` function. + + **This is an experimental feature**. + + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + + Attributes + ---------- + estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` + The child estimator template used to create the collection of fitted + sub-estimators. + + .. versionadded:: 1.2 + `base_estimator_` was renamed to `estimator_`. + + estimators_ : list of DecisionTreeClassifier + The collection of fitted sub-estimators. + + classes_ : ndarray of shape (n_classes,) or a list of such arrays + The classes labels (single output problem), or a list of arrays of + class labels (multi-output problem). + + n_classes_ : int or list + The number of classes (single output problem), or a list containing the + number of classes for each output (multi-output problem). + + n_features_in_ : int + Number of features seen during :term:`fit`. + + .. versionadded:: 0.24 + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + .. versionadded:: 1.0 + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + feature_importances_ : ndarray of shape (n_features,) + The impurity-based feature importances. + The higher, the more important the feature. + The importance of a feature is computed as the (normalized) + total reduction of the criterion brought by that feature. It is also + known as the Gini importance. + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + This attribute exists only when ``oob_score`` is True. + + oob_decision_function_ : ndarray of shape (n_samples, n_classes) or \ + (n_samples, n_classes, n_outputs) + Decision function computed with out-of-bag estimate on the training + set. If n_estimators is small it might be possible that a data point + was never left out during the bootstrap. In this case, + `oob_decision_function_` might contain NaN. This attribute exists + only when ``oob_score`` is True. 
+ + estimators_samples_ : list of arrays + The subset of drawn samples (i.e., the in-bag samples) for each base + estimator. Each subset is defined by an array of the indices selected. + + .. versionadded:: 1.4 + + See Also + -------- + sklearn.tree.DecisionTreeClassifier : A decision tree classifier. + sklearn.ensemble.ExtraTreesClassifier : Ensemble of extremely randomized + tree classifiers. + sklearn.ensemble.HistGradientBoostingClassifier : A Histogram-based Gradient + Boosting Classification Tree, very fast for big datasets (n_samples >= + 10_000). + + Notes + ----- + The default values for the parameters controlling the size of the trees + (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and + unpruned trees which can potentially be very large on some data sets. To + reduce memory consumption, the complexity and size of the trees should be + controlled by setting those parameter values. + + The features are always randomly permuted at each split. Therefore, + the best found split may vary, even with the same training data, + ``max_features=n_features`` and ``bootstrap=False``, if the improvement + of the criterion is identical for several splits enumerated during the + search of the best split. To obtain a deterministic behaviour during + fitting, ``random_state`` has to be fixed. + + References + ---------- + .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. + + Examples + -------- + >>> from sklearn.ensemble import RandomForestClassifier + >>> from sklearn.datasets import make_classification + >>> X, y = make_classification(n_samples=1000, n_features=4, + ... n_informative=2, n_redundant=0, + ... random_state=0, shuffle=False) + >>> clf = RandomForestClassifier(max_depth=2, random_state=0) + >>> clf.fit(X, y) + RandomForestClassifier(...) 
+ >>> print(clf.predict([[0, 0, 0, 0]])) + [1] + """ + + _parameter_constraints: dict = { + **ForestClassifier._parameter_constraints, + **DecisionTreeClassifier._parameter_constraints, + **HonestDecisionTree._parameter_constraints, + "class_weight": [ + StrOptions({"balanced_subsample", "balanced"}), + dict, + list, + None, + ] + } + _parameter_constraints.pop("splitter") + _parameter_constraints.pop("max_samples") + _parameter_constraints["max_samples"] = [ + None, + Interval(RealNotInt, 0.0, None, closed="right"), + Interval(Integral, 1, None, closed="left"), + ] + + @staticmethod + def _generate_sample_indices(tree, random_state, n_samples): + return _generate_sample_indices(tree, random_state, n_samples) + + def __init__( + self, + n_estimators=100, + *, + target_tree_class=DecisionTreeClassifier, + criterion="gini", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="sqrt", + max_leaf_nodes=None, + min_impurity_decrease=0.0, + bootstrap=True, + oob_score=False, + n_jobs=None, + random_state=None, + verbose=0, + warm_start=False, + class_weight=None, + ccp_alpha=0.0, + max_samples=None, + max_bins=None, + store_leaf_values=False, + monotonic_cst=None, + stratify=False, + honest_prior="ignore", + honest_fraction=0.5 + ): + self.target_tree_kwargs = { + "criterion": criterion, + "max_depth": max_depth, + "min_samples_split": min_samples_split, + "min_samples_leaf": min_samples_leaf, + "min_weight_fraction_leaf": min_weight_fraction_leaf, + "max_features": max_features, + "max_leaf_nodes": max_leaf_nodes, + "min_impurity_decrease": min_impurity_decrease, + "random_state": random_state, + "ccp_alpha": ccp_alpha, + "store_leaf_values": store_leaf_values, + "monotonic_cst": monotonic_cst + } + super().__init__( + estimator=HonestDecisionTree( + target_tree_class=target_tree_class, + target_tree_kwargs=self.target_tree_kwargs, + stratify=stratify, + honest_prior=honest_prior, + honest_fraction=honest_fraction, + random_state=random_state + ), + n_estimators=n_estimators, + estimator_params=( + "target_tree_class", + "target_tree_kwargs", + "stratify", + "honest_prior", + "honest_fraction", + "random_state" + ), + # estimator_params=( + # "criterion", + # "max_depth", + # "min_samples_split", + # "min_samples_leaf", + # "min_weight_fraction_leaf", + # "max_features", + # "max_leaf_nodes", + # "min_impurity_decrease", + # "random_state", + # "ccp_alpha", + # "store_leaf_values", + # "monotonic_cst", + # ), + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start, + class_weight=class_weight, + max_samples=max_samples, + max_bins=max_bins, + store_leaf_values=store_leaf_values, + ) + + self.criterion = criterion + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + self.min_impurity_decrease = min_impurity_decrease + self.monotonic_cst = monotonic_cst + self.ccp_alpha = ccp_alpha + self.target_tree_class = target_tree_class + self.stratify = stratify + self.honest_prior = honest_prior + self.honest_fraction = honest_fraction + + + @property + def structure_indices_(self): + """The indices used to learn the structure of the trees.""" + check_is_fitted(self) + return [tree.structure_indices_ for tree in self.estimators_] + + @property + def honest_indices_(self): + 
"""The indices used to fit the leaf nodes.""" + check_is_fitted(self) + return [tree.honest_indices_ for tree in self.estimators_] + + @property + def oob_samples_(self): + """The sample indices that are out-of-bag. + + Only utilized if ``bootstrap=True``, otherwise, all samples are "in-bag". + """ + if self.bootstrap is False and ( + self._n_samples_bootstrap is None or self._n_samples_bootstrap == self._n_samples + ): + raise RuntimeError( + "Cannot extract out-of-bag samples when bootstrap is False and " + "n_samples == n_samples_bootstrap" + ) + check_is_fitted(self) + + oob_samples = [] + + possible_indices = np.arange(self._n_samples) + for structure_idx, honest_idx in zip(self.structure_indices_, self.honest_indices_): + _oob_samples = np.setdiff1d( + possible_indices, np.concatenate((structure_idx, honest_idx)) + ) + oob_samples.append(_oob_samples) + return oob_samples + + class RandomForestRegressor(ForestRegressor): """ A random forest regressor. diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 51fbb3e823726..ea7f899dc5851 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -34,6 +34,7 @@ from sklearn.ensemble._forest import ( _generate_unsampled_indices, _get_n_samples_bootstrap, + HonestRandomForestClassifier, ) from sklearn.exceptions import NotFittedError from sklearn.metrics import ( @@ -44,6 +45,7 @@ ) from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split from sklearn.svm import LinearSVC +from sklearn.tree.tests.test_tree import make_trunk_classification from sklearn.tree._classes import SPARSE_SPLITTERS from sklearn.utils._testing import ( _convert_container, @@ -270,6 +272,137 @@ def test_iris_criterion(name, criterion): score = clf.score(iris.data, iris.target) assert score > 0.5, "Failed with criterion %s and score = %f" % (criterion, score) +@pytest.mark.parametrize("criterion", ("gini", "log_loss")) +def test_honest_forest_iris_criterion(criterion): + # Check consistency on dataset iris. + clf = HonestRandomForestClassifier( + n_estimators=10, criterion=criterion, random_state=1 + ) + clf.fit(iris.data, iris.target) + score = clf.score(iris.data, iris.target) + assert score > 0.9, "Failed with criterion %s and score = %f" % (criterion, score) + + clf = HonestRandomForestClassifier( + n_estimators=10, criterion=criterion, max_features=2, random_state=1 + ) + clf.fit(iris.data, iris.target) + score = clf.score(iris.data, iris.target) + assert score > 0.5, "Failed with criterion %s and score = %f" % (criterion, score) + + +def test_honest_forest_separation(): + # verify that splits by trees in an honest forest are made independent of honest + # Y labels. this can't be done using the shuffle test method used in the tree + # tests because in a forest using stratified sampling, the honest Y labels are + # used to determine the stratification, making it impossible to both shuffle the + # Y labels and keep the honest index selection fixed between trials. thus we must + # use a different method to test forests, which is simply to run two trials, + # shifting the honest X values in the second trial such that any split which + # considered the honest Y labels must move. we also do a third trial moving some + # of the structure X values to verify that moving X's under consideration would + # in fact alter splits, obvious as it may seem. 
+ # + # in order for this test to work, one must ensure that the honest split rejection + # criteria never veto a desired split by the shadow structure tree. + # the lazy way to do this is to make sure there are enough honest observations + # so that there will be enough on either side of any potential structure split. + # thus more dims => more samples + N_TREES = 1 + N_DIM = 10 + SAMPLE_SIZE = 2098 + RANDOM_STATE = 1 + HONEST_FRACTION = 0.95 + STRATIFY = True + + X, y = make_trunk_classification( + n_samples=SAMPLE_SIZE, + n_dim=N_DIM, + n_informative=1, + seed=0, + mu_0=-5, + mu_1=5 + ) + X_t = np.concatenate(( + X[: SAMPLE_SIZE // 2], + X[SAMPLE_SIZE // 2 :] + )) + y_t = np.concatenate(( + y[: SAMPLE_SIZE // 2], + y[SAMPLE_SIZE // 2 :] + )) + + + def perturb(X, y, indices): + for d in range(N_DIM): + for i in indices: + if y[i] == 0 and np.random.randint(0, 2, 1) > 0: + X[i, d] -= 5 + elif np.random.randint(0, 2, 1) > 0: + X[i, d] -= 2 + + return X, y + + + class Trial: + def __init__(self, X, y): + self.est = HonestRandomForestClassifier( + n_estimators=N_TREES, + max_samples=1.0, + max_features=0.3, + bootstrap=True, + stratify=STRATIFY, + n_jobs=-2, + random_state=RANDOM_STATE, + honest_prior="ignore", + honest_fraction=HONEST_FRACTION, + ) + self.est.fit(X, y) + + self.tree = self.est.estimators_[0] + self.honest_tree = self.tree.tree_ + self.structure_tree = self.honest_tree.target_tree + self.honest_indices = np.sort(self.tree.honest_indices_) + self.structure_indices = np.sort(self.tree.structure_indices_) + self.threshold = self.honest_tree.target_tree.threshold.copy() + + + trial_results = [] + trial_results.append(Trial(X_t, y_t)) + + # perturb honest X values; threshold should not change + X_t, y_t = perturb(X_t, y_t, trial_results[0].honest_indices) + + trial_results.append(Trial(X_t, y_t)) + assert np.array_equal( + trial_results[0].honest_indices, + trial_results[1].honest_indices + ) + assert np.array_equal( + trial_results[0].structure_indices, + trial_results[1].structure_indices + ) + assert np.array_equal( + trial_results[0].threshold, + trial_results[1].threshold + ), f"threshold1 = {trial_results[0].threshold}\nthreshold2 = {trial_results[1].threshold}" + + + # perturb structure X's; threshold should change + X_t, y_t = perturb(X_t, y_t, trial_results[0].structure_indices) + trial_results.append(Trial(X_t, y_t)) + assert np.array_equal( + trial_results[0].honest_indices, + trial_results[2].honest_indices + ) + assert np.array_equal( + trial_results[0].structure_indices, + trial_results[2].structure_indices + ) + assert not np.array_equal( + trial_results[0].threshold, + trial_results[2].threshold + ) + @pytest.mark.parametrize("name", FOREST_REGRESSORS) @pytest.mark.parametrize( diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py index c961a811fe05c..8eec4a25dc7c4 100644 --- a/sklearn/tree/__init__.py +++ b/sklearn/tree/__init__.py @@ -10,10 +10,12 @@ ExtraTreeClassifier, ExtraTreeRegressor, ) +from ._honest_tree import HonestDecisionTree from ._export import export_graphviz, export_text, plot_tree __all__ = [ "BaseDecisionTree", + "HonestDecisionTree", "DecisionTreeClassifier", "DecisionTreeRegressor", "ExtraTreeClassifier", diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 7a49c6dc93485..e86262ece6af6 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -91,6 +91,36 @@ # ============================================================================= +class BuildTreeArgs: + def __init__( + self, + X, + y, + 
sample_weight, + missing_values_in_feature_mask, + min_samples_leaf, + min_weight_leaf, + max_leaf_nodes, + min_samples_split, + max_depth, + random_state, + classes, + n_classes + ): + self.X = X + self.y = y + self.sample_weight = sample_weight + self.missing_values_in_feature_mask = missing_values_in_feature_mask + self.min_samples_leaf = min_samples_leaf + self.min_weight_leaf = min_weight_leaf + self.max_leaf_nodes = max_leaf_nodes + self.min_samples_split = min_samples_split + self.max_depth = max_depth + self.random_state = random_state + self.classes = classes + self.n_classes = n_classes + + class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): """Base class for decision trees. @@ -162,6 +192,10 @@ def __init__( self.ccp_alpha = ccp_alpha self.store_leaf_values = store_leaf_values self.monotonic_cst = monotonic_cst + self.presplit_conditions = None + self.postsplit_conditions = None + self.splitter_listeners = None + self.tree_build_listeners = None def get_depth(self): """Return the depth of the decision tree. @@ -235,7 +269,7 @@ def _compute_missing_values_in_feature_mask(self, X, estimator_name=None): missing_values_in_feature_mask = _any_isnan_axis0(X) return missing_values_in_feature_mask - def _fit( + def _prep_data( self, X, y, @@ -412,8 +446,7 @@ def _fit( min_weight_leaf = self.min_weight_fraction_leaf * np.sum(sample_weight) self.min_weight_leaf_ = min_weight_leaf - # build the actual tree now with the parameters - self = self._build_tree( + return BuildTreeArgs( X=X, y=y, sample_weight=sample_weight, @@ -424,12 +457,88 @@ def _fit( min_samples_split=min_samples_split, max_depth=max_depth, random_state=random_state, + classes=classes, + n_classes=getattr(self, 'n_classes_', None) ) - return self + + # The existing implementation of _fit was almost nothing but data prep and + # state initialization, followed by a call to _build_tree. This made it + # impossible to tweak _fit ever so slightly without duplicating a lot of + # code. So we've modularized it a bit. + def _fit( + self, + X, + y, + sample_weight=None, + check_input=True, + missing_values_in_feature_mask=None, + classes=None, + ): + bta = self._prep_data( + X=X, + y=y, + sample_weight=sample_weight, + check_input=check_input, + missing_values_in_feature_mask=missing_values_in_feature_mask, + classes=classes + ) + + # Criterion can't be created until we do the class distribution analysis + # in _prep_data, so we have to create it here, and best to do it as a + # factory which can be overridden if necessary. This used to be in + # _build_tree, but that is the wrong place to commit to a particular + # implementation; it should be passed in as a parameter. 
+ criterion = BaseDecisionTree._create_criterion( + self, + n_outputs=bta.y.shape[1], + n_samples=bta.X.shape[0], + n_classes=bta.n_classes + ) + + # build the actual tree now with the parameters + return self._build_tree( + criterion=criterion, + X=bta.X, + y=bta.y, + sample_weight=bta.sample_weight, + missing_values_in_feature_mask=bta.missing_values_in_feature_mask, + min_samples_leaf=bta.min_samples_leaf, + min_weight_leaf=bta.min_weight_leaf, + max_leaf_nodes=bta.max_leaf_nodes, + min_samples_split=bta.min_samples_split, + max_depth=bta.max_depth, + random_state=bta.random_state, + ) + + @staticmethod + # n_classes is an array of length n_outputs + # containing the number of classes in each output dimension + def _create_criterion( + tree: "BaseDecisionTree", + n_outputs, + n_samples, + n_classes=None + ) -> BaseCriterion: + criterion = tree.criterion + if not isinstance(tree.criterion, BaseCriterion): + if is_classifier(tree): + criterion = CRITERIA_CLF[tree.criterion]( + n_outputs, n_classes + ) + else: + criterion = CRITERIA_REG[tree.criterion](n_outputs, n_samples) + else: + # Make a deepcopy in case the criterion has mutable attributes that + # might be shared and modified concurrently during parallel fitting + criterion = copy.deepcopy(tree.criterion) + + return criterion + def _build_tree( self, + criterion, X, y, sample_weight, @@ -466,20 +575,6 @@ def _build_tree( """ n_samples = X.shape[0] - # Build tree - criterion = self.criterion - if not isinstance(criterion, BaseCriterion): - if is_classifier(self): - criterion = CRITERIA_CLF[self.criterion]( - self.n_outputs_, self.n_classes_ - ) - else: - criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) - else: - # Make a deepcopy in case the criterion has mutable attributes that - # might be shared and modified concurrently during parallel fitting - criterion = copy.deepcopy(criterion) - SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS if self.monotonic_cst is None: @@ -530,6 +625,9 @@ def _build_tree( min_weight_leaf, random_state, monotonic_cst, + presplit_conditions=self.presplit_conditions, + postsplit_conditions=self.postsplit_conditions, + listeners=self.splitter_listeners ) if is_classifier(self): @@ -552,6 +650,7 @@ def _build_tree( max_depth, self.min_impurity_decrease, self.store_leaf_values, + listeners = self.tree_build_listeners ) else: builder = BestFirstTreeBuilder( @@ -563,6 +662,7 @@ def _build_tree( max_leaf_nodes, self.min_impurity_decrease, self.store_leaf_values, + listeners = self.tree_build_listeners ) builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) diff --git a/sklearn/tree/_events.pxd b/sklearn/tree/_events.pxd new file mode 100644 index 0000000000000..1dc9b0a87f116 --- /dev/null +++ b/sklearn/tree/_events.pxd @@ -0,0 +1,63 @@ +# Authors: Samuel Carliles +# +# License: BSD 3 clause + +# See _events.pyx for details. + +from libcpp.vector cimport vector +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t + + +# a simple, general purpose event broker. +# +# it utilizes a somewhat clunky interface built around an event handler closure +# struct, as we are trying to balance generality with execution speed, and in +# practice nothing's faster than simply applying a function pointer. 
+# +# the idea is we would like something like a closure for event handlers, so that +# we may bind instances to instance-specific parameter values, like say you have +# a "threshold" parameter and you would like threshold-dependent handler behavior, +# but you want this threshold configurable at runtime. so we keep this threshold +# parameter in an environment bound to a "closure" instance, which is just a struct +# with a pointer to the environment instance and handler function. now vectors of +# these closures are compact, fast to iterate through, and low overhead to execute. +# +# the idea with EventType is that you have an event broker handling a class of +# conceptually related events, like suppose "server" events, and EventType would +# typically be values from an enum like say: +# +# cdef enum ServerEvent: +# SERVER_UP = 1 +# SERVER_DOWN = 2 +# SERVER_ON_FIRE = 3 +# +# an assumption of the current implementation is that these enum values are small +# integers, and we use them to allocate and index into a listener vector. +# +# EventData is simply a pointer to whatever event payload information is relevant +# to your handler, and it is expected that event_type maps to an associated handler +# which knows what specific "concrete" type to cast its event_data to. + +ctypedef int EventType +ctypedef void* EventHandlerEnv +ctypedef void* EventData +ctypedef bint (*EventHandlerFunction)( + EventType event_type, + EventHandlerEnv handler_env, + EventData event_data +) noexcept nogil + +cdef struct EventHandlerClosure: + EventHandlerFunction f + EventHandlerEnv e + +cdef class EventHandler: + cdef public int[:] event_types + cdef EventHandlerClosure c + +cdef class NullHandler(EventHandler): + pass + +cdef class EventBroker: + cdef vector[vector[EventHandlerClosure]] listeners # listeners acts as a map from EventType to corresponding event handlers + cdef bint fire_event(self, EventType event_type, EventData event_data) noexcept nogil diff --git a/sklearn/tree/_events.pyx b/sklearn/tree/_events.pyx new file mode 100644 index 0000000000000..7a143be44d487 --- /dev/null +++ b/sklearn/tree/_events.pyx @@ -0,0 +1,57 @@ + +# Authors: Samuel Carliles +# +# License: BSD 3 clause + + +cdef class EventBroker: + def __cinit__(self, listeners: [EventHandler], event_types: [EventType]): + """ + Parameters: + - listeners ([EventHandler]) + - event_types ([EventType]): a list of EventTypes that may be fired by this EventBroker + + Notes: + - Don't mix event types in a single EventBroker instance, + i.e. 
don't use the same EventBroker for brokering NodeSplitEvent that you use
+          for brokering TreeBuildEvent, etc.
+        """
+        self.listeners.resize(max(event_types) + 1)
+
+        if listeners is None:
+            for e in range(max(event_types) + 1):
+                self.listeners[e].resize(0)
+        else:
+            self.add_listeners(listeners, event_types)
+
+    def add_listeners(self, listeners: [EventHandler], event_types: [EventType]):
+        cdef int e, i, j, offset, mx, ct
+        cdef list l
+
+        # listeners is a vector of vectors which we index using EventType,
+        # so if event_types contains any EventType for which we don't already
+        # have a vector, its integer value will be at least our current size
+        mx = max(event_types)
+        offset = self.listeners.size()
+        if mx >= offset:
+            self.listeners.resize(mx + 1)
+
+        if listeners is not None:
+            for e in event_types:
+                # find indices for all listeners to event type e
+                l = [j for j, _l in enumerate(listeners) if e in (<EventHandler>_l).event_types]
+                offset = self.listeners[e].size()
+                ct = len(l)
+                self.listeners[e].resize(offset + ct)
+                for i in range(ct):
+                    j = l[i]
+                    self.listeners[e][offset + i] = (<EventHandler>listeners[j]).c
+
+    cdef bint fire_event(self, EventType event_type, EventData event_data) noexcept nogil:
+        cdef bint result = True
+
+        if event_type < self.listeners.size():
+            for l in self.listeners[event_type]:
+                result = result and l.f(event_type, l.e, event_data)
+
+        return result
diff --git a/sklearn/tree/_honest_tree.py b/sklearn/tree/_honest_tree.py
new file mode 100644
index 0000000000000..96e27ed1eaf9a
--- /dev/null
+++ b/sklearn/tree/_honest_tree.py
@@ -0,0 +1,377 @@
+# Authors: Haoyin Xu
+#          Samuel Carliles
+#
+# Adapted from: https://github.com/neurodata/honest-forests
+
+# An honest classification tree implemented by inheriting BaseDecisionTree and
+# including the honesty module. The general idea is that:
+#
+# 1. The interface looks mostly like a regular DecisionTree, and we inherit as
+#    much of the implementation as we can.
+# 2. Rather than actually being our own tree however, we have a target tree for
+#    learning the structure, which is just a regular DecisionTree trained on the
+#    structure sample, and an honesty instance which grows the shadow tree described
+#    in the honesty module.
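+#
+# A rough usage sketch (illustrative only, not part of this module; parameter
+# values are arbitrary): the honest tree is configured with the class and kwargs
+# of the target tree whose structure it shadows, e.g.
+#
+#     from sklearn.datasets import make_classification
+#     from sklearn.tree import DecisionTreeClassifier, HonestDecisionTree
+#
+#     X, y = make_classification(n_samples=200, random_state=0)
+#     est = HonestDecisionTree(
+#         target_tree_class=DecisionTreeClassifier,
+#         target_tree_kwargs={"criterion": "gini", "max_depth": 3},
+#         honest_fraction=0.5,  # half the samples reserved for leaf estimation
+#         stratify=True,
+#         random_state=0,
+#     )
+#     est.fit(X, y)
+#     est.structure_indices_, est.honest_indices_  # the two disjoint subsamples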
+ +import numpy as np +from numpy import float32 as DTYPE + +from ..base import _fit_context, is_classifier +from ..model_selection import StratifiedShuffleSplit +from ..utils import compute_sample_weight +from ..utils._param_validation import Interval, RealNotInt, StrOptions +from ..utils.multiclass import check_classification_targets + +from ._classes import ( + BaseDecisionTree, + CRITERIA_CLF, CRITERIA_REG, DENSE_SPLITTERS, SPARSE_SPLITTERS +) +from ._honesty import HonestTree, Honesty +from ._tree import DOUBLE, Tree + +import inspect + + +# note: max_n_classes is the maximum number of classes observed +# in any response variable dimension +class HonestDecisionTree(BaseDecisionTree): + _parameter_constraints: dict = { + **BaseDecisionTree._parameter_constraints, + "target_tree_class": "no_validation", + "target_tree_kwargs": [dict], + "honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="both")], + "honest_prior": [StrOptions({"empirical", "uniform", "ignore"})], + "stratify": ["boolean"], + } + + def __init__( + self, + *, + criterion=None, + target_tree_class=None, + target_tree_kwargs=None, + random_state=None, + honest_fraction=0.5, + honest_prior="empirical", + stratify=False + ): + self.criterion = criterion + self.target_tree_class = target_tree_class + self.target_tree_kwargs = target_tree_kwargs if target_tree_kwargs is not None else {} + + self.random_state = random_state + self.honest_fraction = honest_fraction + self.honest_prior = honest_prior + self.stratify = stratify + + # TODO: unwind this whole gross antipattern + if target_tree_class is not None: + HonestDecisionTree._target_tree_hack(self, target_tree_class, **target_tree_kwargs) + + # In order to inherit behavior from BaseDecisionTree, we must satisfy a lot of + # pythonic introspective attribute assumptions. This was the lowest effort way + # that came to mind. + @staticmethod + def _target_tree_hack(honest_tree, target_tree_class, **kwargs): + honest_tree.target_tree_class = target_tree_class + honest_tree.target_tree = target_tree_class(**kwargs) + + # copy over the attributes of the target tree + for attr_name in vars(honest_tree.target_tree): + setattr( + honest_tree, + attr_name, + getattr(honest_tree.target_tree, attr_name, None) + ) + + if is_classifier(honest_tree.target_tree): + honest_tree._estimator_type = honest_tree.target_tree._estimator_type + honest_tree.predict_proba = honest_tree.target_tree.predict_proba + honest_tree.predict_log_proba = honest_tree.target_tree.predict_log_proba + + def _fit( + self, + X, + y, + sample_weight=None, + check_input=True, + missing_values_in_feature_mask=None, + classes=None + ): + return self.fit( + X, y, sample_weight, check_input, missing_values_in_feature_mask, classes + ) + + @_fit_context(prefer_skip_nested_validation=True) + def fit( + self, + X, + y, + sample_weight=None, + check_input=True, + missing_values_in_feature_mask=None, + classes=None, + ): + """Build an honest tree from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. 
Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you do. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + + Returns + ------- + self : HonestTree + Fitted tree estimator. + """ + + # run this again because of the way ensemble creates estimators + HonestDecisionTree._target_tree_hack(self, self.target_tree_class, **self.target_tree_kwargs) + target_bta = self.target_tree._prep_data( + X=X, + y=y, + sample_weight=sample_weight, + check_input=check_input, + missing_values_in_feature_mask=missing_values_in_feature_mask, + classes=classes + ) + + # TODO: go fix TODO in classes.py line 636 + if target_bta.n_classes is None: + target_bta.n_classes = np.array( + [1] * self.target_tree.n_outputs_, + dtype=np.intp + ) + + # Determine output settings + self._init_output_shape(target_bta.X, target_bta.y, target_bta.classes) + + # obtain the structure sample weights + sample_weights_structure, sample_weights_honest = self._partition_honest_indices( + target_bta.y, + target_bta.sample_weight + ) + + # create honesty, set up listeners in target tree + self.honesty = Honesty( + target_bta.X, + self.honest_indices_, + target_bta.min_samples_leaf, + missing_values_in_feature_mask = target_bta.missing_values_in_feature_mask + ) + + self.target_tree.presplit_conditions = self.honesty.presplit_conditions + self.target_tree.postsplit_conditions = self.honesty.postsplit_conditions + self.target_tree.splitter_listeners = self.honesty.splitter_event_handlers + self.target_tree.tree_build_listeners = self.honesty.tree_event_handlers + + # Learn structure on subsample + # XXX: this allows us to use BaseDecisionTree without partial_fit API + try: + self.target_tree.fit( + target_bta.X, + target_bta.y, + sample_weight=sample_weights_structure, + check_input=check_input, + classes=target_bta.classes + ) + except Exception: + self.target_tree.fit( + target_bta.X, + target_bta.y, + sample_weight=sample_weights_structure, + check_input=check_input + ) + + # more pythonic introspection minutiae + setattr( + self, + "classes_", + getattr(self.target_tree, "classes_", None) + ) + + n_samples = target_bta.X.shape[0] + samples = np.empty(n_samples, dtype=np.intp) + weighted_n_samples = 0.0 + j = 0 + + for i in range(n_samples): + # Only work with positively weighted samples + if sample_weights_honest[i] != 0.0: + samples[j] = i + j += 1 + + weighted_n_samples += sample_weights_honest[i] + + # more pythonic introspection minutiae + # fingers crossed sklearn.utils.validation.check_is_fitted doesn't + # change its behavior + self.tree_ = HonestTree( + self.target_tree.n_features_in_, + target_bta.n_classes, + self.target_tree.n_outputs_, + self.target_tree.tree_ + ) + self.honesty.resize_tree(self.tree_, self.honesty.get_node_count()) + self.tree_.node_count = self.honesty.get_node_count() + + # Criterion is very stateful, so do all the instantiation and initialization + criterion = BaseDecisionTree._create_criterion( + self.target_tree, + n_outputs=target_bta.y.shape[1], + n_samples=target_bta.X.shape[0], + n_classes=target_bta.n_classes + ) + self.honesty.init_criterion( + criterion, + target_bta.y, + sample_weights_honest, + 
weighted_n_samples, + self.honest_indices_ + ) + + for i in range(self.honesty.get_node_count()): + start, end = self.honesty.get_node_range(i) + self.honesty.set_sample_pointers(criterion, start, end) + + if missing_values_in_feature_mask is not None: + self.honesty.init_sum_missing(criterion) + + self.honesty.node_value(self.tree_, criterion, i) + + if self.honesty.is_leaf(i): + self.honesty.node_samples(self.tree_, criterion, i) + + # more pythonic introspection minutiae + setattr( + self, + "__sklearn_is_fitted__", + lambda: True + ) + + return self + + + def _init_output_shape(self, X, y, classes=None): + # Determine output settings + self.n_samples_, self.n_features_in_ = X.shape + + # Do preprocessing if 'y' is passed + is_classification = False + if y is not None: + is_classification = is_classifier(self) + y = np.atleast_1d(y) + expanded_class_weight = None + + if y.ndim == 1: + # reshape is necessary to preserve the data contiguity against vs + # [:, np.newaxis] that does not. + y = np.reshape(y, (-1, 1)) + + self.n_outputs_ = y.shape[1] + + if is_classification: + check_classification_targets(y) + y = np.copy(y) + + self.classes_ = [] + self.n_classes_ = [] + + if self.class_weight is not None: + y_original = np.copy(y) + + y_encoded = np.zeros(y.shape, dtype=int) + if classes is not None: + classes = np.atleast_1d(classes) + if classes.ndim == 1: + classes = np.array([classes]) + + for k in classes: + self.classes_.append(np.array(k)) + self.n_classes_.append(np.array(k).shape[0]) + + for i in range(self.n_samples_): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][ + 0 + ] + else: + for k in range(self.n_outputs_): + classes_k, y_encoded[:, k] = np.unique( + y[:, k], return_inverse=True + ) + self.classes_.append(classes_k) + self.n_classes_.append(classes_k.shape[0]) + + y = y_encoded + + if self.class_weight is not None: + expanded_class_weight = compute_sample_weight( + self.class_weight, y_original + ) + + self.n_classes_ = np.array(self.n_classes_, dtype=np.intp) + self._n_classes_ = self.n_classes_ + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + if len(y) != self.n_samples_: + raise ValueError( + "Number of labels=%d does not match number of samples=%d" + % (len(y), self.n_samples_) + ) + + + def _partition_honest_indices(self, y, sample_weight): + # Account for bootstrapping too + if sample_weight is None: + structure_weight = np.ones((len(y),), dtype=np.float64) + honest_weight = np.ones((len(y),), dtype=np.float64) + else: + structure_weight = np.array(sample_weight) + honest_weight = np.array(sample_weight) + + nonzero_indices = np.where(structure_weight > 0)[0] + + # sample the structure indices + if self.stratify: + ss = StratifiedShuffleSplit( + n_splits=1, test_size=self.honest_fraction, random_state=self.random_state + ) + for structure_idx, _ in ss.split( + np.zeros((len(nonzero_indices), 1)), y[nonzero_indices] + ): + self.structure_indices_ = nonzero_indices[structure_idx] + + else: + rng = np.random.default_rng(self.random_state) + self.structure_indices_ = rng.choice( + nonzero_indices, + int((1 - self.honest_fraction) * len(nonzero_indices)), + replace=False, + ) + + honest_weight[self.structure_indices_] = 0 + + self.honest_indices_ = np.setdiff1d(nonzero_indices, self.structure_indices_) + structure_weight[self.honest_indices_] = 0 + + return structure_weight, honest_weight diff --git a/sklearn/tree/_honesty.pxd b/sklearn/tree/_honesty.pxd 
new file mode 100644 index 0000000000000..781a7738800c3 --- /dev/null +++ b/sklearn/tree/_honesty.pxd @@ -0,0 +1,110 @@ +# Authors: Samuel Carliles +# +# License: BSD 3 clause + +# See _honesty.pyx for details. + +# Here we cash in the architectural changes/additions we made to Splitter and +# TreeBuilder. We implement this as an honest module not dependent on any particular +# type of Tree so that it can be composed into any type of Tree. +# +# The general ideas are that we: +# 1. inject honest split rejection criteria into Splitter +# 2. listen to tree build events fired by TreeBuilder to build a shadow tree +# which contains the honest sample +# +# So we implement honest split rejection criteria for injection into Splitter, +# and event handlers which construct the shadow tree in response to events fired +# by TreeBuilder. + +from ._events cimport EventData, EventHandler, EventHandlerEnv, EventType +from ._partitioner cimport Partitioner +from ._splitter cimport ( + NodeSplitEvent, + NodeSortFeatureEventData, + NodeSplitEventData, + Splitter, + SplitConditionEnv, + SplitConditionFunction, + SplitConditionClosure, + SplitCondition +) +from ._tree cimport ( + Tree, + TreeBuildEvent, + TreeBuildSetActiveParentEventData, + TreeBuildAddNodeEventData +) + +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t + +from libcpp.vector cimport vector + + +# We do a much simplified tree model, barely more than enough to define the +# partition extents in the honest-masked data array corresponding to the node's +# elements. We store it in a vector indexed by the corresponding node IDs in the +# "structure" tree. +cdef struct Interval: + intp_t start_idx # index into samples + intp_t n + intp_t feature + intp_t split_idx # start of right child + float64_t split_value + +cdef class Views: + cdef: + const float32_t[:, :] X + const float32_t[:, ::1] y + intp_t[::1] samples + float32_t[::1] feature_values # temp. 
array holding feature values + Partitioner partitioner + +cdef struct HonestEnv: + void* data_views + vector[Interval] tree + intp_t node_count + Interval* active_parent + Interval active_node + intp_t active_is_left + +cdef class Honesty: + cdef: + public list splitter_event_handlers # python list of EventHandler + public list presplit_conditions # python list of SplitCondition + public list postsplit_conditions # python list of SplitCondition + public list tree_event_handlers # python list of EventHandler + + public Views views + HonestEnv env + + +cdef class HonestTree(Tree): + cdef public Tree target_tree + + +cdef struct TrivialEnv: + vector[int32_t] event_types + +cdef class TrivialHandler(EventHandler): + cdef TrivialEnv _env + +cdef class NodeSortFeatureHandler(EventHandler): + pass + +cdef class AddNodeHandler(EventHandler): + pass + +cdef class SetActiveParentHandler(EventHandler): + pass + +cdef class TrivialCondition(SplitCondition): + pass + + +cdef struct MinSamplesLeafConditionEnv: + intp_t min_samples + HonestEnv* honest_env + +cdef class HonestMinSamplesLeafCondition(SplitCondition): + cdef MinSamplesLeafConditionEnv _env diff --git a/sklearn/tree/_honesty.pyx b/sklearn/tree/_honesty.pyx new file mode 100644 index 0000000000000..11b9719c78670 --- /dev/null +++ b/sklearn/tree/_honesty.pyx @@ -0,0 +1,375 @@ +from cython cimport cast +from libc.stdint cimport uintptr_t +from libc.math cimport floor, fmax, log2, pow, isnan, NAN + +from ._criterion cimport BaseCriterion, Criterion +from ._partitioner cimport DensePartitioner, SparsePartitioner + +cimport numpy as cnp +import numpy as np +from scipy.sparse import issparse + + +cdef class HonestTree(Tree): + """args[0] must be target_tree of type Tree""" + def __init__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs, Tree target_tree, *args): + self.target_tree = target_tree + + cpdef cnp.ndarray apply(self, object X): + """Finds the terminal region (=leaf node) for each sample in X.""" + + return self.target_tree.apply(X) + + +cdef class Honesty: + def __cinit__( + self, + object X, + object samples, + intp_t min_samples_leaf, + const unsigned char[::1] missing_values_in_feature_mask = None, + Partitioner honest_partitioner = None, + splitter_event_handlers : [EventHandler] = None, + presplit_conditions : [SplitCondition] = None, + postsplit_conditions : [SplitCondition] = None, + tree_event_handlers : [EventHandler] = None + ): + if splitter_event_handlers is None: + splitter_event_handlers = [] + if presplit_conditions is None: + presplit_conditions = [] + if postsplit_conditions is None: + postsplit_conditions = [] + if tree_event_handlers is None: + tree_event_handlers = [] + + self.env.node_count = 0 + self.views = Views() + self.views.X = X + self.views.samples = samples + self.views.feature_values = np.empty(len(samples), dtype=np.float32) + self.views.partitioner = ( + honest_partitioner if honest_partitioner is not None + else Honesty.create_partitioner( + X, + samples, + self.views.feature_values, + missing_values_in_feature_mask + ) + ) + self.env.data_views = self.views + + self.splitter_event_handlers = [NodeSortFeatureHandler(self)] + ( + splitter_event_handlers if splitter_event_handlers is not None else [] + ) + self.presplit_conditions = [HonestMinSamplesLeafCondition(self, min_samples_leaf)] + ( + presplit_conditions if presplit_conditions is not None else [] + ) + self.postsplit_conditions = [] + ( + postsplit_conditions if postsplit_conditions is not None else [] + ) + 
self.tree_event_handlers = [ + SetActiveParentHandler(self), + AddNodeHandler(self) + ] + (tree_event_handlers if tree_event_handlers is not None else []) + + + @staticmethod + def create_partitioner(X, samples, feature_values, missing_values_in_feature_mask): + return SparsePartitioner( + X, samples, feature_values, missing_values_in_feature_mask + ) if issparse(X) else DensePartitioner( + X, samples, feature_values, missing_values_in_feature_mask + ) + + # The Criterion classes are quite stateful, and since we wish to reuse them + # to maintain behavior consistent with them, we have to do some implementational + # shenanigans like this. + def init_criterion( + self, + Criterion criterion, + y, + sample_weights, + weighted_n_samples, + sample_indices + ): + criterion.init(y, sample_weights, weighted_n_samples, sample_indices) + + def set_sample_pointers(self, Criterion criterion, intp_t start, intp_t end): + criterion.set_sample_pointers(start, end) + + def init_sum_missing(self, Criterion criterion): + criterion.init_sum_missing() + + def node_value(self, Tree tree, Criterion criterion, intp_t i): + criterion.node_value((tree.value + i * tree.value_stride)) + + def node_samples(self, Tree tree, Criterion criterion, intp_t i): + criterion.node_samples(tree.value_samples[i]) + + def get_node_count(self): + return self.env.node_count + + def resize_tree(self, Tree tree, intp_t capacity): + tree._resize(capacity) + + def get_node_range(self, i): + return ( + self.env.tree[i].start_idx, + self.env.tree[i].start_idx + self.env.tree[i].n + ) + + def is_leaf(self, i): + return self.env.tree[i].feature == -1 + + @staticmethod + def get_value_samples_ndarray(Tree tree, intp_t node_id): + return tree._get_value_samples_ndarray(node_id) + + +cdef bint _handle_trivial( + EventType event_type, + EventHandlerEnv handler_env, + EventData event_data +) noexcept nogil: + cdef bint result = False + cdef TrivialEnv* env = handler_env + + with gil: + print("in _handle_trivial") + + for i in range(env.event_types.size()): + result = result | env.event_types[i] + + return result + + +cdef class TrivialHandler(EventHandler): + def __cinit__(self, event_types : [EventType]): + self.event_types = np.array(event_types, dtype=np.int32) + + self._env.event_types.resize(len(event_types)) + for i in range(len(event_types)): + self._env.event_types[i] = event_types[i] + + self.c.f = _handle_trivial + self.c.e = &self._env + + +cdef bint _handle_set_active_parent( + EventType event_type, + EventHandlerEnv handler_env, + EventData event_data +) noexcept nogil: + if event_type != TreeBuildEvent.SET_ACTIVE_PARENT: + return True + + cdef HonestEnv* env = handler_env + cdef TreeBuildSetActiveParentEventData* data = event_data + cdef Interval* node = &env.active_node + + if (data.parent_node_id) >= (env.tree.size()): + return False + + env.active_is_left = data.child_is_left + + node.feature = -1 + node.split_idx = 0 + node.split_value = NAN + + if data.parent_node_id < 0: + env.active_parent = NULL + node.start_idx = 0 + node.n = (env.data_views).samples.shape[0] + else: + env.active_parent = &(env.tree[data.parent_node_id]) + if env.active_is_left: + node.start_idx = env.active_parent.start_idx + node.n = env.active_parent.split_idx - env.active_parent.start_idx + else: + node.start_idx = env.active_parent.split_idx + node.n = env.active_parent.n - env.active_parent.split_idx + + (env.data_views).partitioner.init_node_split(node.start_idx, node.start_idx + node.n) + + return True + +cdef class 
SetActiveParentHandler(EventHandler): + def __cinit__(self, Honesty h): + self.event_types = np.array([TreeBuildEvent.SET_ACTIVE_PARENT], dtype=np.int32) + + self.c.f = _handle_set_active_parent + self.c.e = &h.env + + +cdef bint _handle_sort_feature( + EventType event_type, + EventHandlerEnv handler_env, + EventData event_data +) noexcept nogil: + if event_type != NodeSplitEvent.SORT_FEATURE: + return True + + cdef HonestEnv* env = handler_env + cdef NodeSortFeatureEventData* data = event_data + cdef Interval* node = &env.active_node + + node.feature = data.feature + node.split_idx = 0 + node.split_value = NAN + + (env.data_views).partitioner.sort_samples_and_feature_values(node.feature) + + return True + +# When the structure tree sorts by a feature, we must do the same +cdef class NodeSortFeatureHandler(EventHandler): + def __cinit__(self, Honesty h): + self.event_types = np.array([NodeSplitEvent.SORT_FEATURE], dtype=np.int32) + + self.c.f = _handle_sort_feature + self.c.e = &h.env + + +cdef bint _handle_add_node( + EventType event_type, + EventHandlerEnv handler_env, + EventData event_data +) noexcept nogil: + if event_type != TreeBuildEvent.ADD_NODE: + return True + + cdef HonestEnv* env = handler_env + cdef const float32_t[:, :] X = (env.data_views).X + cdef intp_t[::1] samples = (env.data_views).samples + cdef float64_t h, feature_value + cdef intp_t i, n_left, n_missing, size = env.tree.size() + cdef TreeBuildAddNodeEventData* data = event_data + cdef Interval *interval = NULL + cdef Interval *parent = NULL + + if data.node_id >= size: + # as a heuristic, assume a complete tree and add a level + h = floor(fmax(0, log2(size))) + env.tree.resize(size + pow(2, h + 1)) + + interval = &(env.tree[data.node_id]) + interval.feature = data.feature + interval.split_value = data.split_point + + if data.parent_node_id < 0: + # the node being added is the tree root + interval.start_idx = 0 + interval.n = samples.shape[0] + else: + parent = &(env.tree[data.parent_node_id]) + + if data.is_left: + interval.start_idx = parent.start_idx + interval.n = parent.split_idx - parent.start_idx + else: + interval.start_idx = parent.split_idx + interval.n = parent.n - (parent.split_idx - parent.start_idx) + + # We also reuse Partitioner. 
We don't need to sort to find the split pos we'll
+    # need for partitioning, but the partitioner internals are so stateful we had
+    # better just do it to ensure that it's in the expected state
+    (env.data_views).partitioner.init_node_split(interval.start_idx, interval.start_idx + interval.n)
+    (env.data_views).partitioner.sort_samples_and_feature_values(interval.feature)
+
+    # count n_left to find the split pos; check the bound before reading so we
+    # never index past the node's samples
+    n_left = 0
+    i = interval.start_idx
+
+    while i < interval.start_idx + interval.n:
+        feature_value = X[samples[i], interval.feature]
+        if isnan(feature_value) or feature_value >= interval.split_value:
+            break
+        n_left += 1
+        i += 1
+
+    interval.split_idx = interval.start_idx + n_left
+
+    (env.data_views).partitioner.partition_samples_final(
+        interval.split_idx, interval.split_value, interval.feature, (env.data_views).partitioner.n_missing
+    )
+
+    env.node_count += 1
+
+    return True
+
+
+cdef class AddNodeHandler(EventHandler):
+    def __cinit__(self, Honesty h):
+        self.event_types = np.array([TreeBuildEvent.ADD_NODE], dtype=np.int32)
+
+        self.c.f = _handle_add_node
+        self.c.e = &h.env
+
+
+cdef bint _trivial_condition(
+    Splitter splitter,
+    intp_t split_feature,
+    intp_t split_pos,
+    float64_t split_value,
+    intp_t n_missing,
+    bint missing_go_to_left,
+    float64_t lower_bound,
+    float64_t upper_bound,
+    SplitConditionEnv split_condition_env
+) noexcept nogil:
+    return True
+
+cdef class TrivialCondition(SplitCondition):
+    def __cinit__(self):
+        self.c.f = _trivial_condition
+        self.c.e = NULL
+
+
+cdef bint _honest_min_sample_leaf_condition(
+    Splitter splitter,
+    intp_t split_feature,
+    intp_t split_pos,
+    float64_t split_value,
+    intp_t n_missing,
+    bint missing_go_to_left,
+    float64_t lower_bound,
+    float64_t upper_bound,
+    SplitConditionEnv split_condition_env
+) noexcept nogil:
+    cdef MinSamplesLeafConditionEnv* env = split_condition_env
+    cdef Interval* node = &env.honest_env.active_node
+
+    cdef intp_t min_samples_leaf = env.min_samples
+    cdef intp_t end_non_missing, n_left, n_right
+
+    # we don't care about n_missing in the structure set
+    n_missing = (env.honest_env.data_views).partitioner.n_missing
+    end_non_missing = node.start_idx + node.n - n_missing
+
+    # we don't care about split_pos in the structure set,
+    # need to scan forward in the honest set based on split_value to find it
+    while node.split_idx < node.start_idx + node.n and (env.honest_env.data_views).X[(env.honest_env.data_views).samples[node.split_idx], node.feature] <= split_value:
+        node.split_idx += 1
+
+    if missing_go_to_left:
+        n_left = node.split_idx - node.start_idx + n_missing
+        n_right = end_non_missing - node.split_idx
+    else:
+        n_left = node.split_idx - node.start_idx
+        n_right = end_non_missing - node.split_idx + n_missing
+
+    # Reject if min_samples_leaf is not guaranteed
+    if n_left < min_samples_leaf or n_right < min_samples_leaf:
+        return False
+
+    return True
+
+# Check that the honest set will have sufficient samples on each side of this
+# candidate split.
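+#
+# Illustrative wiring (an assumption about the intended use, not something added
+# by this change): the condition is built from an Honesty instance and injected
+# through the splitter's `presplit_conditions` so that candidate splits are also
+# vetted against the honest sample set, e.g.
+#
+#     condition = HonestMinSamplesLeafCondition(honesty, min_samples_leaf)
+#     # then pass [condition] as `presplit_conditions` when constructing the splitter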
+cdef class HonestMinSamplesLeafCondition(SplitCondition): + def __cinit__(self, Honesty h, intp_t min_samples): + self._env.min_samples = min_samples + self._env.honest_env = &h.env + + self.c.f = _honest_min_sample_leaf_condition + self.c.e = &self._env diff --git a/sklearn/tree/_partitioner.pxd b/sklearn/tree/_partitioner.pxd index 4ddd2a9cf9eb6..019b9c040162c 100644 --- a/sklearn/tree/_partitioner.pxd +++ b/sklearn/tree/_partitioner.pxd @@ -1,178 +1,143 @@ -# Authors: The scikit-learn developers +# Authors: Gilles Louppe +# Peter Prettenhofer +# Brian Holt +# Joel Nothman +# Arnaud Joly +# Jacob Schreiber +# Adam Li +# Jong Shin +# Samuel Carliles +# +# License: BSD 3 clause # SPDX-License-Identifier: BSD-3-Clause -# See _partitioner.pyx for details. +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, uint8_t, int32_t, uint32_t -from ..utils._typedefs cimport ( - float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t -) -from ._splitter cimport SplitRecord +# Constant to switch between algorithm non zero value extract algorithm +# in SparsePartitioner +cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 +# We introduce a different approach to the fused type for {Dense, Sparse}Partitioner. +# The main drawback of the fused type approach is that it seemed to require a +# proliferation of concrete Splitter types in order to accommodate holding ownership +# of each concrete type of Partitioner, hence the +# {Best, BestSparse, Random, RandomSparse}Splitter classes. This pattern generalizes +# to any class wishing to hold a concrete instance of Partitioner, which makes +# reusing the Partitioner code (as we wish to do for honesty and obliqueness) a +# fractal class-generating process. +# +# The alternative we introduce is the same pattern we use all over the place: +# function pointers. Assigning method implementations as function pointer values +# in init allows DensePartitioner and SparsePartitioner to be plain old subclasses +# of Partitioner, and there is no performance hit from virtual method lookup. +# +# Since we also seek to reuse Partitioner as its own module, we break it out into +# its own files. -# Mitigate precision differences between 32 bit and 64 bit -cdef float32_t FEATURE_THRESHOLD = 1e-7 +# Introduce a fused-class to make it possible to share the split implementation +# between the dense and sparse cases in the node_split_best and node_split_random +# functions. The alternative would have been to use inheritance-based polymorphism +# but it would have resulted in a ~10% overall tree fitting performance +# degradation caused by the overhead frequent virtual method lookups. +#ctypedef fused Partitioner: +# DensePartitioner +# SparsePartitioner -# We provide here the abstract interfact for a Partitioner that would be -# theoretically shared between the Dense and Sparse partitioners. However, -# we leave it commented out for now as it is not used in the current -# implementation due to the performance hit from vtable lookups when using -# inheritance based polymorphism. It is left here for future reference. -# -# Note: Instead, in `_splitter.pyx`, we define a fused type that can be used -# to represent both the dense and sparse partitioners. 
-# -# cdef class BasePartitioner: -# cdef intp_t[::1] samples -# cdef float32_t[::1] feature_values -# cdef intp_t start -# cdef intp_t end -# cdef intp_t n_missing -# cdef const uint8_t[::1] missing_values_in_feature_mask - -# cdef void sort_samples_and_feature_values( -# self, intp_t current_feature -# ) noexcept nogil -# cdef void init_node_split( -# self, -# intp_t start, -# intp_t end -# ) noexcept nogil -# cdef void find_min_max( -# self, -# intp_t current_feature, -# float32_t* min_feature_value_out, -# float32_t* max_feature_value_out, -# ) noexcept nogil -# cdef void next_p( -# self, -# intp_t* p_prev, -# intp_t* p -# ) noexcept nogil -# cdef intp_t partition_samples( -# self, -# float64_t current_threshold -# ) noexcept nogil -# cdef void partition_samples_final( -# self, -# intp_t best_pos, -# float64_t best_threshold, -# intp_t best_feature, -# intp_t n_missing, -# ) noexcept nogil - - -cdef class DensePartitioner: +ctypedef void (*InitNodeSplitFunction)( + Partitioner partitioner, intp_t start, intp_t end +) noexcept nogil + +ctypedef void (*SortSamplesAndFeatureValuesFunction)( + Partitioner partitioner, intp_t current_feature +) noexcept nogil + +ctypedef void (*FindMinMaxFunction)( + Partitioner partitioner, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, +) noexcept nogil + +ctypedef void (*NextPFunction)( + Partitioner partitioner, intp_t* p_prev, intp_t* p +) noexcept nogil + +ctypedef intp_t (*PartitionSamplesFunction)( + Partitioner partitioner, float64_t current_threshold +) noexcept nogil + +ctypedef void (*PartitionSamplesFinalFunction)( + Partitioner partitioner, + intp_t best_pos, + float64_t best_threshold, + intp_t best_feature, + intp_t best_n_missing, +) noexcept nogil + + +cdef class Partitioner: + cdef: + intp_t[::1] samples + float32_t[::1] feature_values + intp_t start + intp_t end + intp_t n_missing + const uint8_t[::1] missing_values_in_feature_mask + + inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil + inline void sort_samples_and_feature_values( + self, + intp_t current_feature + ) noexcept nogil + inline void find_min_max( + self, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, + ) noexcept nogil + inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil + inline intp_t partition_samples(self, float64_t current_threshold) noexcept nogil + inline void partition_samples_final( + self, + intp_t best_pos, + float64_t best_threshold, + intp_t best_feature, + intp_t best_n_missing, + ) noexcept nogil + + InitNodeSplitFunction _init_node_split + SortSamplesAndFeatureValuesFunction _sort_samples_and_feature_values + FindMinMaxFunction _find_min_max + NextPFunction _next_p + PartitionSamplesFunction _partition_samples + PartitionSamplesFinalFunction _partition_samples_final + + +cdef class DensePartitioner(Partitioner): """Partitioner specialized for dense data. Note that this partitioner is agnostic to the splitting strategy (best vs. random). 
""" - cdef const float32_t[:, :] X - cdef intp_t[::1] samples - cdef float32_t[::1] feature_values - cdef intp_t start - cdef intp_t end - cdef intp_t n_missing - cdef const uint8_t[::1] missing_values_in_feature_mask - - cdef void sort_samples_and_feature_values( - self, intp_t current_feature - ) noexcept nogil - cdef void init_node_split( - self, - intp_t start, - intp_t end - ) noexcept nogil - cdef void find_min_max( - self, - intp_t current_feature, - float32_t* min_feature_value_out, - float32_t* max_feature_value_out, - ) noexcept nogil - cdef void next_p( - self, - intp_t* p_prev, - intp_t* p - ) noexcept nogil - cdef intp_t partition_samples( - self, - float64_t current_threshold - ) noexcept nogil - cdef void partition_samples_final( - self, - intp_t best_pos, - float64_t best_threshold, - intp_t best_feature, - intp_t n_missing, - ) noexcept nogil - - -cdef class SparsePartitioner: + cdef: + const float32_t[:, :] X + + +cdef class SparsePartitioner(Partitioner): """Partitioner specialized for sparse CSC data. Note that this partitioner is agnostic to the splitting strategy (best vs. random). """ - cdef const float32_t[::1] X_data - cdef const int32_t[::1] X_indices - cdef const int32_t[::1] X_indptr - cdef intp_t n_total_samples - cdef intp_t[::1] index_to_samples - cdef intp_t[::1] sorted_samples - cdef intp_t start_positive - cdef intp_t end_negative - cdef bint is_samples_sorted - - cdef intp_t[::1] samples - cdef float32_t[::1] feature_values - cdef intp_t start - cdef intp_t end - cdef intp_t n_missing - cdef const uint8_t[::1] missing_values_in_feature_mask - - cdef void sort_samples_and_feature_values( - self, intp_t current_feature - ) noexcept nogil - cdef void init_node_split( - self, - intp_t start, - intp_t end - ) noexcept nogil - cdef void find_min_max( - self, - intp_t current_feature, - float32_t* min_feature_value_out, - float32_t* max_feature_value_out, - ) noexcept nogil - cdef void next_p( - self, - intp_t* p_prev, - intp_t* p - ) noexcept nogil - cdef intp_t partition_samples( - self, - float64_t current_threshold - ) noexcept nogil - cdef void partition_samples_final( - self, - intp_t best_pos, - float64_t best_threshold, - intp_t best_feature, - intp_t n_missing, - ) noexcept nogil - - cdef void extract_nnz( - self, - intp_t feature - ) noexcept nogil - cdef intp_t _partition( - self, - float64_t threshold, - intp_t zero_pos - ) noexcept nogil - - -cdef void shift_missing_values_to_left_if_required( - SplitRecord* best, - intp_t[::1] samples, - intp_t end, -) noexcept nogil + cdef: + const float32_t[::1] X_data + const int32_t[::1] X_indices + const int32_t[::1] X_indptr + + intp_t n_total_samples + + intp_t[::1] index_to_samples + intp_t[::1] sorted_samples + + intp_t start_positive + intp_t end_negative + bint is_samples_sorted diff --git a/sklearn/tree/_partitioner.pyx b/sklearn/tree/_partitioner.pyx index 57801c3f279ed..6e6ffba42af63 100644 --- a/sklearn/tree/_partitioner.pyx +++ b/sklearn/tree/_partitioner.pyx @@ -1,34 +1,51 @@ -"""Partition samples in the construction of a tree. - -This module contains the algorithms for moving sample indices to -the left and right child node given a split determined by the -splitting algorithm in `_splitter.pyx`. - -Partitioning is done in a way that is efficient for both dense data, -and sparse data stored in a Compressed Sparse Column (CSC) format. 
-""" -# Authors: The scikit-learn developers -# SPDX-License-Identifier: BSD-3-Clause - from cython cimport final from libc.math cimport isnan, log from libc.stdlib cimport qsort from libc.string cimport memcpy - -import numpy as np from scipy.sparse import issparse +import numpy as np -# Constant to switch between algorithm non zero value extract algorithm -# in SparsePartitioner -cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 - -# Allow for 32 bit float comparisons -cdef float32_t INFINITY_32t = np.inf +from ._sort cimport sort, sparse_swap, swap, FEATURE_THRESHOLD + + +cdef class Partitioner: + cdef: + inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: + self._init_node_split(self, start, end) + + inline void sort_samples_and_feature_values( + self, + intp_t current_feature + ) noexcept nogil: + self._sort_samples_and_feature_values(self, current_feature) + + inline void find_min_max( + self, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, + ) noexcept nogil: + self._find_min_max(self, current_feature, min_feature_value_out, max_feature_value_out) + + inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil: + self._next_p(self, p_prev, p) + + inline intp_t partition_samples(self, float64_t current_threshold) noexcept nogil: + return self._partition_samples(self, current_threshold) + + inline void partition_samples_final( + self, + intp_t best_pos, + float64_t best_threshold, + intp_t best_feature, + intp_t best_n_missing, + ) noexcept nogil: + self._partition_samples_final(self, best_pos, best_threshold, best_feature, best_n_missing) @final -cdef class DensePartitioner: +cdef class DensePartitioner(Partitioner): """Partitioner specialized for dense data. Note that this partitioner is agnostic to the splitting strategy (best vs. random). @@ -45,228 +62,203 @@ cdef class DensePartitioner: self.feature_values = feature_values self.missing_values_in_feature_mask = missing_values_in_feature_mask - cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: - """Initialize splitter at the beginning of node_split.""" - self.start = start - self.end = end - self.n_missing = 0 - - cdef inline void sort_samples_and_feature_values( - self, intp_t current_feature - ) noexcept nogil: - """Simultaneously sort based on the feature_values. - - Missing values are stored at the end of feature_values. - The number of missing values observed in feature_values is stored - in self.n_missing. - """ - cdef: - intp_t i, current_end - float32_t[::1] feature_values = self.feature_values - const float32_t[:, :] X = self.X - intp_t[::1] samples = self.samples - intp_t n_missing = 0 - const uint8_t[::1] missing_values_in_feature_mask = self.missing_values_in_feature_mask - - # Sort samples along that feature; by copying the values into an array and - # sorting the array in a manner which utilizes the cache more effectively. - if missing_values_in_feature_mask is not None and missing_values_in_feature_mask[current_feature]: - i, current_end = self.start, self.end - 1 - # Missing values are placed at the end and do not participate in the sorting. - while i <= current_end: - # Finds the right-most value that is not missing so that - # it can be swapped with missing values at its left. 
- if isnan(X[samples[current_end], current_feature]): - n_missing += 1 - current_end -= 1 - continue - - # X[samples[current_end], current_feature] is a non-missing value - if isnan(X[samples[i], current_feature]): - samples[i], samples[current_end] = samples[current_end], samples[i] - n_missing += 1 - current_end -= 1 - - feature_values[i] = X[samples[i], current_feature] - i += 1 - else: - # When there are no missing values, we only need to copy the data into - # feature_values - for i in range(self.start, self.end): - feature_values[i] = X[samples[i], current_feature] + self._init_node_split = dense_init_node_split + self._sort_samples_and_feature_values = dense_sort_samples_and_feature_values + self._find_min_max = dense_find_min_max + self._next_p = dense_next_p + self._partition_samples = dense_partition_samples + self._partition_samples_final = dense_partition_samples_final - sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) - self.n_missing = n_missing - cdef inline void find_min_max( - self, - intp_t current_feature, - float32_t* min_feature_value_out, - float32_t* max_feature_value_out, - ) noexcept nogil: - """Find the minimum and maximum value for current_feature. - - Missing values are stored at the end of feature_values. The number of missing - values observed in feature_values is stored in self.n_missing. - """ - cdef: - intp_t p, current_end - float32_t current_feature_value - const float32_t[:, :] X = self.X - intp_t[::1] samples = self.samples - float32_t min_feature_value = INFINITY_32t - float32_t max_feature_value = -INFINITY_32t - float32_t[::1] feature_values = self.feature_values - intp_t n_missing = 0 - const uint8_t[::1] missing_values_in_feature_mask = self.missing_values_in_feature_mask - - # We are copying the values into an array and finding min/max of the array in - # a manner which utilizes the cache more effectively. We need to also count - # the number of missing-values there are. - if missing_values_in_feature_mask is not None and missing_values_in_feature_mask[current_feature]: - p, current_end = self.start, self.end - 1 - # Missing values are placed at the end and do not participate in the - # min/max calculation. - while p <= current_end: - # Finds the right-most value that is not missing so that - # it can be swapped with missing values towards its left. 
- if isnan(X[samples[current_end], current_feature]): - n_missing += 1 - current_end -= 1 - continue - - # X[samples[current_end], current_feature] is a non-missing value - if isnan(X[samples[p], current_feature]): - samples[p], samples[current_end] = samples[current_end], samples[p] - n_missing += 1 - current_end -= 1 - - current_feature_value = X[samples[p], current_feature] - feature_values[p] = current_feature_value - if current_feature_value < min_feature_value: - min_feature_value = current_feature_value - elif current_feature_value > max_feature_value: - max_feature_value = current_feature_value - p += 1 - else: - min_feature_value = X[samples[self.start], current_feature] - max_feature_value = min_feature_value - - feature_values[self.start] = min_feature_value - for p in range(self.start + 1, self.end): - current_feature_value = X[samples[p], current_feature] - feature_values[p] = current_feature_value - - if current_feature_value < min_feature_value: - min_feature_value = current_feature_value - elif current_feature_value > max_feature_value: - max_feature_value = current_feature_value - - min_feature_value_out[0] = min_feature_value - max_feature_value_out[0] = max_feature_value - self.n_missing = n_missing - - cdef inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil: - """Compute the next p_prev and p for iteratiing over feature values. - - The missing values are not included when iterating through the feature values. - """ - cdef: - float32_t[::1] feature_values = self.feature_values - intp_t end_non_missing = self.end - self.n_missing - - while ( - p[0] + 1 < end_non_missing and - feature_values[p[0] + 1] <= feature_values[p[0]] + FEATURE_THRESHOLD - ): - p[0] += 1 - - p_prev[0] = p[0] - - # By adding 1, we have - # (feature_values[p] >= end) or (feature_values[p] > feature_values[p - 1]) +cdef inline void dense_init_node_split( + Partitioner self, intp_t start, intp_t end +) noexcept nogil: + """Initialize splitter at the beginning of node_split.""" + self.start = start + self.end = end + self.n_missing = 0 + +cdef inline void dense_sort_samples_and_feature_values( + Partitioner self, intp_t current_feature +) noexcept nogil: + """Simultaneously sort based on the feature_values. + + Missing values are stored at the end of feature_values. + The number of missing values observed in feature_values is stored + in self.n_missing. + """ + cdef: + intp_t i, current_end + float32_t[::1] feature_values = self.feature_values + const float32_t[:, :] X = (self).X + intp_t[::1] samples = self.samples + intp_t n_missing = 0 + const uint8_t[::1] missing_values_in_feature_mask = self.missing_values_in_feature_mask + + # Sort samples along that feature; by + # copying the values into an array and + # sorting the array in a manner which utilizes the cache more + # effectively. + if missing_values_in_feature_mask is not None and missing_values_in_feature_mask[current_feature]: + i, current_end = self.start, self.end - 1 + # Missing values are placed at the end and do not participate in the sorting. + while i <= current_end: + # Finds the right-most value that is not missing so that + # it can be swapped with missing values at its left. 
+ if isnan(X[samples[current_end], current_feature]): + n_missing += 1 + current_end -= 1 + continue + + # X[samples[current_end], current_feature] is a non-missing value + if isnan(X[samples[i], current_feature]): + samples[i], samples[current_end] = samples[current_end], samples[i] + n_missing += 1 + current_end -= 1 + + feature_values[i] = X[samples[i], current_feature] + i += 1 + else: + # When there are no missing values, we only need to copy the data into + # feature_values + for i in range(self.start, self.end): + feature_values[i] = X[samples[i], current_feature] + + sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) + self.n_missing = n_missing + +cdef inline void dense_find_min_max( + Partitioner self, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, +) noexcept nogil: + """Find the minimum and maximum value for current_feature.""" + cdef: + intp_t p + float32_t current_feature_value + const float32_t[:, :] X = (self).X + intp_t[::1] samples = self.samples + float32_t min_feature_value = X[samples[self.start], current_feature] + float32_t max_feature_value = min_feature_value + float32_t[::1] feature_values = self.feature_values + + feature_values[self.start] = min_feature_value + + for p in range(self.start + 1, self.end): + current_feature_value = X[samples[p], current_feature] + feature_values[p] = current_feature_value + + if current_feature_value < min_feature_value: + min_feature_value = current_feature_value + elif current_feature_value > max_feature_value: + max_feature_value = current_feature_value + + min_feature_value_out[0] = min_feature_value + max_feature_value_out[0] = max_feature_value + +cdef inline void dense_next_p( + Partitioner self, intp_t* p_prev, intp_t* p +) noexcept nogil: + """Compute the next p_prev and p for iteratiing over feature values. + + The missing values are not included when iterating through the feature values. 
+ """ + cdef: + float32_t[::1] feature_values = self.feature_values + intp_t end_non_missing = self.end - self.n_missing + + while ( + p[0] + 1 < end_non_missing and + feature_values[p[0] + 1] <= feature_values[p[0]] + FEATURE_THRESHOLD + ): p[0] += 1 - cdef inline intp_t partition_samples( - self, - float64_t current_threshold - ) noexcept nogil: - """Partition samples for feature_values at the current_threshold.""" - cdef: - intp_t p = self.start - intp_t partition_end = self.end - intp_t[::1] samples = self.samples - float32_t[::1] feature_values = self.feature_values + p_prev[0] = p[0] + + # By adding 1, we have + # (feature_values[p] >= end) or (feature_values[p] > feature_values[p - 1]) + p[0] += 1 + +cdef inline intp_t dense_partition_samples( + Partitioner self, float64_t current_threshold +) noexcept nogil: + """Partition samples for feature_values at the current_threshold.""" + cdef: + intp_t p = self.start + intp_t partition_end = self.end + intp_t[::1] samples = self.samples + float32_t[::1] feature_values = self.feature_values + + while p < partition_end: + if feature_values[p] <= current_threshold: + p += 1 + else: + partition_end -= 1 + feature_values[p], feature_values[partition_end] = ( + feature_values[partition_end], feature_values[p] + ) + samples[p], samples[partition_end] = samples[partition_end], samples[p] + + return partition_end + +cdef inline void dense_partition_samples_final( + Partitioner self, + intp_t best_pos, + float64_t best_threshold, + intp_t best_feature, + intp_t best_n_missing, +) noexcept nogil: + """Partition samples for X at the best_threshold and best_feature. + + If missing values are present, this method partitions `samples` + so that the `best_n_missing` missing values' indices are in the + right-most end of `samples`, that is `samples[end_non_missing:end]`. + """ + cdef: + # Local invariance: start <= p <= partition_end <= end + intp_t start = self.start + intp_t p = start + intp_t end = self.end - 1 + intp_t partition_end = end - best_n_missing + intp_t[::1] samples = self.samples + const float32_t[:, :] X = (self).X + float32_t current_value + + if best_n_missing != 0: + # Move samples with missing values to the end while partitioning the + # non-missing samples while p < partition_end: - if feature_values[p] <= current_threshold: + # Keep samples with missing values at the end + if isnan(X[samples[end], best_feature]): + end -= 1 + continue + + # Swap sample with missing values with the sample at the end + current_value = X[samples[p], best_feature] + if isnan(current_value): + samples[p], samples[end] = samples[end], samples[p] + end -= 1 + + # The swapped sample at the end is always a non-missing value, so + # we can continue the algorithm without checking for missingness. 
+ current_value = X[samples[p], best_feature] + + # Partition the non-missing samples + if current_value <= best_threshold: p += 1 else: + samples[p], samples[partition_end] = samples[partition_end], samples[p] partition_end -= 1 - - feature_values[p], feature_values[partition_end] = ( - feature_values[partition_end], feature_values[p] - ) + else: + # Partitioning routine when there are no missing values + while p < partition_end: + if X[samples[p], best_feature] <= best_threshold: + p += 1 + else: samples[p], samples[partition_end] = samples[partition_end], samples[p] - - return partition_end - - cdef inline void partition_samples_final( - self, - intp_t best_pos, - float64_t best_threshold, - intp_t best_feature, - intp_t best_n_missing, - ) noexcept nogil: - """Partition samples for X at the best_threshold and best_feature. - - If missing values are present, this method partitions `samples` - so that the `best_n_missing` missing values' indices are in the - right-most end of `samples`, that is `samples[end_non_missing:end]`. - """ - cdef: - # Local invariance: start <= p <= partition_end <= end - intp_t start = self.start - intp_t p = start - intp_t end = self.end - 1 - intp_t partition_end = end - best_n_missing - intp_t[::1] samples = self.samples - const float32_t[:, :] X = self.X - float32_t current_value - - if best_n_missing != 0: - # Move samples with missing values to the end while partitioning the - # non-missing samples - while p < partition_end: - # Keep samples with missing values at the end - if isnan(X[samples[end], best_feature]): - end -= 1 - continue - - # Swap sample with missing values with the sample at the end - current_value = X[samples[p], best_feature] - if isnan(current_value): - samples[p], samples[end] = samples[end], samples[p] - end -= 1 - - # The swapped sample at the end is always a non-missing value, so - # we can continue the algorithm without checking for missingness. 
- current_value = X[samples[p], best_feature] - - # Partition the non-missing samples - if current_value <= best_threshold: - p += 1 - else: - samples[p], samples[partition_end] = samples[partition_end], samples[p] - partition_end -= 1 - else: - # Partitioning routine when there are no missing values - while p < partition_end: - if X[samples[p], best_feature] <= best_threshold: - p += 1 - else: - samples[p], samples[partition_end] = samples[partition_end], samples[p] - partition_end -= 1 + partition_end -= 1 @final @@ -307,221 +299,259 @@ cdef class SparsePartitioner: self.missing_values_in_feature_mask = missing_values_in_feature_mask - cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: - """Initialize splitter at the beginning of node_split.""" - self.start = start - self.end = end - self.is_samples_sorted = 0 - self.n_missing = 0 + self._init_node_split = sparse_init_node_split + self._sort_samples_and_feature_values = sparse_sort_samples_and_feature_values + self._find_min_max = sparse_find_min_max + self._next_p = sparse_next_p + self._partition_samples = sparse_partition_samples + self._partition_samples_final = sparse_partition_samples_final - cdef inline void sort_samples_and_feature_values( - self, - intp_t current_feature - ) noexcept nogil: - """Simultaneously sort based on the feature_values.""" - cdef: - float32_t[::1] feature_values = self.feature_values - intp_t[::1] index_to_samples = self.index_to_samples - intp_t[::1] samples = self.samples - - self.extract_nnz(current_feature) - # Sort the positive and negative parts of `feature_values` - sort(&feature_values[self.start], &samples[self.start], self.end_negative - self.start) - if self.start_positive < self.end: - sort( - &feature_values[self.start_positive], - &samples[self.start_positive], - self.end - self.start_positive - ) - # Update index_to_samples to take into account the sort - for p in range(self.start, self.end_negative): - index_to_samples[samples[p]] = p - for p in range(self.start_positive, self.end): - index_to_samples[samples[p]] = p +cdef inline void sparse_init_node_split(Partitioner self, intp_t start, intp_t end) noexcept nogil: + """Initialize splitter at the beginning of node_split.""" + self.start = start + self.end = end + (self).is_samples_sorted = 0 + self.n_missing = 0 - # Add one or two zeros in feature_values, if there is any - if self.end_negative < self.start_positive: - self.start_positive -= 1 - feature_values[self.start_positive] = 0. - if self.end_negative != self.start_positive: - feature_values[self.end_negative] = 0. 
- self.end_negative += 1 +cdef inline void sparse_sort_samples_and_feature_values( + Partitioner self, intp_t current_feature +) noexcept nogil: + _sparse_sort_samples_and_feature_values(self, current_feature) - # XXX: When sparse supports missing values, this should be set to the - # number of missing values for current_feature - self.n_missing = 0 - cdef inline void find_min_max( - self, - intp_t current_feature, - float32_t* min_feature_value_out, - float32_t* max_feature_value_out, - ) noexcept nogil: - """Find the minimum and maximum value for current_feature.""" - cdef: - intp_t p - float32_t current_feature_value, min_feature_value, max_feature_value - float32_t[::1] feature_values = self.feature_values - - self.extract_nnz(current_feature) +cdef inline void _sparse_sort_samples_and_feature_values( + SparsePartitioner self, intp_t current_feature +) noexcept nogil: + """Simultaneously sort based on the feature_values.""" + cdef: + float32_t[::1] feature_values = self.feature_values + intp_t[::1] index_to_samples = self.index_to_samples + intp_t[::1] samples = self.samples + + sparse_extract_nnz(self, current_feature) + # Sort the positive and negative parts of `feature_values` + sort(&feature_values[self.start], &samples[self.start], self.end_negative - self.start) + if self.start_positive < self.end: + sort( + &feature_values[self.start_positive], + &samples[self.start_positive], + self.end - self.start_positive + ) + + # Update index_to_samples to take into account the sort + for p in range(self.start, self.end_negative): + index_to_samples[samples[p]] = p + for p in range(self.start_positive, self.end): + index_to_samples[samples[p]] = p + + # Add one or two zeros in feature_values, if there is any + if self.end_negative < self.start_positive: + self.start_positive -= 1 + feature_values[self.start_positive] = 0. if self.end_negative != self.start_positive: - # There is a zero - min_feature_value = 0 - max_feature_value = 0 - else: - min_feature_value = feature_values[self.start] - max_feature_value = min_feature_value + feature_values[self.end_negative] = 0. 
+ self.end_negative += 1 - # Find min, max in feature_values[start:end_negative] - for p in range(self.start, self.end_negative): - current_feature_value = feature_values[p] + # XXX: When sparse supports missing values, this should be set to the + # number of missing values for current_feature + self.n_missing = 0 - if current_feature_value < min_feature_value: - min_feature_value = current_feature_value - elif current_feature_value > max_feature_value: - max_feature_value = current_feature_value - # Update min, max given feature_values[start_positive:end] - for p in range(self.start_positive, self.end): - current_feature_value = feature_values[p] +cdef inline void sparse_find_min_max( + Partitioner self, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, +) noexcept nogil: + _sparse_find_min_max( + self, + current_feature, + min_feature_value_out, + max_feature_value_out + ) + +cdef inline void _sparse_find_min_max( + SparsePartitioner self, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, +) noexcept nogil: + """Find the minimum and maximum value for current_feature.""" + cdef: + intp_t p + float32_t current_feature_value, min_feature_value, max_feature_value + float32_t[::1] feature_values = self.feature_values + + sparse_extract_nnz(self, current_feature) + + if self.end_negative != self.start_positive: + # There is a zero + min_feature_value = 0 + max_feature_value = 0 + else: + min_feature_value = feature_values[self.start] + max_feature_value = min_feature_value + + # Find min, max in feature_values[start:end_negative] + for p in range(self.start, self.end_negative): + current_feature_value = feature_values[p] + + if current_feature_value < min_feature_value: + min_feature_value = current_feature_value + elif current_feature_value > max_feature_value: + max_feature_value = current_feature_value + + # Update min, max given feature_values[start_positive:end] + for p in range(self.start_positive, self.end): + current_feature_value = feature_values[p] + + if current_feature_value < min_feature_value: + min_feature_value = current_feature_value + elif current_feature_value > max_feature_value: + max_feature_value = current_feature_value + + min_feature_value_out[0] = min_feature_value + max_feature_value_out[0] = max_feature_value - if current_feature_value < min_feature_value: - min_feature_value = current_feature_value - elif current_feature_value > max_feature_value: - max_feature_value = current_feature_value - min_feature_value_out[0] = min_feature_value - max_feature_value_out[0] = max_feature_value +cdef inline void sparse_next_p(Partitioner self, intp_t* p_prev, intp_t* p) noexcept nogil: + _sparse_next_p(self, p_prev, p) - cdef inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil: - """Compute the next p_prev and p for iteratiing over feature values.""" - cdef: - intp_t p_next - float32_t[::1] feature_values = self.feature_values +cdef inline void _sparse_next_p(SparsePartitioner self, intp_t* p_prev, intp_t* p) noexcept nogil: + """Compute the next p_prev and p for iteratiing over feature values.""" + cdef: + intp_t p_next + float32_t[::1] feature_values = self.feature_values + + if p[0] + 1 != self.end_negative: + p_next = p[0] + 1 + else: + p_next = self.start_positive + + while (p_next < self.end and + feature_values[p_next] <= feature_values[p[0]] + FEATURE_THRESHOLD): + p[0] = p_next if p[0] + 1 != self.end_negative: p_next = p[0] + 1 else: p_next = 
self.start_positive - while (p_next < self.end and - feature_values[p_next] <= feature_values[p[0]] + FEATURE_THRESHOLD): - p[0] = p_next - if p[0] + 1 != self.end_negative: - p_next = p[0] + 1 - else: - p_next = self.start_positive + p_prev[0] = p[0] + p[0] = p_next - p_prev[0] = p[0] - p[0] = p_next - cdef inline intp_t partition_samples( - self, - float64_t current_threshold - ) noexcept nogil: - """Partition samples for feature_values at the current_threshold.""" - return self._partition(current_threshold, self.start_positive) +cdef inline intp_t sparse_partition_samples( + Partitioner self, float64_t current_threshold +) noexcept nogil: + """Partition samples for feature_values at the current_threshold.""" + return sparse_partition( + self, current_threshold, (self).start_positive + ) + + +cdef inline void sparse_partition_samples_final( + Partitioner self, + intp_t best_pos, + float64_t best_threshold, + intp_t best_feature, + intp_t n_missing, +) noexcept nogil: + """Partition samples for X at the best_threshold and best_feature.""" + sparse_extract_nnz(self, best_feature) + sparse_partition(self, best_threshold, best_pos) + + +cdef inline intp_t sparse_partition(SparsePartitioner self, float64_t threshold, intp_t zero_pos) noexcept nogil: + """Partition samples[start:end] based on threshold.""" + cdef: + intp_t p, partition_end + intp_t[::1] index_to_samples = self.index_to_samples + float32_t[::1] feature_values = self.feature_values + intp_t[::1] samples = self.samples + + if threshold < 0.: + p = self.start + partition_end = self.end_negative + elif threshold > 0.: + p = self.start_positive + partition_end = self.end + else: + # Data are already split + return zero_pos + + while p < partition_end: + if feature_values[p] <= threshold: + p += 1 - cdef inline void partition_samples_final( - self, - intp_t best_pos, - float64_t best_threshold, - intp_t best_feature, - intp_t n_missing, - ) noexcept nogil: - """Partition samples for X at the best_threshold and best_feature.""" - self.extract_nnz(best_feature) - self._partition(best_threshold, best_pos) - - cdef inline intp_t _partition(self, float64_t threshold, intp_t zero_pos) noexcept nogil: - """Partition samples[start:end] based on threshold.""" - cdef: - intp_t p, partition_end - intp_t[::1] index_to_samples = self.index_to_samples - float32_t[::1] feature_values = self.feature_values - intp_t[::1] samples = self.samples - - if threshold < 0.: - p = self.start - partition_end = self.end_negative - elif threshold > 0.: - p = self.start_positive - partition_end = self.end else: - # Data are already split - return zero_pos + partition_end -= 1 - while p < partition_end: - if feature_values[p] <= threshold: - p += 1 + feature_values[p], feature_values[partition_end] = ( + feature_values[partition_end], feature_values[p] + ) + sparse_swap(index_to_samples, samples, p, partition_end) - else: - partition_end -= 1 + return partition_end - feature_values[p], feature_values[partition_end] = ( - feature_values[partition_end], feature_values[p] - ) - sparse_swap(index_to_samples, samples, p, partition_end) - - return partition_end - - cdef inline void extract_nnz(self, intp_t feature) noexcept nogil: - """Extract and partition values for a given feature. - - The extracted values are partitioned between negative values - feature_values[start:end_negative[0]] and positive values - feature_values[start_positive[0]:end]. - The samples and index_to_samples are modified according to this - partition. 
- - The extraction corresponds to the intersection between the arrays - X_indices[indptr_start:indptr_end] and samples[start:end]. - This is done efficiently using either an index_to_samples based approach - or binary search based approach. - - Parameters - ---------- - feature : intp_t, - Index of the feature we want to extract non zero value. - """ - cdef intp_t[::1] samples = self.samples - cdef float32_t[::1] feature_values = self.feature_values - cdef intp_t indptr_start = self.X_indptr[feature], - cdef intp_t indptr_end = self.X_indptr[feature + 1] - cdef intp_t n_indices = (indptr_end - indptr_start) - cdef intp_t n_samples = self.end - self.start - cdef intp_t[::1] index_to_samples = self.index_to_samples - cdef intp_t[::1] sorted_samples = self.sorted_samples - cdef const int32_t[::1] X_indices = self.X_indices - cdef const float32_t[::1] X_data = self.X_data - - # Use binary search if n_samples * log(n_indices) < - # n_indices and index_to_samples approach otherwise. - # O(n_samples * log(n_indices)) is the running time of binary - # search and O(n_indices) is the running time of index_to_samples - # approach. - if ((1 - self.is_samples_sorted) * n_samples * log(n_samples) + - n_samples * log(n_indices) < EXTRACT_NNZ_SWITCH * n_indices): - extract_nnz_binary_search(X_indices, X_data, - indptr_start, indptr_end, - samples, self.start, self.end, - index_to_samples, - feature_values, - &self.end_negative, &self.start_positive, - sorted_samples, &self.is_samples_sorted) - - # Using an index to samples technique to extract non zero values - # index_to_samples is a mapping from X_indices to samples - else: - extract_nnz_index_to_samples(X_indices, X_data, - indptr_start, indptr_end, - samples, self.start, self.end, - index_to_samples, - feature_values, - &self.end_negative, &self.start_positive) + +cdef inline void sparse_extract_nnz(SparsePartitioner self, intp_t feature) noexcept nogil: + """Extract and partition values for a given feature. + + The extracted values are partitioned between negative values + feature_values[start:end_negative[0]] and positive values + feature_values[start_positive[0]:end]. + The samples and index_to_samples are modified according to this + partition. + + The extraction corresponds to the intersection between the arrays + X_indices[indptr_start:indptr_end] and samples[start:end]. + This is done efficiently using either an index_to_samples based approach + or binary search based approach. + + Parameters + ---------- + feature : intp_t, + Index of the feature we want to extract non zero value. + """ + cdef intp_t[::1] samples = self.samples + cdef float32_t[::1] feature_values = self.feature_values + cdef intp_t indptr_start = self.X_indptr[feature], + cdef intp_t indptr_end = self.X_indptr[feature + 1] + cdef intp_t n_indices = (indptr_end - indptr_start) + cdef intp_t n_samples = self.end - self.start + cdef intp_t[::1] index_to_samples = self.index_to_samples + cdef intp_t[::1] sorted_samples = self.sorted_samples + cdef const int32_t[::1] X_indices = self.X_indices + cdef const float32_t[::1] X_data = self.X_data + + # Use binary search if n_samples * log(n_indices) < + # n_indices and index_to_samples approach otherwise. + # O(n_samples * log(n_indices)) is the running time of binary + # search and O(n_indices) is the running time of index_to_samples + # approach. 
+ if ((1 - self.is_samples_sorted) * n_samples * log(n_samples) + + n_samples * log(n_indices) < EXTRACT_NNZ_SWITCH * n_indices): + extract_nnz_binary_search(X_indices, X_data, + indptr_start, indptr_end, + samples, self.start, self.end, + index_to_samples, + feature_values, + &self.end_negative, &self.start_positive, + sorted_samples, &self.is_samples_sorted) + + # Using an index to samples technique to extract non zero values + # index_to_samples is a mapping from X_indices to samples + else: + extract_nnz_index_to_samples(X_indices, X_data, + indptr_start, indptr_end, + samples, self.start, self.end, + index_to_samples, + feature_values, + &self.end_negative, &self.start_positive) cdef int compare_SIZE_t(const void* a, const void* b) noexcept nogil: @@ -666,151 +696,3 @@ cdef inline void extract_nnz_binary_search(const int32_t[::1] X_indices, # Returned values end_negative[0] = end_negative_ start_positive[0] = start_positive_ - - -cdef inline void sparse_swap(intp_t[::1] index_to_samples, intp_t[::1] samples, - intp_t pos_1, intp_t pos_2) noexcept nogil: - """Swap sample pos_1 and pos_2 preserving sparse invariant.""" - samples[pos_1], samples[pos_2] = samples[pos_2], samples[pos_1] - index_to_samples[samples[pos_1]] = pos_1 - index_to_samples[samples[pos_2]] = pos_2 - - -cdef inline void shift_missing_values_to_left_if_required( - SplitRecord* best, - intp_t[::1] samples, - intp_t end, -) noexcept nogil: - """Shift missing value sample indices to the left of the split if required. - - Note: this should always be called at the very end because it will - move samples around, thereby affecting the criterion. - This affects the computation of the children impurity, which affects - the computation of the next node. - """ - cdef intp_t i, p, current_end - # The partitioner partitions the data such that the missing values are in - # samples[-n_missing:] for the criterion to consume. If the missing values - # are going to the right node, then the missing values are already in the - # correct position. If the missing values go left, then we move the missing - # values to samples[best.pos:best.pos+n_missing] and update `best.pos`. - if best.n_missing > 0 and best.missing_go_to_left: - for p in range(best.n_missing): - i = best.pos + p - current_end = end - 1 - p - samples[i], samples[current_end] = samples[current_end], samples[i] - best.pos += best.n_missing - - -# Sort n-element arrays pointed to by feature_values and samples, simultaneously, -# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: - if n == 0: - return - cdef intp_t maxd = 2 * log(n) - introsort(feature_values, samples, n, maxd) - - -cdef inline void swap(float32_t* feature_values, intp_t* samples, - intp_t i, intp_t j) noexcept nogil: - # Helper for sort - feature_values[i], feature_values[j] = feature_values[j], feature_values[i] - samples[i], samples[j] = samples[j], samples[i] - - -cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogil: - # Median of three pivot selection, after Bentley and McIlroy (1993). - # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. 
- cdef float32_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] - if a < b: - if b < c: - return b - elif a < c: - return c - else: - return a - elif b < c: - if a < c: - return a - else: - return c - else: - return b - - -# Introsort with median of 3 pivot selection and 3-way partition function -# (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(float32_t* feature_values, intp_t *samples, - intp_t n, intp_t maxd) noexcept nogil: - cdef float32_t pivot - cdef intp_t i, l, r - - while n > 1: - if maxd <= 0: # max depth limit exceeded ("gone quadratic") - heapsort(feature_values, samples, n) - return - maxd -= 1 - - pivot = median3(feature_values, n) - - # Three-way partition. - i = l = 0 - r = n - while i < r: - if feature_values[i] < pivot: - swap(feature_values, samples, i, l) - i += 1 - l += 1 - elif feature_values[i] > pivot: - r -= 1 - swap(feature_values, samples, i, r) - else: - i += 1 - - introsort(feature_values, samples, l, maxd) - feature_values += r - samples += r - n -= r - - -cdef inline void sift_down(float32_t* feature_values, intp_t* samples, - intp_t start, intp_t end) noexcept nogil: - # Restore heap order in feature_values[start:end] by moving the max element to start. - cdef intp_t child, maxind, root - - root = start - while True: - child = root * 2 + 1 - - # find max of root, left child, right child - maxind = root - if child < end and feature_values[maxind] < feature_values[child]: - maxind = child - if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: - maxind = child + 1 - - if maxind == root: - break - else: - swap(feature_values, samples, root, maxind) - root = maxind - - -cdef void heapsort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: - cdef intp_t start, end - - # heapify - start = (n - 2) / 2 - end = n - while True: - sift_down(feature_values, samples, start, end) - if start == 0: - break - start -= 1 - - # sort by shrinking the heap, putting the max element immediately after it - end = n - 1 - while end > 0: - swap(feature_values, samples, 0, end) - sift_down(feature_values, samples, 0, end) - end = end - 1 diff --git a/sklearn/tree/_sort.pxd b/sklearn/tree/_sort.pxd new file mode 100644 index 0000000000000..99db858c52a96 --- /dev/null +++ b/sklearn/tree/_sort.pxd @@ -0,0 +1,29 @@ +# Authors: Gilles Louppe +# Peter Prettenhofer +# Brian Holt +# Joel Nothman +# Arnaud Joly +# Jacob Schreiber +# Adam Li +# Jong Shin +# Samuel Carliles +# +# License: BSD 3 clause +# SPDX-License-Identifier: BSD-3-Clause + +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, int32_t, uint32_t + +# Since we broke Partitioner out into its own module in order to reuse it, and since +# both Splitter and Partitioner use these sort functions, we break them out into +# their own files in order to avoid cyclic file dependency. + +# Mitigate precision differences between 32 bit and 64 bit +cdef float32_t FEATURE_THRESHOLD = 1e-7 + +# Sort n-element arrays pointed to by feature_values and samples, simultaneously, +# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). 
+cdef void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil + +cdef void swap(float32_t* feature_values, intp_t* samples, intp_t i, intp_t j) noexcept nogil +cdef void sparse_swap(intp_t[::1] index_to_samples, intp_t[::1] samples, + intp_t pos_1, intp_t pos_2) noexcept nogil diff --git a/sklearn/tree/_sort.pyx b/sklearn/tree/_sort.pyx new file mode 100644 index 0000000000000..9a9db6edf6e00 --- /dev/null +++ b/sklearn/tree/_sort.pyx @@ -0,0 +1,123 @@ +from ._utils cimport log + + +cdef inline void sparse_swap(intp_t[::1] index_to_samples, intp_t[::1] samples, + intp_t pos_1, intp_t pos_2) noexcept nogil: + """Swap sample pos_1 and pos_2 preserving sparse invariant.""" + samples[pos_1], samples[pos_2] = samples[pos_2], samples[pos_1] + index_to_samples[samples[pos_1]] = pos_1 + index_to_samples[samples[pos_2]] = pos_2 + + +# Sort n-element arrays pointed to by feature_values and samples, simultaneously, +# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). +cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: + if n == 0: + return + cdef intp_t maxd = 2 * log(n) + introsort(feature_values, samples, n, maxd) + + +# Introsort with median of 3 pivot selection and 3-way partition function +# (robust to repeated elements, e.g. lots of zero features). +cdef void introsort(float32_t* feature_values, intp_t *samples, + intp_t n, intp_t maxd) noexcept nogil: + cdef float32_t pivot + cdef intp_t i, l, r + + while n > 1: + if maxd <= 0: # max depth limit exceeded ("gone quadratic") + heapsort(feature_values, samples, n) + return + maxd -= 1 + + pivot = median3(feature_values, n) + + # Three-way partition. + i = l = 0 + r = n + while i < r: + if feature_values[i] < pivot: + swap(feature_values, samples, i, l) + i += 1 + l += 1 + elif feature_values[i] > pivot: + r -= 1 + swap(feature_values, samples, i, r) + else: + i += 1 + + introsort(feature_values, samples, l, maxd) + feature_values += r + samples += r + n -= r + + +cdef void heapsort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: + cdef intp_t start, end + + # heapify + start = (n - 2) / 2 + end = n + while True: + sift_down(feature_values, samples, start, end) + if start == 0: + break + start -= 1 + + # sort by shrinking the heap, putting the max element immediately after it + end = n - 1 + while end > 0: + swap(feature_values, samples, 0, end) + sift_down(feature_values, samples, 0, end) + end = end - 1 + + +cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogil: + # Median of three pivot selection, after Bentley and McIlroy (1993). + # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. + cdef float32_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] + if a < b: + if b < c: + return b + elif a < c: + return c + else: + return a + elif b < c: + if a < c: + return a + else: + return c + else: + return b + + +cdef inline void swap(float32_t* feature_values, intp_t* samples, + intp_t i, intp_t j) noexcept nogil: + # Helper for sort + feature_values[i], feature_values[j] = feature_values[j], feature_values[i] + samples[i], samples[j] = samples[j], samples[i] + + +cdef inline void sift_down(float32_t* feature_values, intp_t* samples, + intp_t start, intp_t end) noexcept nogil: + # Restore heap order in feature_values[start:end] by moving the max element to start. 
+ cdef intp_t child, maxind, root + + root = start + while True: + child = root * 2 + 1 + + # find max of root, left child, right child + maxind = root + if child < end and feature_values[maxind] < feature_values[child]: + maxind = child + if child + 1 < end and feature_values[maxind] < feature_values[child + 1]: + maxind = child + 1 + + if maxind == root: + break + else: + swap(feature_values, samples, root, maxind) + root = maxind diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 393c81fd8b9a9..f5813cb06b8b5 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -1,12 +1,66 @@ -# Authors: The scikit-learn developers +# Authors: Gilles Louppe +# Peter Prettenhofer +# Brian Holt +# Joel Nothman +# Arnaud Joly +# Jacob Schreiber +# Adam Li +# Jong Shin +# Samuel Carliles +# +# License: BSD 3 clause # SPDX-License-Identifier: BSD-3-Clause # See _splitter.pyx for details. from libcpp.vector cimport vector +from ._partitioner cimport Partitioner, DensePartitioner, SparsePartitioner from ._criterion cimport BaseCriterion, Criterion from ._tree cimport ParentInfo -from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, int32_t, uint8_t, uint32_t + +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, uint8_t, int32_t, uint32_t + +from ._events cimport EventBroker, EventHandler, NullHandler + + +cdef enum NodeSplitEvent: + SORT_FEATURE = 1 + +cdef struct NodeSortFeatureEventData: + intp_t feature + intp_t is_left + +cdef struct NodeSplitEventData: + intp_t feature + float64_t threshold + +# We wish to generalize Splitter so that arbitrary split rejection criteria can be +# passed in dynamically at construction. The natural way to want to do this is to +# pass in a list of lambdas, but as we are in cython, this is not so straightforward. +# We want the convience of being able to pass them in as a python list, and while it +# would be nice to receive them as a memoryview, this is quite a nuisance with +# cython extension types, so we do cpp vector instead. We do the same closure struct +# pattern for execution speed, but they need to be wrapped in cython extension types +# both for convenience and to go in python list. +ctypedef void* SplitConditionEnv +ctypedef bint (*SplitConditionFunction)( + Splitter splitter, + intp_t split_feature, + intp_t split_pos, + float64_t split_value, + intp_t n_missing, + bint missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound, + SplitConditionEnv split_condition_env +) noexcept nogil + +cdef struct SplitConditionClosure: + SplitConditionFunction f + SplitConditionEnv e + +cdef class SplitCondition: + cdef SplitConditionClosure c cdef struct SplitRecord: @@ -24,6 +78,19 @@ cdef struct SplitRecord: uint8_t missing_go_to_left # Controls if missing values go to the left node. intp_t n_missing # Number of missing values for the feature being split on + +# In the neurodata fork of sklearn there was a hack added where SplitRecords are +# created which queries splitter for pointer size and does an inline malloc. This +# is to accommodate the ability to create extended SplitRecord types in Splitter +# subclasses. We refactor that into a factory method again implemented as a closure +# struct. 
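+#
+# A minimal sketch of the intent (hypothetical subclass code, not part of this
+# change): a Splitter subclass defining an extended record type would install its
+# own factory so that create_split_record() allocates the larger struct, e.g.
+#
+#     cdef SplitRecord* _extended_record_factory(SplitRecordFactoryEnv env) except NULL nogil:
+#         return <SplitRecord*>malloc(sizeof(ExtendedSplitRecord))
+#
+#     # in the subclass __cinit__:
+#     self.split_record_factory.f = _extended_record_factory
+#     self.split_record_factory.e = NULL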
+ctypedef void* SplitRecordFactoryEnv +ctypedef SplitRecord* (*SplitRecordFactory)(SplitRecordFactoryEnv env) except NULL nogil + +cdef struct SplitRecordFactoryClosure: + SplitRecordFactory f + SplitRecordFactoryEnv e + cdef class BaseSplitter: """Abstract interface for splitter.""" @@ -53,6 +120,8 @@ cdef class BaseSplitter: cdef const float64_t[:] sample_weight + cdef SplitRecordFactoryClosure split_record_factory + # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, # `node_split` reorganizes the node samples `samples[start:end]` in two @@ -84,6 +153,7 @@ cdef class BaseSplitter: cdef void node_value(self, float64_t* dest) noexcept nogil cdef float64_t node_impurity(self) noexcept nogil cdef intp_t pointer_size(self) noexcept nogil + cdef SplitRecord* create_split_record(self) except NULL nogil cdef class Splitter(BaseSplitter): """Base class for supervised splitters.""" @@ -99,6 +169,19 @@ cdef class Splitter(BaseSplitter): cdef const int8_t[:] monotonic_cst cdef bint with_monotonic_cst + cdef SplitCondition min_samples_leaf_condition + cdef SplitCondition min_weight_leaf_condition + cdef SplitCondition monotonic_constraint_condition + + # split rejection criteria checked before split selection + cdef vector[SplitConditionClosure] presplit_conditions + + # split rejection criteria checked after split selection + cdef vector[SplitConditionClosure] postsplit_conditions + + # event broker for handling splitter events + cdef EventBroker event_broker + cdef int init( self, object X, @@ -127,3 +210,16 @@ cdef class Splitter(BaseSplitter): float64_t lower_bound, float64_t upper_bound ) noexcept nogil + + cdef void _add_conditions( + self, + vector[SplitConditionClosure]* v, + split_conditions : [SplitCondition] + ) + + +cdef void shift_missing_values_to_left_if_required( + SplitRecord* best, + intp_t[::1] samples, + intp_t end, +) noexcept nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index d69bfa067c92c..f41532e523033 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1,50 +1,138 @@ -"""Splitting algorithms in the construction of a tree. - -This module contains the main splitting algorithms for constructing a tree. -Splitting is concerned with finding the optimal partition of the data into -two groups. The impurity of the groups is minimized, and the impurity is measured -by some criterion, which is typically the Gini impurity or the entropy. Criterion -are implemented in the ``_criterion`` module. - -Splitting evaluates a subset of features (defined by `max_features` also -known as mtry in the literature). The module supports two primary types -of splitting strategies: - -- Best Split: A greedy approach to find the optimal split. This method - ensures that the best possible split is chosen by examining various - thresholds for each candidate feature. -- Random Split: A stochastic approach that selects a split randomly - from a subset of the best splits. This method is faster but does - not guarantee the optimal split. 
-""" -# Authors: The scikit-learn developers +# Authors: Gilles Louppe +# Peter Prettenhofer +# Brian Holt +# Noel Dawe +# Satrajit Gosh +# Lars Buitinck +# Arnaud Joly +# Joel Nothman +# Fares Hedayati +# Jacob Schreiber +# Adam Li +# Jong Shin +# Samuel Carliles +# + +# License: BSD 3 clause # SPDX-License-Identifier: BSD-3-Clause + +from libc.stdlib cimport malloc from libc.string cimport memcpy +from ._criterion cimport Criterion +from ._sort cimport FEATURE_THRESHOLD +from ._utils cimport rand_int +from ._utils cimport rand_uniform +from ._utils cimport RAND_R_MAX from ..utils._typedefs cimport int8_t from ._criterion cimport Criterion -from ._partitioner cimport ( - FEATURE_THRESHOLD, DensePartitioner, SparsePartitioner, - shift_missing_values_to_left_if_required -) +from ._partitioner cimport DensePartitioner, SparsePartitioner + from ._utils cimport RAND_R_MAX, rand_int, rand_uniform import numpy as np -# Introduce a fused-class to make it possible to share the split implementation -# between the dense and sparse cases in the node_split_best and node_split_random -# functions. The alternative would have been to use inheritance-based polymorphism -# but it would have resulted in a ~10% overall tree fitting performance -# degradation caused by the overhead frequent virtual method lookups. -ctypedef fused Partitioner: - DensePartitioner - SparsePartitioner - cdef float64_t INFINITY = np.inf +# we refactor the inline min sample leaf split rejection criterion +# into our injectable SplitCondition pattern +cdef bint min_sample_leaf_condition( + Splitter splitter, + intp_t split_feature, + intp_t split_pos, + float64_t split_value, + intp_t n_missing, + bint missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound, + SplitConditionEnv split_condition_env +) noexcept nogil: + cdef intp_t min_samples_leaf = splitter.min_samples_leaf + cdef intp_t end_non_missing = splitter.end - n_missing + cdef intp_t n_left, n_right + + if missing_go_to_left: + n_left = split_pos - splitter.start + n_missing + n_right = end_non_missing - split_pos + else: + n_left = split_pos - splitter.start + n_right = end_non_missing - split_pos + n_missing + + # Reject if min_samples_leaf is not guaranteed + if n_left < min_samples_leaf or n_right < min_samples_leaf: + return False + + return True + +cdef class MinSamplesLeafCondition(SplitCondition): + def __cinit__(self): + self.c.f = min_sample_leaf_condition + self.c.e = NULL # min_samples is stored in splitter, which is already passed to f + + +# we refactor the inline min weight leaf split rejection criterion +# into our injectable SplitCondition pattern +cdef bint min_weight_leaf_condition( + Splitter splitter, + intp_t split_feature, + intp_t split_pos, + float64_t split_value, + intp_t n_missing, + bint missing_go_to_left, + float64_t lower_bound, + float64_t upper_bound, + SplitConditionEnv split_condition_env +) noexcept nogil: + cdef float64_t min_weight_leaf = splitter.min_weight_leaf + + # Reject if min_weight_leaf is not satisfied + if ((splitter.criterion.weighted_n_left < min_weight_leaf) or + (splitter.criterion.weighted_n_right < min_weight_leaf)): + return False + + return True + +cdef class MinWeightLeafCondition(SplitCondition): + def __cinit__(self): + self.c.f = min_weight_leaf_condition + self.c.e = NULL # min_weight_leaf is stored in splitter, which is already passed to f + + +# we refactor the inline monotonic constraint split rejection criterion +# into our injectable SplitCondition pattern +cdef bint 
monotonic_constraint_condition(
+    Splitter splitter,
+    intp_t split_feature,
+    intp_t split_pos,
+    float64_t split_value,
+    intp_t n_missing,
+    bint missing_go_to_left,
+    float64_t lower_bound,
+    float64_t upper_bound,
+    SplitConditionEnv split_condition_env
+) noexcept nogil:
+    if (
+        splitter.with_monotonic_cst and
+        splitter.monotonic_cst[split_feature] != 0 and
+        not splitter.criterion.check_monotonicity(
+            splitter.monotonic_cst[split_feature],
+            lower_bound,
+            upper_bound,
+        )
+    ):
+        return False
+
+    return True
+
+cdef class MonotonicConstraintCondition(SplitCondition):
+    def __cinit__(self):
+        self.c.f = monotonic_constraint_condition
+        self.c.e = NULL
+
+
cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil:
    self.impurity_left = INFINITY
    self.impurity_right = INFINITY
@@ -55,6 +143,10 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
    self.missing_go_to_left = False
    self.n_missing = 0

+# the default SplitRecord factory method simply mallocs a SplitRecord
+cdef SplitRecord* _base_split_record_factory(SplitRecordFactoryEnv env) except NULL nogil:
+    return <SplitRecord*>malloc(sizeof(SplitRecord))
+
cdef class BaseSplitter:
    """This is an abstract interface for splitters.

@@ -139,6 +231,9 @@ cdef class BaseSplitter:
        `SplitRecord`.
        """
        return sizeof(SplitRecord)
+
+    cdef SplitRecord* create_split_record(self) except NULL nogil:
+        return self.split_record_factory.f(self.split_record_factory.e)

cdef class Splitter(BaseSplitter):
    """Base class for supervised splitters."""
@@ -151,6 +246,9 @@ cdef class Splitter(BaseSplitter):
        float64_t min_weight_leaf,
        object random_state,
        const int8_t[:] monotonic_cst,
+        presplit_conditions : [SplitCondition] = None,
+        postsplit_conditions : [SplitCondition] = None,
+        listeners : [EventHandler] = None,
        *argv
    ):
        """
@@ -191,6 +289,58 @@ cdef class Splitter(BaseSplitter):
        self.monotonic_cst = monotonic_cst
        self.with_monotonic_cst = monotonic_cst is not None

+        self.event_broker = EventBroker(listeners, [NodeSplitEvent.SORT_FEATURE])
+
+        self.min_samples_leaf_condition = MinSamplesLeafCondition()
+        self.min_weight_leaf_condition = MinWeightLeafCondition()
+
+        l_pre = [self.min_samples_leaf_condition]
+        l_post = [self.min_weight_leaf_condition]
+
+        if self.with_monotonic_cst:
+            self.monotonic_constraint_condition = MonotonicConstraintCondition()
+            l_pre.append(self.monotonic_constraint_condition)
+            l_post.append(self.monotonic_constraint_condition)
+
+        if presplit_conditions is not None:
+            l_pre += presplit_conditions
+
+        if postsplit_conditions is not None:
+            l_post += postsplit_conditions
+
+        self.presplit_conditions.resize(0)
+        self.add_presplit_conditions(l_pre)
+
+        self.postsplit_conditions.resize(0)
+        self.add_postsplit_conditions(l_post)
+
+        self.split_record_factory.f = _base_split_record_factory
+        self.split_record_factory.e = NULL
+
+    def add_listeners(self, listeners: [EventHandler], event_types: [EventType]):
+        self.event_broker.add_listeners(listeners, event_types)
+
+    def add_presplit_conditions(self, presplit_conditions):
+        self._add_conditions(&self.presplit_conditions, presplit_conditions)
+
+    def add_postsplit_conditions(self, postsplit_conditions):
+        self._add_conditions(&self.postsplit_conditions, postsplit_conditions)
+
+    cdef void _add_conditions(
+        self,
+        vector[SplitConditionClosure]* v,
+        split_conditions : 
[SplitCondition] + ): + cdef int offset, ct, i + + offset = v.size() + if split_conditions is not None: + ct = len(split_conditions) + v.resize(offset + ct) + for i in range(ct): + v[0][i + offset] = (split_conditions[i]).c + + def __reduce__(self): return (type(self), (self.criterion, self.max_features, @@ -400,6 +553,32 @@ cdef class Splitter(BaseSplitter): return 0 +cdef inline void shift_missing_values_to_left_if_required( + SplitRecord* best, + intp_t[::1] samples, + intp_t end, +) noexcept nogil: + """Shift missing value sample indices to the left of the split if required. + + Note: this should always be called at the very end because it will + move samples around, thereby affecting the criterion. + This affects the computation of the children impurity, which affects + the computation of the next node. + """ + cdef intp_t i, p, current_end + # The partitioner partitions the data such that the missing values are in + # samples[-n_missing:] for the criterion to consume. If the missing values + # are going to the right node, then the missing values are already in the + # correct position. If the missing values go left, then we move the missing + # values to samples[best.pos:best.pos+n_missing] and update `best.pos`. + if best.n_missing > 0 and best.missing_go_to_left: + for p in range(best.n_missing): + i = best.pos + p + current_end = end - 1 - p + samples[i], samples[current_end] = samples[current_end], samples[i] + best.pos += best.n_missing + + cdef inline intp_t node_split_best( Splitter splitter, Partitioner partitioner, @@ -437,6 +616,7 @@ cdef inline intp_t node_split_best( cdef uint32_t* random_state = &splitter.rand_r_state cdef SplitRecord best_split, current_split + cdef float64_t current_threshold cdef float64_t current_proxy_improvement = -INFINITY cdef float64_t best_proxy_improvement = -INFINITY @@ -458,6 +638,12 @@ cdef inline intp_t node_split_best( # n_total_constants = n_known_constants + n_found_constants cdef intp_t n_total_constants = n_known_constants + cdef bint conditions_hold = True + + # payloads for different node events + cdef NodeSortFeatureEventData sort_event_data + cdef NodeSplitEventData split_event_data + _init_split(&best_split, end) partitioner.init_node_split(start, end) @@ -506,6 +692,11 @@ cdef inline intp_t node_split_best( # f_j in the interval [n_total_constants, f_i[ current_split.feature = features[f_j] partitioner.sort_samples_and_feature_values(current_split.feature) + + # notify any interested parties which feature we're investingating splits for now + sort_event_data.feature = current_split.feature + splitter.event_broker.fire_event(NodeSplitEvent.SORT_FEATURE, &sort_event_data) + n_missing = partitioner.n_missing end_non_missing = end - n_missing @@ -552,28 +743,49 @@ cdef inline intp_t node_split_best( current_split.pos = p + # probably want to assign this to current_split.threshold later, + # but the code is so stateful that Write Everything Twice is the + # safer move here for now + current_threshold = ( + feature_values[p_prev] / 2.0 + feature_values[p] / 2.0 + ) + + # check pre split rejection criteria + conditions_hold = True + for condition in splitter.presplit_conditions: + if not condition.f( + splitter, current_split.feature, current_split.pos, + current_threshold, n_missing, missing_go_to_left, + lower_bound, upper_bound, condition.e + ): + conditions_hold = False + break + + if not conditions_hold: + continue + # Reject if min_samples_leaf is not guaranteed + # this can probably (and should) be removed as it is generalized + 
# by injectable split rejection criteria if splitter.check_presplit_conditions(¤t_split, n_missing, missing_go_to_left) == 1: continue criterion.update(current_split.pos) - # Reject if monotonicity constraints are not satisfied - if ( - with_monotonic_cst and - monotonic_cst[current_split.feature] != 0 and - not criterion.check_monotonicity( - monotonic_cst[current_split.feature], - lower_bound, - upper_bound, - ) - ): - continue - - # Reject if min_weight_leaf is not satisfied - if splitter.check_postsplit_conditions() == 1: + # check post split rejection criteria + conditions_hold = True + for condition in splitter.postsplit_conditions: + if not condition.f( + splitter, current_split.feature, current_split.pos, + current_threshold, n_missing, missing_go_to_left, + lower_bound, upper_bound, condition.e + ): + conditions_hold = False + break + + if not conditions_hold: continue - + current_proxy_improvement = criterion.proxy_impurity_improvement() if current_proxy_improvement > best_proxy_improvement: @@ -632,6 +844,7 @@ cdef inline intp_t node_split_best( current_split.pos = p best_split = current_split + # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end] if best_split.pos < end: partitioner.partition_samples_final( @@ -640,6 +853,7 @@ cdef inline intp_t node_split_best( best_split.feature, best_split.n_missing ) + criterion.init_missing(best_split.n_missing) criterion.missing_go_to_left = best_split.missing_go_to_left @@ -656,6 +870,7 @@ cdef inline intp_t node_split_best( shift_missing_values_to_left_if_required(&best_split, samples, end) + # Respect invariant for constant features: the original order of # element in features[:n_known_constants] must be preserved for sibling # and child nodes @@ -669,6 +884,7 @@ cdef inline intp_t node_split_best( # Return values parent_record.n_constant_features = n_total_constants split[0] = best_split + return 0 @@ -1017,6 +1233,7 @@ cdef class RandomSparseSplitter(Splitter): self.partitioner = SparsePartitioner( X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask ) + cdef int node_split( self, ParentInfo* parent_record, diff --git a/sklearn/tree/_test.pxd b/sklearn/tree/_test.pxd new file mode 100644 index 0000000000000..b8ae6cbe715c8 --- /dev/null +++ b/sklearn/tree/_test.pxd @@ -0,0 +1,21 @@ +from libcpp.vector cimport vector + +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t + +from ._tree cimport Node +from ._honesty cimport Interval as Cinterval + + +cdef class TestNode(): + cdef: + public list bounds + public int start_idx + public int n + + +cdef class HonestyTester(): + cdef: + Node* nodes + vector[Cinterval] intervals + const float32_t[:, :] X + const intp_t[::1] samples diff --git a/sklearn/tree/_test.pyx b/sklearn/tree/_test.pyx new file mode 100644 index 0000000000000..e36405d161395 --- /dev/null +++ b/sklearn/tree/_test.pyx @@ -0,0 +1,119 @@ +from collections import namedtuple +from libc.math cimport INFINITY + +from ._honest_tree import HonestTree + +from ._honesty cimport Honesty, HonestEnv, Views +from ._tree cimport BaseTree, Tree + + +Interval = namedtuple('Interval', ['lower', 'upper']) + + +cdef class TestNode(): + def __init__(self, bounds : [Interval], start_idx, n): + self.bounds = bounds + self.start_idx = start_idx + self.n = n + + def valid(self, float32_t[:, :] X, intp_t[:] samples): + for i in range(self.start_idx, self.start_idx + self.n): + for j in range(len(self.bounds)): + if X[samples[i]][j] < self.bounds[j].lower: + print("") 
+ print(f"start_idx = {self.start_idx}") + print(f"n = {self.n}") + print(f"dimension = {j}") + print(f"X.shape = {X.shape}") + print(f"bounds = {self.bounds[j]}") + print(f"range = {[i for i in range(self.start_idx, self.start_idx + self.n)]}") + print(f"failed on {X[samples[i]][j]} < {self.bounds[j].lower}") + print(f"leaf feature values = {[ X[samples[ii]][j] for ii in range(self.start_idx, self.start_idx + self.n) ]}") + return False + + if X[samples[i]][j] > self.bounds[j].upper: + print("") + print(f"start_idx = {self.start_idx}") + print(f"n = {self.n}") + print(f"dimension = {j}") + print(f"X.shape = {X.shape}") + print(f"bounds = {self.bounds[j]}") + print(f"range = {[i for i in range(self.start_idx, self.start_idx + self.n)]}") + print(f"failed on {X[samples[i]][j]} > {self.bounds[j].upper}") + print(f"leaf feature values = {[ X[samples[ii]][j] for ii in range(self.start_idx, self.start_idx + self.n) ]}") + return False + + return True + + def to_dict(self): + return { + "bounds": self.bounds, + "start_idx": self.start_idx, + "n": self.n + } + + +cdef class HonestyTester(): + def __init__(self, honest_tree: HonestTree): + cdef Honesty honesty = honest_tree.honesty + cdef Tree t = honest_tree.target_tree.tree_ + + self.nodes = t.nodes + self.intervals = honesty.env.tree + self.X = honesty.views.X + self.samples = honesty.views.samples + + + #cdef struct Node: + # # Base storage structure for the nodes in a Tree object + # + # intp_t left_child # id of the left child of the node + # intp_t right_child # id of the right child of the node + # intp_t feature # Feature used for splitting the node + # float64_t threshold # Threshold value at the node + # float64_t impurity # Impurity of the node (i.e., the value of the criterion) + # intp_t n_node_samples # Number of samples at the node + # float64_t weighted_n_node_samples # Weighted number of samples at the node + # unsigned char missing_go_to_left # Whether features have missing values + + def get_invalid_nodes(self): + return [ + n for n in self.to_cells() + if not n.valid(self.X, self.samples) + ] + + + def to_cells(self, intp_t node_id = 0, bounds : [Interval] = None): + cdef Node* node = &self.nodes[node_id] + if bounds is None: + bounds = [ + Interval(-INFINITY, INFINITY) + for _ in range(self.X.shape[0]) + ] + + if node.feature < 0: + return [ + TestNode( + bounds, + self.intervals[node_id].start_idx, + self.intervals[node_id].n + ) + ] + else: + return self.to_cells( + node.left_child, + [ + Interval(bounds[j].lower, node.threshold) + if j == node.feature + else bounds[j] + for j in range(len(bounds)) + ] + ) + self.to_cells( + node.right_child, + [ + Interval(node.threshold, bounds[j].upper) + if j == node.feature + else bounds[j] + for j in range(len(bounds)) + ] + ) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 47a3cd2bd5c9d..6e549bc7adc1b 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -1,6 +1,17 @@ -# Authors: The scikit-learn developers +# Authors: Gilles Louppe +# Peter Prettenhofer +# Brian Holt +# Joel Nothman +# Arnaud Joly +# Jacob Schreiber +# Nelson Liu +# Haoyin Xu +# Samuel Carliles +# +# License: BSD 3 clause # SPDX-License-Identifier: BSD-3-Clause + # See _tree.pyx for details. 
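For intuition about what `HonestyTester.to_cells` above computes, the same leaf-cell decomposition can be sketched in pure Python against any fitted scikit-learn tree using the public `tree_` arrays. This is only an illustrative sketch; the bound checks follow the `X[:, feature] <= threshold` routing convention rather than the exact comparisons used in `_test.pyx`.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

def leaf_cells(tree, node_id=0, bounds=None):
    # Return (leaf_id, bounds) pairs, where bounds holds one (lower, upper)
    # interval per feature describing the axis-aligned cell of that leaf.
    if bounds is None:
        bounds = [(-np.inf, np.inf)] * tree.n_features
    if tree.children_left[node_id] == -1:  # TREE_LEAF
        return [(node_id, bounds)]
    f, t = tree.feature[node_id], tree.threshold[node_id]
    left = [(lo, t) if j == f else (lo, hi) for j, (lo, hi) in enumerate(bounds)]
    right = [(t, hi) if j == f else (lo, hi) for j, (lo, hi) in enumerate(bounds)]
    return (leaf_cells(tree, tree.children_left[node_id], left)
            + leaf_cells(tree, tree.children_right[node_id], right))

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
leaf_of = clf.apply(X)
for leaf_id, cell in leaf_cells(clf.tree_):
    pts = X[leaf_of == leaf_id]
    for j, (lo, hi) in enumerate(cell):
        assert np.all(pts[:, j] > lo) and np.all(pts[:, j] <= hi)

Every sample routed to a leaf must lie inside that leaf's cell; this is the invariant that `get_invalid_nodes` checks for the honest sample intervals.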
import numpy as np

@@ -9,7 +20,9 @@ cimport numpy as cnp
from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector

-from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint8_t, uint32_t
+from ..utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t
+
+from ._events cimport EventBroker, EventHandler

from ._splitter cimport SplitRecord, Splitter

@@ -35,6 +48,95 @@ cdef struct ParentInfo:
    float64_t impurity            # the impurity of the parent
    intp_t n_constant_features    # the number of constant features found in parent

+# A record on the stack for depth-first tree growing
+cdef struct StackRecord:
+    intp_t start
+    intp_t end
+    intp_t depth
+    intp_t parent
+    bint is_left
+    float64_t impurity
+    intp_t n_constant_features
+    float64_t lower_bound
+    float64_t upper_bound
+
+cdef extern from "<stack>" namespace "std" nogil:
+    cdef cppclass stack[T]:
+        ctypedef T value_type
+        stack() except +
+        bint empty()
+        void pop()
+        void push(T&) except +  # Raise c++ exception for bad_alloc -> MemoryError
+        T& top()
+
+# A large portion of the tree build function was duplicated almost verbatim in the
+# neurodata fork of sklearn. We refactor that out into its own function, and it's
+# most convenient to encapsulate all the tree build state into its own env struct.
+cdef enum TreeBuildStatus:
+    OK = 0
+    EXCEPTION_OR_MEMORY_ERROR = -1
+    EVENT_ERROR = -2
+
+cdef struct BuildEnv:
+    # Parameters
+    intp_t max_depth
+    intp_t min_samples_leaf
+    float64_t min_weight_leaf
+    intp_t min_samples_split
+    float64_t min_impurity_decrease
+
+    uint8_t store_leaf_values
+
+    # Initial capacity
+    intp_t init_capacity
+    bint first
+
+    intp_t start
+    intp_t end
+    intp_t depth
+    intp_t parent
+    bint is_left
+    intp_t n_node_samples
+    float64_t weighted_n_node_samples
+    intp_t node_id
+    float64_t right_child_min, left_child_min, right_child_max, left_child_max
+
+    SplitRecord* split
+
+    float64_t middle_value
+    bint is_leaf
+    intp_t max_depth_seen
+
+    TreeBuildStatus rc
+
+    stack[StackRecord] builder_stack
+    stack[StackRecord] update_stack
+    stack[StackRecord]* target_stack
+    StackRecord stack_record
+
+    ParentInfo parent_record
+
+
+# We add tree build events to notify interested parties of tree build state.
+# Only the currently relevant events are implemented.
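The `StackRecord`/`BuildEnv` refactor above is easier to read once the control flow is reduced to its skeleton: one work item per node on an explicit stack in place of recursion. The sketch below is a schematic Python rendering of the loop that `_build_body` implements later in this patch; `find_split` and `add_node` are hypothetical stand-ins for the Splitter and Tree calls, not real APIs.

def build_depth_first(n_samples, find_split, add_node, max_depth):
    # Each tuple mirrors a StackRecord: (start, end, depth, parent, is_left).
    stack = [(0, n_samples, 0, None, False)]
    while stack:
        start, end, depth, parent, is_left = stack.pop()
        # find_split(start, end) returns a split position, or None for no usable split
        pos = find_split(start, end) if depth < max_depth else None
        is_leaf = pos is None
        node_id = add_node(parent, is_left, is_leaf)
        if not is_leaf:
            # Push the right child first so the left child is expanded next (LIFO order).
            stack.append((pos, end, depth + 1, node_id, False))
            stack.append((start, pos, depth + 1, node_id, True))

Bundling this per-node state into `BuildEnv` is what lets the original build pass and the update pass share a single loop body.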
+cdef enum TreeBuildEvent: + ADD_NODE = 1 + UPDATE_NODE = 2 + SET_ACTIVE_PARENT = 3 + +cdef struct TreeBuildSetActiveParentEventData: + intp_t parent_node_id + bint child_is_left + +cdef struct TreeBuildAddNodeEventData: + intp_t parent_node_id + intp_t node_id + bint is_leaf + bint is_left + intp_t feature + float64_t split_point + + cdef class BaseTree: # Inner structures: values are stored separately from node structure, @@ -165,7 +267,11 @@ cdef class TreeBuilder: cdef float64_t min_impurity_decrease # Impurity threshold for early stopping cdef cnp.ndarray initial_roots # Leaf nodes for streaming updates - cdef uint8_t store_leaf_values # Whether to store leaf values + cdef uint8_t store_leaf_values # Whether to store leaf values + + # event broker for distributing tree build events + cdef EventBroker event_broker + cpdef initialize_node_queue( self, @@ -182,7 +288,7 @@ cdef class TreeBuilder: object X, const float64_t[:, ::1] y, const float64_t[:] sample_weight=*, - const uint8_t[::1] missing_values_in_feature_mask=*, + const uint8_t[::1] missing_values_in_feature_mask=* ) cdef _check_input( diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 943d5e6148538..ca2541354d9f1 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1,9 +1,26 @@ -# Authors: The scikit-learn developers +# cython: language_level=3 +# cython: boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True + +# Authors: Gilles Louppe +# Peter Prettenhofer +# Brian Holt +# Noel Dawe +# Satrajit Gosh +# Lars Buitinck +# Arnaud Joly +# Joel Nothman +# Fares Hedayati +# Jacob Schreiber +# Nelson Liu +# Haoyin Xu +# Samuel Carliles +# +# License: BSD 3 clause # SPDX-License-Identifier: BSD-3-Clause from cpython cimport Py_INCREF, PyObject, PyTypeObject from cython.operator cimport dereference as deref -from libc.math cimport isnan +from libc.math cimport isnan, NAN from libc.stdint cimport INTPTR_MAX from libc.stdlib cimport free, malloc from libc.string cimport memcpy, memset @@ -13,6 +30,7 @@ from libcpp cimport bool from libcpp.algorithm cimport pop_heap, push_heap from libcpp.vector cimport vector + import struct import numpy as np @@ -139,19 +157,6 @@ cdef class TreeBuilder: # Depth first builder --------------------------------------------------------- -# A record on the stack for depth-first tree growing -cdef struct StackRecord: - intp_t start - intp_t end - intp_t depth - intp_t parent - bint is_left - float64_t impurity - intp_t n_constant_features - float64_t lower_bound - float64_t upper_bound - - cdef class DepthFirstTreeBuilder(TreeBuilder): """Build a decision tree in depth-first fashion.""" @@ -165,6 +170,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): float64_t min_impurity_decrease, uint8_t store_leaf_values=False, cnp.ndarray initial_roots=None, + listeners : [EventHandler] =None ): self.splitter = splitter self.min_samples_split = min_samples_split @@ -175,6 +181,16 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): self.store_leaf_values = store_leaf_values self.initial_roots = initial_roots + self.event_broker = EventBroker( + listeners, + [ + TreeBuildEvent.ADD_NODE, + TreeBuildEvent.UPDATE_NODE, + TreeBuildEvent.SET_ACTIVE_PARENT + ] + ) + + def __reduce__(self): """Reduce re-implementation, for pickling.""" return(DepthFirstTreeBuilder, (self.splitter, @@ -250,6 +266,179 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # convert dict to numpy array and store value self.initial_roots = np.array(list(false_roots.items())) + + cdef void 
_build_body(self, EventBroker broker, Tree tree, Splitter splitter, BuildEnv* e, bint update) noexcept nogil: + cdef TreeBuildEvent evt + + # payloads for different tree build events + cdef TreeBuildSetActiveParentEventData parent_event_data + cdef TreeBuildAddNodeEventData add_update_node_data + + while not e.target_stack.empty(): + e.stack_record = e.target_stack.top() + e.target_stack.pop() + + e.start = e.stack_record.start + e.end = e.stack_record.end + e.depth = e.stack_record.depth + e.parent = e.stack_record.parent + e.is_left = e.stack_record.is_left + e.parent_record.impurity = e.stack_record.impurity + e.parent_record.n_constant_features = e.stack_record.n_constant_features + e.parent_record.lower_bound = e.stack_record.lower_bound + e.parent_record.upper_bound = e.stack_record.upper_bound + + e.n_node_samples = e.end - e.start + + parent_event_data.parent_node_id = e.stack_record.parent + parent_event_data.child_is_left = e.stack_record.is_left + + # tree build state is kind of weird as implemented because + # the child node id is assigned after child node creation, and all + # situational awareness during creation is referenced to the parent node. + # so we fire an event indicating the current active parent. + if not broker.fire_event(TreeBuildEvent.SET_ACTIVE_PARENT, &parent_event_data): + e.rc = TreeBuildStatus.EVENT_ERROR + break + + splitter.node_reset(e.start, e.end, &e.weighted_n_node_samples) + + e.is_leaf = (e.depth >= e.max_depth or + e.n_node_samples < e.min_samples_split or + e.n_node_samples < 2 * e.min_samples_leaf or + e.weighted_n_node_samples < 2 * e.min_weight_leaf) + + if e.first: + e.parent_record.impurity = splitter.node_impurity() + e.first = 0 + + # impurity == 0 with tolerance due to rounding errors + e.is_leaf = e.is_leaf or e.parent_record.impurity <= EPSILON + + add_update_node_data.parent_node_id = e.parent + add_update_node_data.is_left = e.is_left + add_update_node_data.feature = -1 + add_update_node_data.split_point = NAN + + if not e.is_leaf: + splitter.node_split( + &e.parent_record, + e.split, + ) + + # If EPSILON=0 in the below comparison, float precision + # issues stop splitting, producing trees that are + # dissimilar to v0.18 + e.is_leaf = (e.is_leaf or e.split.pos >= e.end or + (e.split.improvement + EPSILON < + e.min_impurity_decrease)) + + if not e.is_leaf: + add_update_node_data.feature = e.split.feature + add_update_node_data.split_point = e.split.threshold + + + if update == 1: + e.node_id = tree._update_node( + e.parent, e.is_left, e.is_leaf, e.split, + e.parent_record.impurity, e.n_node_samples, e.weighted_n_node_samples, + e.split.missing_go_to_left + ) + evt = TreeBuildEvent.UPDATE_NODE + else: + e.node_id = tree._add_node( + e.parent, e.is_left, e.is_leaf, e.split, + e.parent_record.impurity, e.n_node_samples, e.weighted_n_node_samples, + e.split.missing_go_to_left + ) + evt = TreeBuildEvent.ADD_NODE + + if e.node_id == INTPTR_MAX: + e.rc = TreeBuildStatus.EXCEPTION_OR_MEMORY_ERROR + break + + add_update_node_data.node_id = e.node_id + add_update_node_data.is_leaf = e.is_leaf + + # now that all relevant information has been accumulated, + # notify interested parties that a node has been added/updated + broker.fire_event(evt, &add_update_node_data) + + # Store value for all nodes, to facilitate tree/model + # inspection and interpretation + splitter.node_value(tree.value + e.node_id * tree.value_stride) + if splitter.with_monotonic_cst: + splitter.clip_node_value( + tree.value + e.node_id * tree.value_stride, + 
e.parent_record.lower_bound, + e.parent_record.upper_bound + ) + + if not e.is_leaf: + if ( + not splitter.with_monotonic_cst or + splitter.monotonic_cst[e.split.feature] == 0 + ): + # Split on a feature with no monotonicity constraint + + # Current bounds must always be propagated to both children. + # If a monotonic constraint is active, bounds are used in + # node value clipping. + e.left_child_min = e.right_child_min = e.parent_record.lower_bound + e.left_child_max = e.right_child_max = e.parent_record.upper_bound + elif splitter.monotonic_cst[e.split.feature] == 1: + # Split on a feature with monotonic increase constraint + e.left_child_min = e.parent_record.lower_bound + e.right_child_max = e.parent_record.upper_bound + + # Lower bound for right child and upper bound for left child + # are set to the same value. + e.middle_value = splitter.criterion.middle_value() + e.right_child_min = e.middle_value + e.left_child_max = e.middle_value + else: # i.e. splitter.monotonic_cst[e.split.feature] == -1 + # Split on a feature with monotonic decrease constraint + e.right_child_min = e.parent_record.lower_bound + e.left_child_max = e.parent_record.upper_bound + + # Lower bound for left child and upper bound for right child + # are set to the same value. + e.middle_value = splitter.criterion.middle_value() + e.left_child_min = e.middle_value + e.right_child_max = e.middle_value + + # Push right child on stack + e.builder_stack.push({ + "start": e.split.pos, + "end": e.end, + "depth": e.depth + 1, + "parent": e.node_id, + "is_left": 0, + "impurity": e.split.impurity_right, + "n_constant_features": e.parent_record.n_constant_features, + "lower_bound": e.right_child_min, + "upper_bound": e.right_child_max, + }) + + # Push left child on stack + e.builder_stack.push({ + "start": e.start, + "end": e.split.pos, + "depth": e.depth + 1, + "parent": e.node_id, + "is_left": 1, + "impurity": e.split.impurity_left, + "n_constant_features": e.parent_record.n_constant_features, + "lower_bound": e.left_child_min, + "upper_bound": e.left_child_max, + }) + elif e.store_leaf_values and e.is_leaf: + # copy leaf values to leaf_values array + splitter.node_samples(tree.value_samples[e.node_id]) + + if e.depth > e.max_depth_seen: + e.max_depth_seen = e.depth + cpdef build( self, Tree tree, @@ -263,31 +452,31 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # check input X, y, sample_weight = self._check_input(X, y, sample_weight) - # Parameters cdef Splitter splitter = self.splitter - cdef intp_t max_depth = self.max_depth - cdef intp_t min_samples_leaf = self.min_samples_leaf - cdef float64_t min_weight_leaf = self.min_weight_leaf - cdef intp_t min_samples_split = self.min_samples_split - cdef float64_t min_impurity_decrease = self.min_impurity_decrease - - cdef uint8_t store_leaf_values = self.store_leaf_values cdef cnp.ndarray initial_roots = self.initial_roots + cdef BuildEnv e + e.max_depth = self.max_depth + e.min_samples_leaf = self.min_samples_leaf + e.min_weight_leaf = self.min_weight_leaf + e.min_samples_split = self.min_samples_split + e.min_impurity_decrease = self.min_impurity_decrease + + e.store_leaf_values = self.store_leaf_values + # Initial capacity - cdef intp_t init_capacity - cdef bint first = 0 + e.first = 0 if initial_roots is None: # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight, missing_values_in_feature_mask) if tree.max_depth <= 10: - init_capacity = (2 ** (tree.max_depth + 1)) - 1 + e.init_capacity = (2 ** (tree.max_depth + 1)) - 1 else: - init_capacity 
= 2047 + e.init_capacity = 2047 - tree._resize(init_capacity) - first = 1 + tree._resize(e.init_capacity) + e.first = 1 else: # convert numpy array back to dict false_roots = {} @@ -297,39 +486,24 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # reset the root array self.initial_roots = None - cdef intp_t start = 0 - cdef intp_t end = 0 - cdef intp_t depth - cdef intp_t parent - cdef bint is_left - cdef intp_t n_node_samples = splitter.n_samples - cdef float64_t weighted_n_node_samples - cdef intp_t node_id - cdef float64_t right_child_min, left_child_min, right_child_max, left_child_max - - cdef SplitRecord split - cdef SplitRecord* split_ptr = malloc(splitter.pointer_size()) - - cdef float64_t middle_value - cdef bint is_leaf - cdef intp_t max_depth_seen = -1 if first else tree.max_depth + e.start = 0 + e.end = 0 + e.n_node_samples = splitter.n_samples + e.split = self.splitter.create_split_record() - cdef intp_t rc = 0 + e.max_depth_seen = -1 if e.first else tree.max_depth - cdef stack[StackRecord] builder_stack - cdef stack[StackRecord] update_stack - cdef StackRecord stack_record + e.rc = TreeBuildStatus.OK - cdef ParentInfo parent_record - _init_parent_record(&parent_record) + _init_parent_record(&e.parent_record) - if not first: + if not e.first: # push reached leaf nodes onto stack for key, value in reversed(sorted(false_roots.items())): - end += value[0] - update_stack.push({ - "start": start, - "end": end, + e.end += value[0] + e.update_stack.push({ + "start": e.start, + "end": e.end, "depth": value[1], "parent": key[0], "is_left": key[1], @@ -338,12 +512,12 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "lower_bound": -INFINITY, "upper_bound": INFINITY, }) - start += value[0] + e.start += value[0] else: # push root node onto stack - builder_stack.push({ + e.builder_stack.push({ "start": 0, - "end": n_node_samples, + "end": e.n_node_samples, "depth": 0, "parent": _TREE_UNDEFINED, "is_left": 0, @@ -354,276 +528,26 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): }) with nogil: - while not update_stack.empty(): - stack_record = update_stack.top() - update_stack.pop() - - start = stack_record.start - end = stack_record.end - depth = stack_record.depth - parent = stack_record.parent - is_left = stack_record.is_left - parent_record.impurity = stack_record.impurity - parent_record.n_constant_features = stack_record.n_constant_features - parent_record.lower_bound = stack_record.lower_bound - parent_record.upper_bound = stack_record.upper_bound - - n_node_samples = end - start - splitter.node_reset(start, end, &weighted_n_node_samples) - - is_leaf = (depth >= max_depth or - n_node_samples < min_samples_split or - n_node_samples < 2 * min_samples_leaf or - weighted_n_node_samples < 2 * min_weight_leaf) - - if first: - parent_record.impurity = splitter.node_impurity() - first = 0 - - # impurity == 0 with tolerance due to rounding errors - is_leaf = is_leaf or parent_record.impurity <= EPSILON - - if not is_leaf: - splitter.node_split( - &parent_record, - split_ptr, - ) - - # assign local copy of SplitRecord to assign - # pos, improvement, and impurity scores - split = deref(split_ptr) - - # If EPSILON=0 in the below comparison, float precision - # issues stop splitting, producing trees that are - # dissimilar to v0.18 - is_leaf = (is_leaf or split.pos >= end or - (split.improvement + EPSILON < - min_impurity_decrease)) - - node_id = tree._update_node(parent, is_left, is_leaf, split_ptr, - parent_record.impurity, - n_node_samples, weighted_n_node_samples, - split.missing_go_to_left) 
- - if node_id == INTPTR_MAX: - rc = -1 - break - - # Store value for all nodes, to facilitate tree/model - # inspection and interpretation - splitter.node_value(tree.value + node_id * tree.value_stride) - if splitter.with_monotonic_cst: - splitter.clip_node_value( - tree.value + node_id * tree.value_stride, - parent_record.lower_bound, - parent_record.upper_bound - ) - - if not is_leaf: - if ( - not splitter.with_monotonic_cst or - splitter.monotonic_cst[split.feature] == 0 - ): - # Split on a feature with no monotonicity constraint - - # Current bounds must always be propagated to both children. - # If a monotonic constraint is active, bounds are used in - # node value clipping. - left_child_min = right_child_min = parent_record.lower_bound - left_child_max = right_child_max = parent_record.upper_bound - elif splitter.monotonic_cst[split.feature] == 1: - # Split on a feature with monotonic increase constraint - left_child_min = parent_record.lower_bound - right_child_max = parent_record.upper_bound - - # Lower bound for right child and upper bound for left child - # are set to the same value. - middle_value = splitter.criterion.middle_value() - right_child_min = middle_value - left_child_max = middle_value - else: # i.e. splitter.monotonic_cst[split.feature] == -1 - # Split on a feature with monotonic decrease constraint - right_child_min = parent_record.lower_bound - left_child_max = parent_record.upper_bound - - # Lower bound for left child and upper bound for right child - # are set to the same value. - middle_value = splitter.criterion.middle_value() - left_child_min = middle_value - right_child_max = middle_value - - # Push right child on stack - builder_stack.push({ - "start": split.pos, - "end": end, - "depth": depth + 1, - "parent": node_id, - "is_left": 0, - "impurity": split.impurity_right, - "n_constant_features": parent_record.n_constant_features, - "lower_bound": right_child_min, - "upper_bound": right_child_max, - }) - - # Push left child on stack - builder_stack.push({ - "start": start, - "end": split.pos, - "depth": depth + 1, - "parent": node_id, - "is_left": 1, - "impurity": split.impurity_left, - "n_constant_features": parent_record.n_constant_features, - "lower_bound": left_child_min, - "upper_bound": left_child_max, - }) - elif store_leaf_values and is_leaf: - # copy leaf values to leaf_values array - splitter.node_samples(tree.value_samples[node_id]) - - if depth > max_depth_seen: - max_depth_seen = depth - - while not builder_stack.empty(): - stack_record = builder_stack.top() - builder_stack.pop() - - start = stack_record.start - end = stack_record.end - depth = stack_record.depth - parent = stack_record.parent - is_left = stack_record.is_left - parent_record.impurity = stack_record.impurity - parent_record.n_constant_features = stack_record.n_constant_features - parent_record.lower_bound = stack_record.lower_bound - parent_record.upper_bound = stack_record.upper_bound - - n_node_samples = end - start - splitter.node_reset(start, end, &weighted_n_node_samples) - - is_leaf = (depth >= max_depth or - n_node_samples < min_samples_split or - n_node_samples < 2 * min_samples_leaf or - weighted_n_node_samples < 2 * min_weight_leaf) - - if first: - parent_record.impurity = splitter.node_impurity() - first=0 - - # impurity == 0 with tolerance due to rounding errors - is_leaf = is_leaf or parent_record.impurity <= EPSILON - - if not is_leaf: - splitter.node_split( - &parent_record, - split_ptr, - ) - - # assign local copy of SplitRecord to assign - # pos, improvement, and 
impurity scores - split = deref(split_ptr) - - # If EPSILON=0 in the below comparison, float precision - # issues stop splitting, producing trees that are - # dissimilar to v0.18 - is_leaf = (is_leaf or split.pos >= end or - (split.improvement + EPSILON < - min_impurity_decrease)) - - node_id = tree._add_node(parent, is_left, is_leaf, split_ptr, - parent_record.impurity, n_node_samples, - weighted_n_node_samples, split.missing_go_to_left) - - if node_id == INTPTR_MAX: - rc = -1 - break - - # Store value for all nodes, to facilitate tree/model - # inspection and interpretation - splitter.node_value(tree.value + node_id * tree.value_stride) - if splitter.with_monotonic_cst: - splitter.clip_node_value( - tree.value + node_id * tree.value_stride, - parent_record.lower_bound, - parent_record.upper_bound - ) + e.target_stack = &e.update_stack + self._build_body(self.event_broker, tree, splitter, &e, 1) - if not is_leaf: - if ( - not splitter.with_monotonic_cst or - splitter.monotonic_cst[split.feature] == 0 - ): - # Split on a feature with no monotonicity constraint - - # Current bounds must always be propagated to both children. - # If a monotonic constraint is active, bounds are used in - # node value clipping. - left_child_min = right_child_min = parent_record.lower_bound - left_child_max = right_child_max = parent_record.upper_bound - elif splitter.monotonic_cst[split.feature] == 1: - # Split on a feature with monotonic increase constraint - left_child_min = parent_record.lower_bound - right_child_max = parent_record.upper_bound - - # Lower bound for right child and upper bound for left child - # are set to the same value. - middle_value = splitter.criterion.middle_value() - right_child_min = middle_value - left_child_max = middle_value - else: # i.e. splitter.monotonic_cst[split.feature] == -1 - # Split on a feature with monotonic decrease constraint - right_child_min = parent_record.lower_bound - left_child_max = parent_record.upper_bound - - # Lower bound for left child and upper bound for right child - # are set to the same value. 
- middle_value = splitter.criterion.middle_value() - left_child_min = middle_value - right_child_max = middle_value - - # Push right child on stack - builder_stack.push({ - "start": split.pos, - "end": end, - "depth": depth + 1, - "parent": node_id, - "is_left": 0, - "impurity": split.impurity_right, - "n_constant_features": parent_record.n_constant_features, - "lower_bound": right_child_min, - "upper_bound": right_child_max, - }) - - # Push left child on stack - builder_stack.push({ - "start": start, - "end": split.pos, - "depth": depth + 1, - "parent": node_id, - "is_left": 1, - "impurity": split.impurity_left, - "n_constant_features": parent_record.n_constant_features, - "lower_bound": left_child_min, - "upper_bound": left_child_max, - }) - elif store_leaf_values and is_leaf: - # copy leaf values to leaf_values array - splitter.node_samples(tree.value_samples[node_id]) - - if depth > max_depth_seen: - max_depth_seen = depth + e.target_stack = &e.builder_stack + self._build_body(self.event_broker, tree, splitter, &e, 0) - if rc >= 0: - rc = tree._resize_c(tree.node_count) + if e.rc >= 0: + e.rc = tree._resize_c(tree.node_count) - if rc >= 0: - tree.max_depth = max_depth_seen + if e.rc >= 0: + tree.max_depth = e.max_depth_seen # free the memory created for the SplitRecord pointer - free(split_ptr) + free(e.split) - if rc == -1: + if e.rc == TreeBuildStatus.EXCEPTION_OR_MEMORY_ERROR: raise MemoryError() + + if e.rc == TreeBuildStatus.EVENT_ERROR: + raise RuntimeError("Event handler failure") # Best first builder ---------------------------------------------------------- cdef struct FrontierRecord: @@ -678,6 +602,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): float64_t min_impurity_decrease, uint8_t store_leaf_values=False, cnp.ndarray initial_roots=None, + listeners : [EventHandler] =None ): self.splitter = splitter self.min_samples_split = min_samples_split @@ -689,6 +614,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self.store_leaf_values = store_leaf_values self.initial_roots = initial_roots + self.event_broker = EventBroker(listeners, [TreeBuildEvent.ADD_NODE, TreeBuildEvent.UPDATE_NODE]) + def __reduce__(self): """Reduce re-implementation, for pickling.""" return(BestFirstTreeBuilder, (self.splitter, diff --git a/sklearn/tree/meson.build b/sklearn/tree/meson.build index 04d1d5f353d02..12ffc2be9e8d7 100644 --- a/sklearn/tree/meson.build +++ b/sklearn/tree/meson.build @@ -2,6 +2,9 @@ tree_extension_metadata = { '_tree': {'sources': ['_tree.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']}, + '_sort': + {'sources': ['_sort.pyx'], + 'override_options': ['cython_language=cpp', 'optimization=3']}, '_splitter': {'sources': ['_splitter.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']}, @@ -14,6 +17,15 @@ tree_extension_metadata = { '_utils': {'sources': ['_utils.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']}, + '_events': + {'sources': ['_events.pyx'], + 'override_options': ['cython_language=cpp', 'optimization=3']}, + '_honesty': + {'sources': ['_honesty.pyx'], + 'override_options': ['cython_language=cpp', 'optimization=3']}, + '_test': + {'sources': ['_test.pyx'], + 'override_options': ['cython_language=cpp', 'optimization=3']} } foreach ext_name, ext_dict : tree_extension_metadata diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 656a7257521ce..d533041430f80 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -35,6 +35,9 @@ DENSE_SPLITTERS, 
SPARSE_SPLITTERS, ) +from sklearn.tree._honesty import Honesty +from sklearn.tree._honest_tree import HonestDecisionTree +from sklearn.tree._test import HonestyTester from sklearn.tree._tree import ( NODE_DTYPE, TREE_LEAF, @@ -198,6 +201,137 @@ } +def _moving_avg_cov(n_dim, rho): + # Create a meshgrid of indices + i, j = np.meshgrid(np.arange(1, n_dim + 1), np.arange(1, n_dim + 1), indexing="ij") + + # Calculate the covariance matrix using the corrected formula + cov_matrix = rho ** np.abs(i - j) + + # Apply the banding condition + cov_matrix[abs(i - j) > 1] = 0 + return cov_matrix + + +def _autoregressive_cov(n_dim, rho): + # Create a meshgrid of indices + i, j = np.meshgrid(np.arange(1, n_dim + 1), np.arange(1, n_dim + 1), indexing="ij") + + # Calculate the covariance matrix using the corrected formula + cov_matrix = rho ** np.abs(i - j) + + return cov_matrix + + +def make_trunk_classification( + n_samples, + n_dim, + n_informative=1, + simulation: str = "trunk", + mu_0: float = 0, + mu_1: float = 1, + rho: int = 0, + band_type: str = "ma", + return_params: bool = False, + mix: float = 0.5, + seed=None, +): + if n_dim < n_informative: + raise ValueError( + f"Number of informative dimensions {n_informative} must be less than number " + f"of dimensions, {n_dim}" + ) + rng = np.random.default_rng(seed=seed) + rng1 = np.random.default_rng(seed=seed) + mu_0 = np.array([mu_0 / np.sqrt(i) for i in range(1, n_informative + 1)]) + mu_1 = np.array([mu_1 / np.sqrt(i) for i in range(1, n_informative + 1)]) + if rho != 0: + if band_type == "ma": + cov = _moving_avg_cov(n_informative, rho) + elif band_type == "ar": + cov = _autoregressive_cov(n_informative, rho) + else: + raise ValueError(f'Band type {band_type} must be one of "ma", or "ar".') + else: + cov = np.identity(n_informative) + if mix < 0 or mix > 1: + raise ValueError("Mix must be between 0 and 1.") + # speed up computations for large multivariate normal matrix with SVD approximation + if n_informative > 1000: + method = "cholesky" + else: + method = "svd" + if simulation == "trunk": + X = np.vstack( + ( + rng.multivariate_normal(mu_0, cov, n_samples // 2, method=method), + rng1.multivariate_normal(mu_1, cov, n_samples // 2, method=method), + ) + ) + elif simulation == "trunk_overlap": + mixture_idx = rng.choice( + 2, n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix] + ) + norm_params = [[mu_0, cov], [mu_1, cov]] + X_mixture = np.fromiter( + ( + rng.multivariate_normal(*(norm_params[i]), size=1, method=method) + for i in mixture_idx + ), + dtype=np.dtype((float, n_informative)), + ) + X_mixture_2 = np.fromiter( + ( + rng1.multivariate_normal(*(norm_params[i]), size=1, method=method) + for i in mixture_idx + ), + dtype=np.dtype((float, n_informative)), + ) + X = np.vstack( + ( + X_mixture.reshape(n_samples // 2, n_informative), + X_mixture_2.reshape(n_samples // 2, n_informative), + ) + ) + elif simulation == "trunk_mix": + mixture_idx = rng.choice( + 2, n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix] + ) + norm_params = [[mu_0, cov], [mu_1, cov]] + X_mixture = np.fromiter( + ( + rng1.multivariate_normal(*(norm_params[i]), size=1, method=method) + for i in mixture_idx + ), + dtype=np.dtype((float, n_informative)), + ) + X = np.vstack( + ( + rng.multivariate_normal( + np.zeros(n_informative), cov, n_samples // 2, method=method + ), + X_mixture.reshape(n_samples // 2, n_informative), + ) + ) + else: + raise ValueError(f"Simulation must be: trunk, trunk_overlap, trunk_mix") + if n_dim > n_informative: + X = np.hstack( + 
(X, rng.normal(loc=0, scale=1, size=(X.shape[0], n_dim - n_informative))) + ) + y = np.concatenate((np.zeros(n_samples // 2), np.ones(n_samples // 2))) + if return_params: + returns = [X, y] + if simulation == "trunk": + returns += [[mu_0, mu_1], [cov, cov]] + elif simulation == "trunk-overlap": + returns += [[np.zeros(n_informative), np.zeros(n_informative)], [cov, cov]] + elif simulation == "trunk-mix": + returns += [*list(zip(*norm_params)), X_mixture] + return returns + return X, y + + def assert_tree_equal(d, s, message): assert ( s.node_count == d.node_count @@ -334,6 +468,162 @@ def test_iris(): name, criterion, score ) +def test_honest_iris(): + import json + + for criterion in CLF_CRITERIONS: + hf = HonestDecisionTree( + target_tree_class=DecisionTreeClassifier, + target_tree_kwargs={ + 'criterion': criterion, + 'random_state': 0, + 'store_leaf_values': True + } + ) + hf.fit(iris.data, iris.target) + + # verify their apply results are identical + dishonest = hf.target_tree.apply(iris.data) + honest = hf.apply(iris.data) + assert np.sum((honest - dishonest)**2) == 0, ( + "Failed with apply delta. dishonest: {0}, honest: {1}".format( + dishonest, honest + ) + ) + + # # verify their predict results are identical + # # technically they may correctly differ, + # # but at least in this test case they tend not to, + # # so it's a reasonable smoke test + # dishonest = hf.target_tree.predict(iris.data) + # honest = hf.predict(iris.data) + # assert np.sum((honest - dishonest)**2) == 0, ( + # "Failed with predict delta. dishonest: {0}, honest: {1}".format( + # dishonest, honest + # ) + # ) + + # verify that at least some leaf sample sets + # are in fact different for corresponding leaves. + # again, possible to fail by chance, + # but usually a reasonable smoke test + leaf_eq = [] + leaf_ct = 0 + for i in range(hf.tree_.node_count): + if hf.honesty.is_leaf(i): + leaf_ct += 1 + dishonest = Honesty.get_value_samples_ndarray(hf.target_tree.tree_, i) + honest = Honesty.get_value_samples_ndarray(hf.tree_, i) + uniques = np.unique(np.concatenate((dishonest, honest))) + dishonest_hist, _ = np.histogram(dishonest, bins=len(uniques)) + honest_hist, _ = np.histogram(honest, bins=len(uniques)) + if np.array_equal(dishonest_hist, honest_hist): + leaf_eq.append(i) + + assert len(leaf_eq) != leaf_ct, ( + "Failed with all leaves equal: {0}".format(leaf_eq) + ) + + # check accuracy + score = accuracy_score(hf.target_tree.predict(iris.data), iris.target) + assert score > 0.9, "Failed with {0}, criterion = {1} and dishonest score = {2}".format( + "DecisionTreeClassifier", criterion, score + ) + # score = accuracy_score(hf.predict(iris.data), iris.target) + # assert score > 0.9, "Failed with {0}, criterion = {1} and honest score = {2}".format( + # "DecisionTreeClassifier", criterion, score + # ) + + # check predict_proba + dishonest_proba = hf.target_tree.predict_log_proba(iris.data) + honest_proba = hf.predict_log_proba(iris.data) + assert len(dishonest_proba) == len(honest_proba), (( + "Mismatched predict_log_proba: len(dishonest_proba) = {0}, " + "len(honest_proba) = {1}" + ).format(len(dishonest_proba), len(honest_proba))) + + for i in range(len(dishonest_proba)): + assert np.all(dishonest_proba[i] == honest_proba[i]), (( + "Failed with predict_log_proba delta row {0}. 
" + "dishonest: {1}, honest: {2}" + ).format(i, dishonest_proba[i], honest_proba[i])) + + # verify no invalid nodes in honest tree + ht = HonestyTester(hf) + invalid_nodes = ht.get_invalid_nodes() + invalid_nodes_dict = [node.to_dict() if hasattr(node, 'to_dict') else node for node in invalid_nodes] + invalid_nodes_json = json.dumps(invalid_nodes_dict, indent=4) + assert len(invalid_nodes) == 0, "Failed with invalid nodes: {0}".format(invalid_nodes_json) + + +def test_honest_separation(): + # verify that splits are made independently of the honest data set. + # we do this by eliminating randomness from the training process, + # running repeated trials with honest Y labels shuffled, and verifying + # that the splits do not change. + N_ITER = 100 + SAMPLE_SIZE = 1024 + RANDOM_STATE = 1 + HONEST_PRIOR = "ignore" + HONEST_FRACTION = 0.9 + + X, y = make_trunk_classification( + n_samples=SAMPLE_SIZE, + n_dim=1, + n_informative=1, + seed=0, + ) + X_t = np.concatenate(( + X[: SAMPLE_SIZE // 2], + X[SAMPLE_SIZE // 2 :] + )) + y_t = np.concatenate((np.zeros(SAMPLE_SIZE // 2), np.ones(SAMPLE_SIZE // 2))) + + + tree=HonestDecisionTree( + target_tree_class=DecisionTreeClassifier, + target_tree_kwargs={ + "criterion": "gini", + "random_state": RANDOM_STATE + }, + honest_prior=HONEST_PRIOR, + honest_fraction=HONEST_FRACTION + ) + tree.fit(X_t, y_t.ravel()) + honest_tree = tree.tree_ + structure_tree = honest_tree.target_tree + old_threshold = structure_tree.threshold.copy() + old_y = y_t.copy() + + honest_indices = tree.honest_indices_ + + for _ in range(N_ITER): + y_perm = y_t.copy() + honest_shuffled = honest_indices.copy() + np.random.shuffle(honest_shuffled) + for i in range(len(honest_indices)): + y_perm[honest_indices[i]] = y_t[honest_shuffled[i]] + + assert(not np.array_equal(y_t, y_perm)) + assert(not np.array_equal(old_y, y_perm)) + + tree=HonestDecisionTree( + target_tree_class=DecisionTreeClassifier, + target_tree_kwargs={ + "criterion": "gini", + "random_state": RANDOM_STATE + }, + honest_prior=HONEST_PRIOR, + honest_fraction=HONEST_FRACTION + ) + tree.fit(X_t, y_perm.ravel()) + honest_tree = tree.tree_ + structure_tree = honest_tree.target_tree + + assert(np.array_equal(old_threshold, structure_tree.threshold)) + old_threshold = structure_tree.threshold.copy() + old_y = y_perm.copy() + @pytest.mark.parametrize("name, Tree", REG_TREES.items()) @pytest.mark.parametrize("criterion", REG_CRITERIONS)