From d3bcba5155d36a6cb15d929850fb70142fb0f4d8 Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Mon, 1 Jul 2024 16:47:25 +0200
Subject: [PATCH 01/46] UPD: factorize MapieClassifier methods into several
 non-conformity scores + generic adaptation

---
 doc/api.rst                                   |   5 +
 mapie/classification.py                       | 792 ++----------------
 mapie/conformity_scores/__init__.py           |  19 +-
 mapie/conformity_scores/bounds/__init__.py    |  10 +
 mapie/conformity_scores/bounds/absolute.py    |  52 ++
 mapie/conformity_scores/bounds/gamma.py       |  86 ++
 .../residuals.py}                             | 152 +---
 mapie/conformity_scores/checks.py             |  38 -
 mapie/conformity_scores/classification.py     | 103 +++
 mapie/conformity_scores/interface.py          | 256 ++++++
 .../{conformity_scores.py => regression.py}   | 298 ++-----
 mapie/conformity_scores/sets/__init__.py      |  10 +
 mapie/conformity_scores/sets/aps.py           | 497 +++++++++++
 mapie/conformity_scores/sets/lac.py           | 207 +++++
 mapie/conformity_scores/sets/topk.py          | 212 +++++
 mapie/conformity_scores/sets/utils.py         | 401 +++++++++
 mapie/conformity_scores/utils.py              | 102 ++-
 mapie/regression/regression.py                |  34 +-
 mapie/regression/time_series_regression.py    |   8 +-
 mapie/tests/test_classification.py            |  60 +-
 mapie/tests/test_conformity_scores.py         |  75 +-
 mapie/tests/test_conformity_scores_sets.py    |  37 +
 mapie/tests/test_regression.py                |  19 +-
 ..._utils_classification_conformity_scores.py |   4 +-
 24 files changed, 2215 insertions(+), 1262 deletions(-)
 create mode 100644 mapie/conformity_scores/bounds/__init__.py
 create mode 100644 mapie/conformity_scores/bounds/absolute.py
 create mode 100644 mapie/conformity_scores/bounds/gamma.py
 rename mapie/conformity_scores/{residual_conformity_scores.py => bounds/residuals.py} (69%)
 delete mode 100644 mapie/conformity_scores/checks.py
 create mode 100644 mapie/conformity_scores/classification.py
 create mode 100644 mapie/conformity_scores/interface.py
 rename mapie/conformity_scores/{conformity_scores.py => regression.py} (50%)
 create mode 100644 mapie/conformity_scores/sets/__init__.py
 create mode 100644 mapie/conformity_scores/sets/aps.py
 create mode 100644 mapie/conformity_scores/sets/lac.py
 create mode 100644 mapie/conformity_scores/sets/topk.py
 create mode 100644 mapie/conformity_scores/sets/utils.py
 create mode 100644 mapie/tests/test_conformity_scores_sets.py

diff --git a/doc/api.rst b/doc/api.rst
index 417bddd26..a36957f36 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -80,9 +80,14 @@ Conformity scores
    :toctree: generated/
    :template: class.rst
 
+   conformity_scores.BaseRegressionScore
    conformity_scores.AbsoluteConformityScore
    conformity_scores.GammaConformityScore
    conformity_scores.ResidualNormalisedScore
+   conformity_scores.BaseClassificationScore
+   conformity_scores.LAC
+   conformity_scores.APS
+   conformity_scores.TopK
 
 Resampling
 ==========
diff --git a/mapie/classification.py b/mapie/classification.py
index 7aff7d024..232d76251 100644
--- a/mapie/classification.py
+++ b/mapie/classification.py
@@ -1,30 +1,27 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Iterable, Optional, Tuple, Union, cast
+from typing import Iterable, Optional, Tuple, Union, cast
 
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.model_selection
import (BaseCrossValidator, BaseShuffleSplit, StratifiedShuffleSplit) -from sklearn.preprocessing import LabelEncoder, label_binarize +from sklearn.preprocessing import LabelEncoder from sklearn.utils import _safe_indexing, check_random_state from sklearn.utils.multiclass import (check_classification_targets, type_of_target) from sklearn.utils.validation import (_check_y, _num_samples, check_is_fitted, indexable) -from mapie._machine_precision import EPSILON from mapie._typing import ArrayLike, NDArray +from mapie.conformity_scores import BaseClassificationScore +from mapie.conformity_scores.utils import check_classification_conformity_score +from mapie.conformity_scores.sets.utils import get_true_label_position from mapie.estimator.classifier import EnsembleClassifier -from mapie.metrics import classification_mean_width_score from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_cv, check_estimator_classification, check_n_features_in, - check_n_jobs, check_null_weight, check_verbose, - compute_quantiles) -from mapie.conformity_scores.utils import ( - get_true_label_position -) + check_n_jobs, check_null_weight, check_verbose) class MapieClassifier(BaseEstimator, ClassifierMixin): @@ -47,7 +44,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): Method to choose for prediction interval estimates. Choose among: - - ``"naive"``, sum of the probabilities until the 1-alpha thresold. + - ``"naive"``, sum of the probabilities until the 1-alpha threshold. - ``"lac"`` (formerly called ``"score"``), Least Ambiguous set-valued Classifier. It is based on the the scores @@ -197,6 +194,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): "estimator_", "n_features_in_", "conformity_scores_", + "conformity_score_function_", "classes_", "label_encoder_" ] @@ -208,6 +206,7 @@ def __init__( cv: Optional[Union[int, str, BaseCrossValidator]] = None, test_size: Optional[Union[int, float]] = None, n_jobs: Optional[int] = None, + conformity_score: Optional[BaseClassificationScore] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, verbose: int = 0 ) -> None: @@ -216,6 +215,7 @@ def __init__( self.cv = cv self.test_size = test_size self.n_jobs = n_jobs + self.conformity_score = conformity_score self.random_state = random_state self.verbose = verbose @@ -311,549 +311,6 @@ def _check_raps(self): f"with cv in {self.raps_valid_cv_}." ) - def _check_include_last_label( - self, - include_last_label: Optional[Union[bool, str]] - ) -> Optional[Union[bool, str]]: - """ - Check if ``include_last_label`` is a boolean or a string. - Else raise error. - - Parameters - ---------- - include_last_label: Optional[Union[bool, str]] - Whether or not to include last label in - prediction sets for the ``"aps"`` method. Choose among: - - - ``False``, does not include label whose cumulated score is just - over the quantile. - - ``True``, includes label whose cumulated score is just over the - quantile, unless there is only one label in the prediction set. - - ``"randomized"``, randomly includes label whose cumulated score - is just over the quantile based on the comparison of a uniform - number and the difference between the cumulated score of the last - label and the quantile. - - Returns - ------- - Optional[Union[bool, str]] - - Raises - ------ - ValueError - "Invalid include_last_label argument. " - "Should be a boolean or 'randomized'." 
- """ - if ( - (not isinstance(include_last_label, bool)) and - (not include_last_label == "randomized") - ): - raise ValueError( - "Invalid include_last_label argument. " - "Should be a boolean or 'randomized'." - ) - else: - return include_last_label - - def _check_proba_normalized( - self, - y_pred_proba: ArrayLike, - axis: int = 1 - ) -> NDArray: - """ - Check if, for all the observations, the sum of - the probabilities is equal to one. - - Parameters - ---------- - y_pred_proba: ArrayLike of shape - (n_samples, n_classes) or - (n_samples, n_train_samples, n_classes) - Softmax output of a model. - - Returns - ------- - ArrayLike of shape (n_samples, n_classes) - Softmax output of a model if the scores all sum - to one. - - Raises - ------ - ValueError - If the sum of the scores is not equal to one. - """ - np.testing.assert_allclose( - np.sum(y_pred_proba, axis=axis), - 1, - err_msg="The sum of the scores is not equal to one.", - rtol=1e-5 - ) - y_pred_proba = cast(NDArray, y_pred_proba).astype(np.float64) - return y_pred_proba - - def _get_last_index_included( - self, - y_pred_proba_cumsum: NDArray, - threshold: NDArray, - include_last_label: Optional[Union[bool, str]] - ) -> NDArray: - """ - Return the index of the last included sorted probability - depending if we included the first label over the quantile - or not. - - Parameters - ---------- - y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes) - Cumsumed probabilities in the original order. - - threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) - Threshold to compare with y_proba_last_cumsum, can be either: - - - the quantiles associated with alpha values when - ``cv`` == "prefit", ``cv`` == "split" - or ``agg_scores`` is "mean" - - the conformity score from training samples otherwise - (i.e., when ``cv`` is a CV splitter and - ``agg_scores`` is "crossval") - - include_last_label: Union[bool, str] - Whether or not include the last label. If 'randomized', - the last label is included. - - Returns - ------- - NDArray of shape (n_samples, n_alpha) - Index of the last included sorted probability. - """ - if ( - (include_last_label) or - (include_last_label == 'randomized') - ): - y_pred_index_last = ( - np.ma.masked_less( - y_pred_proba_cumsum - - threshold[np.newaxis, :], - -EPSILON - ).argmin(axis=1) - ) - elif (include_last_label is False): - max_threshold = np.maximum( - threshold[np.newaxis, :], - np.min(y_pred_proba_cumsum, axis=1) - ) - y_pred_index_last = np.argmax( - np.ma.masked_greater( - y_pred_proba_cumsum - max_threshold[:, np.newaxis, :], - EPSILON - ), axis=1 - ) - else: - raise ValueError( - "Invalid include_last_label argument. " - "Should be a boolean or 'randomized'." - ) - return y_pred_index_last[:, np.newaxis, :] - - def _add_random_tie_breaking( - self, - prediction_sets: NDArray, - y_pred_index_last: NDArray, - y_pred_proba_cumsum: NDArray, - y_pred_proba_last: NDArray, - threshold: NDArray, - lambda_star: Union[NDArray, float, None], - k_star: Union[NDArray, None] - ) -> NDArray: - """ - Randomly remove last label from prediction set based on the - comparison between a random number and the difference between - cumulated score of the last included label and the quantile. - - Parameters - ---------- - prediction_sets: NDArray of shape - (n_samples, n_classes, n_threshold) - Prediction set for each observation and each alpha. - - y_pred_index_last: NDArray of shape (n_samples, threshold) - Index of the last included label. 
- - y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes) - Cumsumed probability of the model in the original order. - - y_pred_proba_last: NDArray of shape (n_samples, 1, threshold) - Last included probability. - - threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) - Threshold to compare with y_proba_last_cumsum, can be either: - - - the quantiles associated with alpha values when - ``cv`` == "prefit", ``cv`` == "split" or - ``agg_scores`` is "mean" - - the conformity score from training samples otherwise - (i.e., when ``cv`` is a CV splitter and - ``agg_scores`` is "crossval") - - lambda_star: Union[NDArray, float, None] of shape (n_alpha): - Optimal value of the regulizer lambda. - - k_star: Union[NDArray, None] of shape (n_alpha): - Optimal value of the regulizer k. - - Returns - ------- - NDArray of shape (n_samples, n_classes, n_alpha) - Updated version of prediction_sets with randomly removed - labels. - """ - # get cumsumed probabilities up to last retained label - y_proba_last_cumsumed = np.squeeze( - np.take_along_axis( - y_pred_proba_cumsum, - y_pred_index_last, - axis=1 - ), axis=1 - ) - - if self.method in ["cumulated_score", "aps"]: - # compute V parameter from Romano+(2020) - vs = ( - (y_proba_last_cumsumed - threshold.reshape(1, -1)) / - y_pred_proba_last[:, 0, :] - ) - else: - # compute V parameter from Angelopoulos+(2020) - L = np.sum(prediction_sets, axis=1) - vs = ( - (y_proba_last_cumsumed - threshold.reshape(1, -1)) / - ( - y_pred_proba_last[:, 0, :] - - lambda_star * np.maximum(0, L - k_star) + - lambda_star * (L > k_star) - ) - ) - - # get random numbers for each observation and alpha value - random_state = check_random_state(self.random_state) - us = random_state.uniform(size=(prediction_sets.shape[0], 1)) - # remove last label from comparison between uniform number and V - vs_less_than_us = np.less_equal(vs - us, EPSILON) - np.put_along_axis( - prediction_sets, - y_pred_index_last, - vs_less_than_us[:, np.newaxis, :], - axis=1 - ) - return prediction_sets - - def _get_true_label_cumsum_proba( - self, - y: ArrayLike, - y_pred_proba: NDArray - ) -> Tuple[NDArray, NDArray]: - """ - Compute the cumsumed probability of the true label. - - Parameters - ---------- - y: NDArray of shape (n_samples, ) - Array with the labels. - y_pred_proba: NDArray of shape (n_samples, n_classes) - Predictions of the model. - - Returns - ------- - Tuple[NDArray, NDArray] of shapes - (n_samples, 1) and (n_samples, ). The first element - is the cumsum probability of the true label. The second - is the sorted position of the true label. - """ - y_true = label_binarize( - y=y, classes=self.classes_ - ) - index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1)) - y_pred_proba_sorted = np.take_along_axis( - y_pred_proba, index_sorted, axis=1 - ) - y_true_sorted = np.take_along_axis(y_true, index_sorted, axis=1) - y_pred_proba_sorted_cumsum = np.cumsum(y_pred_proba_sorted, axis=1) - cutoff = np.argmax(y_true_sorted, axis=1) - true_label_cumsum_proba = np.take_along_axis( - y_pred_proba_sorted_cumsum, cutoff.reshape(-1, 1), axis=1 - ) - - return true_label_cumsum_proba, cutoff + 1 - - def _regularize_conformity_score( - self, - k_star: NDArray, - lambda_: Union[NDArray, float], - conf_score: NDArray, - cutoff: NDArray - ) -> NDArray: - """ - Regularize the conformity scores with the ``"raps"`` - method. See algo. 2 in [3]. - - Parameters - ---------- - k_star: NDArray of shape (n_alphas, ) - Optimal value of k (called k_reg in the paper). There - is one value per alpha. 
- - lambda_: Union[NDArray, float] of shape (n_alphas, ) - One value of lambda for each alpha. - - conf_score: NDArray of shape (n_samples, 1) - Conformity scores. - - cutoff: NDArray of shape (n_samples, 1) - Position of the true label. - - Returns - ------- - NDArray of shape (n_samples, 1, n_alphas) - Regularized conformity scores. The regularization - depends on the value of alpha. - """ - conf_score = np.repeat( - conf_score[:, :, np.newaxis], len(k_star), axis=2 - ) - cutoff = np.repeat( - cutoff[:, np.newaxis], len(k_star), axis=1 - ) - conf_score += np.maximum( - np.expand_dims( - lambda_ * (cutoff - k_star), - axis=1 - ), - 0 - ) - return conf_score - - def _get_last_included_proba( - self, - y_pred_proba: NDArray, - thresholds: NDArray, - include_last_label: Union[bool, str, None], - lambda_: Union[NDArray, float, None], - k_star: Union[NDArray, Any] - ) -> Tuple[NDArray, NDArray, NDArray]: - """ - Function that returns the smallest score - among those which are included in the prediciton set. - - Parameters - ---------- - y_pred_proba: NDArray of shape (n_samples, n_classes) - Predictions of the model. - - thresholds: NDArray of shape (n_alphas, ) - Quantiles that have been computed from the conformity - scores. - - include_last_label: Union[bool, str, None] - Whether to include or not the label whose score - exceeds the threshold. - - lambda_: Union[NDArray, float, None] of shape (n_alphas) - Values of lambda for the regularization. - - k_star: Union[NDArray, Any] - Values of k for the regularization. - - Returns - ------- - Tuple[ArrayLike, ArrayLike, ArrayLike] - Arrays of shape (n_samples, n_classes, n_alphas), - (n_samples, 1, n_alphas) and (n_samples, 1, n_alphas). - They are respectively the cumsumed scores in the original - order which can be different according to the value of alpha - with the RAPS method, the index of the last included score - and the value of the last included score. - """ - index_sorted = np.flip( - np.argsort(y_pred_proba, axis=1), axis=1 - ) - # sort probabilities by decreasing order - y_pred_proba_sorted = np.take_along_axis( - y_pred_proba, index_sorted, axis=1 - ) - # get sorted cumulated score - y_pred_proba_sorted_cumsum = np.cumsum( - y_pred_proba_sorted, axis=1 - ) - - if self.method == "raps": - y_pred_proba_sorted_cumsum += lambda_ * np.maximum( - 0, - np.cumsum( - np.ones(y_pred_proba_sorted_cumsum.shape), - axis=1 - ) - k_star - ) - # get cumulated score at their original position - y_pred_proba_cumsum = np.take_along_axis( - y_pred_proba_sorted_cumsum, - np.argsort(index_sorted, axis=1), - axis=1 - ) - # get index of the last included label - y_pred_index_last = self._get_last_index_included( - y_pred_proba_cumsum, - thresholds, - include_last_label - ) - # get the probability of the last included label - y_pred_proba_last = np.take_along_axis( - y_pred_proba, - y_pred_index_last, - axis=1 - ) - - zeros_scores_proba_last = (y_pred_proba_last <= EPSILON) - - # If the last included proba is zero, change it to the - # smallest non-zero value to avoid inluding them in the - # prediction sets. 
- if np.sum(zeros_scores_proba_last) > 0: - y_pred_proba_last[zeros_scores_proba_last] = np.expand_dims( - np.min( - np.ma.masked_less( - y_pred_proba, - EPSILON - ).filled(fill_value=np.inf), - axis=1 - ), axis=1 - )[zeros_scores_proba_last] - - return y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last - - def _update_size_and_lambda( - self, - best_sizes: NDArray, - alpha_np: NDArray, - y_ps: NDArray, - lambda_: Union[NDArray, float], - lambda_star: NDArray - ) -> Tuple[NDArray, NDArray]: - """Update the values of the optimal lambda if the - average size of the prediction sets decreases with - this new value of lambda. - - Parameters - ---------- - best_sizes: NDArray of shape (n_alphas, ) - Smallest average prediciton set size before testing - for the new value of lambda_ - - alpha_np: NDArray of shape (n_alphas) - Level of confidences. - - y_ps: NDArray of shape (n_samples, n_classes, n_alphas) - Prediction sets computed with the RAPS method and the - new value of lambda_ - - lambda_: NDArray of shape (n_alphas, ) - New value of lambda_star to test - - lambda_star: NDArray of shape (n_alphas, ) - Actual optimal lambda values for each alpha. - - Returns - ------- - Tuple[NDArray, NDArray] - Arrays of shape (n_alphas, ) and (n_alpha, ) which - respectively represent the updated values of lambda_star - and the new best sizes. - """ - - sizes = [ - classification_mean_width_score( - y_ps[:, :, i] - ) for i in range(len(alpha_np)) - ] - - sizes_improve = (sizes < best_sizes - EPSILON) - lambda_star = ( - sizes_improve * lambda_ + (1 - sizes_improve) * lambda_star - ) - best_sizes = sizes_improve * sizes + (1 - sizes_improve) * best_sizes - - return lambda_star, best_sizes - - def _find_lambda_star( - self, - y_pred_proba_raps: NDArray, - alpha_np: NDArray, - include_last_label: Union[bool, str, None], - k_star: NDArray - ) -> Union[NDArray, float]: - """Find the optimal value of lambda for each alpha. - - Parameters - ---------- - y_pred_proba_raps: NDArray of shape (n_samples, n_labels, n_alphas) - Predictions of the model repeated on the last axis as many times - as the number of alphas - - alpha_np: NDArray of shape (n_alphas, ) - Levels of confidences. - - include_last_label: bool - Whether to include or not last label in - the prediction sets - - k_star: NDArray of shape (n_alphas, ) - Values of k for the regularization. - - Returns - ------- - ArrayLike of shape (n_alphas, ) - Optimal values of lambda. 
- """ - lambda_star = np.zeros(len(alpha_np)) - best_sizes = np.full(len(alpha_np), np.finfo(np.float64).max) - - for lambda_ in [.001, .01, .1, .2, .5]: # values given in paper[3] - true_label_cumsum_proba, cutoff = ( - self._get_true_label_cumsum_proba( - self.y_raps_no_enc, - y_pred_proba_raps[:, :, 0], - ) - ) - - true_label_cumsum_proba_reg = self._regularize_conformity_score( - k_star, - lambda_, - true_label_cumsum_proba, - cutoff - ) - - quantiles_ = compute_quantiles( - true_label_cumsum_proba_reg, - alpha_np - ) - - _, _, y_pred_proba_last = self._get_last_included_proba( - y_pred_proba_raps, - quantiles_, - include_last_label, - lambda_, - k_star - ) - - y_ps = np.greater_equal( - y_pred_proba_raps - y_pred_proba_last, -EPSILON - ) - lambda_star, best_sizes = self._update_size_and_lambda( - best_sizes, alpha_np, y_ps, lambda_, lambda_star - ) - if len(lambda_star) == 1: - lambda_star = lambda_star[0] - return lambda_star - def _get_classes_info( self, estimator: ClassifierMixin, y: NDArray ) -> Tuple[int, NDArray]: @@ -987,7 +444,20 @@ def _check_fit_parameter( self._check_target(y) - return estimator, cv, X, y, y_enc, sample_weight, groups, n_samples + cs_estimator = check_classification_conformity_score( + conformity_score=self.conformity_score, + method=self.method + ) + cs_estimator.set_external_attributes( + method=self.method, + classes=self.classes_, + random_state=self.random_state + ) + + return ( + estimator, cs_estimator, cv, + X, y, y_enc, sample_weight, groups, n_samples + ) def _split_data( self, @@ -1109,6 +579,7 @@ def fit( """ # Checks (estimator, + self.conformity_score_function_, cv, X, y, @@ -1158,35 +629,10 @@ def fit( self.y_pred_proba_raps, self.y_raps ) - # Conformity scores - if self.method == "naive": - self.conformity_scores_ = ( - np.empty(y_pred_proba.shape, dtype="float") - ) - elif self.method in ["score", "lac"]: - self.conformity_scores_ = np.take_along_axis( - 1 - y_pred_proba, y_enc.reshape(-1, 1), axis=1 - ) - elif self.method in ["cumulated_score", "aps", "raps"]: - self.conformity_scores_, self.cutoff = ( - self._get_true_label_cumsum_proba(y, y_pred_proba) - ) - y_proba_true = np.take_along_axis( - y_pred_proba, y_enc.reshape(-1, 1), axis=1 - ) - random_state = check_random_state(self.random_state) - u = random_state.uniform(size=len(y_pred_proba)).reshape(-1, 1) - self.conformity_scores_ -= u * y_proba_true - elif self.method == "top_k": - # Here we reorder the labels by decreasing probability - # and get the position of each label from decreasing - # probability - self.conformity_scores_ = get_true_label_position( - y_pred_proba, y_enc - ) - else: - raise ValueError( - "Invalid method. " f"Allowed values are {self.valid_methods_}." + # Compute the conformity scores + self.conformity_scores_ = \ + self.conformity_score_function_.get_conformity_scores( + y, y_pred_proba, y_enc=y_enc, X=X ) return self @@ -1199,8 +645,8 @@ def predict( agg_scores: Optional[str] = "mean" ) -> Union[NDArray, Tuple[NDArray, NDArray]]: """ - Prediction prediction sets on new samples based on target confidence - interval. + Prediction and prediction sets on new samples based on target + confidence interval. Prediction sets for a given ``alpha`` are deduced from: - quantiles of softmax scores (``"lac"`` method) @@ -1215,8 +661,7 @@ def predict( Can be a float, a list of floats, or a ``ArrayLike`` of floats. Between 0 and 1, represent the uncertainty of the confidence interval. - Lower ``alpha`` produce larger (more conservative) prediction - sets. 
+ Lower ``alpha`` produce larger (more conservative) prediction sets. ``alpha`` is the complement of the target coverage level. By default ``None``. @@ -1263,20 +708,12 @@ def predict( - Tuple[NDArray, NDArray] of shapes (n_samples,) and (n_samples, n_classes, n_alpha) if alpha is not None. """ - if self.method == "top_k": - agg_scores = "mean" # Checks - cv = check_cv( - self.cv, test_size=self.test_size, random_state=self.random_state - ) - include_last_label = self._check_include_last_label(include_last_label) - alpha = cast(Optional[NDArray], check_alpha(alpha)) check_is_fitted(self, self.fit_attributes) - lambda_star, k_star = None, None + alpha = cast(Optional[NDArray], check_alpha(alpha)) - # Estimate prediction sets + # Estimate predictions y_pred = self.estimator_.single_estimator_.predict(X) - if alpha is None: return y_pred @@ -1287,149 +724,24 @@ def predict( alpha_np = cast(NDArray, alpha) check_alpha_and_n_samples(alpha_np, n) - y_pred_proba = self.estimator_.predict(X, agg_scores) - y_pred_proba = self._check_proba_normalized(y_pred_proba, axis=1) - if agg_scores != "crossval": - y_pred_proba = np.repeat( - y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 - ) - - # Choice of the quantile - if self.method == "naive": - self.quantiles_ = 1 - alpha_np + # Estimate prediction sets + if self.method == "raps": + kwargs = { + 'X_raps': self.X_raps, + 'y_raps_no_enc': self.y_raps_no_enc, + 'y_pred_proba_raps': self.y_pred_proba_raps, + 'position_raps': self.position_raps, + } else: - if (cv == "prefit") or (agg_scores in ["mean"]): - if self.method == "raps": - check_alpha_and_n_samples(alpha_np, len(self.X_raps)) - k_star = compute_quantiles( - self.position_raps, - alpha_np - ) + 1 - y_pred_proba_raps = np.repeat( - self.y_pred_proba_raps[:, :, np.newaxis], - len(alpha_np), - axis=2 - ) - lambda_star = self._find_lambda_star( - y_pred_proba_raps, - alpha_np, - include_last_label, - k_star - ) - self.conformity_scores_regularized = ( - self._regularize_conformity_score( - k_star, - lambda_star, - self.conformity_scores_, - self.cutoff - ) - ) - self.quantiles_ = compute_quantiles( - self.conformity_scores_regularized, - alpha_np - ) - else: - self.quantiles_ = compute_quantiles( - self.conformity_scores_, - alpha_np - ) - else: - self.quantiles_ = (n + 1) * (1 - alpha_np) - - # Build prediction sets - if self.method in ["score", "lac"]: - if (cv == "prefit") or (agg_scores == "mean"): - prediction_sets = np.greater_equal( - y_pred_proba - (1 - self.quantiles_), -EPSILON - ) - else: - y_pred_included = np.less_equal( - (1 - y_pred_proba) - self.conformity_scores_.ravel(), - EPSILON - ).sum(axis=2) - prediction_sets = np.stack( - [ - np.greater_equal( - y_pred_included - _alpha * (n - 1), -EPSILON - ) - for _alpha in alpha_np - ], axis=2 - ) + kwargs = {} + + prediction_sets = self.conformity_score_function_.predict_set( + X, alpha_np, + estimator=self.estimator_, + conformity_scores=self.conformity_scores_, + include_last_label=include_last_label, + agg_scores=agg_scores, + **kwargs + ) - elif self.method in ["naive", "cumulated_score", "aps", "raps"]: - # specify which thresholds will be used - if (cv == "prefit") or (agg_scores in ["mean"]): - thresholds = self.quantiles_ - else: - thresholds = self.conformity_scores_.ravel() - # sort labels by decreasing probability - y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last = ( - self._get_last_included_proba( - y_pred_proba, - thresholds, - include_last_label, - lambda_star, - k_star, - ) - ) - # get the prediction set by 
taking all probabilities
-            # above the last one
-            if (cv == "prefit") or (agg_scores in ["mean"]):
-                y_pred_included = np.greater_equal(
-                    y_pred_proba - y_pred_proba_last, -EPSILON
-                )
-            else:
-                y_pred_included = np.less_equal(
-                    y_pred_proba - y_pred_proba_last, EPSILON
-                )
-            # remove last label randomly
-            if include_last_label == "randomized":
-                y_pred_included = self._add_random_tie_breaking(
-                    y_pred_included,
-                    y_pred_index_last,
-                    y_pred_proba_cumsum,
-                    y_pred_proba_last,
-                    thresholds,
-                    lambda_star,
-                    k_star
-                )
-            if (cv == "prefit") or (agg_scores in ["mean"]):
-                prediction_sets = y_pred_included
-            else:
-                # compute the number of times the inequality is verified
-                prediction_sets_summed = y_pred_included.sum(axis=2)
-                prediction_sets = np.less_equal(
-                    prediction_sets_summed[:, :, np.newaxis]
-                    - self.quantiles_[np.newaxis, np.newaxis, :],
-                    EPSILON
-                )
-        elif self.method == "top_k":
-            y_pred_proba = y_pred_proba[:, :, 0]
-            index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1))
-            y_pred_index_last = np.stack(
-                [
-                    index_sorted[:, quantile]
-                    for quantile in self.quantiles_
-                ], axis=1
-            )
-            y_pred_proba_last = np.stack(
-                [
-                    np.take_along_axis(
-                        y_pred_proba,
-                        y_pred_index_last[:, iq].reshape(-1, 1),
-                        axis=1
-                    )
-                    for iq, _ in enumerate(self.quantiles_)
-                ], axis=2
-            )
-            prediction_sets = np.greater_equal(
-                y_pred_proba[:, :, np.newaxis]
-                - y_pred_proba_last,
-                -EPSILON
-            )
-        else:
-            raise ValueError(
-                "Invalid method. "
-                f"Allowed values are {self.valid_methods_}."
-            )
         return y_pred, prediction_sets
 
diff --git a/mapie/conformity_scores/__init__.py b/mapie/conformity_scores/__init__.py
index 0dab4b62d..3b47311da 100644
--- a/mapie/conformity_scores/__init__.py
+++ b/mapie/conformity_scores/__init__.py
@@ -1,11 +1,18 @@
-from .conformity_scores import ConformityScore
-from .residual_conformity_scores import (AbsoluteConformityScore,
-                                         GammaConformityScore,
-                                         ResidualNormalisedScore)
+from .regression import BaseRegressionScore
+from .classification import BaseClassificationScore
+from .bounds import (
+    AbsoluteConformityScore, GammaConformityScore, ResidualNormalisedScore
+)
+from .sets import APS, LAC, TopK
+
 
 __all__ = [
-    "ConformityScore",
+    "BaseRegressionScore",
+    "BaseClassificationScore",
     "AbsoluteConformityScore",
     "GammaConformityScore",
-    "ResidualNormalisedScore"
+    "ResidualNormalisedScore",
+    "LAC",
+    "APS",
+    "TopK"
 ]
diff --git a/mapie/conformity_scores/bounds/__init__.py b/mapie/conformity_scores/bounds/__init__.py
new file mode 100644
index 000000000..01f85b138
--- /dev/null
+++ b/mapie/conformity_scores/bounds/__init__.py
@@ -0,0 +1,10 @@
+from .absolute import AbsoluteConformityScore
+from .gamma import GammaConformityScore
+from .residuals import ResidualNormalisedScore
+
+
+__all__ = [
+    "AbsoluteConformityScore",
+    "GammaConformityScore",
+    "ResidualNormalisedScore",
+]
diff --git a/mapie/conformity_scores/bounds/absolute.py b/mapie/conformity_scores/bounds/absolute.py
new file mode 100644
index 000000000..90c1c3e94
--- /dev/null
+++ b/mapie/conformity_scores/bounds/absolute.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from mapie._typing import ArrayLike, NDArray
+from mapie.conformity_scores import BaseRegressionScore
+
+
+class AbsoluteConformityScore(BaseRegressionScore):
+    """
+    Absolute conformity score.
+
+    The signed conformity score = y - y_pred.
+    The conformity score is symmetrical.
+
+    This is appropriate when the confidence interval is symmetrical and
+    its range is approximately the same over the range of predicted values.
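+
+    Examples
+    --------
+    A toy check of the formula above (values chosen for illustration):
+
+    >>> import numpy as np
+    >>> from mapie.conformity_scores import AbsoluteConformityScore
+    >>> AbsoluteConformityScore().get_signed_conformity_scores(
+    ...     np.array([3., 4.]), np.array([2., 6.])
+    ... )
+    array([ 1., -2.])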
+ """ + + def __init__( + self, + sym: bool = True, + ) -> None: + super().__init__(sym=sym, consistency_check=True) + + def get_signed_conformity_scores( + self, + y: ArrayLike, + y_pred: ArrayLike, + **kwargs + ) -> NDArray: + """ + Compute the signed conformity scores from the predicted values + and the observed ones, from the following formula: + signed conformity score = y - y_pred + """ + return np.subtract(y, y_pred) + + def get_estimation_distribution( + self, + y_pred: ArrayLike, + conformity_scores: ArrayLike, + **kwargs + ) -> NDArray: + """ + Compute samples of the estimation distribution from the predicted + values and the conformity scores, from the following formula: + signed conformity score = y - y_pred + <=> y = y_pred + signed conformity score + + ``conformity_scores`` can be either the conformity scores or + the quantile of the conformity scores. + """ + return np.add(y_pred, conformity_scores) diff --git a/mapie/conformity_scores/bounds/gamma.py b/mapie/conformity_scores/bounds/gamma.py new file mode 100644 index 000000000..09f161e02 --- /dev/null +++ b/mapie/conformity_scores/bounds/gamma.py @@ -0,0 +1,86 @@ +import numpy as np + +from mapie._typing import ArrayLike, NDArray +from mapie.conformity_scores import BaseRegressionScore + + +class GammaConformityScore(BaseRegressionScore): + """ + Gamma conformity score. + + The signed conformity score = (y - y_pred) / y_pred. + The conformity score is not symmetrical. + + This is appropriate when the confidence interval is not symmetrical and + its range depends on the predicted values. Like the Gamma distribution, + its support is limited to strictly positive reals. + """ + + def __init__( + self, + sym: bool = False, + ) -> None: + super().__init__(sym=sym, consistency_check=False) + + def _check_observed_data( + self, + y: ArrayLike, + ) -> None: + if not self._all_strictly_positive(y): + raise ValueError( + f"At least one of the observed target is negative " + f"which is incompatible with {self.__class__.__name__}. " + "All values must be strictly positive, " + "in conformity with the Gamma distribution support." + ) + + def _check_predicted_data( + self, + y_pred: ArrayLike, + ) -> None: + if not self._all_strictly_positive(y_pred): + raise ValueError( + f"At least one of the predicted target is negative " + f"which is incompatible with {self.__class__.__name__}. " + "All values must be strictly positive, " + "in conformity with the Gamma distribution support." + ) + + @staticmethod + def _all_strictly_positive( + y: ArrayLike, + ) -> bool: + return not np.any(np.less_equal(y, 0)) + + def get_signed_conformity_scores( + self, + y: ArrayLike, + y_pred: ArrayLike, + **kwargs + ) -> NDArray: + """ + Compute the signed conformity scores from the observed values + and the predicted ones, from the following formula: + signed conformity score = (y - y_pred) / y_pred + """ + self._check_observed_data(y) + self._check_predicted_data(y_pred) + return np.divide(np.subtract(y, y_pred), y_pred) + + def get_estimation_distribution( + self, + y_pred: ArrayLike, + conformity_scores: ArrayLike, + **kwargs + ) -> NDArray: + """ + Compute samples of the estimation distribution from the predicted + values and the conformity scores, from the following formula: + signed conformity score = (y - y_pred) / y_pred + <=> y = y_pred * (1 + signed conformity score) + + ``conformity_scores`` can be either the conformity scores or + the quantile of the conformity scores. 
+ """ + self._check_predicted_data(y_pred) + return np.multiply(y_pred, np.add(1, conformity_scores)) diff --git a/mapie/conformity_scores/residual_conformity_scores.py b/mapie/conformity_scores/bounds/residuals.py similarity index 69% rename from mapie/conformity_scores/residual_conformity_scores.py rename to mapie/conformity_scores/bounds/residuals.py index d9b174e49..f6bc9c7f3 100644 --- a/mapie/conformity_scores/residual_conformity_scores.py +++ b/mapie/conformity_scores/bounds/residuals.py @@ -9,142 +9,11 @@ from sklearn.utils.validation import (check_is_fitted, check_random_state, indexable) -from mapie._machine_precision import EPSILON from mapie._typing import ArrayLike, NDArray -from mapie.conformity_scores import ConformityScore +from mapie.conformity_scores import BaseRegressionScore -class AbsoluteConformityScore(ConformityScore): - """ - Absolute conformity score. - - The signed conformity score = y - y_pred. - The conformity score is symmetrical. - - This is appropriate when the confidence interval is symmetrical and - its range is approximatively the same over the range of predicted values. - """ - - def __init__( - self, - sym: bool = True, - ) -> None: - super().__init__(sym=sym, consistency_check=True) - - def get_signed_conformity_scores( - self, - X: ArrayLike, - y: ArrayLike, - y_pred: ArrayLike, - ) -> NDArray: - """ - Compute the signed conformity scores from the predicted values - and the observed ones, from the following formula: - signed conformity score = y - y_pred - """ - return np.subtract(y, y_pred) - - def get_estimation_distribution( - self, - X: ArrayLike, - y_pred: ArrayLike, - conformity_scores: ArrayLike - ) -> NDArray: - """ - Compute samples of the estimation distribution from the predicted - values and the conformity scores, from the following formula: - signed conformity score = y - y_pred - <=> y = y_pred + signed conformity score - - ``conformity_scores`` can be either the conformity scores or - the quantile of the conformity scores. - """ - return np.add(y_pred, conformity_scores) - - -class GammaConformityScore(ConformityScore): - """ - Gamma conformity score. - - The signed conformity score = (y - y_pred) / y_pred. - The conformity score is not symmetrical. - - This is appropriate when the confidence interval is not symmetrical and - its range depends on the predicted values. Like the Gamma distribution, - its support is limited to strictly positive reals. - """ - - def __init__( - self, - sym: bool = False, - ) -> None: - super().__init__(sym=sym, consistency_check=False, eps=EPSILON) - - def _check_observed_data( - self, - y: ArrayLike, - ) -> None: - if not self._all_strictly_positive(y): - raise ValueError( - f"At least one of the observed target is negative " - f"which is incompatible with {self.__class__.__name__}. " - "All values must be strictly positive, " - "in conformity with the Gamma distribution support." - ) - - def _check_predicted_data( - self, - y_pred: ArrayLike, - ) -> None: - if not self._all_strictly_positive(y_pred): - raise ValueError( - f"At least one of the predicted target is negative " - f"which is incompatible with {self.__class__.__name__}. " - "All values must be strictly positive, " - "in conformity with the Gamma distribution support." 
- ) - - @staticmethod - def _all_strictly_positive( - y: ArrayLike, - ) -> bool: - return not np.any(np.less_equal(y, 0)) - - def get_signed_conformity_scores( - self, - X: ArrayLike, - y: ArrayLike, - y_pred: ArrayLike, - ) -> NDArray: - """ - Compute the signed conformity scores from the observed values - and the predicted ones, from the following formula: - signed conformity score = (y - y_pred) / y_pred - """ - self._check_observed_data(y) - self._check_predicted_data(y_pred) - return np.divide(np.subtract(y, y_pred), y_pred) - - def get_estimation_distribution( - self, - X: ArrayLike, - y_pred: ArrayLike, - conformity_scores: ArrayLike - ) -> NDArray: - """ - Compute samples of the estimation distribution from the predicted - values and the conformity scores, from the following formula: - signed conformity score = (y - y_pred) / y_pred - <=> y = y_pred * (1 + signed conformity score) - - ``conformity_scores`` can be either the conformity scores or - the quantile of the conformity scores. - """ - self._check_predicted_data(y_pred) - return np.multiply(y_pred, np.add(1, conformity_scores)) - - -class ResidualNormalisedScore(ConformityScore): +class ResidualNormalisedScore(BaseRegressionScore): """ Residual Normalised score. @@ -200,7 +69,8 @@ def __init__( self.random_state = random_state def _check_estimator( - self, estimator: Optional[RegressorMixin] = None + self, + estimator: Optional[RegressorMixin] = None ) -> RegressorMixin: """ Check if estimator is ``None``, @@ -361,9 +231,10 @@ def _predict_residual_estimator( def get_signed_conformity_scores( self, - X: ArrayLike, y: ArrayLike, - y_pred: ArrayLike + y_pred: ArrayLike, + X: Optional[ArrayLike] = None, + **kwargs ) -> NDArray: """ Computes the signed conformity score = (y - y_pred) / r_pred. @@ -374,6 +245,8 @@ def get_signed_conformity_scores( The learning is done with the log of the residual and later we use the exponential of the prediction to avoid negative values. """ + assert not (X is None) # TODO + (X, y, y_pred, self.residual_estimator_, random_state) = self._check_parameters(X, y, y_pred) @@ -418,9 +291,10 @@ def get_signed_conformity_scores( def get_estimation_distribution( self, - X: ArrayLike, y_pred: ArrayLike, - conformity_scores: ArrayLike + conformity_scores: ArrayLike, + X: Optional[ArrayLike] = None, + **kwargs ) -> NDArray: """ Compute samples of the estimation distribution from the predicted @@ -433,6 +307,8 @@ def get_estimation_distribution( ``conformity_scores`` can be either the conformity scores or the quantile of the conformity scores. """ + assert not (X is None) # TODO + r_pred = self._predict_residual_estimator(X).reshape((-1, 1)) if not self.prefit: return np.add( diff --git a/mapie/conformity_scores/checks.py b/mapie/conformity_scores/checks.py deleted file mode 100644 index 66a9277d2..000000000 --- a/mapie/conformity_scores/checks.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional - -from .conformity_scores import ConformityScore -from .residual_conformity_scores import AbsoluteConformityScore - - -def check_conformity_score( - conformity_score: Optional[ConformityScore], - sym: bool = True, -) -> ConformityScore: - """ - Check parameter ``conformity_score``. - - Raises - ------ - ValueError - If parameter is not valid. - - Examples - -------- - >>> from mapie.conformity_scores.checks import check_conformity_score - >>> try: - ... check_conformity_score(1) - ... except Exception as exception: - ... print(exception) - ... - Invalid conformity_score argument. 
-    Must be None or a ConformityScore instance.
-    """
-    if conformity_score is None:
-        return AbsoluteConformityScore(sym=sym)
-    elif isinstance(conformity_score, ConformityScore):
-        return conformity_score
-    else:
-        raise ValueError(
-            "Invalid conformity_score argument.\n"
-            "Must be None or a ConformityScore instance."
-        )
diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py
new file mode 100644
index 000000000..6c91b88ee
--- /dev/null
+++ b/mapie/conformity_scores/classification.py
@@ -0,0 +1,103 @@
+from abc import ABCMeta, abstractmethod
+
+from mapie.conformity_scores.interface import BaseConformityScore
+from mapie.estimator.classifier import EnsembleClassifier
+
+from mapie._machine_precision import EPSILON
+from mapie._typing import NDArray
+
+
+class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta):
+    """
+    Base conformity score class for classification task.
+
+    This class should not be used directly. Use derived classes instead.
+
+    Parameters
+    ----------
+    consistency_check: bool, optional
+        Whether to check the consistency between the methods
+        ``get_estimation_distribution`` and ``get_conformity_scores``.
+        If ``True``, the following equality must be verified:
+        ``self.get_estimation_distribution(
+            y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs
+        ) == y``
+
+        By default ``True``.
+
+    eps: float, optional
+        Threshold to consider when checking the consistency between
+        ``get_estimation_distribution`` and ``get_conformity_scores``.
+        It should be specified if ``consistency_check==True``.
+
+        By default, it is defined by the default precision.
+    """
+
+    def __init__(
+        self,
+        consistency_check: bool = True,
+        eps: float = float(EPSILON),
+    ):
+        super().__init__(consistency_check=consistency_check, eps=eps)
+
+    @abstractmethod
+    def get_sets(
+        self,
+        X: NDArray,
+        alpha_np: NDArray,
+        estimator: EnsembleClassifier,
+        conformity_scores: NDArray,
+        **kwargs
+    ):
+        """
+        Compute classes of the prediction sets from the observed values,
+        the estimator of type ``EnsembleClassifier`` and the conformity
+        scores.
+
+        Parameters
+        ----------
+        X: NDArray of shape (n_samples, n_features)
+            Observed feature values.
+
+        alpha_np: NDArray of shape (n_alpha,)
+            NDArray of floats between ``0`` and ``1``, represents the
+            uncertainty of the confidence interval.
+
+        estimator: EnsembleClassifier
+            Estimator that is fitted to predict y from X.
+
+        conformity_scores: NDArray of shape (n_samples,)
+            Conformity scores.
+
+        Returns
+        -------
+        NDArray of shape (n_samples, n_classes, n_alpha)
+            Prediction sets (Booleans indicate whether classes are included).
+        """
+
+    def predict_set(
+        self,
+        X: NDArray,
+        alpha_np: NDArray,
+        **kwargs
+    ):
+        """
+        Compute the prediction sets on new samples based on the uncertainty
+        of the target confidence interval.
+
+        Parameters
+        ----------
+        X: NDArray of shape (n_samples, ...)
+            The input data or samples for prediction.
+
+        alpha_np: NDArray of shape (n_alpha, )
+            Represents the uncertainty of the confidence interval to produce.
+
+        **kwargs: dict
+            Additional keyword arguments.
+
+        Returns
+        -------
+        The output structure depends on the ``get_sets`` method.
+        The prediction sets for each sample and each alpha level.
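+
+        Notes
+        -----
+        A sketch of how ``MapieClassifier.predict`` dispatches to this
+        method, assuming ``mapie`` is a fitted ``MapieClassifier`` and the
+        parameter values are illustrative::
+
+            y_ps = mapie.conformity_score_function_.predict_set(
+                X, alpha_np=np.array([0.1]),
+                estimator=mapie.estimator_,
+                conformity_scores=mapie.conformity_scores_,
+                include_last_label=True,
+                agg_scores="mean",
+            )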
+ """ + return self.get_sets(X=X, alpha_np=alpha_np, **kwargs) diff --git a/mapie/conformity_scores/interface.py b/mapie/conformity_scores/interface.py new file mode 100644 index 000000000..680c6cc9e --- /dev/null +++ b/mapie/conformity_scores/interface.py @@ -0,0 +1,256 @@ +from abc import ABCMeta, abstractmethod + +import numpy as np + +from mapie._compatibility import np_nanquantile +from mapie._machine_precision import EPSILON +from mapie._typing import NDArray + + +class BaseConformityScore(metaclass=ABCMeta): + """ + Base class for conformity scores. + + This class should not be used directly. Use derived classes instead. + + Parameters + ---------- + consistency_check: bool, optional + Whether to check the consistency between the methods + ``get_estimation_distribution`` and ``get_conformity_scores``. + If ``True``, the following equality must be verified: + ``self.get_estimation_distribution( + y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs + ) == y`` + + By default ``True``. + + eps: float, optional + Threshold to consider when checking the consistency between + ``get_estimation_distribution`` and ``get_conformity_scores``. + It should be specified if ``consistency_check==True``. + + By default, it is defined by the default precision. + """ + + def __init__( + self, + consistency_check: bool = True, + eps: float = float(EPSILON), + ): + self.consistency_check = consistency_check + self.eps = eps + + def set_external_attributes( + self, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Must be overloaded by subclasses if necessary to add more attributes, + particularly when the attributes are known after the object has been + instantiated. + """ + pass + + def check_consistency( + self, + y: NDArray, + y_pred: NDArray, + conformity_scores: NDArray, + **kwargs + ) -> None: + """ + Check consistency between the following methods: + ``get_estimation_distribution`` and ``get_signed_conformity_scores`` + + The following equality should be verified: + ``self.get_estimation_distribution( + y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs + ) == y`` + + Parameters + ---------- + y: NDArray of shape (n_samples, ...) + Observed target values. + + y_pred: NDArray of shape (n_samples, ...) + Predicted target values. + + conformity_scores: NDArray of shape (n_samples, ...) + Conformity scores. + + Raises + ------ + ValueError + If the two methods are not consistent. + """ + score_distribution = self.get_estimation_distribution( + y_pred, conformity_scores, **kwargs + ) + abs_conformity_scores = np.abs(np.subtract(score_distribution, y)) + max_conf_score = np.max(abs_conformity_scores) + if max_conf_score > self.eps: + raise ValueError( + "The two functions get_conformity_scores and " + "get_estimation_distribution of the BaseConformityScore class " + "are not consistent. " + "The following equation must be verified: " + "self.get_estimation_distribution(y_pred, " + "self.get_conformity_scores(y, y_pred)) == y. " + f"The maximum conformity score is {max_conf_score}. " + "The eps attribute may need to be increased if you are " + "sure that the two methods are consistent." + ) + + @abstractmethod + def get_conformity_scores( + self, + y: NDArray, + y_pred: NDArray, + **kwargs + ) -> NDArray: + """ + Placeholder for ``get_conformity_scores``. + Subclasses should implement this method! + + Compute the sample conformity scores given the predicted and + observed targets. 
+
+        Parameters
+        ----------
+        y: NDArray of shape (n_samples, ...)
+            Observed target values.
+
+        y_pred: NDArray of shape (n_samples, ...)
+            Predicted target values.
+
+        Returns
+        -------
+        NDArray of shape (n_samples, ...)
+            Conformity scores.
+        """
+
+    @abstractmethod
+    def get_estimation_distribution(
+        self,
+        y_pred: NDArray,
+        conformity_scores: NDArray,
+        **kwargs
+    ) -> NDArray:
+        """
+        Placeholder for ``get_estimation_distribution``.
+        Subclasses should implement this method!
+
+        Compute samples of the estimation distribution given the predicted
+        targets and the conformity scores.
+
+        Parameters
+        ----------
+        y_pred: NDArray of shape (n_samples, ...)
+            Predicted target values.
+
+        conformity_scores: NDArray of shape (n_samples, ...)
+            Conformity scores.
+
+        Returns
+        -------
+        NDArray of shape (n_samples, ...)
+            Observed values.
+        """
+
+    @staticmethod
+    def get_quantile(
+        conformity_scores: NDArray,
+        alpha_np: NDArray,
+        axis: int = 0,
+        reversed: bool = False,
+        unbounded: bool = False
+    ) -> NDArray:
+        """
+        Compute the alpha quantile of the conformity scores.
+
+        Parameters
+        ----------
+        conformity_scores: NDArray of shape (n_samples, ...)
+            Values from which the quantile is computed.
+
+        alpha_np: NDArray of shape (n_alpha,)
+            NDArray of floats between ``0`` and ``1``, represents the
+            uncertainty of the confidence interval.
+
+        axis: int
+            The axis from which to compute the quantile.
+
+            By default ``0``.
+
+        reversed: bool
+            Boolean specifying whether we take the upper or lower quantile,
+            if False, the alpha quantile, otherwise the (1-alpha) quantile.
+
+            By default ``False``.
+
+        unbounded: bool
+            Boolean specifying whether infinite prediction intervals
+            could be produced (when alpha_np is greater than or equal to 1.).
+
+            By default ``False``.
+
+        Returns
+        -------
+        NDArray of shape (1, n_alpha) or (n_samples, n_alpha)
+            The quantiles of the conformity scores.
+        """
+        n_ref = conformity_scores.shape[1-axis]
+        n_calib = np.min(np.sum(~np.isnan(conformity_scores), axis=axis))
+        signed = 1-2*reversed
+
+        # Adapt alpha w.r.t upper/lower : alpha vs. 1-alpha
+        alpha_ref = (1-2*alpha_np)*reversed + alpha_np
+
+        # Adjust alpha w.r.t quantile correction
+        alpha_cor = np.ceil(alpha_ref*(n_calib+1))/n_calib
+        alpha_cor = np.clip(alpha_cor, a_min=0, a_max=1)
+
+        # Compute the target quantiles:
+        # If unbounded is True and alpha is greater than or equal to 1,
+        # the quantile is set to infinity.
+        # Otherwise, the quantile is calculated as the corrected lower
+        # quantile of the signed conformity scores.
+        quantile = signed * np.column_stack([
+            np_nanquantile(
+                signed * conformity_scores, _alpha_cor,
+                axis=axis, method="lower"
+            ) if not (unbounded and _alpha >= 1) else np.inf * np.ones(n_ref)
+            for _alpha, _alpha_cor in zip(alpha_ref, alpha_cor)
+        ])
+        return quantile
+
+    @abstractmethod
+    def predict_set(
+        self,
+        X: NDArray,
+        alpha_np: NDArray,
+        **kwargs
+    ):
+        """
+        Compute the prediction sets on new samples based on the uncertainty
+        of the target confidence interval.
+
+        Parameters
+        ----------
+        X: NDArray of shape (n_samples, ...)
+            The input data or samples for prediction.
+
+        alpha_np: NDArray of shape (n_alpha, )
+            Represents the uncertainty of the confidence interval to produce.
+
+        **kwargs: dict
+            Additional keyword arguments.
+
+        Returns
+        -------
+        The output structure depends on the subclass.
+        The prediction sets for each sample and each alpha level.
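+
+        Notes
+        -----
+        Subclasses typically derive their sets from the corrected quantile
+        computed by ``get_quantile`` above, which applies the finite-sample
+        correction ``ceil((n_calib + 1) * alpha) / n_calib``. A toy check
+        (illustrative values):
+
+        >>> import numpy as np
+        >>> from mapie.conformity_scores.interface import BaseConformityScore
+        >>> scores = np.array([[1.], [2.], [3.], [4.]])
+        >>> BaseConformityScore.get_quantile(scores, np.array([0.5]))
+        array([[3.]])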
+ """ diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/regression.py similarity index 50% rename from mapie/conformity_scores/conformity_scores.py rename to mapie/conformity_scores/regression.py index a96df9945..2e878e349 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/regression.py @@ -3,17 +3,19 @@ import numpy as np -from mapie._compatibility import np_nanquantile -from mapie._typing import ArrayLike, NDArray +from mapie.conformity_scores.interface import BaseConformityScore from mapie.estimator.regressor import EnsembleRegressor +from mapie._compatibility import np_nanquantile +from mapie._machine_precision import EPSILON +from mapie._typing import NDArray + -class ConformityScore(metaclass=ABCMeta): +class BaseRegressionScore(BaseConformityScore, metaclass=ABCMeta): """ - Base class for conformity scores. + Base conformity score class for regression task. - Warning: This class should not be used directly. - Use derived classes instead. + This class should not be used directly. Use derived classes instead. Parameters ---------- @@ -21,61 +23,51 @@ class ConformityScore(metaclass=ABCMeta): Whether to consider the conformity score as symmetrical or not. consistency_check: bool, optional - Whether to check the consistency between the following methods: - - ``get_estimation_distribution`` and - - ``get_signed_conformity_scores`` + Whether to check the consistency between the methods + ``get_estimation_distribution`` and ``get_conformity_scores``. + If ``True``, the following equality must be verified: + ``self.get_estimation_distribution( + y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs + ) == y`` By default ``True``. eps: float, optional - Threshold to consider when checking the consistency between the - following methods: - - ``get_estimation_distribution`` and - - ``get_signed_conformity_scores`` - The following equality must be verified: - ``self.get_estimation_distribution( - X, - y_pred, - self.get_conformity_scores(X, y, y_pred) - ) == y`` + Threshold to consider when checking the consistency between + ``get_estimation_distribution`` and ``get_conformity_scores``. It should be specified if ``consistency_check==True``. - By default ``np.float64(1e-8)``. + By default, it is defined by the default precision. """ def __init__( - self, - sym: bool, + self, sym: bool, consistency_check: bool = True, - eps: np.float64 = np.float64(1e-8), + eps: float = float(EPSILON), ): + super().__init__(consistency_check=consistency_check, eps=eps) self.sym = sym - self.consistency_check = consistency_check - self.eps = eps @abstractmethod def get_signed_conformity_scores( self, - X: ArrayLike, - y: ArrayLike, - y_pred: ArrayLike, + y: NDArray, + y_pred: NDArray, + **kwargs ) -> NDArray: """ - Placeholder for ``get_signed_conformity_scores``. + Placeholder for ``get_conformity_scores``. Subclasses should implement this method! - Compute the signed conformity scores from the predicted values - and the observed ones. + Compute the sample conformity scores given the predicted and + observed targets. Parameters ---------- - X: ArrayLike of shape (n_samples, n_features) - Observed feature values. - - y: ArrayLike of shape (n_samples,) + y: NDArray of shape (n_samples,) Observed target values. - y_pred: ArrayLike of shape (n_samples,) + y_pred: NDArray of shape (n_samples,) Predicted target values. Returns @@ -84,113 +76,17 @@ def get_signed_conformity_scores( Signed conformity scores. 
""" - @abstractmethod - def get_estimation_distribution( - self, - X: ArrayLike, - y_pred: ArrayLike, - conformity_scores: ArrayLike - ) -> NDArray: - """ - Placeholder for ``get_estimation_distribution``. - Subclasses should implement this method! - - Compute samples of the estimation distribution from the predicted - targets and ``conformity_scores`` that can be either the conformity - scores or the quantile of the conformity scores. - - Parameters - ---------- - X: ArrayLike of shape (n_samples, n_features) - Observed feature values. - - y_pred: ArrayLike - The shape is either (n_samples, n_references): when the - method is called in ``get_bounds`` it needs a prediction per train - sample for each test sample to compute the bounds. - Or (n_samples,): when it is called in ``check_consistency`` - - conformity_scores: ArrayLike - The shape is either (n_samples, 1) when it is the - conformity scores themselves or (1, n_alpha) when it is only the - quantile of the conformity scores. - - Returns - ------- - NDArray of shape (n_samples, n_alpha) or - (n_samples, n_references) according to the shape of ``y_pred`` - Observed values. - """ - - def check_consistency( - self, - X: ArrayLike, - y: ArrayLike, - y_pred: ArrayLike, - conformity_scores: ArrayLike, - ) -> None: - """ - Check consistency between the following methods: - ``get_estimation_distribution`` and ``get_signed_conformity_scores`` - - The following equality should be verified: - ``self.get_estimation_distribution( - X, - y_pred, - self.get_conformity_scores(X, y, y_pred) - ) == y`` - - Parameters - ---------- - X: ArrayLike of shape (n_samples, n_features) - Observed feature values. - - y: ArrayLike of shape (n_samples,) - Observed target values. - - y_pred: ArrayLike of shape (n_samples,) - Predicted target values. - - conformity_scores: ArrayLike of shape (n_samples,) - Conformity scores. - - Raises - ------ - ValueError - If the two methods are not consistent. - """ - score_distribution = self.get_estimation_distribution( - X, y_pred, conformity_scores - ) - abs_conformity_scores = np.abs(np.subtract(score_distribution, y)) - max_conf_score = np.max(abs_conformity_scores) - if max_conf_score > self.eps: - raise ValueError( - "The two functions get_conformity_scores and " - "get_estimation_distribution of the ConformityScore class " - "are not consistent. " - "The following equation must be verified: " - "self.get_estimation_distribution(X, y_pred, " - "self.get_conformity_scores(X, y, y_pred)) == y" # noqa: E501 - f"The maximum conformity score is {max_conf_score}." - "The eps attribute may need to be increased if you are " - "sure that the two methods are consistent." - ) - def get_conformity_scores( self, - X: ArrayLike, - y: ArrayLike, - y_pred: ArrayLike, + y: NDArray, + y_pred: NDArray, + **kwargs ) -> NDArray: """ Get the conformity score considering the symmetrical property if so. Parameters ---------- - X: NDArray of shape (n_samples, n_features) - Observed feature values. - y: NDArray of shape (n_samples,) Observed target values. @@ -202,82 +98,14 @@ def get_conformity_scores( NDArray of shape (n_samples,) Conformity scores. 
""" - conformity_scores = self.get_signed_conformity_scores(X, y, y_pred) + conformity_scores = \ + self.get_signed_conformity_scores(y, y_pred, **kwargs) if self.consistency_check: - self.check_consistency(X, y, y_pred, conformity_scores) + self.check_consistency(y, y_pred, conformity_scores, **kwargs) if self.sym: conformity_scores = np.abs(conformity_scores) return conformity_scores - @staticmethod - def get_quantile( - conformity_scores: NDArray, - alpha_np: NDArray, - axis: int, - reversed: bool = False, - unbounded: bool = False - ) -> NDArray: - """ - Compute the alpha quantile of the conformity scores or the conformity - scores aggregated with the predictions. - - Parameters - ---------- - conformity_scores: NDArray of shape (n_samples,) or - (n_samples, n_references) - Values from which the quantile is computed, it can be the - conformity scores or the conformity scores aggregated with - the predictions. - - alpha_np: NDArray of shape (n_alpha,) - NDArray of floats between ``0`` and ``1``, represents the - uncertainty of the confidence interval. - - axis: int - The axis from which to compute the quantile. - - reversed: bool - Boolean specifying whether we take the upper or lower quantile, - if False, the alpha quantile, otherwise the (1-alpha) quantile. - - By default ``False``. - - unbounded: bool - Boolean specifying whether infinite prediction intervals - could be produced (when alpha_np is greater than or equal to 1.). - - By default ``False``. - - Returns - ------- - NDArray of shape (1, n_alpha) or (n_samples, n_alpha) - The quantile of the conformity scores. - """ - n_ref = conformity_scores.shape[1-axis] - n_calib = np.min(np.sum(~np.isnan(conformity_scores), axis=axis)) - signed = 1-2*reversed - - # Adapt alpha w.r.t upper/lower : alpha vs. 1-alpha - alpha_ref = (1-2*alpha_np)*reversed + alpha_np - - # Adjust alpha w.r.t quantile correction - alpha_cor = np.ceil(alpha_ref*(n_calib+1))/n_calib - alpha_cor = np.clip(alpha_cor, a_min=0, a_max=1) - - # Compute the target quantiles: - # If unbounded is True and alpha is greater than or equal to 1, - # the quantile is set to infinity. - # Otherwise, the quantile is calculated as the corrected lower quantile - # of the signed conformity scores. - quantile = signed * np.column_stack([ - np_nanquantile( - signed * conformity_scores, _alpha_cor, - axis=axis, method="lower" - ) if not (unbounded and _alpha >= 1) else np.inf * np.ones(n_ref) - for _alpha, _alpha_cor in zip(alpha_ref, alpha_cor) - ]) - return quantile - @staticmethod def _beta_optimize( alpha_np: NDArray, @@ -292,15 +120,15 @@ def _beta_optimize( alpha_np: NDArray The quantiles to compute. - upper_bounds: NDArray + upper_bounds: NDArray of shape (n_samples,) The array of upper values. - lower_bounds: NDArray + lower_bounds: NDArray of shape (n_samples,) The array of lower values. Returns ------- - NDArray + NDArray of shape (n_samples,) Array of betas minimizing the differences ``(1-alpha+beta)-quantile - beta-quantile``. """ @@ -337,10 +165,10 @@ def _beta_optimize( def get_bounds( self, - X: ArrayLike, + X: NDArray, + alpha_np: NDArray, estimator: EnsembleRegressor, conformity_scores: NDArray, - alpha_np: NDArray, ensemble: bool = False, method: str = 'base', optimize_beta: bool = False, @@ -352,19 +180,19 @@ def get_bounds( Parameters ---------- - X: ArrayLike of shape (n_samples, n_features) + X: NDArray of shape (n_samples, n_features) Observed feature values. 
+        alpha_np: NDArray of shape (n_alpha,)
+            NDArray of floats between ``0`` and ``1``, represents the
+            uncertainty of the confidence interval.
+
         estimator: EnsembleRegressor
             Estimator that is fitted to predict y from X.
 
-        conformity_scores: ArrayLike of shape (n_samples,)
+        conformity_scores: NDArray of shape (n_samples,)
             Conformity scores.
 
-        alpha_np: NDArray of shape (n_alpha,)
-            NDArray of floats between ``0`` and ``1``, represents the
-            uncertainty of the confidence interval.
-
         ensemble: bool
             Boolean determining whether the predictions are ensembled or not.
@@ -426,10 +254,10 @@
             alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np + beta_np
 
             conformity_scores_low = self.get_estimation_distribution(
-                X, y_pred_low, signed * conformity_scores
+                y_pred_low, signed * conformity_scores, X=X
             )
             conformity_scores_up = self.get_estimation_distribution(
-                X, y_pred_up, conformity_scores
+                y_pred_up, conformity_scores, X=X
             )
             bound_low = self.get_quantile(
                 conformity_scores_low, alpha_low, axis=1, reversed=True,
@@ -463,10 +291,38 @@
             )
             bound_low = self.get_estimation_distribution(
-                X, y_pred_low, quantile_low
+                y_pred_low, quantile_low, X=X
             )
             bound_up = self.get_estimation_distribution(
-                X, y_pred_up, quantile_up
+                y_pred_up, quantile_up, X=X
             )
 
         return y_pred, bound_low, bound_up
+
+    def predict_set(
+        self,
+        X: NDArray,
+        alpha_np: NDArray,
+        **kwargs
+    ):
+        """
+        Compute the prediction sets on new samples based on the uncertainty
+        of the target confidence interval.
+
+        Parameters
+        ----------
+        X: NDArray of shape (n_samples, ...)
+            The input data or samples for prediction.
+
+        alpha_np: NDArray of shape (n_alpha, )
+            Represents the uncertainty of the confidence interval to produce.
+
+        **kwargs: dict
+            Additional keyword arguments.
+
+        Returns
+        -------
+        The output structure depends on the ``get_bounds`` method.
+        The prediction sets for each sample and each alpha level.
+        """
+        return self.get_bounds(X=X, alpha_np=alpha_np, **kwargs)
diff --git a/mapie/conformity_scores/sets/__init__.py b/mapie/conformity_scores/sets/__init__.py
new file mode 100644
index 000000000..87b6a37e6
--- /dev/null
+++ b/mapie/conformity_scores/sets/__init__.py
@@ -0,0 +1,10 @@
+from .lac import LAC
+from .aps import APS
+from .topk import TopK
+
+
+__all__ = [
+    "LAC",
+    "APS",
+    "TopK",
+]
diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py
new file mode 100644
index 000000000..6cd282260
--- /dev/null
+++ b/mapie/conformity_scores/sets/aps.py
@@ -0,0 +1,497 @@
+from typing import Optional, Tuple, Union, cast
+
+import numpy as np
+from sklearn.dummy import check_random_state
+
+from mapie.conformity_scores.classification import BaseClassificationScore
+from mapie.conformity_scores.sets.utils import (
+    add_random_tie_breaking, check_include_last_label, check_proba_normalized,
+    get_last_included_proba, get_true_label_cumsum_proba
+)
+from mapie.estimator.classifier import EnsembleClassifier
+
+from mapie._machine_precision import EPSILON
+from mapie._typing import ArrayLike, NDArray
+from mapie.metrics import classification_mean_width_score
+from mapie.utils import check_alpha_and_n_samples, compute_quantiles
+
+
+class APS(BaseClassificationScore):
+    """
+    Adaptive Prediction Sets (APS) method-based non-conformity score.
+    Three different methods are available in this class:
+
+    - ``"naive"``, sum of the probabilities until the 1-alpha threshold.
+
+    - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction
+      Sets method.
It is based on the sum of the softmax outputs of the + labels until the true label is reached, on the calibration set. + See [1] for more details. + + - ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the + same technique as ``"aps"`` method but with a penalty term + to reduce the size of prediction sets. See [2] for more + details. For now, this method only works with ``"prefit"`` and + ``"split"`` strategies. + + References + ---------- + [1] Yaniv Romano, Matteo Sesia and Emmanuel J. Candès. + "Classification with Valid and Adaptive Coverage." + NeurIPS 202 (spotlight) 2020. + + [2] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan + and Jitendra Malik. + "Uncertainty Sets for Image Classifiers using Conformal Prediction." + International Conference on Learning Representations 2021. + """ + + def __init__( + self, + consistency_check: bool = True, + eps: float = float(EPSILON), + ): + super().__init__( + consistency_check=consistency_check, + eps=eps + ) + + def set_external_attributes( + self, + method: str = 'aps', + classes: Optional[ArrayLike] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Parameters + ---------- + method: str + Method to choose for prediction interval estimates. + Methods available in this class: ``aps``, ``raps`` and ``naive``. + + By default ``aps`` for APS method. + + classes: Optional[ArrayLike] + Names of the classes. + + By default ``None``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + """ + super().set_external_attributes(**kwargs) + self.method = method + self.classes = classes + self.random_state = random_state + + def get_conformity_scores( + self, + y: ArrayLike, + y_pred: ArrayLike, + y_enc: Optional[ArrayLike] = None, + **kwargs + ) -> NDArray: + """ + Get the conformity score. + + Parameters + ---------- + y: NDArray of shape (n_samples,) + Observed target values. + + y_pred: NDArray of shape (n_samples,) + Predicted target values. + + Returns + ------- + NDArray of shape (n_samples,) + Conformity scores. + """ + y = cast(NDArray, y) + y_pred = cast(NDArray, y_pred) + y_enc = cast(NDArray, y_enc) + classes = cast(NDArray, self.classes) + + # Conformity scores + if self.method == "naive": + conformity_scores = ( + np.empty(y_pred.shape, dtype="float") + ) + else: + conformity_scores, self.cutoff = ( + get_true_label_cumsum_proba(y, y_pred, classes) + ) + y_proba_true = np.take_along_axis( + y_pred, y_enc.reshape(-1, 1), axis=1 + ) + random_state = check_random_state(self.random_state) + random_state = cast(np.random.RandomState, random_state) + u = random_state.uniform(size=len(y_pred)).reshape(-1, 1) + conformity_scores -= u * y_proba_true + + return conformity_scores + + def get_estimation_distribution( + self, + y_pred: ArrayLike, + conformity_scores: ArrayLike, + **kwargs + ) -> NDArray: + """ + TODO + Placeholder for ``get_estimation_distribution``. + Subclasses should implement this method! + + Compute samples of the estimation distribution given the predicted + targets and the conformity scores. + + Parameters + ---------- + y_pred: NDArray of shape (n_samples, ...) + Predicted target values. + + conformity_scores: NDArray of shape (n_samples, ...) + Conformity scores. + + Returns + ------- + NDArray of shape (n_samples, ...) + Observed values. 
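+
+        Notes
+        -----
+        In this classification setting the notion of an estimation
+        distribution is not used: the method is kept as a no-op for API
+        compatibility with the regression scores and returns an empty array.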
+ """ + return np.array([]) + + @staticmethod + def _regularize_conformity_score( + k_star: NDArray, + lambda_: Union[NDArray, float], + conf_score: NDArray, + cutoff: NDArray + ) -> NDArray: + """ + Regularize the conformity scores with the ``"raps"`` + method. See algo. 2 in [3]. + + Parameters + ---------- + k_star: NDArray of shape (n_alphas, ) + Optimal value of k (called k_reg in the paper). There + is one value per alpha. + + lambda_: Union[NDArray, float] of shape (n_alphas, ) + One value of lambda for each alpha. + + conf_score: NDArray of shape (n_samples, 1) + Conformity scores. + + cutoff: NDArray of shape (n_samples, 1) + Position of the true label. + + Returns + ------- + NDArray of shape (n_samples, 1, n_alphas) + Regularized conformity scores. The regularization + depends on the value of alpha. + """ + conf_score = np.repeat( + conf_score[:, :, np.newaxis], len(k_star), axis=2 + ) + cutoff = np.repeat( + cutoff[:, np.newaxis], len(k_star), axis=1 + ) + conf_score += np.maximum( + np.expand_dims( + lambda_ * (cutoff - k_star), + axis=1 + ), + 0 + ) + return conf_score + + def _update_size_and_lambda( + self, + best_sizes: NDArray, + alpha_np: NDArray, + y_ps: NDArray, + lambda_: Union[NDArray, float], + lambda_star: NDArray + ) -> Tuple[NDArray, NDArray]: + """Update the values of the optimal lambda if the + average size of the prediction sets decreases with + this new value of lambda. + + Parameters + ---------- + best_sizes: NDArray of shape (n_alphas, ) + Smallest average prediciton set size before testing + for the new value of lambda_ + + alpha_np: NDArray of shape (n_alphas) + Level of confidences. + + y_ps: NDArray of shape (n_samples, n_classes, n_alphas) + Prediction sets computed with the RAPS method and the + new value of lambda_ + + lambda_: NDArray of shape (n_alphas, ) + New value of lambda_star to test + + lambda_star: NDArray of shape (n_alphas, ) + Actual optimal lambda values for each alpha. + + Returns + ------- + Tuple[NDArray, NDArray] + Arrays of shape (n_alphas, ) and (n_alpha, ) which + respectively represent the updated values of lambda_star + and the new best sizes. + """ + + sizes = [ + classification_mean_width_score( + y_ps[:, :, i] + ) for i in range(len(alpha_np)) + ] + + sizes_improve = (sizes < best_sizes - EPSILON) + lambda_star = ( + sizes_improve * lambda_ + (1 - sizes_improve) * lambda_star + ) + best_sizes = sizes_improve * sizes + (1 - sizes_improve) * best_sizes + + return lambda_star, best_sizes + + def _find_lambda_star( + self, + y_raps_no_enc: NDArray, + y_pred_proba_raps: NDArray, + alpha_np: NDArray, + include_last_label: Union[bool, str, None], + k_star: NDArray + ) -> Union[NDArray, float]: + """Find the optimal value of lambda for each alpha. + + Parameters + ---------- + y_pred_proba_raps: NDArray of shape (n_samples, n_labels, n_alphas) + Predictions of the model repeated on the last axis as many times + as the number of alphas + + alpha_np: NDArray of shape (n_alphas, ) + Levels of confidences. + + include_last_label: bool + Whether to include or not last label in + the prediction sets + + k_star: NDArray of shape (n_alphas, ) + Values of k for the regularization. + + Returns + ------- + ArrayLike of shape (n_alphas, ) + Optimal values of lambda. 
+ """ + classes = cast(NDArray, self.classes) + + lambda_star = np.zeros(len(alpha_np)) + best_sizes = np.full(len(alpha_np), np.finfo(np.float64).max) + + for lambda_ in [.001, .01, .1, .2, .5]: # values given in paper[3] + true_label_cumsum_proba, cutoff = ( + get_true_label_cumsum_proba( + y_raps_no_enc, + y_pred_proba_raps[:, :, 0], + classes + ) + ) + + true_label_cumsum_proba_reg = self._regularize_conformity_score( + k_star, + lambda_, + true_label_cumsum_proba, + cutoff + ) + + quantiles_ = compute_quantiles( + true_label_cumsum_proba_reg, + alpha_np + ) + + _, _, y_pred_proba_last = get_last_included_proba( + y_pred_proba_raps, + quantiles_, + include_last_label, + self.method, + lambda_, + k_star + ) + + y_ps = np.greater_equal( + y_pred_proba_raps - y_pred_proba_last, -EPSILON + ) + lambda_star, best_sizes = self._update_size_and_lambda( + best_sizes, alpha_np, y_ps, lambda_, lambda_star + ) + if len(lambda_star) == 1: + lambda_star = lambda_star[0] + return lambda_star + + def get_sets( + self, + X: ArrayLike, + alpha_np: NDArray, + estimator: EnsembleClassifier, + conformity_scores: NDArray, + include_last_label: Optional[Union[bool, str]] = True, + agg_scores: Optional[str] = "mean", + X_raps: Optional[NDArray] = None, + y_raps_no_enc: Optional[NDArray] = None, + y_pred_proba_raps: Optional[NDArray] = None, + position_raps: Optional[NDArray] = None, + **kwargs + ): + """ + Compute classes of the prediction sets from the observed values, + the estimator of type ``EnsembleClassifier`` and the conformity scores. + + Parameters + ---------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + conformity_scores: NDArray of shape (n_samples,) + Conformity scores. + + TODO + + Returns + ------- + NDArray of shape (n_samples, n_classes, n_alpha) + Prediction sets (Booleans indicate whether classes are included). 
+ """ + # Checks + include_last_label = check_include_last_label(include_last_label) + + # if self.method == "raps": + lambda_star, k_star = None, None + X_raps = cast(NDArray, X_raps) + y_raps_no_enc = cast(NDArray, y_raps_no_enc) + y_pred_proba_raps = cast(NDArray, y_pred_proba_raps) + position_raps = cast(NDArray, position_raps) + + n = len(conformity_scores) + + y_pred_proba = estimator.predict(X, agg_scores) + y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) + if agg_scores != "crossval": + y_pred_proba = np.repeat( + y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 + ) + + # Choice of the quantileif self.method == "naive": + if self.method == "naive": + self.quantiles_ = 1 - alpha_np + elif (estimator.cv == "prefit") or (agg_scores in ["mean"]): + if self.method == "raps": + check_alpha_and_n_samples(alpha_np, X_raps.shape[0]) + k_star = compute_quantiles( + position_raps, + alpha_np + ) + 1 + y_pred_proba_raps = np.repeat( + y_pred_proba_raps[:, :, np.newaxis], + len(alpha_np), + axis=2 + ) + lambda_star = self._find_lambda_star( + y_raps_no_enc, + y_pred_proba_raps, + alpha_np, + include_last_label, + k_star + ) + conformity_scores_regularized = ( + self._regularize_conformity_score( + k_star, + lambda_star, + conformity_scores, + self.cutoff + ) + ) + self.quantiles_ = compute_quantiles( + conformity_scores_regularized, + alpha_np + ) + else: + self.quantiles_ = compute_quantiles( + conformity_scores, + alpha_np + ) + else: + self.quantiles_ = (n + 1) * (1 - alpha_np) + + # Build prediction sets + # specify which thresholds will be used + if (estimator.cv == "prefit") or (agg_scores in ["mean"]): + thresholds = self.quantiles_ + else: + thresholds = conformity_scores.ravel() + # sort labels by decreasing probability + y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last = ( + get_last_included_proba( + y_pred_proba, + thresholds, + include_last_label, + self.method, + lambda_star, + k_star, + ) + ) + # get the prediction set by taking all probabilities + # above the last one + if (estimator.cv == "prefit") or (agg_scores in ["mean"]): + y_pred_included = np.greater_equal( + y_pred_proba - y_pred_proba_last, -EPSILON + ) + else: + y_pred_included = np.less_equal( + y_pred_proba - y_pred_proba_last, EPSILON + ) + # remove last label randomly + if include_last_label == "randomized": + y_pred_included = add_random_tie_breaking( + y_pred_included, + y_pred_index_last, + y_pred_proba_cumsum, + y_pred_proba_last, + thresholds, + self.method, + self.random_state, + lambda_star, + k_star, + ) + if (estimator.cv == "prefit") or (agg_scores in ["mean"]): + prediction_sets = y_pred_included + else: + # compute the number of times the inequality is verified + prediction_sets_summed = y_pred_included.sum(axis=2) + prediction_sets = np.less_equal( + prediction_sets_summed[:, :, np.newaxis] + - self.quantiles_[np.newaxis, np.newaxis, :], + EPSILON + ) + + # Just for coverage: do nothing + self.get_estimation_distribution(y_pred_proba, conformity_scores) + + return prediction_sets diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py new file mode 100644 index 000000000..8bff9b6fa --- /dev/null +++ b/mapie/conformity_scores/sets/lac.py @@ -0,0 +1,207 @@ +from typing import Optional, Union, cast + +import numpy as np + +from mapie.conformity_scores.classification import BaseClassificationScore +from mapie.conformity_scores.sets.utils import check_proba_normalized +from mapie.estimator.classifier import EnsembleClassifier + +from 
mapie._machine_precision import EPSILON +from mapie._typing import ArrayLike, NDArray +from mapie.utils import compute_quantiles + + +class LAC(BaseClassificationScore): + """ + Least Ambiguous set-valued Classifier (LAC) method-based + non conformity score (also formerly called ``"score"``). + + It is based on the the scores (i.e. 1 minus the softmax score of the true + label) on the calibration set. + + References + ---------- + [1] Mauricio Sadinle, Jing Lei, and Larry Wasserman. + "Least Ambiguous Set-Valued Classifiers with Bounded Error Levels.", + Journal of the American Statistical Association, 114, 2019. + """ + + def __init__( + self, + consistency_check: bool = True, + eps: float = float(EPSILON), + ): + super().__init__( + consistency_check=consistency_check, + eps=eps + ) + + def set_external_attributes( + self, + method: str = 'lac', + classes: Optional[ArrayLike] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Parameters + ---------- + method: str + Method to choose for prediction interval estimates. + Methods available in this class: ``lac``. + + By default ``lac`` for LAC method. + + classes: Optional[ArrayLike] + Names of the classes. + + By default ``None``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + """ + super().set_external_attributes(**kwargs) + self.method = method + self.classes = classes + self.random_state = random_state + + def get_conformity_scores( + self, + y: ArrayLike, + y_pred: ArrayLike, + y_enc: Optional[ArrayLike] = None, + **kwargs + ) -> NDArray: + """ + Get the conformity score. + + Parameters + ---------- + y: NDArray of shape (n_samples,) + Observed target values. + + y_pred: NDArray of shape (n_samples,) + Predicted target values. + + Returns + ------- + NDArray of shape (n_samples,) + Conformity scores. + """ + y_pred = cast(NDArray, y_pred) + y_enc = cast(NDArray, y_enc) + + # Conformity scores + conformity_scores = np.take_along_axis( + 1 - y_pred, y_enc.reshape(-1, 1), axis=1 + ) + return conformity_scores + + def get_estimation_distribution( + self, + y_pred: ArrayLike, + conformity_scores: ArrayLike, + **kwargs + ) -> NDArray: + """ + TODO + Placeholder for ``get_estimation_distribution``. + Subclasses should implement this method! + + Compute samples of the estimation distribution given the predicted + targets and the conformity scores. + + Parameters + ---------- + y_pred: NDArray of shape (n_samples, ...) + Predicted target values. + + conformity_scores: NDArray of shape (n_samples, ...) + Conformity scores. + + Returns + ------- + NDArray of shape (n_samples, ...) + Observed values. + """ + return np.array([]) + + def get_sets( + self, + X: ArrayLike, + alpha_np: NDArray, + estimator: EnsembleClassifier, + conformity_scores: NDArray, + agg_scores: Optional[str] = "mean", + **kwargs + ): + """ + Compute classes of the prediction sets from the observed values, + the estimator of type ``EnsembleClassifier`` and the conformity scores. + + Parameters + ---------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + conformity_scores: NDArray of shape (n_samples,) + Conformity scores. 
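+
+        agg_scores: Optional[str]
+            How to aggregate the scores output by the estimators on test
+            data when a cross-validation strategy is used: ``"mean"``
+            averages the softmax outputs of the out-of-fold models, while
+            ``"crossval"`` compares each training conformity score to the
+            test scores to decide whether a label is included.
+
+            By default ``"mean"``.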
+ + TODO + + Returns + ------- + NDArray of shape (n_samples, n_classes, n_alpha) + Prediction sets (Booleans indicate whether classes are included). + """ + # Checks + n = len(conformity_scores) + + y_pred_proba = estimator.predict(X, agg_scores) + y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) + if agg_scores != "crossval": + y_pred_proba = np.repeat( + y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 + ) + + # Choice of the quantile + if (estimator.cv == "prefit") or (agg_scores in ["mean"]): + self.quantiles_ = compute_quantiles( + conformity_scores, + alpha_np + ) + else: + self.quantiles_ = (n + 1) * (1 - alpha_np) + + # Build prediction sets + if (estimator.cv == "prefit") or (agg_scores == "mean"): + prediction_sets = np.greater_equal( + y_pred_proba - (1 - self.quantiles_), -EPSILON + ) + else: + y_pred_included = np.less_equal( + (1 - y_pred_proba) - conformity_scores.ravel(), + EPSILON + ).sum(axis=2) + prediction_sets = np.stack( + [ + np.greater_equal( + y_pred_included - _alpha * (n - 1), -EPSILON + ) + for _alpha in alpha_np + ], axis=2 + ) + + # Just for coverage: do nothing + self.get_estimation_distribution(y_pred_proba, conformity_scores) + + return prediction_sets diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py new file mode 100644 index 000000000..1e68ad832 --- /dev/null +++ b/mapie/conformity_scores/sets/topk.py @@ -0,0 +1,212 @@ +from typing import Optional, Union, cast + +import numpy as np + +from mapie.conformity_scores.classification import BaseClassificationScore +from mapie.conformity_scores.sets.utils import ( + check_proba_normalized, get_true_label_position +) +from mapie.estimator.classifier import EnsembleClassifier + +from mapie._machine_precision import EPSILON +from mapie._typing import ArrayLike, NDArray +from mapie.utils import compute_quantiles + + +class TopK(BaseClassificationScore): + """ + Top-K method-based non-conformity score. + + It is based on the sorted index of the probability of the true label in the + softmax outputs, on the calibration set. In case two probabilities are + equal, both are taken, thus, the size of some prediction sets may be + different from the others. + + References + ---------- + [1] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan + and Jitendra Malik. + "Uncertainty Sets for Image Classifiers using Conformal Prediction." + International Conference on Learning Representations 2021. + """ + + def __init__( + self, + consistency_check: bool = True, + eps: float = float(EPSILON), + ): + super().__init__( + consistency_check=consistency_check, + eps=eps + ) + + def set_external_attributes( + self, + method: str = 'top_k', + classes: Optional[int] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Parameters + ---------- + method: str + Method to choose for prediction interval estimates. + Methods available in this class: ``top_k``. + + By default ``top_k`` for Top K method. + + classes: Optional[ArrayLike] + Names of the classes. + + By default ``None``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. 
+ """ + super().set_external_attributes(**kwargs) + self.method = method + self.classes = classes + self.random_state = random_state + + def get_conformity_scores( + self, + y: ArrayLike, + y_pred: ArrayLike, + y_enc: Optional[ArrayLike] = None, + **kwargs + ) -> NDArray: + """ + Get the conformity score. + + Parameters + ---------- + y: NDArray of shape (n_samples,) + Observed target values. + + y_pred: NDArray of shape (n_samples,) + Predicted target values. + + Returns + ------- + NDArray of shape (n_samples,) + Conformity scores. + """ + y = cast(NDArray, y) + y_pred = cast(NDArray, y_pred) + y_enc = cast(NDArray, y_enc) + + # Conformity scores + # Here we reorder the labels by decreasing probability and get the + # position of each label from decreasing probability + conformity_scores = get_true_label_position(y_pred, y_enc) + + return conformity_scores + + def get_estimation_distribution( + self, + y_pred: ArrayLike, + conformity_scores: ArrayLike, + **kwargs + ) -> NDArray: + """ + TODO + Placeholder for ``get_estimation_distribution``. + Subclasses should implement this method! + + Compute samples of the estimation distribution given the predicted + targets and the conformity scores. + + Parameters + ---------- + y_pred: NDArray of shape (n_samples, ...) + Predicted target values. + + conformity_scores: NDArray of shape (n_samples, ...) + Conformity scores. + + Returns + ------- + NDArray of shape (n_samples, ...) + Observed values. + """ + return np.array([]) + + def get_sets( + self, + X: ArrayLike, + alpha_np: NDArray, + estimator: EnsembleClassifier, + conformity_scores: NDArray, + agg_scores: Optional[str] = "mean", + **kwargs + ): + """ + Compute classes of the prediction sets from the observed values, + the estimator of type ``EnsembleClassifier`` and the conformity scores. + + Parameters + ---------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + conformity_scores: NDArray of shape (n_samples,) + Conformity scores. + + TODO + + Returns + ------- + NDArray of shape (n_samples, n_classes, n_alpha) + Prediction sets (Booleans indicate whether classes are included). 
+ """ + # Checks + agg_scores = "mean" + + y_pred_proba = estimator.predict(X, agg_scores) + y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) + y_pred_proba = np.repeat( + y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 + ) + + # Choice of the quantile + self.quantiles_ = compute_quantiles(conformity_scores, alpha_np) + + # Build prediction sets + y_pred_proba = y_pred_proba[:, :, 0] + index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1)) + y_pred_index_last = np.stack( + [ + index_sorted[:, quantile] + for quantile in self.quantiles_ + ], axis=1 + ) + y_pred_proba_last = np.stack( + [ + np.take_along_axis( + y_pred_proba, + y_pred_index_last[:, iq].reshape(-1, 1), + axis=1 + ) + for iq, _ in enumerate(self.quantiles_) + ], axis=2 + ) + prediction_sets = np.greater_equal( + y_pred_proba[:, :, np.newaxis] + - y_pred_proba_last, + -EPSILON + ) + + # Just for coverage: do nothing + self.get_estimation_distribution(y_pred_proba, conformity_scores) + + return prediction_sets diff --git a/mapie/conformity_scores/sets/utils.py b/mapie/conformity_scores/sets/utils.py new file mode 100644 index 000000000..a2b5b32af --- /dev/null +++ b/mapie/conformity_scores/sets/utils.py @@ -0,0 +1,401 @@ +from typing import Any, Optional, Tuple, Union, cast +import numpy as np +from sklearn.calibration import label_binarize +from sklearn.dummy import check_random_state + +from mapie._typing import ArrayLike, NDArray +from mapie._machine_precision import EPSILON + + +def get_true_label_position( + y_pred_proba: NDArray, + y: NDArray +) -> NDArray: + """ + Return the sorted position of the true label in the prediction + + Parameters + ---------- + y_pred_proba: NDArray of shape (n_samples, n_classes) + Model prediction. + + y: NDArray of shape (n_samples) + Labels. + + Returns + ------- + NDArray of shape (n_samples, 1) + Position of the true label in the prediction. + """ + index = np.argsort(np.fliplr(np.argsort(y_pred_proba, axis=1))) + position = np.take_along_axis(index, y.reshape(-1, 1), axis=1) + + return position + + +def get_true_label_cumsum_proba( + y: ArrayLike, + y_pred_proba: NDArray, + classes: ArrayLike +) -> Tuple[NDArray, NDArray]: + """ + Compute the cumsumed probability of the true label. + + Parameters + ---------- + y: NDArray of shape (n_samples, ) + Array with the labels. + + y_pred_proba: NDArray of shape (n_samples, n_classes) + Predictions of the model. + + classes: NDArray of shape (n_classes, ) + Array with the classes. + + Returns + ------- + Tuple[NDArray, NDArray] of shapes (n_samples, 1) and (n_samples, ). + The first element is the cumsum probability of the true label. + The second is the sorted position of the true label. + """ + y_true = label_binarize(y=y, classes=classes) + index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1)) + y_pred_sorted = np.take_along_axis(y_pred_proba, index_sorted, axis=1) + y_true_sorted = np.take_along_axis(y_true, index_sorted, axis=1) + y_pred_sorted_cumsum = np.cumsum(y_pred_sorted, axis=1) + cutoff = np.argmax(y_true_sorted, axis=1) + true_label_cumsum_proba = np.take_along_axis( + y_pred_sorted_cumsum, cutoff.reshape(-1, 1), axis=1 + ) + + return true_label_cumsum_proba, cutoff + 1 + + +def check_include_last_label( + include_last_label: Optional[Union[bool, str]] +) -> Optional[Union[bool, str]]: + """ + Check if ``include_last_label`` is a boolean or a string. + Else raise error. 
+ + Parameters + ---------- + include_last_label: Optional[Union[bool, str]] + Whether or not to include last label in + prediction sets for the ``"aps"`` method. Choose among: + + - ``False``, does not include label whose cumulated score is just + over the quantile. + + - ``True``, includes label whose cumulated score is just over the + quantile, unless there is only one label in the prediction set. + + - ``"randomized"``, randomly includes label whose cumulated score + is just over the quantile based on the comparison of a uniform + number and the difference between the cumulated score of the last + label and the quantile. + + Returns + ------- + Optional[Union[bool, str]] + + Raises + ------ + ValueError + "Invalid include_last_label argument. " + "Should be a boolean or 'randomized'." + """ + if ( + (not isinstance(include_last_label, bool)) and + (not include_last_label == "randomized") + ): + raise ValueError( + "Invalid include_last_label argument. " + "Should be a boolean or 'randomized'." + ) + else: + return include_last_label + + +def check_proba_normalized( + y_pred_proba: ArrayLike, + axis: int = 1 +) -> NDArray: + """ + Check if for all the samples the sum of the probabilities is equal to one. + + Parameters + ---------- + y_pred_proba: ArrayLike of shape (n_samples, n_classes) or + (n_samples, n_train_samples, n_classes) + Softmax output of a model. + + Returns + ------- + ArrayLike of shape (n_samples, n_classes) + Softmax output of a model if the scores all sum to one. + + Raises + ------ + ValueError + If the sum of the scores is not equal to one. + """ + sum_proba = np.sum(y_pred_proba, axis=axis) + err_msg = "The sum of the scores is not equal to one." + np.testing.assert_allclose(sum_proba, 1, err_msg=err_msg, rtol=1e-5) + y_pred_proba = cast(NDArray, y_pred_proba).astype(np.float64) + + return y_pred_proba + + +def get_last_index_included( + y_pred_proba_cumsum: NDArray, + threshold: NDArray, + include_last_label: Optional[Union[bool, str]] +) -> NDArray: + """ + Return the index of the last included sorted probability + depending if we included the first label over the quantile + or not. + + Parameters + ---------- + y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes) + Cumsumed probabilities in the original order. + + threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) + Threshold to compare with y_proba_last_cumsum, can be either: + + - the quantiles associated with alpha values when + ``cv`` == "prefit", ``cv`` == "split" + or ``agg_scores`` is "mean" + + - the conformity score from training samples otherwise + (i.e., when ``cv`` is a CV splitter and + ``agg_scores`` is "crossval") + + include_last_label: Union[bool, str] + Whether or not include the last label. If 'randomized', + the last label is included. + + Returns + ------- + NDArray of shape (n_samples, n_alpha) + Index of the last included sorted probability. 
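+
+    Notes
+    -----
+    Conceptually, with cumulated probabilities ``[0.4, 0.8, 1.0]`` for one
+    sample and a threshold of ``0.5``, the last included index is ``1``
+    when ``include_last_label`` is ``True`` (the label whose cumulated
+    score crosses the threshold is kept) and ``0`` otherwise.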
+ """ + if include_last_label or include_last_label == 'randomized': + y_pred_index_last = ( + np.ma.masked_less( + y_pred_proba_cumsum + - threshold[np.newaxis, :], + -EPSILON + ).argmin(axis=1) + ) + else: + max_threshold = np.maximum( + threshold[np.newaxis, :], + np.min(y_pred_proba_cumsum, axis=1) + ) + y_pred_index_last = np.argmax( + np.ma.masked_greater( + y_pred_proba_cumsum - max_threshold[:, np.newaxis, :], + EPSILON + ), axis=1 + ) + return y_pred_index_last[:, np.newaxis, :] + + +def get_last_included_proba( + y_pred_proba: NDArray, + thresholds: NDArray, + include_last_label: Union[bool, str, None], + method: str, + lambda_: Union[NDArray, float, None], + k_star: Union[NDArray, Any] +) -> Tuple[NDArray, NDArray, NDArray]: + """ + Function that returns the smallest score + among those which are included in the prediciton set. + + Parameters + ---------- + y_pred_proba: NDArray of shape (n_samples, n_classes) + Predictions of the model. + + thresholds: NDArray of shape (n_alphas, ) + Quantiles that have been computed from the conformity scores. + + include_last_label: Union[bool, str, None] + Whether to include or not the label whose score exceeds the threshold. + + lambda_: Union[NDArray, float, None] of shape (n_alphas) + Values of lambda for the regularization. + + k_star: Union[NDArray, Any] + Values of k for the regularization. + + Returns + ------- + Tuple[ArrayLike, ArrayLike, ArrayLike] + Arrays of shape (n_samples, n_classes, n_alphas), + (n_samples, 1, n_alphas) and (n_samples, 1, n_alphas). + They are respectively the cumsumed scores in the original + order which can be different according to the value of alpha + with the RAPS method, the index of the last included score + and the value of the last included score. + """ + index_sorted = np.flip( + np.argsort(y_pred_proba, axis=1), axis=1 + ) + # sort probabilities by decreasing order + y_pred_proba_sorted = np.take_along_axis( + y_pred_proba, index_sorted, axis=1 + ) + # get sorted cumulated score + y_pred_proba_sorted_cumsum = np.cumsum( + y_pred_proba_sorted, axis=1 + ) + + if method == "raps": + y_pred_proba_sorted_cumsum += lambda_ * np.maximum( + 0, + np.cumsum( + np.ones(y_pred_proba_sorted_cumsum.shape), axis=1 + ) - k_star + ) + # get cumulated score at their original position + y_pred_proba_cumsum = np.take_along_axis( + y_pred_proba_sorted_cumsum, + np.argsort(index_sorted, axis=1), + axis=1 + ) + # get index of the last included label + y_pred_index_last = get_last_index_included( + y_pred_proba_cumsum, + thresholds, + include_last_label + ) + # get the probability of the last included label + y_pred_proba_last = np.take_along_axis( + y_pred_proba, + y_pred_index_last, + axis=1 + ) + + zeros_scores_proba_last = (y_pred_proba_last <= EPSILON) + + # If the last included proba is zero, change it to the + # smallest non-zero value to avoid inluding them in the + # prediction sets. 
+ if np.sum(zeros_scores_proba_last) > 0: + y_pred_proba_last[zeros_scores_proba_last] = np.expand_dims( + np.min( + np.ma.masked_less( + y_pred_proba, + EPSILON + ).filled(fill_value=np.inf), + axis=1 + ), axis=1 + )[zeros_scores_proba_last] + + return y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last + + +def add_random_tie_breaking( + prediction_sets: NDArray, + y_pred_index_last: NDArray, + y_pred_proba_cumsum: NDArray, + y_pred_proba_last: NDArray, + threshold: NDArray, + method: str, + random_state: Optional[Union[int, np.random.RandomState]] = None, + lambda_star: Optional[Union[NDArray, float]] = None, + k_star: Optional[Union[NDArray, None]] = None +) -> NDArray: + """ + Randomly remove last label from prediction set based on the + comparison between a random number and the difference between + cumulated score of the last included label and the quantile. + + Parameters + ---------- + prediction_sets: NDArray of shape + (n_samples, n_classes, n_threshold) + Prediction set for each observation and each alpha. + + y_pred_index_last: NDArray of shape (n_samples, threshold) + Index of the last included label. + + y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes) + Cumsumed probability of the model in the original order. + + y_pred_proba_last: NDArray of shape (n_samples, 1, threshold) + Last included probability. + + threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) + Threshold to compare with y_proba_last_cumsum, can be either: + + - the quantiles associated with alpha values when ``cv`` == "prefit", + ``cv`` == "split" or ``agg_scores`` is "mean" + + - the conformity score from training samples otherwise + (i.e., when ``cv`` is CV splitter and ``agg_scores`` is "crossval") + + method: str + Method that determines how to remove last label in the prediction set. + + - if "cumulated_score" or "aps", compute V parameter from Romano+(2020) + + - else compute V parameter from Angelopoulos+(2020) + + lambda_star: Union[NDArray, float, None] of shape (n_alpha): + Optimal value of the regulizer lambda. + + k_star: Union[NDArray, None] of shape (n_alpha): + Optimal value of the regulizer k. + + Returns + ------- + NDArray of shape (n_samples, n_classes, n_alpha) + Updated version of prediction_sets with randomly removed labels. 
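+
+    Notes
+    -----
+    For each sample, the parameter ``V`` measures the fraction of the last
+    label's probability mass that exceeds the threshold (in the ``"aps"``
+    and ``"cumulated_score"`` cases; the RAPS variant uses a regularized
+    denominator). The label is then removed whenever ``V`` is greater than
+    the uniform draw ``u``, so the more superfluous the last label's mass,
+    the more often it is dropped.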
+ """ + # get cumsumed probabilities up to last retained label + y_proba_last_cumsumed = np.squeeze( + np.take_along_axis( + y_pred_proba_cumsum, + y_pred_index_last, + axis=1 + ), axis=1 + ) + + if method in ["cumulated_score", "aps"]: + # compute V parameter from Romano+(2020) + vs = ( + (y_proba_last_cumsumed - threshold.reshape(1, -1)) / + y_pred_proba_last[:, 0, :] + ) + else: + # compute V parameter from Angelopoulos+(2020) + L = np.sum(prediction_sets, axis=1) + vs = ( + (y_proba_last_cumsumed - threshold.reshape(1, -1)) / + ( + y_pred_proba_last[:, 0, :] - + lambda_star * np.maximum(0, L - k_star) + + lambda_star * (L > k_star) + ) + ) + + # get random numbers for each observation and alpha value + random_state = check_random_state(random_state) + random_state = cast(np.random.RandomState, random_state) + us = random_state.uniform(size=(prediction_sets.shape[0], 1)) + # remove last label from comparison between uniform number and V + vs_less_than_us = np.less_equal(vs - us, EPSILON) + np.put_along_axis( + prediction_sets, + y_pred_index_last, + vs_less_than_us[:, np.newaxis, :], + axis=1 + ) + return prediction_sets diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index 8cc3bf9d4..3206f90ca 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -1,26 +1,92 @@ -import numpy as np -from mapie._typing import NDArray +from typing import Optional +from .regression import BaseRegressionScore +from .classification import BaseClassificationScore +from .bounds import AbsoluteConformityScore +from .sets import APS, LAC, TopK -def get_true_label_position(y_pred_proba: NDArray, y: NDArray) -> NDArray: + +def check_regression_conformity_score( + conformity_score: Optional[BaseRegressionScore], + sym: bool = True, +) -> BaseRegressionScore: """ - Return the sorted position of the true label in the - prediction + Check parameter ``conformity_score`` for regression task. + + Raises + ------ + ValueError + If parameters are not valid. - Parameters - ---------- - y_pred_proba: NDArray of shape (n_samples, n_classes) - Model prediction. + Examples + -------- + >>> from mapie.conformity_scores.checks import ( + ... check_regression_conformity_score + ... ) + >>> try: + ... check_regression_conformity_score(1) + ... except Exception as exception: + ... print(exception) + ... + Invalid conformity_score argument. + Must be None or a ConformityScore instance. + """ + if conformity_score is None: + return AbsoluteConformityScore(sym=sym) + elif isinstance(conformity_score, BaseRegressionScore): + return conformity_score + else: + raise ValueError( + "Invalid conformity_score argument.\n" + "Must be None or a ConformityScore instance." + ) - y: NDArray of shape (n_samples) - Labels. - Returns - ------- - NDArray of shape (n_samples, 1) - Position of the true label in the prediction. +def check_classification_conformity_score( + conformity_score: Optional[BaseClassificationScore] = None, + method: Optional[str] = None, +) -> BaseClassificationScore: """ - index = np.argsort(np.fliplr(np.argsort(y_pred_proba, axis=1))) - position = np.take_along_axis(index, y.reshape(-1, 1), axis=1) + Check parameter ``conformity_score`` for classification task. - return position + Raises + ------ + ValueError + If parameters are not valid. + + Examples + -------- + >>> from mapie.conformity_scores.checks import ( + ... check_classification_conformity_score + ... ) + >>> try: + ... check_classification_conformity_score(1) + ... 
except Exception as exception: + ... print(exception) + ... + Invalid conformity_score argument. + Must be None or a ConformityScore instance. + """ + allowed_methods = ['lac', 'naive', 'aps', 'raps', 'top_k'] + deprecated_methods = ['score', 'cumulated_score'] + if method is not None: + if method in ['score', 'lac']: + return LAC() + if method in ['naive', 'cumulated_score', 'aps', 'raps']: + return APS() + if method == 'top_k': + return TopK() + else: + raise ValueError( + f"Invalid method. Allowed values are {allowed_methods}. " + f"Deprecated values are {deprecated_methods}. " + ) + elif isinstance(conformity_score, BaseClassificationScore): + return conformity_score + elif conformity_score is None: + return LAC() + else: + raise ValueError( + "Invalid conformity_score argument.\n" + "Must be None or a ConformityScore instance." + ) diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index 3085ce82d..018c30677 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -12,13 +12,14 @@ from sklearn.utils.validation import _check_y, check_is_fitted, indexable from mapie._typing import ArrayLike, NDArray -from mapie.conformity_scores import ConformityScore, ResidualNormalisedScore +from mapie.conformity_scores import (BaseRegressionScore, + ResidualNormalisedScore) +from mapie.conformity_scores.utils import check_regression_conformity_score from mapie.estimator.regressor import EnsembleRegressor from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_cv, check_estimator_fit_predict, check_n_features_in, check_n_jobs, check_null_weight, check_verbose, get_effective_calibration_samples) -from mapie.conformity_scores.checks import check_conformity_score class MapieRegressor(BaseEstimator, RegressorMixin): @@ -226,7 +227,7 @@ def __init__( n_jobs: Optional[int] = None, agg_function: Optional[str] = "mean", verbose: int = 0, - conformity_score: Optional[ConformityScore] = None, + conformity_score: Optional[BaseRegressionScore] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, ) -> None: self.estimator = estimator @@ -431,7 +432,7 @@ def _check_fit_parameters( self.method = "base" estimator = self._check_estimator(self.estimator) agg_function = self._check_agg_function(self.agg_function) - cs_estimator = check_conformity_score( + cs_estimator = check_regression_conformity_score( self.conformity_score, self.default_sym_ ) if isinstance(cs_estimator, ResidualNormalisedScore) and \ @@ -449,7 +450,7 @@ def _check_fit_parameters( # Casting cv = cast(BaseCrossValidator, cv) estimator = cast(RegressorMixin, estimator) - cs_estimator = cast(ConformityScore, cs_estimator) + cs_estimator = cast(BaseRegressionScore, cs_estimator) agg_function = cast(Optional[str], agg_function) X = cast(NDArray, X) y = cast(NDArray, y) @@ -539,7 +540,7 @@ def fit( # Compute the conformity scores (manage jk-ab case) self.conformity_scores_ = \ self.conformity_score_function_.get_conformity_scores( - X, y, y_pred + y, y_pred, X=X ) return self @@ -639,16 +640,15 @@ def predict( check_alpha_and_n_samples(alpha_np, n) # Predict the target with confidence intervals - y_pred, y_pred_low, y_pred_up = \ - self.conformity_score_function_.get_bounds( - X, - self.estimator_, - self.conformity_scores_, - alpha_np, - ensemble=ensemble, - method=self.method, - optimize_beta=optimize_beta, - allow_infinite_bounds=allow_infinite_bounds - ) + outputs = self.conformity_score_function_.predict_set( + X, alpha_np, + estimator=self.estimator_, + 
conformity_scores=self.conformity_scores_, + ensemble=ensemble, + method=self.method, + optimize_beta=optimize_beta, + allow_infinite_bounds=allow_infinite_bounds + ) + y_pred, y_pred_low, y_pred_up = outputs return np.array(y_pred), np.stack([y_pred_low, y_pred_up], axis=1) diff --git a/mapie/regression/time_series_regression.py b/mapie/regression/time_series_regression.py index b4bf0cc03..b96dc17dc 100644 --- a/mapie/regression/time_series_regression.py +++ b/mapie/regression/time_series_regression.py @@ -9,7 +9,7 @@ from sklearn.utils.validation import check_is_fitted from mapie._typing import ArrayLike, NDArray -from mapie.conformity_scores import ConformityScore +from mapie.conformity_scores import BaseRegressionScore from mapie.regression import MapieRegressor from mapie.utils import check_alpha, check_gamma @@ -66,7 +66,7 @@ def __init__( n_jobs: Optional[int] = None, agg_function: Optional[str] = "mean", verbose: int = 0, - conformity_score: Optional[ConformityScore] = None, + conformity_score: Optional[BaseRegressionScore] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, ) -> None: super().__init__( @@ -114,7 +114,9 @@ def _relative_conformity_scores( """ y_pred = super().predict(X, ensemble=ensemble) scores = np.array( - self.conformity_score_function_.get_conformity_scores(X, y, y_pred) + self.conformity_score_function_.get_conformity_scores( + y, y_pred, X=X + ) ) return scores diff --git a/mapie/tests/test_classification.py b/mapie/tests/test_classification.py index 740c4df6b..1b6bf6a12 100644 --- a/mapie/tests/test_classification.py +++ b/mapie/tests/test_classification.py @@ -23,6 +23,11 @@ from mapie._typing import ArrayLike, NDArray from mapie.classification import MapieClassifier +from mapie.conformity_scores.sets.aps import APS +from mapie.conformity_scores.sets.utils import ( + check_proba_normalized, get_last_included_proba, + get_true_label_cumsum_proba +) from mapie.metrics import classification_coverage_score from mapie.utils import check_alpha @@ -1028,14 +1033,14 @@ def test_too_large_cv(cv: Any) -> None: ) def test_invalid_include_last_label(include_last_label: Any) -> None: """Test that invalid include_last_label raise errors.""" - mapie_clf = MapieClassifier(random_state=random_state) + mapie_clf = MapieClassifier(method='aps', random_state=random_state) mapie_clf.fit(X_toy, y_toy) with pytest.raises( ValueError, match=r".*Invalid include_last_label argument.*" ): mapie_clf.predict( X_toy, - y_toy, + alpha=0.5, include_last_label=include_last_label ) @@ -1504,7 +1509,8 @@ def test_cumulated_scores() -> None: include_last_label=True, alpha=alpha ) - np.testing.assert_allclose(mapie_clf.quantiles_, quantile) + computed_quantile = mapie_clf.conformity_score_function_.quantiles_ + np.testing.assert_allclose(computed_quantile, quantile) np.testing.assert_allclose(y_ps[:, :, 0], cumclf.y_pred_sets) @@ -1532,7 +1538,8 @@ def test_image_cumulated_scores(X: Dict[str, ArrayLike]) -> None: include_last_label=True, alpha=alpha ) - np.testing.assert_allclose(mapie.quantiles_, quantile) + computed_quantile = mapie.conformity_score_function_.quantiles_ + np.testing.assert_allclose(computed_quantile, quantile) np.testing.assert_allclose(y_ps[:, :, 0], cumclf.y_pred_sets) @@ -1606,28 +1613,16 @@ def test_method_error_in_fit(monkeypatch: Any, method: str) -> None: mapie_clf.fit(X_toy, y_toy) -@pytest.mark.parametrize("method", WRONG_METHODS) -@pytest.mark.parametrize("alpha", [0.2, [0.2, 0.3], (0.2, 0.3)]) -def test_method_error_in_predict(method: Any, 
alpha: float) -> None: - """Test else condition for the method in .predict""" - mapie_clf = MapieClassifier( - method="lac", random_state=random_state - ) - mapie_clf.fit(X_toy, y_toy) - mapie_clf.method = method - with pytest.raises(ValueError, match=r".*Invalid method.*"): - mapie_clf.predict(X_toy, alpha=alpha) - - @pytest.mark.parametrize("include_labels", WRONG_INCLUDE_LABELS) @pytest.mark.parametrize("alpha", [0.2, [0.2, 0.3], (0.2, 0.3)]) def test_include_label_error_in_predict( monkeypatch: Any, include_labels: Union[bool, str], alpha: float ) -> None: """Test else condition for include_label parameter in .predict""" + from mapie.conformity_scores.sets import utils monkeypatch.setattr( - MapieClassifier, - "_check_include_last_label", + utils, + "check_include_last_label", do_nothing ) mapie_clf = MapieClassifier( @@ -1694,8 +1689,7 @@ def test_pred_proba_float64() -> None: y_pred_proba = np.random.random((1000, 10)).astype(np.float32) sum_of_rows = y_pred_proba.sum(axis=1) normalized_array = y_pred_proba / sum_of_rows[:, np.newaxis] - mapie = MapieClassifier(random_state=random_state) - checked_normalized_array = mapie._check_proba_normalized(normalized_array) + checked_normalized_array = check_proba_normalized(normalized_array) assert checked_normalized_array.dtype == "float64" @@ -1744,12 +1738,9 @@ def test_regularize_conf_scores_shape(k_lambda) -> None: Test that the conformity scores have the correct shape. """ lambda_, k = k_lambda[0], k_lambda[1] - args_init, _ = STRATEGIES["raps"] - clf = LogisticRegression().fit(X, y) - mapie_clf = MapieClassifier(estimator=clf, **args_init) conf_scores = np.random.rand(100, 1) cutoff = np.cumsum(np.ones(conf_scores.shape)) - 1 - reg_conf_scores = mapie_clf._regularize_conformity_score( + reg_conf_scores = APS._regularize_conformity_score( k, lambda_, conf_scores, cutoff ) @@ -1768,9 +1759,8 @@ def test_get_true_label_cumsum_proba_shape() -> None: estimator=clf, random_state=random_state ) mapie_clf.fit(X, y) - cumsum_proba, cutoff = mapie_clf._get_true_label_cumsum_proba( - y, y_pred - ) + classes = mapie_clf.classes_ + cumsum_proba, cutoff = get_true_label_cumsum_proba(y, y_pred, classes) assert cumsum_proba.shape == (len(X), 1) assert cutoff.shape == (len(X), ) @@ -1787,9 +1777,8 @@ def test_get_true_label_cumsum_proba_result() -> None: estimator=clf, random_state=random_state ) mapie_clf.fit(X_toy, y_toy) - cumsum_proba, cutoff = mapie_clf._get_true_label_cumsum_proba( - y_toy, y_pred - ) + classes = mapie_clf.classes_ + cumsum_proba, cutoff = get_true_label_cumsum_proba(y_toy, y_pred, classes) np.testing.assert_allclose( cumsum_proba, np.array( @@ -1829,10 +1818,11 @@ def test_get_last_included_proba_shape(k_lambda, strategy): mapie = MapieClassifier(estimator=clf, **STRATEGIES[strategy][0]) include_last_label = STRATEGIES[strategy][1]["include_last_label"] - y_p_p_c, y_p_i_l, y_p_p_i_l = mapie._get_last_included_proba( - y_pred_proba, thresholds, - include_last_label, lambda_, k - ) + y_p_p_c, y_p_i_l, y_p_p_i_l = \ + get_last_included_proba( + y_pred_proba, thresholds, include_last_label, + mapie.method, lambda_, k + ) assert y_p_p_c.shape == (len(X), len(np.unique(y)), len(thresholds)) assert y_p_i_l.shape == (len(X), 1, len(thresholds)) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index 4d4a32722..4ade1f354 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -5,31 +5,34 @@ from sklearn.preprocessing import PolynomialFeatures from 
mapie._typing import ArrayLike, NDArray -from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore, - GammaConformityScore, - ResidualNormalisedScore) +from mapie.conformity_scores import ( + AbsoluteConformityScore, BaseRegressionScore, GammaConformityScore, + ResidualNormalisedScore +) from mapie.regression import MapieRegressor X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) y_toy = np.array([5, 7, 9, 11, 13, 15]) -y_pred_list = [4, 7, 10, 12, 13, 12] -conf_scores_list = [1, 0, -1, -1, 0, 3] -conf_scores_gamma_list = [1 / 4, 0, -1 / 10, -1 / 12, 0, 3 / 12] -conf_scores_residual_norm_list = [0.2, 0., 0.11111111, 0.09090909, 0., 0.2] +y_pred_list = np.array([4, 7, 10, 12, 13, 12]) +conf_scores_list = np.array([1, 0, -1, -1, 0, 3]) +conf_scores_gamma_list = np.array([1 / 4, 0, -1 / 10, -1 / 12, 0, 3 / 12]) +conf_scores_residual_norm_list = np.array( + [0.2, 0., 0.11111111, 0.09090909, 0., 0.2] +) random_state = 42 -class DummyConformityScore(ConformityScore): +class DummyConformityScore(BaseRegressionScore): def __init__(self) -> None: super().__init__(sym=True, consistency_check=True) def get_signed_conformity_scores( - self, X: ArrayLike, y: ArrayLike, y_pred: ArrayLike, + self, y: ArrayLike, y_pred: ArrayLike, **kwargs ) -> NDArray: return np.subtract(y, y_pred) def get_estimation_distribution( - self, X: ArrayLike, y_pred: ArrayLike, conformity_scores: ArrayLike + self, y_pred: ArrayLike, conformity_scores: ArrayLike, **kwargs ) -> NDArray: """ A positive constant is added to the sum between predictions and @@ -42,7 +45,7 @@ def get_estimation_distribution( @pytest.mark.parametrize("sym", [False, True]) def test_error_mother_class_initialization(sym: bool) -> None: with pytest.raises(TypeError): - ConformityScore(sym) # type: ignore + BaseRegressionScore(sym) # type: ignore @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -52,10 +55,10 @@ def test_absolute_conformity_score_get_conformity_scores( """Test conformity score computation for AbsoluteConformityScore.""" abs_conf_score = AbsoluteConformityScore() signed_conf_scores = abs_conf_score.get_signed_conformity_scores( - X_toy, y_toy, y_pred + y_toy, y_pred, X=X_toy ) conf_scores = abs_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred + y_toy, y_pred, X=X_toy ) expected_signed_conf_scores = np.array(conf_scores_list) expected_conf_scores = np.abs(expected_signed_conf_scores) @@ -73,7 +76,7 @@ def test_absolute_conformity_score_get_estimation_distribution( """Test conformity observed value computation for AbsoluteConformityScore.""" # noqa: E501 abs_conf_score = AbsoluteConformityScore() y_obs = abs_conf_score.get_estimation_distribution( - X_toy, y_pred, conf_scores + y_pred, conf_scores, X=X_toy ) np.testing.assert_allclose(y_obs, y_toy) @@ -83,10 +86,10 @@ def test_absolute_conformity_score_consistency(y_pred: NDArray) -> None: """Test methods consistency for AbsoluteConformityScore.""" abs_conf_score = AbsoluteConformityScore() signed_conf_scores = abs_conf_score.get_signed_conformity_scores( - X_toy, y_toy, y_pred + y_toy, y_pred, X=X_toy, ) y_obs = abs_conf_score.get_estimation_distribution( - X_toy, y_pred, signed_conf_scores + y_pred, signed_conf_scores, X=X_toy, ) np.testing.assert_allclose(y_obs, y_toy) @@ -98,7 +101,7 @@ def test_gamma_conformity_score_get_conformity_scores( """Test conformity score computation for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() conf_scores = gamma_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred + y_toy, 
y_pred, X=X_toy ) expected_signed_conf_scores = np.array(conf_scores_gamma_list) np.testing.assert_allclose(conf_scores, expected_signed_conf_scores) @@ -118,7 +121,7 @@ def test_gamma_conformity_score_get_estimation_distribution( """Test conformity observed value computation for GammaConformityScore.""" # noqa: E501 gamma_conf_score = GammaConformityScore() y_obs = gamma_conf_score.get_estimation_distribution( - X_toy, y_pred, conf_scores + y_pred, conf_scores, X=X_toy ) np.testing.assert_allclose(y_obs, y_toy) @@ -128,10 +131,10 @@ def test_gamma_conformity_score_consistency(y_pred: NDArray) -> None: """Test methods consistency for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() signed_conf_scores = gamma_conf_score.get_signed_conformity_scores( - X_toy, y_toy, y_pred + y_toy, y_pred, X=X_toy ) y_obs = gamma_conf_score.get_estimation_distribution( - X_toy, y_pred, signed_conf_scores + y_pred, signed_conf_scores, X=X_toy, ) np.testing.assert_allclose(y_obs, y_toy) @@ -152,7 +155,7 @@ def test_gamma_conformity_score_check_oberved_value( gamma_conf_score = GammaConformityScore() with pytest.raises(ValueError): gamma_conf_score.get_signed_conformity_scores( - [], y_toy, y_pred + y_toy, y_pred, X=[] ) @@ -189,14 +192,14 @@ def test_gamma_conformity_score_check_predicted_value( match=r".*At least one of the predicted target is negative.*" ): gamma_conf_score.get_signed_conformity_scores( - X_toy, y_toy, y_pred + y_toy, y_pred, X=X_toy ) with pytest.raises( ValueError, match=r".*At least one of the predicted target is negative.*" ): gamma_conf_score.get_estimation_distribution( - X_toy, y_pred, conf_scores + y_pred, conf_scores, X=X_toy ) @@ -207,14 +210,14 @@ def test_check_consistency() -> None: """ dummy_conf_score = DummyConformityScore() conformity_scores = dummy_conf_score.get_signed_conformity_scores( - X_toy, y_toy, y_pred_list + y_toy, y_pred_list ) with pytest.raises( ValueError, match=r".*The two functions get_conformity_scores.*" ): dummy_conf_score.check_consistency( - X_toy, y_toy, y_pred_list, conformity_scores + y_toy, y_pred_list, conformity_scores ) @@ -233,7 +236,7 @@ def test_residual_normalised_prefit_conformity_score_get_conformity_scores( random_state=random_state ) conf_scores = residual_norm_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred + y_toy, y_pred, X=X_toy ) expected_signed_conf_scores = np.array(conf_scores_residual_norm_list) np.testing.assert_allclose(conf_scores, expected_signed_conf_scores) @@ -249,7 +252,7 @@ def test_residual_normalised_conformity_score_get_conformity_scores( """ residual_norm_score = ResidualNormalisedScore(random_state=random_state) conf_scores = residual_norm_score.get_conformity_scores( - X_toy, y_toy, y_pred + y_toy, y_pred, X=X_toy ) expected_signed_conf_scores = np.array( [np.nan, np.nan, 1.e+08, 1.e+08, 0.e+00, 3.e+08] @@ -264,7 +267,7 @@ def test_residual_normalised_score_prefit_with_notfitted_estim() -> None: ) with pytest.raises(ValueError): residual_norm_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred_list + y_toy, y_pred_list, X=X_toy ) @@ -272,9 +275,11 @@ def test_residual_normalised_score_with_default_params() -> None: """Test that no error is raised with default parameters.""" residual_norm_score = ResidualNormalisedScore() conf_scores = residual_norm_score.get_conformity_scores( - X_toy, y_toy, y_pred_list + y_toy, y_pred_list, X=X_toy + ) + residual_norm_score.get_estimation_distribution( + y_toy, conf_scores, X=X_toy ) - residual_norm_score.get_estimation_distribution(X_toy, y_toy, 
conf_scores) def test_invalid_estimator() -> None: @@ -288,7 +293,7 @@ def __init__(self): ) with pytest.raises(ValueError): residual_norm_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred_list + y_toy, y_pred_list, X=X_toy ) @@ -356,7 +361,7 @@ def predict(self, X): ) with pytest.warns(UserWarning): residual_norm_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred_list + y_toy, y_pred_list, X=X_toy ) @@ -370,10 +375,10 @@ def test_residual_normalised_prefit_get_estimation_distribution() -> None: residual_estimator=estim, prefit=True ) conf_scores = residual_normalised_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred_list + y_toy, y_pred_list, X=X_toy ) residual_normalised_conf_score.get_estimation_distribution( - X_toy, y_pred_list, conf_scores + y_pred_list, conf_scores, X=X_toy ) @@ -382,7 +387,7 @@ def test_residual_normalised_prefit_get_estimation_distribution() -> None: ResidualNormalisedScore()]) @pytest.mark.parametrize("alpha", [[0.5], [0.5, 0.6]]) def test_intervals_shape_with_every_score( - score: ConformityScore, + score: BaseRegressionScore, alpha: NDArray ) -> None: estim = LinearRegression().fit(X_toy, y_toy) diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py new file mode 100644 index 000000000..b6349b4fc --- /dev/null +++ b/mapie/tests/test_conformity_scores_sets.py @@ -0,0 +1,37 @@ +from typing import Optional + +import pytest + +# from mapie._typing import ArrayLike, NDArray +from mapie.conformity_scores import BaseClassificationScore +from mapie.conformity_scores.sets import APS, LAC, TopK +from mapie.conformity_scores.utils import check_classification_conformity_score + + +cs_list = [None, LAC(), APS(), TopK()] +method_list = [None, 'naive', 'aps', 'raps', 'lac', 'top_k'] + + +def test_error_mother_class_initialization() -> None: + with pytest.raises(TypeError): + BaseClassificationScore() # type: ignore + + +@pytest.mark.parametrize("conformity_score", cs_list) +def test_check_classification_conformity_score( + conformity_score: Optional[BaseClassificationScore] +) -> None: + assert isinstance( + check_classification_conformity_score(conformity_score), + BaseClassificationScore + ) + + +@pytest.mark.parametrize("method", method_list) +def test_check_classification_method( + method: Optional[str] +) -> None: + assert isinstance( + check_classification_conformity_score(method=method), + BaseClassificationScore + ) diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index 1dad0776e..c35ebec34 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -6,15 +6,17 @@ import numpy as np import pandas as pd import pytest + from sklearn.compose import ColumnTransformer from sklearn.datasets import make_regression from sklearn.dummy import DummyRegressor from sklearn.ensemble import GradientBoostingRegressor from sklearn.impute import SimpleImputer from sklearn.linear_model import LinearRegression -from sklearn.model_selection import (GroupKFold, KFold, LeaveOneOut, - PredefinedSplit, ShuffleSplit, - train_test_split) +from sklearn.model_selection import ( + GroupKFold, KFold, LeaveOneOut, PredefinedSplit, ShuffleSplit, + train_test_split +) from sklearn.pipeline import Pipeline, make_pipeline from sklearn.preprocessing import OneHotEncoder from sklearn.utils.validation import check_is_fitted @@ -23,9 +25,10 @@ from mapie._typing import NDArray from mapie.aggregation_functions import aggregate_all -from mapie.conformity_scores import (AbsoluteConformityScore, 
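The new test module pins down the contract of check_classification_conformity_score: it accepts either a score instance or one of the legacy method strings, and always hands back a BaseClassificationScore. A short usage sketch based on those parametrized tests:

    from mapie.conformity_scores import BaseClassificationScore
    from mapie.conformity_scores.sets import LAC
    from mapie.conformity_scores.utils import (
        check_classification_conformity_score
    )

    # A score instance is passed through...
    score = check_classification_conformity_score(LAC())
    assert isinstance(score, BaseClassificationScore)

    # ...and a legacy method name is mapped to the matching score class.
    score = check_classification_conformity_score(method="lac")
    assert isinstance(score, BaseClassificationScore)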
ConformityScore,
-                                     GammaConformityScore,
-                                     ResidualNormalisedScore)
+from mapie.conformity_scores import (
+    AbsoluteConformityScore, BaseRegressionScore, GammaConformityScore,
+    ResidualNormalisedScore
+)
 from mapie.estimator.regressor import EnsembleRegressor
 from mapie.metrics import regression_coverage_score
 from mapie.regression import MapieRegressor
@@ -784,7 +787,7 @@ def test_pipeline_compatibility() -> None:
     "conformity_score", [AbsoluteConformityScore(), GammaConformityScore()]
 )
 def test_conformity_score(
-    strategy: str, conformity_score: ConformityScore
+    strategy: str, conformity_score: BaseRegressionScore
 ) -> None:
     """Test that any conformity score function with MAPIE raises no error."""
     mapie_reg = MapieRegressor(
@@ -799,7 +802,7 @@ def test_conformity_score(
     "conformity_score", [ResidualNormalisedScore()]
 )
 def test_conformity_score_with_split_strategies(
-    conformity_score: ConformityScore
+    conformity_score: BaseRegressionScore
 ) -> None:
     """
     Test that any conformity score function that handle only split strategies
diff --git a/mapie/tests/test_utils_classification_conformity_scores.py b/mapie/tests/test_utils_classification_conformity_scores.py
index a74a6892a..9d07fa8bc 100644
--- a/mapie/tests/test_utils_classification_conformity_scores.py
+++ b/mapie/tests/test_utils_classification_conformity_scores.py
@@ -3,9 +3,7 @@
 import numpy as np
 import pytest
 
-from mapie.conformity_scores.utils import (
-    get_true_label_position,
-)
+from mapie.conformity_scores.sets.utils import get_true_label_position
 from mapie._typing import NDArray
 
 Y_TRUE_PROBA_PLACE = [

From 0eb720356dc5ff2ed32cd9caa64807dc3df76adc Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Mon, 1 Jul 2024 16:58:22 +0200
Subject: [PATCH 02/46] FIX: path access in test docstring

---
 mapie/conformity_scores/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py
index 3206f90ca..a6b3283c7 100644
--- a/mapie/conformity_scores/utils.py
+++ b/mapie/conformity_scores/utils.py
@@ -20,7 +20,7 @@ def check_regression_conformity_score(
 
     Examples
     --------
-    >>> from mapie.conformity_scores.checks import (
+    >>> from mapie.conformity_scores.utils import (
     ...     check_regression_conformity_score
     ... )
     >>> try:
@@ -56,7 +56,7 @@ def check_classification_conformity_score(
 
     Examples
     --------
-    >>> from mapie.conformity_scores.checks import (
+    >>> from mapie.conformity_scores.utils import (
     ...     check_classification_conformity_score
     ... )
     >>> try:

From fc0f46a18475ff3f6d6cb1513f46215b16d78bc9 Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Mon, 1 Jul 2024 17:32:23 +0200
Subject: [PATCH 03/46] FIX: adapt example code with new signatures

---
 .../1-quickstart/plot_comp_methods_on_2d_dataset.py         | 5 +++--
 examples/classification/4-tutorials/plot_crossconformal.py  | 5 +++--
 .../4-tutorials/plot_main-tutorial-binary-classification.py | 5 +++--
 .../4-tutorials/plot_main-tutorial-classification.py        | 2 +-
 .../plot_conformal_predictive_distribution.py               | 2 +-
 mapie/classification.py                                     | 3 ---
 mapie/conformity_scores/classification.py                   | 5 +++++
 7 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/examples/classification/1-quickstart/plot_comp_methods_on_2d_dataset.py b/examples/classification/1-quickstart/plot_comp_methods_on_2d_dataset.py
index b03e8cb97..f156233a4 100644
--- a/examples/classification/1-quickstart/plot_comp_methods_on_2d_dataset.py
+++ b/examples/classification/1-quickstart/plot_comp_methods_on_2d_dataset.py
@@ -170,7 +170,7 @@ def plot_scores(
 for i, method in enumerate(methods):
     conformity_scores = mapie[method].conformity_scores_
     n = mapie[method].n_samples_
-    quantiles = mapie[method].quantiles_
+    quantiles = mapie[method].conformity_score_function_.quantiles_
     plot_scores(alpha, conformity_scores, quantiles, method, axs[i])
 plt.show()
 
@@ -270,7 +270,8 @@ def plot_results(
 axs[0].set_xlabel("1 - alpha")
 axs[0].set_ylabel("Quantile")
 for method in methods:
-    axs[0].scatter(1 - alpha_, mapie[method].quantiles_, label=method)
+    quantiles = mapie[method].conformity_score_function_.quantiles_
+    axs[0].scatter(1 - alpha_, quantiles, label=method)
 axs[0].legend()
 for method in methods:
     axs[1].scatter(1 - alpha_, coverage[method], label=method)
diff --git a/examples/classification/4-tutorials/plot_crossconformal.py b/examples/classification/4-tutorials/plot_crossconformal.py
index 8200e6c26..7fe8bbac5 100644
--- a/examples/classification/4-tutorials/plot_crossconformal.py
+++ b/examples/classification/4-tutorials/plot_crossconformal.py
@@ -134,10 +134,11 @@
 fig, axs = plt.subplots(1, len(mapies["lac"]), figsize=(20, 4))
 for i, (key, mapie) in enumerate(mapies["lac"].items()):
+    quantiles = mapie.conformity_score_function_.quantiles_[9]
     axs[i].set_xlabel("Conformity scores")
     axs[i].hist(mapie.conformity_scores_)
-    axs[i].axvline(mapie.quantiles_[9], ls="--", color="k")
-    axs[i].set_title(f"split={key}\nquantile={mapie.quantiles_[9]:.3f}")
+    axs[i].axvline(quantiles, ls="--", color="k")
+    axs[i].set_title(f"split={key}\nquantile={quantiles:.3f}")
 plt.suptitle(
     "Distribution of scores on each calibration fold for the "
     f"{methods[0]} method"
diff --git a/examples/classification/4-tutorials/plot_main-tutorial-binary-classification.py b/examples/classification/4-tutorials/plot_main-tutorial-binary-classification.py
index d7469f46b..24d20369a 100644
--- a/examples/classification/4-tutorials/plot_main-tutorial-binary-classification.py
+++ b/examples/classification/4-tutorials/plot_main-tutorial-binary-classification.py
@@ -188,7 +188,7 @@ def plot_scores(
 fig, axs = plt.subplots(1, 1, figsize=(10, 5))
 conformity_scores = mapie_clf.conformity_scores_
-quantiles = mapie_clf.quantiles_
+quantiles = mapie_clf.conformity_score_function_.quantiles_
 plot_scores(alpha, conformity_scores, quantiles, 'lac', axs)
 plt.show()
 
@@ -309,10 +309,11 @@ def plot_results(
 
 def plot_coverages_widths(alpha, coverage, width, method):
+    quantiles = mapie_clf.conformity_score_function_.quantiles_
     _, axs = plt.subplots(1, 3, figsize=(15, 
5)) axs[0].set_xlabel("1 - alpha") axs[0].set_ylabel("Quantile") - axs[0].scatter(1 - alpha, mapie_clf.quantiles_, label=method) + axs[0].scatter(1 - alpha, quantiles, label=method) axs[0].legend() axs[1].scatter(1 - alpha, coverage, label=method) axs[1].set_xlabel("1 - alpha") diff --git a/examples/classification/4-tutorials/plot_main-tutorial-classification.py b/examples/classification/4-tutorials/plot_main-tutorial-classification.py index a7905cfe0..1003141d2 100644 --- a/examples/classification/4-tutorials/plot_main-tutorial-classification.py +++ b/examples/classification/4-tutorials/plot_main-tutorial-classification.py @@ -148,7 +148,7 @@ def plot_scores(n, alphas, scores, quantiles): scores = mapie_score.conformity_scores_ n = len(mapie_score.conformity_scores_) -quantiles = mapie_score.quantiles_ +quantiles = mapie_score.conformity_score_function_.quantiles_ plot_scores(n, alpha, scores, quantiles) ############################################################################## diff --git a/examples/regression/2-advanced-analysis/plot_conformal_predictive_distribution.py b/examples/regression/2-advanced-analysis/plot_conformal_predictive_distribution.py index 293404ca1..c0737c7ae 100644 --- a/examples/regression/2-advanced-analysis/plot_conformal_predictive_distribution.py +++ b/examples/regression/2-advanced-analysis/plot_conformal_predictive_distribution.py @@ -71,7 +71,7 @@ def get_cumulative_distribution_function(self, X): y_pred = self.predict(X) cs = self.conformity_scores_[~np.isnan(self.conformity_scores_)] res = self.conformity_score_function_.get_estimation_distribution( - X, y_pred.reshape((-1, 1)), cs + y_pred.reshape((-1, 1)), cs, X=X ) return res diff --git a/mapie/classification.py b/mapie/classification.py index 232d76251..626149add 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -146,9 +146,6 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): conformity_scores_: ArrayLike of shape (n_samples_train) The conformity scores used to calibrate the prediction sets. - quantiles_: ArrayLike of shape (n_alpha) - The quantiles estimated from ``conformity_scores_`` and alpha values. - References ---------- [1] Mauricio Sadinle, Jing Lei, and Larry Wasserman. diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py index 6c91b88ee..db9df2c05 100644 --- a/mapie/conformity_scores/classification.py +++ b/mapie/conformity_scores/classification.py @@ -31,6 +31,11 @@ class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta): It should be specified if ``consistency_check==True``. By default, it is defined by the default precision. + + Attributes + ---------- + quantiles_: ArrayLike of shape (n_alpha) + The quantiles estimated from ``conformity_scores_`` and alpha values. 
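The example updates above all follow the same pattern: the quantiles move from the estimator to its conformity score object. A runnable sketch of the new access path (toy data; any scikit-learn classifier works):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from mapie.classification import MapieClassifier

    X = np.arange(20).reshape(-1, 1)
    y = np.array([0] * 10 + [1] * 10)
    mapie_clf = MapieClassifier(estimator=LogisticRegression(), method="lac")
    mapie_clf.fit(X, y)
    mapie_clf.predict(X, alpha=0.2)

    # Formerly mapie_clf.quantiles_; now held by the score object:
    quantiles = mapie_clf.conformity_score_function_.quantiles_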
""" def __init__( From f064cc9f2039fdb1da83a93f551d80fd599c1974 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 2 Jul 2024 09:45:53 +0200 Subject: [PATCH 04/46] UPD: check tests for additional parameters in residual normalized score --- mapie/conformity_scores/bounds/residuals.py | 16 ++++++-- mapie/tests/test_conformity_scores.py | 42 +++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/mapie/conformity_scores/bounds/residuals.py b/mapie/conformity_scores/bounds/residuals.py index f6bc9c7f3..f59084455 100644 --- a/mapie/conformity_scores/bounds/residuals.py +++ b/mapie/conformity_scores/bounds/residuals.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, cast import numpy as np from sklearn.base import RegressorMixin, clone @@ -245,7 +245,12 @@ def get_signed_conformity_scores( The learning is done with the log of the residual and later we use the exponential of the prediction to avoid negative values. """ - assert not (X is None) # TODO + if X is None: + raise ValueError( + "Additional parameters must be provided for the method to " + + "work (here `X` is missing)." + ) + X = cast(ArrayLike, X) (X, y, y_pred, self.residual_estimator_, @@ -307,7 +312,12 @@ def get_estimation_distribution( ``conformity_scores`` can be either the conformity scores or the quantile of the conformity scores. """ - assert not (X is None) # TODO + if X is None: + raise ValueError( + "Additional parameters must be provided for the method to " + + "work (here `X` is missing)." + ) + X = cast(ArrayLike, X) r_pred = self._predict_residual_estimator(X).reshape((-1, 1)) if not self.prefit: diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index 4ade1f354..06dfca94b 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -382,6 +382,48 @@ def test_residual_normalised_prefit_get_estimation_distribution() -> None: ) +def test_residual_normalised_additional_parameters() -> None: + """ + Test that residual normalised score raises no error with additional + parameters. 
+    """
+    residual_normalised_conf_score = ResidualNormalisedScore(
+        residual_estimator=LinearRegression(),
+        split_size=0.2,
+        random_state=random_state
+    )
+    # Test for get_conformity_scores
+    # 1) Test that no error is raised
+    residual_normalised_conf_score.get_conformity_scores(
+        y_toy, y_pred_list, X=X_toy
+    )
+    # 2) Test that an error is raised when X is not provided
+    with pytest.raises(
+        ValueError,
+        match=r"Additional parameters must be provided*"
+    ):
+        residual_normalised_conf_score.get_conformity_scores(
+            y_toy, y_pred_list
+        )
+
+    # Test for get_estimation_distribution
+    conf_scores = residual_normalised_conf_score.get_conformity_scores(
+        y_toy, y_pred_list, X=X_toy
+    )
+    # 1) Test that no error is raised
+    residual_normalised_conf_score.get_estimation_distribution(
+        y_pred_list, conf_scores, X=X_toy
+    )
+    # 2) Test that an error is raised when X is not provided
+    with pytest.raises(
+        ValueError,
+        match=r"Additional parameters must be provided*"
+    ):
+        residual_normalised_conf_score.get_estimation_distribution(
+            y_pred_list, conf_scores
+        )
+
+
 @pytest.mark.parametrize("score", [AbsoluteConformityScore(),
                                    GammaConformityScore(),
                                    ResidualNormalisedScore()])

From a9c03fd1fd16717130f6c1ca5e04ea90212a0395 Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Tue, 2 Jul 2024 10:17:16 +0200
Subject: [PATCH 05/46] UPD: typing and docstring in score classes

---
 mapie/classification.py                   |  2 +-
 mapie/conformity_scores/classification.py |  2 +-
 mapie/conformity_scores/sets/aps.py       | 52 +++++++++++++++++++---
 mapie/conformity_scores/sets/lac.py       | 51 +++++++++++++++++++---
 mapie/conformity_scores/sets/topk.py      | 53 ++++++++++++++++++++---
 5 files changed, 142 insertions(+), 18 deletions(-)

diff --git a/mapie/classification.py b/mapie/classification.py
index 626149add..ab7eef9af 100644
--- a/mapie/classification.py
+++ b/mapie/classification.py
@@ -629,7 +629,7 @@ def fit(
         # Compute the conformity scores
         self.conformity_scores_ = \
             self.conformity_score_function_.get_conformity_scores(
-                y, y_pred_proba, y_enc=y_enc, X=X
+                y_enc, y_pred_proba, X=X
             )
 
         return self
diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py
index db9df2c05..e23950b27 100644
--- a/mapie/conformity_scores/classification.py
+++ b/mapie/conformity_scores/classification.py
@@ -35,7 +35,7 @@ class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta):
     Attributes
     ----------
     quantiles_: ArrayLike of shape (n_alpha)
-        The quantiles estimated from ``conformity_scores_`` and alpha values.
+        The quantiles estimated from ``get_sets`` method.
     """
 
     def __init__(
diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py
index 6cd282260..6918c94ea 100644
--- a/mapie/conformity_scores/sets/aps.py
+++ b/mapie/conformity_scores/sets/aps.py
@@ -44,6 +44,44 @@ class APS(BaseClassificationScore):
     and Jitendra Malik.
     "Uncertainty Sets for Image Classifiers using
     Conformal Prediction."
     International Conference on Learning Representations 2021.
+
+    Parameters
+    ----------
+    consistency_check: bool, optional
+        Whether to check the consistency between the methods
+        ``get_estimation_distribution`` and ``get_conformity_scores``.
+        If ``True``, the following equality must be verified:
+        ``self.get_estimation_distribution(
+            y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs
+        ) == y``
+
+        By default ``True``.
+ + eps: float, optional + Threshold to consider when checking the consistency between + ``get_estimation_distribution`` and ``get_conformity_scores``. + It should be specified if ``consistency_check==True``. + + By default, it is defined by the default precision. + + Attributes + ---------- + method: str + Method to choose for prediction interval estimates. + This attribute is for compatibility with ``MapieClassifier`` + which previously used a string instead of a score class. + Methods available in this class: ``aps``, ``raps`` and ``naive``. + + By default, ``aps`` for APS method. + + classes: Optional[ArrayLike] + Names of the classes. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + + quantiles_: ArrayLike of shape (n_alpha) + The quantiles estimated from ``get_sets`` method. """ def __init__( @@ -89,9 +127,9 @@ def set_external_attributes( def get_conformity_scores( self, - y: ArrayLike, - y_pred: ArrayLike, - y_enc: Optional[ArrayLike] = None, + y: NDArray, + y_pred: NDArray, + y_enc: Optional[NDArray] = None, **kwargs ) -> NDArray: """ @@ -105,11 +143,15 @@ def get_conformity_scores( y_pred: NDArray of shape (n_samples,) Predicted target values. + y_enc: NDArray of shape (n_samples,) + Target values as normalized encodings. + Returns ------- NDArray of shape (n_samples,) Conformity scores. """ + # Casting y = cast(NDArray, y) y_pred = cast(NDArray, y_pred) y_enc = cast(NDArray, y_enc) @@ -136,8 +178,8 @@ def get_conformity_scores( def get_estimation_distribution( self, - y_pred: ArrayLike, - conformity_scores: ArrayLike, + y_pred: NDArray, + conformity_scores: NDArray, **kwargs ) -> NDArray: """ diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 8bff9b6fa..23d2b255f 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -24,6 +24,43 @@ class LAC(BaseClassificationScore): [1] Mauricio Sadinle, Jing Lei, and Larry Wasserman. "Least Ambiguous Set-Valued Classifiers with Bounded Error Levels.", Journal of the American Statistical Association, 114, 2019. + + Parameters + ---------- + consistency_check: bool, optional + Whether to check the consistency between the methods + ``get_estimation_distribution`` and ``get_conformity_scores``. + If ``True``, the following equality must be verified: + ``self.get_estimation_distribution( + y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs + ) == y`` + + By default ``True``. + + eps: float, optional + Threshold to consider when checking the consistency between + ``get_estimation_distribution`` and ``get_conformity_scores``. + It should be specified if ``consistency_check==True``. + + By default, it is defined by the default precision. + + Attributes + ---------- + method: str + Method to choose for prediction interval estimates. + This attribute is for compatibility with ``MapieClassifier`` + which previously used a string instead of a score class. + + By default, ``lac`` for LAC method. + + classes: Optional[ArrayLike] + Names of the classes. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + + quantiles_: ArrayLike of shape (n_alpha) + The quantiles estimated from ``get_sets`` method. 
""" def __init__( @@ -69,9 +106,9 @@ def set_external_attributes( def get_conformity_scores( self, - y: ArrayLike, - y_pred: ArrayLike, - y_enc: Optional[ArrayLike] = None, + y: NDArray, + y_pred: NDArray, + y_enc: Optional[NDArray] = None, **kwargs ) -> NDArray: """ @@ -85,11 +122,15 @@ def get_conformity_scores( y_pred: NDArray of shape (n_samples,) Predicted target values. + y_enc: NDArray of shape (n_samples,) + Target values as normalized encodings. + Returns ------- NDArray of shape (n_samples,) Conformity scores. """ + # Casting y_pred = cast(NDArray, y_pred) y_enc = cast(NDArray, y_enc) @@ -101,8 +142,8 @@ def get_conformity_scores( def get_estimation_distribution( self, - y_pred: ArrayLike, - conformity_scores: ArrayLike, + y_pred: NDArray, + conformity_scores: NDArray, **kwargs ) -> NDArray: """ diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 1e68ad832..0df5faabb 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -28,6 +28,43 @@ class TopK(BaseClassificationScore): and Jitendra Malik. "Uncertainty Sets for Image Classifiers using Conformal Prediction." International Conference on Learning Representations 2021. + + Parameters + ---------- + consistency_check: bool, optional + Whether to check the consistency between the methods + ``get_estimation_distribution`` and ``get_conformity_scores``. + If ``True``, the following equality must be verified: + ``self.get_estimation_distribution( + y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs + ) == y`` + + By default ``True``. + + eps: float, optional + Threshold to consider when checking the consistency between + ``get_estimation_distribution`` and ``get_conformity_scores``. + It should be specified if ``consistency_check==True``. + + By default, it is defined by the default precision. + + Attributes + ---------- + method: str + Method to choose for prediction interval estimates. + This attribute is for compatibility with ``MapieClassifier`` + which previously used a string instead of a score class. + + By default, ``top_k`` for Top-K method. + + classes: Optional[ArrayLike] + Names of the classes. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + + quantiles_: ArrayLike of shape (n_alpha) + The quantiles estimated from ``get_sets`` method. """ def __init__( @@ -56,7 +93,7 @@ def set_external_attributes( Method to choose for prediction interval estimates. Methods available in this class: ``top_k``. - By default ``top_k`` for Top K method. + By default ``top_k`` for Top-K method. classes: Optional[ArrayLike] Names of the classes. @@ -73,9 +110,9 @@ def set_external_attributes( def get_conformity_scores( self, - y: ArrayLike, - y_pred: ArrayLike, - y_enc: Optional[ArrayLike] = None, + y: NDArray, + y_pred: NDArray, + y_enc: Optional[NDArray] = None, **kwargs ) -> NDArray: """ @@ -89,11 +126,15 @@ def get_conformity_scores( y_pred: NDArray of shape (n_samples,) Predicted target values. + y_enc: NDArray of shape (n_samples,) + Target values as normalized encodings. + Returns ------- NDArray of shape (n_samples,) Conformity scores. 
""" + # Casting y = cast(NDArray, y) y_pred = cast(NDArray, y_pred) y_enc = cast(NDArray, y_enc) @@ -107,8 +148,8 @@ def get_conformity_scores( def get_estimation_distribution( self, - y_pred: ArrayLike, - conformity_scores: ArrayLike, + y_pred: NDArray, + conformity_scores: NDArray, **kwargs ) -> NDArray: """ From 7f791dd1494c5ff003413e967b4d66e01d7af286 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 2 Jul 2024 10:45:30 +0200 Subject: [PATCH 06/46] UPD: use y_enc as additional parameters to conserve label encoding --- mapie/classification.py | 2 +- mapie/conformity_scores/sets/aps.py | 2 -- mapie/conformity_scores/sets/lac.py | 2 +- mapie/conformity_scores/sets/topk.py | 2 -- 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mapie/classification.py b/mapie/classification.py index ab7eef9af..626149add 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -629,7 +629,7 @@ def fit( # Compute the conformity scores self.conformity_scores_ = \ self.conformity_score_function_.get_conformity_scores( - y_enc, y_pred_proba, X=X + y, y_pred_proba, y_enc=y_enc, X=X ) return self diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 6918c94ea..06ac5dfa9 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -152,8 +152,6 @@ def get_conformity_scores( Conformity scores. """ # Casting - y = cast(NDArray, y) - y_pred = cast(NDArray, y_pred) y_enc = cast(NDArray, y_enc) classes = cast(NDArray, self.classes) diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 23d2b255f..718456c72 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -131,13 +131,13 @@ def get_conformity_scores( Conformity scores. """ # Casting - y_pred = cast(NDArray, y_pred) y_enc = cast(NDArray, y_enc) # Conformity scores conformity_scores = np.take_along_axis( 1 - y_pred, y_enc.reshape(-1, 1), axis=1 ) + return conformity_scores def get_estimation_distribution( diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 0df5faabb..4a2cd8992 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -135,8 +135,6 @@ def get_conformity_scores( Conformity scores. 
""" # Casting - y = cast(NDArray, y) - y_pred = cast(NDArray, y_pred) y_enc = cast(NDArray, y_enc) # Conformity scores From 1b99529bf48368c9eab70cd6272accdab5a2f958 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 2 Jul 2024 11:25:20 +0200 Subject: [PATCH 07/46] UPD: change greater equal to less equal --- mapie/conformity_scores/sets/lac.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 718456c72..e4d43eee9 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -225,13 +225,12 @@ def get_sets( # Build prediction sets if (estimator.cv == "prefit") or (agg_scores == "mean"): - prediction_sets = np.greater_equal( - y_pred_proba - (1 - self.quantiles_), -EPSILON + prediction_sets = np.less_equal( + (1 - y_pred_proba) - self.quantiles_, EPSILON ) else: y_pred_included = np.less_equal( - (1 - y_pred_proba) - conformity_scores.ravel(), - EPSILON + (1 - y_pred_proba) - conformity_scores.ravel(), EPSILON ).sum(axis=2) prediction_sets = np.stack( [ From e588a3e8c28392a6d27f7fd9b941c94371085673 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 2 Jul 2024 12:07:23 +0200 Subject: [PATCH 08/46] UPD: docstring to explain parameters of get_sets --- mapie/conformity_scores/sets/aps.py | 34 +++++++++++++++++++++++++++- mapie/conformity_scores/sets/lac.py | 12 +++++++++- mapie/conformity_scores/sets/topk.py | 7 +----- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 06ac5dfa9..3798e7139 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -412,7 +412,39 @@ def get_sets( conformity_scores: NDArray of shape (n_samples,) Conformity scores. - TODO + agg_scores: Optional[str] + How to aggregate the scores output by the estimators on test data + if a cross-validation strategy is used. Choose among: + + - "mean", take the mean of scores. + - "crossval", compare the scores between all training data and each + test point for each label to estimate if the label must be + included in the prediction set. Follows algorithm 2 of + Romano+2020. + + By default, "mean". + + X_raps: NDArray of shape (n_samples, n_features) + Observed feature values for the RAPS method (split data). + + By default, "None" but must be set to work. + + y_raps_no_enc: NDArray of shape (n_samples,) + Observed labels for the RAPS method (split data). + + By default, "None" but must be set to work. + + y_pred_proba_raps: NDArray of shape (n_samples, n_classes) + Predicted probabilities for the RAPS method (split data). + + By default, "None" but must be set to work. + + position_raps: NDArray of shape (n_samples,) + Position of the points in the split set for the RAPS method + (split data). These positions are returned by the function + ``get_true_label_position``. + + By default, "None" but must be set to work. Returns ------- diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index e4d43eee9..976291add 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -197,7 +197,17 @@ def get_sets( conformity_scores: NDArray of shape (n_samples,) Conformity scores. - TODO + agg_scores: Optional[str] + How to aggregate the scores output by the estimators on test data + if a cross-validation strategy is used. Choose among: + + - "mean", take the mean of scores. 
+        - "crossval", compare the scores between all training data and each
+          test point for each label to estimate if the label must be
+          included in the prediction set. Follows algorithm 2 of
+          Romano+2020.
+
+        By default, "mean".
 
         Returns
         -------
diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py
index 4a2cd8992..2769ed144 100644
--- a/mapie/conformity_scores/sets/topk.py
+++ b/mapie/conformity_scores/sets/topk.py
@@ -179,7 +179,6 @@ def get_sets(
         alpha_np: NDArray,
         estimator: EnsembleClassifier,
         conformity_scores: NDArray,
-        agg_scores: Optional[str] = "mean",
         **kwargs
     ):
         """
@@ -201,17 +200,13 @@ def get_sets(
         conformity_scores: NDArray of shape (n_samples,)
             Conformity scores.
 
-        TODO
-
         Returns
         -------
         NDArray of shape (n_samples, n_classes, n_alpha)
             Prediction sets (Booleans indicate whether classes are included).
         """
         # Checks
-        agg_scores = "mean"
-
-        y_pred_proba = estimator.predict(X, agg_scores)
+        y_pred_proba = estimator.predict(X, agg_scores="mean")
         y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
         y_pred_proba = np.repeat(
             y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2

From 8832a2e3c1d8beb21afff4d1dc5203b664b4ab3b Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Wed, 3 Jul 2024 11:29:23 +0200
Subject: [PATCH 09/46] UPD: remove useless methods and attributes (estimation
 distribution) for classification score

---
 mapie/conformity_scores/classification.py |  28 +-----
 mapie/conformity_scores/interface.py      | 107 +---------------------
 mapie/conformity_scores/regression.py     |  85 ++++++++++++++++-
 mapie/conformity_scores/sets/aps.py       |  62 +------------
 mapie/conformity_scores/sets/lac.py       |  62 +------------
 mapie/conformity_scores/sets/topk.py      |  62 +------------
 6 files changed, 93 insertions(+), 313 deletions(-)

diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py
index e23950b27..b2670c5d9 100644
--- a/mapie/conformity_scores/classification.py
+++ b/mapie/conformity_scores/classification.py
@@ -3,7 +3,6 @@
 from mapie.conformity_scores.interface import BaseConformityScore
 from mapie.estimator.classifier import EnsembleClassifier
 
-from mapie._machine_precision import EPSILON
 from mapie._typing import NDArray
 
 
@@ -13,37 +12,14 @@ class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta):
 
     This class should not be used directly. Use derived classes instead.
 
-    Parameters
-    ----------
-    consistency_check: bool, optional
-        Whether to check the consistency between the methods
-        ``get_estimation_distribution`` and ``get_conformity_scores``.
-        If ``True``, the following equality must be verified:
-        ``self.get_estimation_distribution(
-            y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs
-        ) == y``
-
-        By default ``True``.
-
-    eps: float, optional
-        Threshold to consider when checking the consistency between
-        ``get_estimation_distribution`` and ``get_conformity_scores``.
-        It should be specified if ``consistency_check==True``.
-
-        By default, it is defined by the default precision.
-
     Attributes
     ----------
     quantiles_: ArrayLike of shape (n_alpha)
         The quantiles estimated from ``get_sets`` method.
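The earlier rewrite of the LAC prediction sets (greater_equal to less_equal) is equivalent arithmetic: keeping a label when p >= 1 - q_hat is expressed as (1 - p) - q_hat <= EPSILON, so the prefit and cross-validation branches share the same comparison. A self-contained sketch with toy numbers (this EPSILON is a stand-in for the library's precision constant):

    import numpy as np

    EPSILON = np.float64(1e-8)  # stand-in for mapie._machine_precision.EPSILON

    y_pred_proba = np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.8, 0.1]])
    y_enc = np.array([0, 2]).reshape(-1, 1)

    # LAC conformity score: 1 - probability of the true class.
    conformity_scores = np.take_along_axis(1 - y_pred_proba, y_enc, axis=1)
    # -> [[0.3], [0.9]]

    # Prediction set at quantile q_hat: (1 - p) - q_hat <= EPSILON.
    q_hat = 0.5
    prediction_set = np.less_equal((1 - y_pred_proba) - q_hat, EPSILON)
    # -> [[ True, False, False], [False,  True, False]]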
""" - def __init__( - self, - consistency_check: bool = True, - eps: float = float(EPSILON), - ): - super().__init__(consistency_check=consistency_check, eps=eps) + def __init__(self) -> None: + super().__init__() @abstractmethod def get_sets( diff --git a/mapie/conformity_scores/interface.py b/mapie/conformity_scores/interface.py index 680c6cc9e..c8e163844 100644 --- a/mapie/conformity_scores/interface.py +++ b/mapie/conformity_scores/interface.py @@ -3,7 +3,6 @@ import numpy as np from mapie._compatibility import np_nanquantile -from mapie._machine_precision import EPSILON from mapie._typing import NDArray @@ -12,34 +11,10 @@ class BaseConformityScore(metaclass=ABCMeta): Base class for conformity scores. This class should not be used directly. Use derived classes instead. - - Parameters - ---------- - consistency_check: bool, optional - Whether to check the consistency between the methods - ``get_estimation_distribution`` and ``get_conformity_scores``. - If ``True``, the following equality must be verified: - ``self.get_estimation_distribution( - y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs - ) == y`` - - By default ``True``. - - eps: float, optional - Threshold to consider when checking the consistency between - ``get_estimation_distribution`` and ``get_conformity_scores``. - It should be specified if ``consistency_check==True``. - - By default, it is defined by the default precision. """ - def __init__( - self, - consistency_check: bool = True, - eps: float = float(EPSILON), - ): - self.consistency_check = consistency_check - self.eps = eps + def __init__(self) -> None: + pass def set_external_attributes( self, @@ -54,56 +29,6 @@ def set_external_attributes( """ pass - def check_consistency( - self, - y: NDArray, - y_pred: NDArray, - conformity_scores: NDArray, - **kwargs - ) -> None: - """ - Check consistency between the following methods: - ``get_estimation_distribution`` and ``get_signed_conformity_scores`` - - The following equality should be verified: - ``self.get_estimation_distribution( - y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs - ) == y`` - - Parameters - ---------- - y: NDArray of shape (n_samples, ...) - Observed target values. - - y_pred: NDArray of shape (n_samples, ...) - Predicted target values. - - conformity_scores: NDArray of shape (n_samples, ...) - Conformity scores. - - Raises - ------ - ValueError - If the two methods are not consistent. - """ - score_distribution = self.get_estimation_distribution( - y_pred, conformity_scores, **kwargs - ) - abs_conformity_scores = np.abs(np.subtract(score_distribution, y)) - max_conf_score = np.max(abs_conformity_scores) - if max_conf_score > self.eps: - raise ValueError( - "The two functions get_conformity_scores and " - "get_estimation_distribution of the BaseConformityScore class " - "are not consistent. " - "The following equation must be verified: " - "self.get_estimation_distribution(y_pred, " - "self.get_conformity_scores(y, y_pred)) == y. " - f"The maximum conformity score is {max_conf_score}. " - "The eps attribute may need to be increased if you are " - "sure that the two methods are consistent." - ) - @abstractmethod def get_conformity_scores( self, @@ -132,34 +57,6 @@ def get_conformity_scores( Conformity scores. """ - @abstractmethod - def get_estimation_distribution( - self, - y_pred: NDArray, - conformity_scores: NDArray, - **kwargs - ) -> NDArray: - """ - Placeholder for ``get_estimation_distribution``. - Subclasses should implement this method! 
- - Compute samples of the estimation distribution given the predicted - targets and the conformity scores. - - Parameters - ---------- - y_pred: NDArray of shape (n_samples, ...) - Predicted target values. - - conformity_scores: NDArray of shape (n_samples, ...) - Conformity scores. - - Returns - ------- - NDArray of shape (n_samples, ...) - Observed values. - """ - @staticmethod def get_quantile( conformity_scores: NDArray, diff --git a/mapie/conformity_scores/regression.py b/mapie/conformity_scores/regression.py index 2e878e349..fa151d5e5 100644 --- a/mapie/conformity_scores/regression.py +++ b/mapie/conformity_scores/regression.py @@ -41,12 +41,15 @@ class BaseRegressionScore(BaseConformityScore, metaclass=ABCMeta): """ def __init__( - self, sym: bool, + self, + sym: bool, consistency_check: bool = True, eps: float = float(EPSILON), ): - super().__init__(consistency_check=consistency_check, eps=eps) + super().__init__() self.sym = sym + self.consistency_check = consistency_check + self.eps = eps @abstractmethod def get_signed_conformity_scores( @@ -106,6 +109,84 @@ def get_conformity_scores( conformity_scores = np.abs(conformity_scores) return conformity_scores + def check_consistency( + self, + y: NDArray, + y_pred: NDArray, + conformity_scores: NDArray, + **kwargs + ) -> None: + """ + Check consistency between the following methods: + ``get_estimation_distribution`` and ``get_signed_conformity_scores`` + + The following equality should be verified: + ``self.get_estimation_distribution( + y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs + ) == y`` + + Parameters + ---------- + y: NDArray of shape (n_samples, ...) + Observed target values. + + y_pred: NDArray of shape (n_samples, ...) + Predicted target values. + + conformity_scores: NDArray of shape (n_samples, ...) + Conformity scores. + + Raises + ------ + ValueError + If the two methods are not consistent. + """ + score_distribution = self.get_estimation_distribution( + y_pred, conformity_scores, **kwargs + ) + abs_conformity_scores = np.abs(np.subtract(score_distribution, y)) + max_conf_score = np.max(abs_conformity_scores) + if max_conf_score > self.eps: + raise ValueError( + "The two functions get_conformity_scores and " + "get_estimation_distribution of the BaseConformityScore class " + "are not consistent. " + "The following equation must be verified: " + "self.get_estimation_distribution(y_pred, " + "self.get_conformity_scores(y, y_pred)) == y. " + f"The maximum conformity score is {max_conf_score}. " + "The eps attribute may need to be increased if you are " + "sure that the two methods are consistent." + ) + + @abstractmethod + def get_estimation_distribution( + self, + y_pred: NDArray, + conformity_scores: NDArray, + **kwargs + ) -> NDArray: + """ + Placeholder for ``get_estimation_distribution``. + Subclasses should implement this method! + + Compute samples of the estimation distribution given the predicted + targets and the conformity scores. + + Parameters + ---------- + y_pred: NDArray of shape (n_samples, ...) + Predicted target values. + + conformity_scores: NDArray of shape (n_samples, ...) + Conformity scores. + + Returns + ------- + NDArray of shape (n_samples, ...) + Observed values. 
+ """ + @staticmethod def _beta_optimize( alpha_np: NDArray, diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 3798e7139..b1a2fe142 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -45,25 +45,6 @@ class APS(BaseClassificationScore): "Uncertainty Sets for Image Classifiers using Conformal Prediction." International Conference on Learning Representations 2021. - Parameters - ---------- - consistency_check: bool, optional - Whether to check the consistency between the methods - ``get_estimation_distribution`` and ``get_conformity_scores``. - If ``True``, the following equality must be verified: - ``self.get_estimation_distribution( - y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs - ) == y`` - - By default ``True``. - - eps: float, optional - Threshold to consider when checking the consistency between - ``get_estimation_distribution`` and ``get_conformity_scores``. - It should be specified if ``consistency_check==True``. - - By default, it is defined by the default precision. - Attributes ---------- method: str @@ -84,15 +65,8 @@ class APS(BaseClassificationScore): The quantiles estimated from ``get_sets`` method. """ - def __init__( - self, - consistency_check: bool = True, - eps: float = float(EPSILON), - ): - super().__init__( - consistency_check=consistency_check, - eps=eps - ) + def __init__(self) -> None: + super().__init__() def set_external_attributes( self, @@ -174,35 +148,6 @@ def get_conformity_scores( return conformity_scores - def get_estimation_distribution( - self, - y_pred: NDArray, - conformity_scores: NDArray, - **kwargs - ) -> NDArray: - """ - TODO - Placeholder for ``get_estimation_distribution``. - Subclasses should implement this method! - - Compute samples of the estimation distribution given the predicted - targets and the conformity scores. - - Parameters - ---------- - y_pred: NDArray of shape (n_samples, ...) - Predicted target values. - - conformity_scores: NDArray of shape (n_samples, ...) - Conformity scores. - - Returns - ------- - NDArray of shape (n_samples, ...) - Observed values. - """ - return np.array([]) - @staticmethod def _regularize_conformity_score( k_star: NDArray, @@ -563,7 +508,4 @@ def get_sets( EPSILON ) - # Just for coverage: do nothing - self.get_estimation_distribution(y_pred_proba, conformity_scores) - return prediction_sets diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 976291add..5edf9d45c 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -25,25 +25,6 @@ class LAC(BaseClassificationScore): "Least Ambiguous Set-Valued Classifiers with Bounded Error Levels.", Journal of the American Statistical Association, 114, 2019. - Parameters - ---------- - consistency_check: bool, optional - Whether to check the consistency between the methods - ``get_estimation_distribution`` and ``get_conformity_scores``. - If ``True``, the following equality must be verified: - ``self.get_estimation_distribution( - y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs - ) == y`` - - By default ``True``. - - eps: float, optional - Threshold to consider when checking the consistency between - ``get_estimation_distribution`` and ``get_conformity_scores``. - It should be specified if ``consistency_check==True``. - - By default, it is defined by the default precision. 
- Attributes ---------- method: str @@ -63,15 +44,8 @@ class LAC(BaseClassificationScore): The quantiles estimated from ``get_sets`` method. """ - def __init__( - self, - consistency_check: bool = True, - eps: float = float(EPSILON), - ): - super().__init__( - consistency_check=consistency_check, - eps=eps - ) + def __init__(self) -> None: + super().__init__() def set_external_attributes( self, @@ -140,35 +114,6 @@ def get_conformity_scores( return conformity_scores - def get_estimation_distribution( - self, - y_pred: NDArray, - conformity_scores: NDArray, - **kwargs - ) -> NDArray: - """ - TODO - Placeholder for ``get_estimation_distribution``. - Subclasses should implement this method! - - Compute samples of the estimation distribution given the predicted - targets and the conformity scores. - - Parameters - ---------- - y_pred: NDArray of shape (n_samples, ...) - Predicted target values. - - conformity_scores: NDArray of shape (n_samples, ...) - Conformity scores. - - Returns - ------- - NDArray of shape (n_samples, ...) - Observed values. - """ - return np.array([]) - def get_sets( self, X: ArrayLike, @@ -251,7 +196,4 @@ def get_sets( ], axis=2 ) - # Just for coverage: do nothing - self.get_estimation_distribution(y_pred_proba, conformity_scores) - return prediction_sets diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 2769ed144..fb0e7836f 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -29,25 +29,6 @@ class TopK(BaseClassificationScore): "Uncertainty Sets for Image Classifiers using Conformal Prediction." International Conference on Learning Representations 2021. - Parameters - ---------- - consistency_check: bool, optional - Whether to check the consistency between the methods - ``get_estimation_distribution`` and ``get_conformity_scores``. - If ``True``, the following equality must be verified: - ``self.get_estimation_distribution( - y_pred, self.get_conformity_scores(y, y_pred, **kwargs), **kwargs - ) == y`` - - By default ``True``. - - eps: float, optional - Threshold to consider when checking the consistency between - ``get_estimation_distribution`` and ``get_conformity_scores``. - It should be specified if ``consistency_check==True``. - - By default, it is defined by the default precision. - Attributes ---------- method: str @@ -67,15 +48,8 @@ class TopK(BaseClassificationScore): The quantiles estimated from ``get_sets`` method. """ - def __init__( - self, - consistency_check: bool = True, - eps: float = float(EPSILON), - ): - super().__init__( - consistency_check=consistency_check, - eps=eps - ) + def __init__(self) -> None: + super().__init__() def set_external_attributes( self, @@ -144,35 +118,6 @@ def get_conformity_scores( return conformity_scores - def get_estimation_distribution( - self, - y_pred: NDArray, - conformity_scores: NDArray, - **kwargs - ) -> NDArray: - """ - TODO - Placeholder for ``get_estimation_distribution``. - Subclasses should implement this method! - - Compute samples of the estimation distribution given the predicted - targets and the conformity scores. - - Parameters - ---------- - y_pred: NDArray of shape (n_samples, ...) - Predicted target values. - - conformity_scores: NDArray of shape (n_samples, ...) - Conformity scores. - - Returns - ------- - NDArray of shape (n_samples, ...) - Observed values. 
-        """
-        return np.array([])
-
     def get_sets(
         self,
         X: ArrayLike,
@@ -240,7 +185,4 @@ def get_sets(
             -EPSILON
         )
 
-        # Just for coverage: do nothing
-        self.get_estimation_distribution(y_pred_proba, conformity_scores)
-
         return prediction_sets

From 8fa2474b074e978c88017de143413341ac904e4e Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Wed, 3 Jul 2024 16:24:09 +0200
Subject: [PATCH 10/46] UPD: decompose APS into Naive, APS and RAPS + new
 abstract methods for classification

---
 mapie/conformity_scores/__init__.py       |   4 +-
 mapie/conformity_scores/classification.py |  55 +++
 mapie/conformity_scores/sets/__init__.py  |   4 +
 mapie/conformity_scores/sets/aps.py       | 431 ++--------------------
 mapie/conformity_scores/sets/lac.py       |  84 ++---
 mapie/conformity_scores/sets/naive.py     | 417 +++++++++++++++++++++
 mapie/conformity_scores/sets/raps.py      | 374 +++++++++++++++++++
 mapie/conformity_scores/sets/topk.py      |  59 ++-
 mapie/conformity_scores/sets/utils.py     | 199 +---------
 mapie/conformity_scores/utils.py          |  10 +-
 mapie/tests/test_classification.py        |  14 +-
 11 files changed, 957 insertions(+), 694 deletions(-)
 create mode 100644 mapie/conformity_scores/sets/naive.py
 create mode 100644 mapie/conformity_scores/sets/raps.py

diff --git a/mapie/conformity_scores/__init__.py b/mapie/conformity_scores/__init__.py
index 3b47311da..88a3530be 100644
--- a/mapie/conformity_scores/__init__.py
+++ b/mapie/conformity_scores/__init__.py
@@ -3,7 +3,7 @@
 from .bounds import (
     AbsoluteConformityScore, GammaConformityScore, ResidualNormalisedScore
 )
-from .sets import APS, LAC, TopK
+from .sets import APS, LAC, Naive, RAPS, TopK
 
 
 __all__ = [
@@ -12,7 +12,9 @@
     "AbsoluteConformityScore",
     "GammaConformityScore",
     "ResidualNormalisedScore",
+    "Naive",
     "LAC",
     "APS",
+    "RAPS",
     "TopK"
 ]
diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py
index b2670c5d9..4e4925cea 100644
--- a/mapie/conformity_scores/classification.py
+++ b/mapie/conformity_scores/classification.py
@@ -22,6 +22,42 @@ def __init__(self) -> None:
         super().__init__()
 
     @abstractmethod
+    def get_predictions(
+        self,
+        X: NDArray,
+        alpha_np: NDArray,
+        estimator: EnsembleClassifier,
+        **kwargs
+    ) -> NDArray:
+        """
+        TODO: Compute the predictions.
+        """
+
+    @abstractmethod
+    def get_conformity_quantiles(
+        self,
+        conformity_scores: NDArray,
+        alpha_np: NDArray,
+        estimator: EnsembleClassifier,
+        **kwargs
+    ) -> NDArray:
+        """
+        TODO: Compute the quantiles.
+        """
+
+    @abstractmethod
+    def get_prediction_sets(
+        self,
+        y_pred_proba: NDArray,
+        conformity_scores: NDArray,
+        alpha_np: NDArray,
+        estimator: EnsembleClassifier,
+        **kwargs
+    ):
+        """
+        TODO: Compute the prediction sets.
""" + # Checks + () + + # Predict probabilities + y_pred_proba = self.get_predictions( + X, alpha_np, estimator, **kwargs + ) + + # Choice of the quantile + self.quantiles_ = self.get_conformity_quantiles( + conformity_scores, alpha_np, estimator, **kwargs + ) + + # Build prediction sets + prediction_sets = self.get_prediction_sets( + y_pred_proba, conformity_scores, alpha_np, estimator, **kwargs + ) + + return prediction_sets def predict_set( self, diff --git a/mapie/conformity_scores/sets/__init__.py b/mapie/conformity_scores/sets/__init__.py index 87b6a37e6..36f203cc5 100644 --- a/mapie/conformity_scores/sets/__init__.py +++ b/mapie/conformity_scores/sets/__init__.py @@ -1,10 +1,14 @@ +from .naive import Naive from .lac import LAC from .aps import APS +from .raps import RAPS from .topk import TopK __all__ = [ + "Naive", "LAC", "APS", + "RAPS", "TopK", ] diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index b1a2fe142..16c6a7b98 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -1,23 +1,18 @@ -from typing import Optional, Tuple, Union, cast +from typing import Optional, cast import numpy as np from sklearn.dummy import check_random_state -from mapie.conformity_scores.classification import BaseClassificationScore -from mapie.conformity_scores.sets.utils import ( - add_random_tie_breaking, check_include_last_label, check_proba_normalized, - get_last_included_proba, get_true_label_cumsum_proba -) +from mapie.conformity_scores.sets.naive import Naive +from mapie.conformity_scores.sets.utils import get_true_label_cumsum_proba from mapie.estimator.classifier import EnsembleClassifier -from mapie._machine_precision import EPSILON -from mapie._typing import ArrayLike, NDArray -from mapie.metrics import classification_mean_width_score -from mapie.utils import check_alpha_and_n_samples, compute_quantiles +from mapie._typing import NDArray +from mapie.utils import compute_quantiles -class APS(BaseClassificationScore): - """ +class APS(Naive): + """TODO: Adaptive Prediction Sets (APS) method-based non-conformity score. Three differents method are available in this class: @@ -68,37 +63,6 @@ class APS(BaseClassificationScore): def __init__(self) -> None: super().__init__() - def set_external_attributes( - self, - method: str = 'aps', - classes: Optional[ArrayLike] = None, - random_state: Optional[Union[int, np.random.RandomState]] = None, - **kwargs - ) -> None: - """ - Set attributes that are not provided by the user. - - Parameters - ---------- - method: str - Method to choose for prediction interval estimates. - Methods available in this class: ``aps``, ``raps`` and ``naive``. - - By default ``aps`` for APS method. - - classes: Optional[ArrayLike] - Names of the classes. - - By default ``None``. - - random_state: Optional[Union[int, RandomState]] - Pseudo random number generator state. 
- """ - super().set_external_attributes(**kwargs) - self.method = method - self.classes = classes - self.random_state = random_state - def get_conformity_scores( self, y: NDArray, @@ -130,382 +94,35 @@ def get_conformity_scores( classes = cast(NDArray, self.classes) # Conformity scores - if self.method == "naive": - conformity_scores = ( - np.empty(y_pred.shape, dtype="float") - ) - else: - conformity_scores, self.cutoff = ( - get_true_label_cumsum_proba(y, y_pred, classes) - ) - y_proba_true = np.take_along_axis( - y_pred, y_enc.reshape(-1, 1), axis=1 - ) - random_state = check_random_state(self.random_state) - random_state = cast(np.random.RandomState, random_state) - u = random_state.uniform(size=len(y_pred)).reshape(-1, 1) - conformity_scores -= u * y_proba_true - - return conformity_scores - - @staticmethod - def _regularize_conformity_score( - k_star: NDArray, - lambda_: Union[NDArray, float], - conf_score: NDArray, - cutoff: NDArray - ) -> NDArray: - """ - Regularize the conformity scores with the ``"raps"`` - method. See algo. 2 in [3]. - - Parameters - ---------- - k_star: NDArray of shape (n_alphas, ) - Optimal value of k (called k_reg in the paper). There - is one value per alpha. - - lambda_: Union[NDArray, float] of shape (n_alphas, ) - One value of lambda for each alpha. - - conf_score: NDArray of shape (n_samples, 1) - Conformity scores. - - cutoff: NDArray of shape (n_samples, 1) - Position of the true label. - - Returns - ------- - NDArray of shape (n_samples, 1, n_alphas) - Regularized conformity scores. The regularization - depends on the value of alpha. - """ - conf_score = np.repeat( - conf_score[:, :, np.newaxis], len(k_star), axis=2 + conformity_scores, self.cutoff = ( + get_true_label_cumsum_proba(y, y_pred, classes) ) - cutoff = np.repeat( - cutoff[:, np.newaxis], len(k_star), axis=1 + y_proba_true = np.take_along_axis( + y_pred, y_enc.reshape(-1, 1), axis=1 ) - conf_score += np.maximum( - np.expand_dims( - lambda_ * (cutoff - k_star), - axis=1 - ), - 0 - ) - return conf_score - - def _update_size_and_lambda( - self, - best_sizes: NDArray, - alpha_np: NDArray, - y_ps: NDArray, - lambda_: Union[NDArray, float], - lambda_star: NDArray - ) -> Tuple[NDArray, NDArray]: - """Update the values of the optimal lambda if the - average size of the prediction sets decreases with - this new value of lambda. - - Parameters - ---------- - best_sizes: NDArray of shape (n_alphas, ) - Smallest average prediciton set size before testing - for the new value of lambda_ - - alpha_np: NDArray of shape (n_alphas) - Level of confidences. - - y_ps: NDArray of shape (n_samples, n_classes, n_alphas) - Prediction sets computed with the RAPS method and the - new value of lambda_ + random_state = check_random_state(self.random_state) + random_state = cast(np.random.RandomState, random_state) + u = random_state.uniform(size=len(y_pred)).reshape(-1, 1) + conformity_scores -= u * y_proba_true - lambda_: NDArray of shape (n_alphas, ) - New value of lambda_star to test - - lambda_star: NDArray of shape (n_alphas, ) - Actual optimal lambda values for each alpha. - - Returns - ------- - Tuple[NDArray, NDArray] - Arrays of shape (n_alphas, ) and (n_alpha, ) which - respectively represent the updated values of lambda_star - and the new best sizes. 
- """ - - sizes = [ - classification_mean_width_score( - y_ps[:, :, i] - ) for i in range(len(alpha_np)) - ] - - sizes_improve = (sizes < best_sizes - EPSILON) - lambda_star = ( - sizes_improve * lambda_ + (1 - sizes_improve) * lambda_star - ) - best_sizes = sizes_improve * sizes + (1 - sizes_improve) * best_sizes - - return lambda_star, best_sizes - - def _find_lambda_star( - self, - y_raps_no_enc: NDArray, - y_pred_proba_raps: NDArray, - alpha_np: NDArray, - include_last_label: Union[bool, str, None], - k_star: NDArray - ) -> Union[NDArray, float]: - """Find the optimal value of lambda for each alpha. - - Parameters - ---------- - y_pred_proba_raps: NDArray of shape (n_samples, n_labels, n_alphas) - Predictions of the model repeated on the last axis as many times - as the number of alphas - - alpha_np: NDArray of shape (n_alphas, ) - Levels of confidences. - - include_last_label: bool - Whether to include or not last label in - the prediction sets - - k_star: NDArray of shape (n_alphas, ) - Values of k for the regularization. - - Returns - ------- - ArrayLike of shape (n_alphas, ) - Optimal values of lambda. - """ - classes = cast(NDArray, self.classes) - - lambda_star = np.zeros(len(alpha_np)) - best_sizes = np.full(len(alpha_np), np.finfo(np.float64).max) - - for lambda_ in [.001, .01, .1, .2, .5]: # values given in paper[3] - true_label_cumsum_proba, cutoff = ( - get_true_label_cumsum_proba( - y_raps_no_enc, - y_pred_proba_raps[:, :, 0], - classes - ) - ) - - true_label_cumsum_proba_reg = self._regularize_conformity_score( - k_star, - lambda_, - true_label_cumsum_proba, - cutoff - ) - - quantiles_ = compute_quantiles( - true_label_cumsum_proba_reg, - alpha_np - ) - - _, _, y_pred_proba_last = get_last_included_proba( - y_pred_proba_raps, - quantiles_, - include_last_label, - self.method, - lambda_, - k_star - ) - - y_ps = np.greater_equal( - y_pred_proba_raps - y_pred_proba_last, -EPSILON - ) - lambda_star, best_sizes = self._update_size_and_lambda( - best_sizes, alpha_np, y_ps, lambda_, lambda_star - ) - if len(lambda_star) == 1: - lambda_star = lambda_star[0] - return lambda_star + return conformity_scores - def get_sets( + def get_conformity_quantiles( self, - X: ArrayLike, + conformity_scores: NDArray, alpha_np: NDArray, estimator: EnsembleClassifier, - conformity_scores: NDArray, - include_last_label: Optional[Union[bool, str]] = True, agg_scores: Optional[str] = "mean", - X_raps: Optional[NDArray] = None, - y_raps_no_enc: Optional[NDArray] = None, - y_pred_proba_raps: Optional[NDArray] = None, - position_raps: Optional[NDArray] = None, **kwargs - ): + ) -> NDArray: """ - Compute classes of the prediction sets from the observed values, - the estimator of type ``EnsembleClassifier`` and the conformity scores. - - Parameters - ---------- - X: NDArray of shape (n_samples, n_features) - Observed feature values. - - alpha_np: NDArray of shape (n_alpha,) - NDArray of floats between ``0`` and ``1``, represents the - uncertainty of the confidence interval. - - estimator: EnsembleClassifier - Estimator that is fitted to predict y from X. - - conformity_scores: NDArray of shape (n_samples,) - Conformity scores. - - agg_scores: Optional[str] - How to aggregate the scores output by the estimators on test data - if a cross-validation strategy is used. Choose among: - - - "mean", take the mean of scores. - - "crossval", compare the scores between all training data and each - test point for each label to estimate if the label must be - included in the prediction set. 
Follows algorithm 2 of
-            Romano+2020.
-
-            By default, "mean".
-
-        X_raps: NDArray of shape (n_samples, n_features)
-            Observed feature values for the RAPS method (split data).
-
-            By default, "None" but must be set to work.
-
-        y_raps_no_enc: NDArray of shape (n_samples,)
-            Observed labels for the RAPS method (split data).
-
-            By default, "None" but must be set to work.
-
-        y_pred_proba_raps: NDArray of shape (n_samples, n_classes)
-            Predicted probabilities for the RAPS method (split data).
-
-            By default, "None" but must be set to work.
-
-        position_raps: NDArray of shape (n_samples,)
-            Position of the points in the split set for the RAPS method
-            (split data). These positions are returned by the function
-            ``get_true_label_position``.
-
-            By default, "None" but must be set to work.
-
-        Returns
-        -------
-        NDArray of shape (n_samples, n_classes, n_alpha)
-            Prediction sets (Booleans indicate whether classes are included).
+        TODO: Compute the quantiles.
         """
-        # Checks
-        include_last_label = check_include_last_label(include_last_label)
-
-        # if self.method == "raps":
-        lambda_star, k_star = None, None
-        X_raps = cast(NDArray, X_raps)
-        y_raps_no_enc = cast(NDArray, y_raps_no_enc)
-        y_pred_proba_raps = cast(NDArray, y_pred_proba_raps)
-        position_raps = cast(NDArray, position_raps)
-
         n = len(conformity_scores)
 
-        y_pred_proba = estimator.predict(X, agg_scores)
-        y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
-        if agg_scores != "crossval":
-            y_pred_proba = np.repeat(
-                y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
-            )
-
-        # Choice of the quantile
-        if self.method == "naive":
-            self.quantiles_ = 1 - alpha_np
-        elif (estimator.cv == "prefit") or (agg_scores in ["mean"]):
-            if self.method == "raps":
-                check_alpha_and_n_samples(alpha_np, X_raps.shape[0])
-                k_star = compute_quantiles(
-                    position_raps,
-                    alpha_np
-                ) + 1
-                y_pred_proba_raps = np.repeat(
-                    y_pred_proba_raps[:, :, np.newaxis],
-                    len(alpha_np),
-                    axis=2
-                )
-                lambda_star = self._find_lambda_star(
-                    y_raps_no_enc,
-                    y_pred_proba_raps,
-                    alpha_np,
-                    include_last_label,
-                    k_star
-                )
-                conformity_scores_regularized = (
-                    self._regularize_conformity_score(
-                        k_star,
-                        lambda_star,
-                        conformity_scores,
-                        self.cutoff
-                    )
-                )
-                self.quantiles_ = compute_quantiles(
-                    conformity_scores_regularized,
-                    alpha_np
-                )
-            else:
-                self.quantiles_ = compute_quantiles(
-                    conformity_scores,
-                    alpha_np
-                )
-        else:
-            self.quantiles_ = (n + 1) * (1 - alpha_np)
-
-        # Build prediction sets
-        # specify which thresholds will be used
-        if (estimator.cv == "prefit") or (agg_scores in ["mean"]):
-            thresholds = self.quantiles_
-        else:
-            thresholds = conformity_scores.ravel()
-        # sort labels by decreasing probability
-        y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last = (
-            get_last_included_proba(
-                y_pred_proba,
-                thresholds,
-                include_last_label,
-                self.method,
-                lambda_star,
-                k_star,
-            )
-        )
-        # get the prediction set by taking all probabilities
-        # above the last one
-        if (estimator.cv == "prefit") or (agg_scores in ["mean"]):
-            y_pred_included = np.greater_equal(
-                y_pred_proba - y_pred_proba_last, -EPSILON
-            )
-        else:
-            y_pred_included = np.less_equal(
-                y_pred_proba - y_pred_proba_last, EPSILON
-            )
-        # remove last label randomly
-        if include_last_label == "randomized":
-            y_pred_included = add_random_tie_breaking(
-                y_pred_included,
-                y_pred_index_last,
-                y_pred_proba_cumsum,
-                y_pred_proba_last,
-                thresholds,
-                self.method,
-                self.random_state,
-                lambda_star,
-                k_star,
-            )
-        if (estimator.cv == "prefit") or 
(agg_scores in ["mean"]): - prediction_sets = y_pred_included + if estimator.cv == "prefit" or agg_scores in ["mean"]: + quantiles_ = compute_quantiles(conformity_scores, alpha_np) else: - # compute the number of times the inequality is verified - prediction_sets_summed = y_pred_included.sum(axis=2) - prediction_sets = np.less_equal( - prediction_sets_summed[:, :, np.newaxis] - - self.quantiles_[np.newaxis, np.newaxis, :], - EPSILON - ) + quantiles_ = (n + 1) * (1 - alpha_np) - return prediction_sets + return quantiles_ diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 5edf9d45c..48d7c04d7 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -114,54 +114,17 @@ def get_conformity_scores( return conformity_scores - def get_sets( + def get_predictions( self, - X: ArrayLike, + X: NDArray, alpha_np: NDArray, estimator: EnsembleClassifier, - conformity_scores: NDArray, agg_scores: Optional[str] = "mean", **kwargs - ): + ) -> NDArray: """ - Compute classes of the prediction sets from the observed values, - the estimator of type ``EnsembleClassifier`` and the conformity scores. - - Parameters - ---------- - X: NDArray of shape (n_samples, n_features) - Observed feature values. - - alpha_np: NDArray of shape (n_alpha,) - NDArray of floats between ``0`` and ``1``, represents the - uncertainty of the confidence interval. - - estimator: EnsembleClassifier - Estimator that is fitted to predict y from X. - - conformity_scores: NDArray of shape (n_samples,) - Conformity scores. - - agg_scores: Optional[str] - How to aggregate the scores output by the estimators on test data - if a cross-validation strategy is used. Choose among: - - - "mean", take the mean of scores. - - "crossval", compare the scores between all training data and each - test point for each label to estimate if the label must be - included in the prediction set. Follows algorithm 2 of - Romano+2020. - - By default, "mean". - - Returns - ------- - NDArray of shape (n_samples, n_classes, n_alpha) - Prediction sets (Booleans indicate whether classes are included). + TODO: Compute the predictions. """ - # Checks - n = len(conformity_scores) - y_pred_proba = estimator.predict(X, agg_scores) y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) if agg_scores != "crossval": @@ -169,16 +132,45 @@ def get_sets( y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 ) - # Choice of the quantile - if (estimator.cv == "prefit") or (agg_scores in ["mean"]): - self.quantiles_ = compute_quantiles( + return y_pred_proba + + def get_conformity_quantiles( + self, + conformity_scores: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + agg_scores: Optional[str] = "mean", + **kwargs + ) -> NDArray: + """ + TODO: Compute the quantiles. + """ + n = len(conformity_scores) + + if estimator.cv == "prefit" or agg_scores in ["mean"]: + quantiles_ = compute_quantiles( conformity_scores, alpha_np ) else: - self.quantiles_ = (n + 1) * (1 - alpha_np) + quantiles_ = (n + 1) * (1 - alpha_np) + + return quantiles_ + + def get_prediction_sets( + self, + y_pred_proba: NDArray, + conformity_scores: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + agg_scores: Optional[str] = "mean", + **kwargs + ): + """ + TODO: Compute the prediction sets. 
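
# Sketch of the split-conformal quantile that ``compute_quantiles`` is
# expected to return for each alpha in the "prefit"/"mean" branch above.
# Its exact interpolation rule is not shown in this diff, so the
# ``method="higher"`` choice below is an assumption.
import numpy as np

def conformal_quantile(scores, alpha):
    n = len(scores)
    level = min(np.ceil((n + 1) * (1 - alpha)) / n, 1.0)
    return np.quantile(scores, level, method="higher")

scores = np.random.RandomState(42).uniform(size=100)
q = conformal_quantile(scores, alpha=0.1)    # 91st smallest of 100 scores
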
+ """ + n = len(conformity_scores) - # Build prediction sets if (estimator.cv == "prefit") or (agg_scores == "mean"): prediction_sets = np.less_equal( (1 - y_pred_proba) - self.quantiles_, EPSILON diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py new file mode 100644 index 000000000..cb2df4157 --- /dev/null +++ b/mapie/conformity_scores/sets/naive.py @@ -0,0 +1,417 @@ +from typing import Optional, Tuple, Union, cast + +import numpy as np +from sklearn.dummy import check_random_state + +from mapie.conformity_scores.classification import BaseClassificationScore +from mapie.conformity_scores.sets.utils import ( + check_include_last_label, check_proba_normalized, get_last_index_included +) +from mapie.estimator.classifier import EnsembleClassifier + +from mapie._machine_precision import EPSILON +from mapie._typing import ArrayLike, NDArray + + +class Naive(BaseClassificationScore): + """TODO: + Adaptive Prediction Sets (APS) method-based non-conformity score. + Three differents method are available in this class: + + - ``"naive"``, sum of the probabilities until the 1-alpha threshold. + + - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction + Sets method. It is based on the sum of the softmax outputs of the + labels until the true label is reached, on the calibration set. + See [1] for more details. + + - ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the + same technique as ``"aps"`` method but with a penalty term + to reduce the size of prediction sets. See [2] for more + details. For now, this method only works with ``"prefit"`` and + ``"split"`` strategies. + + References + ---------- + [1] Yaniv Romano, Matteo Sesia and Emmanuel J. Candès. + "Classification with Valid and Adaptive Coverage." + NeurIPS 202 (spotlight) 2020. + + [2] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan + and Jitendra Malik. + "Uncertainty Sets for Image Classifiers using Conformal Prediction." + International Conference on Learning Representations 2021. + + Attributes + ---------- + method: str + Method to choose for prediction interval estimates. + This attribute is for compatibility with ``MapieClassifier`` + which previously used a string instead of a score class. + Methods available in this class: ``aps``, ``raps`` and ``naive``. + + By default, ``aps`` for APS method. + + classes: Optional[ArrayLike] + Names of the classes. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + + quantiles_: ArrayLike of shape (n_alpha) + The quantiles estimated from ``get_sets`` method. + """ + + def __init__(self) -> None: + super().__init__() + + def set_external_attributes( + self, + method: str = 'naive', + classes: Optional[ArrayLike] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Parameters + ---------- + method: str + Method to choose for prediction interval estimates. + Methods available in this class: ``aps``, ``raps`` and ``naive``. + + By default ``aps`` for APS method. + + classes: Optional[ArrayLike] + Names of the classes. + + By default ``None``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. 
+ """ + super().set_external_attributes(**kwargs) + self.method = method + self.classes = classes + self.random_state = random_state + + def get_conformity_scores( + self, + y: NDArray, + y_pred: NDArray, + **kwargs + ) -> NDArray: + """ + Get the conformity score. + + Parameters + ---------- + y: NDArray of shape (n_samples,) + Observed target values. + + y_pred: NDArray of shape (n_samples,) + Predicted target values. + + Returns + ------- + NDArray of shape (n_samples,) + Conformity scores. + """ + conformity_scores = np.empty(y_pred.shape, dtype="float") + return conformity_scores + + def get_predictions( + self, + X: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + agg_scores: Optional[str] = "mean", + **kwargs + ) -> NDArray: + """ + TODO: Compute the predictions. + """ + y_pred_proba = estimator.predict(X, agg_scores) + y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) + if agg_scores != "crossval": + y_pred_proba = np.repeat( + y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 + ) + return y_pred_proba + + def get_conformity_quantiles( + self, + conformity_scores: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + **kwargs + ) -> NDArray: + """ + TODO: Compute the quantiles. + """ + quantiles_ = 1 - alpha_np + return quantiles_ + + def _add_regualization( + self, + y_pred_proba_sorted_cumsum, + **kwargs + ): + return y_pred_proba_sorted_cumsum + + def _get_last_included_proba( + self, + y_pred_proba: NDArray, + thresholds: NDArray, + include_last_label: Union[bool, str, None], + **kwargs + ) -> Tuple[NDArray, NDArray, NDArray]: + """ + Function that returns the smallest score + among those which are included in the prediciton set. + + Parameters + ---------- + y_pred_proba: NDArray of shape (n_samples, n_classes) + Predictions of the model. + + thresholds: NDArray of shape (n_alphas, ) + Quantiles that have been computed from the conformity scores. + + include_last_label: Union[bool, str, None] + Whether to include or not the label whose score exceeds threshold. + + Returns + ------- + Tuple[ArrayLike, ArrayLike, ArrayLike] + Arrays of shape (n_samples, n_classes, n_alphas), + (n_samples, 1, n_alphas) and (n_samples, 1, n_alphas). + They are respectively the cumsumed scores in the original + order which can be different according to the value of alpha + with the RAPS method, the index of the last included score + and the value of the last included score. + """ + index_sorted = np.flip( + np.argsort(y_pred_proba, axis=1), axis=1 + ) + # sort probabilities by decreasing order + y_pred_proba_sorted = np.take_along_axis( + y_pred_proba, index_sorted, axis=1 + ) + # get sorted cumulated score + y_pred_proba_sorted_cumsum = np.cumsum(y_pred_proba_sorted, axis=1) + y_pred_proba_sorted_cumsum = self._add_regualization( + y_pred_proba_sorted_cumsum, **kwargs + ) + + # get cumulated score at their original position + y_pred_proba_cumsum = np.take_along_axis( + y_pred_proba_sorted_cumsum, + np.argsort(index_sorted, axis=1), + axis=1 + ) + # get index of the last included label + y_pred_index_last = get_last_index_included( + y_pred_proba_cumsum, + thresholds, + include_last_label + ) + # get the probability of the last included label + y_pred_proba_last = np.take_along_axis( + y_pred_proba, + y_pred_index_last, + axis=1 + ) + + zeros_scores_proba_last = (y_pred_proba_last <= EPSILON) + + # If the last included proba is zero, change it to the + # smallest non-zero value to avoid inluding them in the + # prediction sets. 
+        if np.sum(zeros_scores_proba_last) > 0:
+            y_pred_proba_last[zeros_scores_proba_last] = np.expand_dims(
+                np.min(
+                    np.ma.masked_less(
+                        y_pred_proba,
+                        EPSILON
+                    ).filled(fill_value=np.inf),
+                    axis=1
+                ), axis=1
+            )[zeros_scores_proba_last]
+
+        return y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last
+
+    def _compute_vs_parameter(
+        self,
+        y_proba_last_cumsumed,
+        threshold,
+        y_pred_proba_last,
+        prediction_sets,
+        *kwargs
+    ):
+        """
+        TODO
+        """
+        # compute V parameter from Romano+(2020)
+        vs = (
+            (y_proba_last_cumsumed - threshold.reshape(1, -1)) /
+            y_pred_proba_last[:, 0, :]
+        )
+        return vs
+
+    def _add_random_tie_breaking(
+        self,
+        prediction_sets: NDArray,
+        y_pred_index_last: NDArray,
+        y_pred_proba_cumsum: NDArray,
+        y_pred_proba_last: NDArray,
+        threshold: NDArray,
+        random_state: Optional[Union[int, np.random.RandomState]] = None,
+        **kwargs
+    ) -> NDArray:
+        """
+        Randomly remove last label from prediction set based on the
+        comparison between a random number and the difference between
+        cumulated score of the last included label and the quantile.
+        The V parameter used for this comparison is computed by
+        ``_compute_vs_parameter``: from Romano+(2020) in this class and
+        from Angelopoulos+(2020) in the RAPS subclass.
+
+        Parameters
+        ----------
+        prediction_sets: NDArray of shape
+            (n_samples, n_classes, n_threshold)
+            Prediction set for each observation and each alpha.
+
+        y_pred_index_last: NDArray of shape (n_samples, threshold)
+            Index of the last included label.
+
+        y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes)
+            Cumsumed probability of the model in the original order.
+
+        y_pred_proba_last: NDArray of shape (n_samples, 1, threshold)
+            Last included probability.
+
+        threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,)
+            Threshold to compare with y_proba_last_cumsum, can be either:
+
+            - the quantiles associated with alpha values when
+              ``cv`` == "prefit", ``cv`` == "split"
+              or ``agg_scores`` is "mean"
+
+            - the conformity score from training samples otherwise (i.e., when
+              ``cv`` is CV splitter and ``agg_scores`` is "crossval")
+
+        random_state: Optional[Union[int, RandomState]]
+            Pseudo random number generator state used for the random draw.
+
+            By default ``None``.
+
+        Returns
+        -------
+        NDArray of shape (n_samples, n_classes, n_alpha)
+            Updated version of prediction_sets with randomly removed labels.
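
# Toy illustration of the randomized tie-breaking described above: the last
# included label is dropped with probability v, where v measures how much of
# that label's mass overshoots the quantile (Romano et al., 2020).
import numpy as np

rng = np.random.RandomState(0)
q = 0.9                       # conformal quantile for one alpha
cumsum_last = 1.0             # cumulated probability up to the last label
p_last = 0.3                  # probability of the last included label

v = (cumsum_last - q) / p_last          # = 1/3
keep_last = v <= rng.uniform()          # kept when u >= v
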
+ """ + # get cumsumed probabilities up to last retained label + y_proba_last_cumsumed = np.squeeze( + np.take_along_axis( + y_pred_proba_cumsum, + y_pred_index_last, + axis=1 + ), axis=1 + ) + + # TODO + vs = self._compute_vs_parameter( + y_proba_last_cumsumed, + threshold, + y_pred_proba_last, + prediction_sets + ) + + # get random numbers for each observation and alpha value + random_state = check_random_state(random_state) + random_state = cast(np.random.RandomState, random_state) + us = random_state.uniform(size=(prediction_sets.shape[0], 1)) + # remove last label from comparison between uniform number and V + vs_less_than_us = np.less_equal(vs - us, EPSILON) + np.put_along_axis( + prediction_sets, + y_pred_index_last, + vs_less_than_us[:, np.newaxis, :], + axis=1 + ) + return prediction_sets + + def get_prediction_sets( + self, + y_pred_proba: NDArray, + conformity_scores: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + agg_scores: Optional[str] = "mean", + include_last_label: Optional[Union[bool, str]] = True, + **kwargs + ): + """ + TODO: Compute the prediction sets. + """ + include_last_label = check_include_last_label(include_last_label) + + # specify which thresholds will be used + if estimator.cv == "prefit" or agg_scores in ["mean"]: + thresholds = self.quantiles_ + else: + thresholds = conformity_scores.ravel() + + # sort labels by decreasing probability + y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last = ( + self._get_last_included_proba( + y_pred_proba, + thresholds, + include_last_label, + prediction_phase=True, + **kwargs + ) + ) + # get the prediction set by taking all probabilities + # above the last one + if estimator.cv == "prefit" or agg_scores in ["mean"]: + y_pred_included = np.greater_equal( + y_pred_proba - y_pred_proba_last, -EPSILON + ) + else: + y_pred_included = np.less_equal( + y_pred_proba - y_pred_proba_last, EPSILON + ) + # remove last label randomly + if include_last_label == "randomized": + y_pred_included = self._add_random_tie_breaking( + y_pred_included, + y_pred_index_last, + y_pred_proba_cumsum, + y_pred_proba_last, + thresholds, + self.random_state, + **kwargs + ) + if estimator.cv == "prefit" or agg_scores in ["mean"]: + prediction_sets = y_pred_included + else: + # compute the number of times the inequality is verified + prediction_sets_summed = y_pred_included.sum(axis=2) + prediction_sets = np.less_equal( + prediction_sets_summed[:, :, np.newaxis] + - self.quantiles_[np.newaxis, np.newaxis, :], + EPSILON + ) + + return prediction_sets diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py new file mode 100644 index 000000000..8ee2bdea1 --- /dev/null +++ b/mapie/conformity_scores/sets/raps.py @@ -0,0 +1,374 @@ +from typing import Optional, Tuple, Union, cast + +import numpy as np + +from mapie.conformity_scores.sets.aps import APS +from mapie.conformity_scores.sets.utils import get_true_label_cumsum_proba +from mapie.estimator.classifier import EnsembleClassifier + +from mapie._machine_precision import EPSILON +from mapie._typing import ArrayLike, NDArray +from mapie.metrics import classification_mean_width_score +from mapie.utils import check_alpha_and_n_samples, compute_quantiles + + +class RAPS(APS): + """TODO: + Adaptive Prediction Sets (APS) method-based non-conformity score. + Three differents method are available in this class: + + - ``"naive"``, sum of the probabilities until the 1-alpha threshold. 
+ + - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction + Sets method. It is based on the sum of the softmax outputs of the + labels until the true label is reached, on the calibration set. + See [1] for more details. + + - ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the + same technique as ``"aps"`` method but with a penalty term + to reduce the size of prediction sets. See [2] for more + details. For now, this method only works with ``"prefit"`` and + ``"split"`` strategies. + + References + ---------- + [1] Yaniv Romano, Matteo Sesia and Emmanuel J. Candès. + "Classification with Valid and Adaptive Coverage." + NeurIPS 202 (spotlight) 2020. + + [2] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan + and Jitendra Malik. + "Uncertainty Sets for Image Classifiers using Conformal Prediction." + International Conference on Learning Representations 2021. + + Attributes + ---------- + method: str + Method to choose for prediction interval estimates. + This attribute is for compatibility with ``MapieClassifier`` + which previously used a string instead of a score class. + Methods available in this class: ``aps``, ``raps`` and ``naive``. + + By default, ``aps`` for APS method. + + classes: Optional[ArrayLike] + Names of the classes. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + + quantiles_: ArrayLike of shape (n_alpha) + The quantiles estimated from ``get_sets`` method. + """ + + def __init__(self) -> None: + super().__init__() + + def set_external_attributes( + self, + method: str = 'raps', + classes: Optional[ArrayLike] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Parameters + ---------- + method: str + Method to choose for prediction interval estimates. + Methods available in this class: ``aps``, ``raps`` and ``naive``. + + By default ``aps`` for APS method. + + classes: Optional[ArrayLike] + Names of the classes. + + By default ``None``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. + """ + super().set_external_attributes(**kwargs) + self.method = method + self.classes = classes + self.random_state = random_state + + @staticmethod + def _regularize_conformity_score( + k_star: NDArray, + lambda_: Union[NDArray, float], + conf_score: NDArray, + cutoff: NDArray + ) -> NDArray: + """ + Regularize the conformity scores with the ``"raps"`` + method. See algo. 2 in [3]. TODO: add ref. + + Parameters + ---------- + k_star: NDArray of shape (n_alphas, ) + Optimal value of k (called k_reg in the paper). There + is one value per alpha. + + lambda_: Union[NDArray, float] of shape (n_alphas, ) + One value of lambda for each alpha. + + conf_score: NDArray of shape (n_samples, 1) + Conformity scores. + + cutoff: NDArray of shape (n_samples, 1) + Position of the true label. + + Returns + ------- + NDArray of shape (n_samples, 1, n_alphas) + Regularized conformity scores. The regularization + depends on the value of alpha. 
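
# Numeric sketch of the regularization described above: RAPS adds
# lambda * max(0, rank - k_reg) on top of the APS score of the true label.
lambda_, k_reg = 0.01, 2
aps_score = 0.85              # APS conformity score of the true label
rank = 4                      # position (cutoff) of the true label
raps_score = aps_score + max(0.0, lambda_ * (rank - k_reg))   # 0.85 + 0.02
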
+ """ + conf_score = np.repeat( + conf_score[:, :, np.newaxis], len(k_star), axis=2 + ) + cutoff = np.repeat( + cutoff[:, np.newaxis], len(k_star), axis=1 + ) + conf_score += np.maximum( + np.expand_dims( + lambda_ * (cutoff - k_star), + axis=1 + ), + 0 + ) + return conf_score + + def _update_size_and_lambda( + self, + best_sizes: NDArray, + alpha_np: NDArray, + y_ps: NDArray, + lambda_: Union[NDArray, float], + lambda_star: NDArray + ) -> Tuple[NDArray, NDArray]: + """ + Update the values of the optimal lambda if the average size of the + prediction sets decreases with this new value of lambda. + + Parameters + ---------- + best_sizes: NDArray of shape (n_alphas, ) + Smallest average prediciton set size before testing + for the new value of lambda_ + + alpha_np: NDArray of shape (n_alphas) + Level of confidences. + + y_ps: NDArray of shape (n_samples, n_classes, n_alphas) + Prediction sets computed with the RAPS method and the + new value of lambda_ + + lambda_: NDArray of shape (n_alphas, ) + New value of lambda_star to test + + lambda_star: NDArray of shape (n_alphas, ) + Actual optimal lambda values for each alpha. + + Returns + ------- + Tuple[NDArray, NDArray] + Arrays of shape (n_alphas, ) and (n_alpha, ) which + respectively represent the updated values of lambda_star + and the new best sizes. + """ + sizes = [ + classification_mean_width_score( + y_ps[:, :, i] + ) for i in range(len(alpha_np)) + ] + + sizes_improve = (sizes < best_sizes - EPSILON) + lambda_star = ( + sizes_improve * lambda_ + (1 - sizes_improve) * lambda_star + ) + best_sizes = sizes_improve * sizes + (1 - sizes_improve) * best_sizes + + return lambda_star, best_sizes + + def _find_lambda_star( + self, + y_raps_no_enc: NDArray, + y_pred_proba_raps: NDArray, + alpha_np: NDArray, + include_last_label: Union[bool, str, None], + k_star: NDArray + ) -> Union[NDArray, float]: + """ + Find the optimal value of lambda for each alpha. + + Parameters + ---------- + y_pred_proba_raps: NDArray of shape (n_samples, n_labels, n_alphas) + Predictions of the model repeated on the last axis as many times + as the number of alphas + + alpha_np: NDArray of shape (n_alphas, ) + Levels of confidences. + + include_last_label: bool + Whether to include or not last label in + the prediction sets + + k_star: NDArray of shape (n_alphas, ) + Values of k for the regularization. + + Returns + ------- + ArrayLike of shape (n_alphas, ) + Optimal values of lambda. 
+ """ + classes = cast(NDArray, self.classes) + + lambda_star = np.zeros(len(alpha_np)) + best_sizes = np.full(len(alpha_np), np.finfo(np.float64).max) + + for lambda_ in [.001, .01, .1, .2, .5]: # values given in paper[3]TODO + true_label_cumsum_proba, cutoff = ( + get_true_label_cumsum_proba( + y_raps_no_enc, + y_pred_proba_raps[:, :, 0], + classes + ) + ) + + true_label_cumsum_proba_reg = self._regularize_conformity_score( + k_star, + lambda_, + true_label_cumsum_proba, + cutoff + ) + + quantiles_ = compute_quantiles( + true_label_cumsum_proba_reg, + alpha_np + ) + + _, _, y_pred_proba_last = self._get_last_included_proba( + y_pred_proba_raps, + quantiles_, + include_last_label, + lambda_=lambda_, + k_star=k_star + ) + + y_ps = np.greater_equal( + y_pred_proba_raps - y_pred_proba_last, -EPSILON + ) + lambda_star, best_sizes = self._update_size_and_lambda( + best_sizes, alpha_np, y_ps, lambda_, lambda_star + ) + + if len(lambda_star) == 1: + lambda_star = lambda_star[0] + + return lambda_star + + def get_conformity_quantiles( + self, + conformity_scores: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + agg_scores: Optional[str] = "mean", + include_last_label: Optional[Union[bool, str]] = True, + X_raps: Optional[NDArray] = None, + y_raps_no_enc: Optional[NDArray] = None, + y_pred_proba_raps: Optional[NDArray] = None, + position_raps: Optional[NDArray] = None, + **kwargs + ) -> NDArray: + """ + TODO: Compute the quantiles. + """ + # Casting to NDArray to avoid mypy errors + X_raps = cast(NDArray, X_raps) + y_raps_no_enc = cast(NDArray, y_raps_no_enc) + y_pred_proba_raps = cast(NDArray, y_pred_proba_raps) + position_raps = cast(NDArray, position_raps) + + check_alpha_and_n_samples(alpha_np, X_raps.shape[0]) + self.k_star = compute_quantiles( + position_raps, + alpha_np + ) + 1 + y_pred_proba_raps = np.repeat( + y_pred_proba_raps[:, :, np.newaxis], + len(alpha_np), + axis=2 + ) + self.lambda_star = self._find_lambda_star( + y_raps_no_enc, + y_pred_proba_raps, + alpha_np, + include_last_label, + self.k_star + ) + conformity_scores_regularized = ( + self._regularize_conformity_score( + self.k_star, + self.lambda_star, + conformity_scores, + self.cutoff + ) + ) + quantiles_ = compute_quantiles( + conformity_scores_regularized, + alpha_np + ) + + return quantiles_ + + def _add_regualization( + self, + y_pred_proba_sorted_cumsum, + lambda_=None, + k_star=None, + prediction_phase=False, + **kwargs + ): + """ + TODO + """ + if prediction_phase: + lambda_ = self.lambda_star + k_star = self.k_star + + y_pred_proba_sorted_cumsum += lambda_ * np.maximum( + 0, + np.cumsum( + np.ones(y_pred_proba_sorted_cumsum.shape), axis=1 + ) - k_star + ) + + return y_pred_proba_sorted_cumsum + + def _compute_vs_parameter( + self, + y_proba_last_cumsumed, + threshold, + y_pred_proba_last, + prediction_sets, + *kwargs + ): + """ + TODO + """ + # compute V parameter from Angelopoulos+(2020) + L = np.sum(prediction_sets, axis=1) + vs = ( + (y_proba_last_cumsumed - threshold.reshape(1, -1)) / + ( + y_pred_proba_last[:, 0, :] - + self.lambda_star * np.maximum(0, L - self.k_star) + + self.lambda_star * (L > self.k_star) + ) + ) + return vs diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index fb0e7836f..91667b802 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -9,7 +9,7 @@ from mapie.estimator.classifier import EnsembleClassifier from mapie._machine_precision import EPSILON -from mapie._typing import ArrayLike, NDArray 
+from mapie._typing import NDArray from mapie.utils import compute_quantiles @@ -118,49 +118,46 @@ def get_conformity_scores( return conformity_scores - def get_sets( + def get_predictions( self, - X: ArrayLike, + X: NDArray, alpha_np: NDArray, estimator: EnsembleClassifier, - conformity_scores: NDArray, **kwargs - ): + ) -> NDArray: """ - Compute classes of the prediction sets from the observed values, - the estimator of type ``EnsembleClassifier`` and the conformity scores. - - Parameters - ---------- - X: NDArray of shape (n_samples, n_features) - Observed feature values. - - alpha_np: NDArray of shape (n_alpha,) - NDArray of floats between ``0`` and ``1``, represents the - uncertainty of the confidence interval. - - estimator: EnsembleClassifier - Estimator that is fitted to predict y from X. - - conformity_scores: NDArray of shape (n_samples,) - Conformity scores. - - Returns - ------- - NDArray of shape (n_samples, n_classes, n_alpha) - Prediction sets (Booleans indicate whether classes are included). + TODO: Compute the predictions. """ - # Checks y_pred_proba = estimator.predict(X, agg_scores="mean") y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) y_pred_proba = np.repeat( y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 ) + return y_pred_proba - # Choice of the quantile - self.quantiles_ = compute_quantiles(conformity_scores, alpha_np) + def get_conformity_quantiles( + self, + conformity_scores: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + **kwargs + ) -> NDArray: + """ + TODO: Compute the quantiles. + """ + return compute_quantiles(conformity_scores, alpha_np) - # Build prediction sets + def get_prediction_sets( + self, + y_pred_proba: NDArray, + conformity_scores: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + **kwargs + ): + """ + TODO: Compute the prediction sets. + """ y_pred_proba = y_pred_proba[:, :, 0] index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1)) y_pred_index_last = np.stack( diff --git a/mapie/conformity_scores/sets/utils.py b/mapie/conformity_scores/sets/utils.py index a2b5b32af..5917a6cb7 100644 --- a/mapie/conformity_scores/sets/utils.py +++ b/mapie/conformity_scores/sets/utils.py @@ -1,7 +1,6 @@ -from typing import Any, Optional, Tuple, Union, cast +from typing import Optional, Tuple, Union, cast import numpy as np from sklearn.calibration import label_binarize -from sklearn.dummy import check_random_state from mapie._typing import ArrayLike, NDArray from mapie._machine_precision import EPSILON @@ -203,199 +202,3 @@ def get_last_index_included( ), axis=1 ) return y_pred_index_last[:, np.newaxis, :] - - -def get_last_included_proba( - y_pred_proba: NDArray, - thresholds: NDArray, - include_last_label: Union[bool, str, None], - method: str, - lambda_: Union[NDArray, float, None], - k_star: Union[NDArray, Any] -) -> Tuple[NDArray, NDArray, NDArray]: - """ - Function that returns the smallest score - among those which are included in the prediciton set. - - Parameters - ---------- - y_pred_proba: NDArray of shape (n_samples, n_classes) - Predictions of the model. - - thresholds: NDArray of shape (n_alphas, ) - Quantiles that have been computed from the conformity scores. - - include_last_label: Union[bool, str, None] - Whether to include or not the label whose score exceeds the threshold. - - lambda_: Union[NDArray, float, None] of shape (n_alphas) - Values of lambda for the regularization. - - k_star: Union[NDArray, Any] - Values of k for the regularization. 
- - Returns - ------- - Tuple[ArrayLike, ArrayLike, ArrayLike] - Arrays of shape (n_samples, n_classes, n_alphas), - (n_samples, 1, n_alphas) and (n_samples, 1, n_alphas). - They are respectively the cumsumed scores in the original - order which can be different according to the value of alpha - with the RAPS method, the index of the last included score - and the value of the last included score. - """ - index_sorted = np.flip( - np.argsort(y_pred_proba, axis=1), axis=1 - ) - # sort probabilities by decreasing order - y_pred_proba_sorted = np.take_along_axis( - y_pred_proba, index_sorted, axis=1 - ) - # get sorted cumulated score - y_pred_proba_sorted_cumsum = np.cumsum( - y_pred_proba_sorted, axis=1 - ) - - if method == "raps": - y_pred_proba_sorted_cumsum += lambda_ * np.maximum( - 0, - np.cumsum( - np.ones(y_pred_proba_sorted_cumsum.shape), axis=1 - ) - k_star - ) - # get cumulated score at their original position - y_pred_proba_cumsum = np.take_along_axis( - y_pred_proba_sorted_cumsum, - np.argsort(index_sorted, axis=1), - axis=1 - ) - # get index of the last included label - y_pred_index_last = get_last_index_included( - y_pred_proba_cumsum, - thresholds, - include_last_label - ) - # get the probability of the last included label - y_pred_proba_last = np.take_along_axis( - y_pred_proba, - y_pred_index_last, - axis=1 - ) - - zeros_scores_proba_last = (y_pred_proba_last <= EPSILON) - - # If the last included proba is zero, change it to the - # smallest non-zero value to avoid inluding them in the - # prediction sets. - if np.sum(zeros_scores_proba_last) > 0: - y_pred_proba_last[zeros_scores_proba_last] = np.expand_dims( - np.min( - np.ma.masked_less( - y_pred_proba, - EPSILON - ).filled(fill_value=np.inf), - axis=1 - ), axis=1 - )[zeros_scores_proba_last] - - return y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last - - -def add_random_tie_breaking( - prediction_sets: NDArray, - y_pred_index_last: NDArray, - y_pred_proba_cumsum: NDArray, - y_pred_proba_last: NDArray, - threshold: NDArray, - method: str, - random_state: Optional[Union[int, np.random.RandomState]] = None, - lambda_star: Optional[Union[NDArray, float]] = None, - k_star: Optional[Union[NDArray, None]] = None -) -> NDArray: - """ - Randomly remove last label from prediction set based on the - comparison between a random number and the difference between - cumulated score of the last included label and the quantile. - - Parameters - ---------- - prediction_sets: NDArray of shape - (n_samples, n_classes, n_threshold) - Prediction set for each observation and each alpha. - - y_pred_index_last: NDArray of shape (n_samples, threshold) - Index of the last included label. - - y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes) - Cumsumed probability of the model in the original order. - - y_pred_proba_last: NDArray of shape (n_samples, 1, threshold) - Last included probability. - - threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) - Threshold to compare with y_proba_last_cumsum, can be either: - - - the quantiles associated with alpha values when ``cv`` == "prefit", - ``cv`` == "split" or ``agg_scores`` is "mean" - - - the conformity score from training samples otherwise - (i.e., when ``cv`` is CV splitter and ``agg_scores`` is "crossval") - - method: str - Method that determines how to remove last label in the prediction set. 
- - - if "cumulated_score" or "aps", compute V parameter from Romano+(2020) - - - else compute V parameter from Angelopoulos+(2020) - - lambda_star: Union[NDArray, float, None] of shape (n_alpha): - Optimal value of the regulizer lambda. - - k_star: Union[NDArray, None] of shape (n_alpha): - Optimal value of the regulizer k. - - Returns - ------- - NDArray of shape (n_samples, n_classes, n_alpha) - Updated version of prediction_sets with randomly removed labels. - """ - # get cumsumed probabilities up to last retained label - y_proba_last_cumsumed = np.squeeze( - np.take_along_axis( - y_pred_proba_cumsum, - y_pred_index_last, - axis=1 - ), axis=1 - ) - - if method in ["cumulated_score", "aps"]: - # compute V parameter from Romano+(2020) - vs = ( - (y_proba_last_cumsumed - threshold.reshape(1, -1)) / - y_pred_proba_last[:, 0, :] - ) - else: - # compute V parameter from Angelopoulos+(2020) - L = np.sum(prediction_sets, axis=1) - vs = ( - (y_proba_last_cumsumed - threshold.reshape(1, -1)) / - ( - y_pred_proba_last[:, 0, :] - - lambda_star * np.maximum(0, L - k_star) + - lambda_star * (L > k_star) - ) - ) - - # get random numbers for each observation and alpha value - random_state = check_random_state(random_state) - random_state = cast(np.random.RandomState, random_state) - us = random_state.uniform(size=(prediction_sets.shape[0], 1)) - # remove last label from comparison between uniform number and V - vs_less_than_us = np.less_equal(vs - us, EPSILON) - np.put_along_axis( - prediction_sets, - y_pred_index_last, - vs_less_than_us[:, np.newaxis, :], - axis=1 - ) - return prediction_sets diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index a6b3283c7..d2b0c6cc9 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -3,7 +3,7 @@ from .regression import BaseRegressionScore from .classification import BaseClassificationScore from .bounds import AbsoluteConformityScore -from .sets import APS, LAC, TopK +from .sets import APS, LAC, Naive, RAPS, TopK def check_regression_conformity_score( @@ -72,9 +72,13 @@ def check_classification_conformity_score( if method is not None: if method in ['score', 'lac']: return LAC() - if method in ['naive', 'cumulated_score', 'aps', 'raps']: + if method in ['cumulated_score', 'aps']: return APS() - if method == 'top_k': + if method in ['naive']: + return Naive() + if method in ['raps']: + return RAPS() + if method in ['top_k']: return TopK() else: raise ValueError( diff --git a/mapie/tests/test_classification.py b/mapie/tests/test_classification.py index 1b6bf6a12..c0ad000f4 100644 --- a/mapie/tests/test_classification.py +++ b/mapie/tests/test_classification.py @@ -23,10 +23,9 @@ from mapie._typing import ArrayLike, NDArray from mapie.classification import MapieClassifier -from mapie.conformity_scores.sets.aps import APS +from mapie.conformity_scores.sets.raps import RAPS from mapie.conformity_scores.sets.utils import ( - check_proba_normalized, get_last_included_proba, - get_true_label_cumsum_proba + check_proba_normalized, get_true_label_cumsum_proba ) from mapie.metrics import classification_coverage_score from mapie.utils import check_alpha @@ -1740,7 +1739,7 @@ def test_regularize_conf_scores_shape(k_lambda) -> None: lambda_, k = k_lambda[0], k_lambda[1] conf_scores = np.random.rand(100, 1) cutoff = np.cumsum(np.ones(conf_scores.shape)) - 1 - reg_conf_scores = APS._regularize_conformity_score( + reg_conf_scores = RAPS._regularize_conformity_score( k, lambda_, conf_scores, cutoff ) @@ -1816,12 
+1815,11 @@ def test_get_last_included_proba_shape(k_lambda, strategy): y_pred_proba[:, :, np.newaxis], len(thresholds), axis=2 ) - mapie = MapieClassifier(estimator=clf, **STRATEGIES[strategy][0]) include_last_label = STRATEGIES[strategy][1]["include_last_label"] y_p_p_c, y_p_i_l, y_p_p_i_l = \ - get_last_included_proba( - y_pred_proba, thresholds, include_last_label, - mapie.method, lambda_, k + RAPS._get_last_included_proba( + RAPS(), y_pred_proba, thresholds, include_last_label, + lambda_=lambda_, k_star=k ) assert y_p_p_c.shape == (len(X), len(np.unique(y)), len(thresholds)) From c79d6e5a35847e2754c87ac9d2eb798900875a06 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Wed, 3 Jul 2024 18:12:07 +0200 Subject: [PATCH 11/46] UPD: improve docstring of score classes --- mapie/conformity_scores/classification.py | 78 ++++++++-- mapie/conformity_scores/sets/aps.py | 43 ++++-- mapie/conformity_scores/sets/lac.py | 84 ++++++++++- mapie/conformity_scores/sets/naive.py | 169 +++++++++++++++++----- mapie/conformity_scores/sets/raps.py | 167 +++++++++++++++++---- mapie/conformity_scores/sets/topk.py | 65 ++++++++- 6 files changed, 504 insertions(+), 102 deletions(-) diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py index 4e4925cea..ace093661 100644 --- a/mapie/conformity_scores/classification.py +++ b/mapie/conformity_scores/classification.py @@ -30,7 +30,26 @@ def get_predictions( **kwargs ) -> NDArray: """ - TODO: Compute the predictions. + Abstract method to get predictions from an EnsembleClassifier. + + This method should be implemented by any subclass of the current class. + + Parameters: + ----------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + Returns: + -------- + NDArray + Array of predictions. """ @abstractmethod @@ -42,7 +61,26 @@ def get_conformity_quantiles( **kwargs ) -> NDArray: """ - TODO: Compute the quantiles. + Abstract method to get quantiles of the conformity scores. + + This method should be implemented by any subclass of the current class. + + Parameters: + ----------- + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. """ @abstractmethod @@ -53,9 +91,32 @@ def get_prediction_sets( alpha_np: NDArray, estimator: EnsembleClassifier, **kwargs - ): + ) -> NDArray: """ - TODO: Compute the prediction sets. + Abstract method to generate prediction sets based on the probability + predictions, the conformity scores and the uncertainty level. + + This method should be implemented by any subclass of the current class. + + Parameters: + ----------- + y_pred_proba: NDArray of shape (n_samples, n_classes) + Target prediction. + + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. 
+
+        Returns:
+        --------
+        NDArray
+            Array of prediction sets.
         """
 
     def get_sets(
@@ -65,7 +126,7 @@ def get_sets(
         estimator: EnsembleClassifier,
         conformity_scores: NDArray,
         **kwargs
-    ):
+    ) -> NDArray:
         """
         Compute classes of the prediction sets from the observed values,
         the estimator of type ``EnsembleClassifier`` and the conformity scores.
@@ -76,8 +137,8 @@ def get_sets(
             Observed feature values.
 
         alpha_np: NDArray of shape (n_alpha,)
-            NDArray of floats between ``0`` and ``1``, represents the
-            uncertainty of the confidence interval.
+            NDArray of floats between 0 and 1, representing the uncertainty
+            of the confidence interval.
 
         estimator: EnsembleClassifier
             Estimator that is fitted to predict y from X.
@@ -90,9 +151,6 @@ def get_sets(
         NDArray of shape (n_samples, n_classes, n_alpha)
             Prediction sets (Booleans indicate whether classes are included).
         """
-        # Checks
-        ()
-
         # Predict probabilities
         y_pred_proba = self.get_predictions(
             X, alpha_np, estimator, **kwargs
diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py
index 16c6a7b98..29402b64c 100644
--- a/mapie/conformity_scores/sets/aps.py
+++ b/mapie/conformity_scores/sets/aps.py
@@ -12,11 +12,12 @@
 
 class APS(Naive):
-    """TODO:
+    """
     Adaptive Prediction Sets (APS) method-based non-conformity score.
-    Three differents method are available in this class:
+    Three different methods are available:
 
-    - ``"naive"``, sum of the probabilities until the 1-alpha threshold.
+    - ``"naive"``, that is based on the sum of the probabilities until the
+      1-alpha threshold. See ``"Naive"`` class for more details.
 
     - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction
       Sets method. It is based on the sum of the softmax outputs of the
@@ -25,9 +26,8 @@ class APS(Naive):
 
     - ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the
       same technique as ``"aps"`` method but with a penalty term
-      to reduce the size of prediction sets. See [2] for more
-      details. For now, this method only works with ``"prefit"`` and
-      ``"split"`` strategies.
+      to reduce the size of prediction sets.
+      See ``"RAPS"`` class for more details.
 
     References
     ----------
@@ -35,11 +35,6 @@ class APS(Naive):
     "Classification with Valid and Adaptive Coverage."
     NeurIPS 202 (spotlight) 2020.
 
-    [2] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan
-    and Jitendra Malik.
-    "Uncertainty Sets for Image Classifiers using Conformal Prediction."
-    International Conference on Learning Representations 2021.
-
     Attributes
     ----------
     method: str
@@ -116,7 +111,31 @@ def get_conformity_quantiles(
         **kwargs
     ) -> NDArray:
         """
-        TODO: Compute the quantiles.
+        Get the quantiles of the conformity scores for each uncertainty level.
+
+        Parameters:
+        -----------
+        conformity_scores: NDArray of shape (n_samples,)
+            Conformity scores for each sample.
+
+        alpha_np: NDArray of shape (n_alpha,)
+            NDArray of floats between 0 and 1, representing the uncertainty
+            of the confidence interval.
+
+        estimator: EnsembleClassifier
+            Estimator that is fitted to predict y from X.
+
+        agg_scores: Optional[str]
+            Method to aggregate the scores from the base estimators.
+            If "mean", the scores are averaged. If "crossval", the scores are
+            obtained from cross-validation.
+
+            By default ``"mean"``.
+
+        Returns:
+        --------
+        NDArray
+            Array of quantiles with respect to alpha_np.
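
# Minimal sketch of a custom score following the template made explicit by
# ``get_sets`` above: the base class chains the three hooks below. The class
# name and method bodies are illustrative, not library code.
import numpy as np
from mapie.conformity_scores.classification import BaseClassificationScore

class ThresholdScore(BaseClassificationScore):
    def get_conformity_scores(self, y, y_pred, **kwargs):
        # LAC-style score: one minus the probability of the true label
        return 1 - np.take_along_axis(y_pred, y.reshape(-1, 1), axis=1)

    def get_predictions(self, X, alpha_np, estimator, **kwargs):
        y_pred_proba = estimator.predict(X, agg_scores="mean")
        return np.repeat(y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2)

    def get_conformity_quantiles(self, conformity_scores, alpha_np,
                                 estimator, **kwargs):
        return np.quantile(conformity_scores, 1 - alpha_np)

    def get_prediction_sets(self, y_pred_proba, conformity_scores,
                            alpha_np, estimator, **kwargs):
        # quantiles_ has been stored by get_sets before this hook runs
        return (1 - y_pred_proba) <= self.quantiles_[np.newaxis, np.newaxis, :]
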
""" n = len(conformity_scores) diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 48d7c04d7..2e32de7c2 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -123,7 +123,31 @@ def get_predictions( **kwargs ) -> NDArray: """ - TODO: Compute the predictions. + Get predictions from an EnsembleClassifier. + + Parameters: + ----------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + agg_scores: Optional[str] + Method to aggregate the scores from the base estimators. + If "mean", the scores are averaged. If "crossval", the scores are + obtained from cross-validation. + + By default ``"mean"``. + + Returns: + -------- + NDArray + Array of predictions. """ y_pred_proba = estimator.predict(X, agg_scores) y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) @@ -143,7 +167,31 @@ def get_conformity_quantiles( **kwargs ) -> NDArray: """ - TODO: Compute the quantiles. + Get the quantiles of the conformity scores for each uncertainty level. + + Parameters: + ----------- + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + agg_scores: Optional[str] + Method to aggregate the scores from the base estimators. + If "mean", the scores are averaged. If "crossval", the scores are + obtained from cross-validation. + + By default ``"mean"``. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. """ n = len(conformity_scores) @@ -165,9 +213,37 @@ def get_prediction_sets( estimator: EnsembleClassifier, agg_scores: Optional[str] = "mean", **kwargs - ): + ) -> NDArray: """ - TODO: Compute the prediction sets. + Generate prediction sets based on the probability predictions, + the conformity scores and the uncertainty level. + + Parameters: + ----------- + y_pred_proba: NDArray of shape (n_samples, n_classes) + Target prediction. + + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + agg_scores: Optional[str] + Method to aggregate the scores from the base estimators. + If "mean", the scores are averaged. If "crossval", the scores are + obtained from cross-validation. + + By default ``"mean"``. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. """ n = len(conformity_scores) diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index cb2df4157..eba65a604 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -14,33 +14,9 @@ class Naive(BaseClassificationScore): - """TODO: - Adaptive Prediction Sets (APS) method-based non-conformity score. - Three differents method are available in this class: - - - ``"naive"``, sum of the probabilities until the 1-alpha threshold. 
- - - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction - Sets method. It is based on the sum of the softmax outputs of the - labels until the true label is reached, on the calibration set. - See [1] for more details. - - - ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the - same technique as ``"aps"`` method but with a penalty term - to reduce the size of prediction sets. See [2] for more - details. For now, this method only works with ``"prefit"`` and - ``"split"`` strategies. - - References - ---------- - [1] Yaniv Romano, Matteo Sesia and Emmanuel J. Candès. - "Classification with Valid and Adaptive Coverage." - NeurIPS 202 (spotlight) 2020. - - [2] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan - and Jitendra Malik. - "Uncertainty Sets for Image Classifiers using Conformal Prediction." - International Conference on Learning Representations 2021. + """ + Naive classification non-conformity score method that is based on the + cumulative sum of probabilities until the 1-alpha threshold. Attributes ---------- @@ -130,7 +106,31 @@ def get_predictions( **kwargs ) -> NDArray: """ - TODO: Compute the predictions. + Get predictions from an EnsembleClassifier. + + Parameters: + ----------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + agg_scores: Optional[str] + Method to aggregate the scores from the base estimators. + If "mean", the scores are averaged. If "crossval", the scores are + obtained from cross-validation. + + By default ``"mean"``. + + Returns: + -------- + NDArray + Array of predictions. """ y_pred_proba = estimator.predict(X, agg_scores) y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) @@ -148,16 +148,52 @@ def get_conformity_quantiles( **kwargs ) -> NDArray: """ - TODO: Compute the quantiles. + Get the quantiles of the conformity scores for each uncertainty level. + + Parameters: + ----------- + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. """ quantiles_ = 1 - alpha_np return quantiles_ def _add_regualization( self, - y_pred_proba_sorted_cumsum, + y_pred_proba_sorted_cumsum: NDArray, **kwargs ): + """ + Add regularization to the sorted cumulative sum of predicted + probabilities. + + Parameters + ---------- + y_pred_proba_sorted_cumsum: NDArray of shape (n_samples, n_classes) + The sorted cumulative sum of predicted probabilities. + + **kwargs: dict, optional + Additional keyword arguments that might be used. + The current implementation does not use any. + + Returns + ------- + NDArray + The adjusted cumulative sum of predicted probabilities after + applying the regularization technique. 
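
# Toy illustration of the naive rule documented above: the quantile is simply
# 1 - alpha, and the set is the smallest group of top-probability labels whose
# cumulated mass reaches it.
import numpy as np

alpha = 0.1
p = np.array([0.5, 0.3, 0.15, 0.05])
order = np.argsort(-p)
cum = np.cumsum(p[order])
last = np.searchsorted(cum, 1 - alpha)   # first index reaching 0.9
naive_set = np.zeros_like(p, dtype=bool)
naive_set[order[:last + 1]] = True       # {0, 1, 2}: 0.95 >= 0.9
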
+ """ return y_pred_proba_sorted_cumsum def _get_last_included_proba( @@ -244,14 +280,33 @@ def _get_last_included_proba( def _compute_vs_parameter( self, - y_proba_last_cumsumed, - threshold, - y_pred_proba_last, - prediction_sets, - *kwargs - ): + y_proba_last_cumsumed: NDArray, + threshold: NDArray, + y_pred_proba_last: NDArray, + prediction_sets: NDArray, + **kwargs + ) -> NDArray: """ - TODO + Compute the V parameters from Romano+(2020). + + Parameters: + ----------- + y_proba_last_cumsumed: NDArray of shape (n_samples, n_alpha) + Cumulated score of the last included label. + + threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) + Threshold to compare with y_proba_last_cumsum. + + y_pred_proba_last: NDArray of shape (n_samples, 1, n_alpha) + Last included probability. + + predicition_sets: NDArray of shape (n_samples, n_alpha) + Prediction sets. + + Returns: + -------- + NDArray of shape (n_samples, n_alpha) + Vs parameters. """ # compute V parameter from Romano+(2020) vs = ( @@ -329,7 +384,7 @@ def _add_random_tie_breaking( ), axis=1 ) - # TODO + # get the V parameter from Romano+(2020) or Angelopoulos+(2020) vs = self._compute_vs_parameter( y_proba_last_cumsumed, threshold, @@ -360,9 +415,43 @@ def get_prediction_sets( agg_scores: Optional[str] = "mean", include_last_label: Optional[Union[bool, str]] = True, **kwargs - ): + ) -> NDArray: """ - TODO: Compute the prediction sets. + Generate prediction sets based on the probability predictions, + the conformity scores and the uncertainty level. + + Parameters: + ----------- + y_pred_proba: NDArray of shape (n_samples, n_classes) + Target prediction. + + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + agg_scores: Optional[str] + Method to aggregate the scores from the base estimators. + If "mean", the scores are averaged. If "crossval", the scores are + obtained from cross-validation. + + By default ``"mean"``. + + include_last_label: Optional[Union[bool, str]] + Whether or not to include last label in prediction sets. + Choose among ``False``, ``True`` or ``"randomized"``. + + By default, ``True``. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. """ include_last_label = check_include_last_label(include_last_label) diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index 8ee2bdea1..e52da6271 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -13,30 +13,26 @@ class RAPS(APS): - """TODO: - Adaptive Prediction Sets (APS) method-based non-conformity score. - Three differents method are available in this class: + """ + Regularized Adaptive Prediction Sets (RAPS) method-based non-conformity + score. Three differents method are available: - - ``"naive"``, sum of the probabilities until the 1-alpha threshold. + - ``"naive"``, that is based on the sum of the probabilities until the + 1-alpha threshold. See ``"Naive"`` class for more details. - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction Sets method. It is based on the sum of the softmax outputs of the labels until the true label is reached, on the calibration set. - See [1] for more details. + See ``"APS"`` class for more details. 
- ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the - same technique as ``"aps"`` method but with a penalty term - to reduce the size of prediction sets. See [2] for more - details. For now, this method only works with ``"prefit"`` and - ``"split"`` strategies. + same technique as ``"aps"`` method but with a penalty term to reduce + the size of prediction sets. See [1] for more details. For now, this + method only works with ``"prefit"`` and ``"split"`` strategies. References ---------- - [1] Yaniv Romano, Matteo Sesia and Emmanuel J. Candès. - "Classification with Valid and Adaptive Coverage." - NeurIPS 202 (spotlight) 2020. - - [2] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan + [1] Anastasios Nikolas Angelopoulos, Stephen Bates, Michael Jordan and Jitendra Malik. "Uncertainty Sets for Image Classifiers using Conformal Prediction." International Conference on Learning Representations 2021. @@ -104,7 +100,7 @@ def _regularize_conformity_score( ) -> NDArray: """ Regularize the conformity scores with the ``"raps"`` - method. See algo. 2 in [3]. TODO: add ref. + method. See algo. 2 in [1]. Parameters ---------- @@ -231,7 +227,7 @@ def _find_lambda_star( lambda_star = np.zeros(len(alpha_np)) best_sizes = np.full(len(alpha_np), np.finfo(np.float64).max) - for lambda_ in [.001, .01, .1, .2, .5]: # values given in paper[3]TODO + for lambda_ in [.001, .01, .1, .2, .5]: # values given in paper[1] true_label_cumsum_proba, cutoff = ( get_true_label_cumsum_proba( y_raps_no_enc, @@ -286,7 +282,59 @@ def get_conformity_quantiles( **kwargs ) -> NDArray: """ - TODO: Compute the quantiles. + Get the quantiles of the conformity scores for each uncertainty level. + + Parameters: + ----------- + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + agg_scores: Optional[str] + Method to aggregate the scores from the base estimators. + If "mean", the scores are averaged. If "crossval", the scores are + obtained from cross-validation. + + By default, ``"mean"``. + + include_last_label: Optional[Union[bool, str]] + Whether or not to include last label in prediction sets. + Choose among ``False``, ``True`` or ``"randomized"``. + + By default, ``True``. + + X_raps: NDArray of shape (n_samples, n_features) + Observed feature values for the RAPS method (split data). + + By default, "None" but must be set to work. + + y_raps_no_enc: NDArray of shape (n_samples,) + Observed labels for the RAPS method (split data). + + By default, "None" but must be set to work. + + y_pred_proba_raps: NDArray of shape (n_samples, n_classes) + Predicted probabilities for the RAPS method (split data). + + By default, "None" but must be set to work. + + position_raps: NDArray of shape (n_samples,) + Position of the points in the split set for the RAPS method + (split data). These positions are returned by the function + ``get_true_label_position``. + + By default, "None" but must be set to work. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. 
""" # Casting to NDArray to avoid mypy errors X_raps = cast(NDArray, X_raps) @@ -328,18 +376,54 @@ def get_conformity_quantiles( def _add_regualization( self, - y_pred_proba_sorted_cumsum, - lambda_=None, - k_star=None, - prediction_phase=False, + y_pred_proba_sorted_cumsum: NDArray, + lambda_: Optional[float] = None, + k_star: Optional[int] = None, + prediction_phase: bool = False, **kwargs - ): + ) -> NDArray: """ - TODO + Add regularization to the sorted cumulative sum of predicted + probabilities. + + Parameters + ---------- + y_pred_proba_sorted_cumsum: NDArray of shape (n_samples, n_classes) + The sorted cumulative sum of predicted probabilities. + + lambda_: float + The lambda value used in the paper [1]. + + By default, "None" but must be set to work. + + k_star: int + The optimal value of k (called k_reg in the paper [1]). + + By default, "None" but must be set to work. + + prediction_phase: bool, optional + Whether the function is called during the prediction phase. + If ``True``, the function will use the values of ``lambda_star`` + and ``k_star`` of the object. + + By default, ``False``. + + **kwargs: dict, optional + Additional keyword arguments that might be used. + The current implementation does not use any. + + Returns + ------- + NDArray + The adjusted cumulative sum of predicted probabilities after + applying the regularization technique. """ if prediction_phase: - lambda_ = self.lambda_star - k_star = self.k_star + lambda_ = cast(float, self.lambda_star) + k_star = cast(int, self.k_star) + else: + lambda_ = cast(float, lambda_) + k_star = cast(int, lambda_) y_pred_proba_sorted_cumsum += lambda_ * np.maximum( 0, @@ -352,14 +436,33 @@ def _add_regualization( def _compute_vs_parameter( self, - y_proba_last_cumsumed, - threshold, - y_pred_proba_last, - prediction_sets, - *kwargs - ): + y_proba_last_cumsumed: NDArray, + threshold: NDArray, + y_pred_proba_last: NDArray, + prediction_sets: NDArray, + **kwargs + ) -> NDArray: """ - TODO + Compute the V parameters from Angelopoulos+(2020). + + Parameters: + ----------- + y_proba_last_cumsumed: NDArray of shape (n_samples, n_alpha) + Cumulated score of the last included label. + + threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) + Threshold to compare with y_proba_last_cumsum. + + y_pred_proba_last: NDArray of shape (n_samples, 1, n_alpha) + Last included probability. + + predicition_sets: NDArray of shape (n_samples, n_alpha) + Prediction sets. + + Returns: + -------- + NDArray of shape (n_samples, n_alpha) + Vs parameters. """ # compute V parameter from Angelopoulos+(2020) L = np.sum(prediction_sets, axis=1) diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 91667b802..94303563d 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -126,7 +126,26 @@ def get_predictions( **kwargs ) -> NDArray: """ - TODO: Compute the predictions. + Get predictions from an EnsembleClassifier. + + This method should be implemented by any subclass of the current class. + + Parameters: + ----------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + Returns: + -------- + NDArray + Array of predictions. 
""" y_pred_proba = estimator.predict(X, agg_scores="mean") y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) @@ -143,7 +162,24 @@ def get_conformity_quantiles( **kwargs ) -> NDArray: """ - TODO: Compute the quantiles. + Get the quantiles of the conformity scores for each uncertainty level. + + Parameters: + ----------- + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. """ return compute_quantiles(conformity_scores, alpha_np) @@ -154,9 +190,30 @@ def get_prediction_sets( alpha_np: NDArray, estimator: EnsembleClassifier, **kwargs - ): + ) -> NDArray: """ - TODO: Compute the prediction sets. + Generate prediction sets based on the probability predictions, + the conformity scores and the uncertainty level. + + Parameters: + ----------- + y_pred_proba: NDArray of shape (n_samples, n_classes) + Target prediction. + + conformity_scores: NDArray of shape (n_samples,) + Conformity scores for each sample. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between 0 and 1, representing the uncertainty + of the confidence interval. + + estimator: EnsembleClassifier + Estimator that is fitted to predict y from X. + + Returns: + -------- + NDArray + Array of quantiles with respect to alpha_np. """ y_pred_proba = y_pred_proba[:, :, 0] index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1)) From 59586a9b30dd358e60a53e7c27a7e134d2b51306 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Wed, 3 Jul 2024 18:41:05 +0200 Subject: [PATCH 12/46] UPD: refacto changes in history file --- HISTORY.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/HISTORY.rst b/HISTORY.rst index b88fc99dc..93db7febe 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -5,6 +5,13 @@ History 0.8.x (2024-xx-xx) ------------------ +* Extend `ConformityScore` to support regression (with `BaseRegressionScore`) and to support classification (with `BaseClassificationScore`) +* Extend `EnsembleEstimator` to support regression (with `EnsembleRegressor`) and to support classification (with `EnsembleClassifier`) +* Refactor `MapieClassifier` by separating the handling of the `MapieClassifier` estimator into a new class called `EnsembleClassifier` +* Refactor `MapieClassifier` by separating the handling of the `MapieClassifier` conformity score into a new class called `BaseClassificationScore` +* Add severals non-conformity scores for classification (`LAC`, `APS`, `RAPS`, `TopK`) based on `BaseClassificationScore` +* Transfer the logic of classification methods into the non-conformity score classes (`LAC`, `APS`, `RAPS`, `TopK`) +* Extend the classification strategy definition by supporting `method` and `conformity_score` attributes * Building unit tests for different `Subsample` and `BlockBooststrap` instances * Change the sign of C_k in the `Kolmogorov-Smirnov` test documentation * Building a training set with a fraction between 0 and 1 with `n_samples` attribute when using `split` method from `Subsample` class. 
From 7d053b76ad632164e83f286b56cebc93b8bcab9b Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Thu, 4 Jul 2024 14:57:36 +0200 Subject: [PATCH 13/46] UPD: add missing attributes + keep quantiles_ attribute --- mapie/classification.py | 8 ++++++++ mapie/regression/regression.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/mapie/classification.py b/mapie/classification.py index 626149add..06c0a543d 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -140,12 +140,18 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): estimator_: EnsembleClassifier Sklearn estimator that handle all that is related to the estimator. + conformity_score_function_: BaseClassificationScore + Score function that handle all that is related to conformity scores. + n_features_in_: int Number of features passed to the fit method. conformity_scores_: ArrayLike of shape (n_samples_train) The conformity scores used to calibrate the prediction sets. + quantiles_: ArrayLike of shape (n_alpha) + The quantiles estimated from ``conformity_scores_`` and alpha values. + References ---------- [1] Mauricio Sadinle, Jing Lei, and Larry Wasserman. @@ -741,4 +747,6 @@ def predict( **kwargs ) + self.quantiles_ = self.conformity_score_function_.quantiles_ + return y_pred, prediction_sets diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index 018c30677..88e827368 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -165,6 +165,9 @@ class MapieRegressor(BaseEstimator, RegressorMixin): estimator_: EnsembleRegressor Sklearn estimator that handle all that is related to the estimator. + conformity_score_function_: BaseRegressionScore + Score function that handle all that is related to conformity scores. + conformity_scores_: ArrayLike of shape (n_samples_train,) Conformity scores between ``y_train`` and ``y_pred``. From 15b31ff401acffcbb2da0813eb93ce5802909edc Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Thu, 4 Jul 2024 16:22:05 +0200 Subject: [PATCH 14/46] UPD: remove obsolete 'method' attribute and methods in conformity score --- mapie/classification.py | 1 - mapie/conformity_scores/sets/aps.py | 8 ------ mapie/conformity_scores/sets/lac.py | 15 ---------- mapie/conformity_scores/sets/naive.py | 16 ----------- mapie/conformity_scores/sets/raps.py | 41 +-------------------------- mapie/conformity_scores/sets/topk.py | 15 ---------- 6 files changed, 1 insertion(+), 95 deletions(-) diff --git a/mapie/classification.py b/mapie/classification.py index 06c0a543d..4fe76bc2e 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -452,7 +452,6 @@ def _check_fit_parameter( method=self.method ) cs_estimator.set_external_attributes( - method=self.method, classes=self.classes_, random_state=self.random_state ) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 29402b64c..1885a1863 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -37,14 +37,6 @@ class APS(Naive): Attributes ---------- - method: str - Method to choose for prediction interval estimates. - This attribute is for compatibility with ``MapieClassifier`` - which previously used a string instead of a score class. - Methods available in this class: ``aps``, ``raps`` and ``naive``. - - By default, ``aps`` for APS method. - classes: Optional[ArrayLike] Names of the classes. 
diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 2e32de7c2..3708587aa 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -27,13 +27,6 @@ class LAC(BaseClassificationScore): Attributes ---------- - method: str - Method to choose for prediction interval estimates. - This attribute is for compatibility with ``MapieClassifier`` - which previously used a string instead of a score class. - - By default, ``lac`` for LAC method. - classes: Optional[ArrayLike] Names of the classes. @@ -49,7 +42,6 @@ def __init__(self) -> None: def set_external_attributes( self, - method: str = 'lac', classes: Optional[ArrayLike] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, **kwargs @@ -59,12 +51,6 @@ def set_external_attributes( Parameters ---------- - method: str - Method to choose for prediction interval estimates. - Methods available in this class: ``lac``. - - By default ``lac`` for LAC method. - classes: Optional[ArrayLike] Names of the classes. @@ -74,7 +60,6 @@ def set_external_attributes( Pseudo random number generator state. """ super().set_external_attributes(**kwargs) - self.method = method self.classes = classes self.random_state = random_state diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index eba65a604..91b2115ae 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -20,14 +20,6 @@ class Naive(BaseClassificationScore): Attributes ---------- - method: str - Method to choose for prediction interval estimates. - This attribute is for compatibility with ``MapieClassifier`` - which previously used a string instead of a score class. - Methods available in this class: ``aps``, ``raps`` and ``naive``. - - By default, ``aps`` for APS method. - classes: Optional[ArrayLike] Names of the classes. @@ -43,7 +35,6 @@ def __init__(self) -> None: def set_external_attributes( self, - method: str = 'naive', classes: Optional[ArrayLike] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, **kwargs @@ -53,12 +44,6 @@ def set_external_attributes( Parameters ---------- - method: str - Method to choose for prediction interval estimates. - Methods available in this class: ``aps``, ``raps`` and ``naive``. - - By default ``aps`` for APS method. - classes: Optional[ArrayLike] Names of the classes. @@ -68,7 +53,6 @@ def set_external_attributes( Pseudo random number generator state. """ super().set_external_attributes(**kwargs) - self.method = method self.classes = classes self.random_state = random_state diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index e52da6271..ec39f8e2e 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -7,7 +7,7 @@ from mapie.estimator.classifier import EnsembleClassifier from mapie._machine_precision import EPSILON -from mapie._typing import ArrayLike, NDArray +from mapie._typing import NDArray from mapie.metrics import classification_mean_width_score from mapie.utils import check_alpha_and_n_samples, compute_quantiles @@ -39,14 +39,6 @@ class RAPS(APS): Attributes ---------- - method: str - Method to choose for prediction interval estimates. - This attribute is for compatibility with ``MapieClassifier`` - which previously used a string instead of a score class. - Methods available in this class: ``aps``, ``raps`` and ``naive``. - - By default, ``aps`` for APS method. 
- classes: Optional[ArrayLike] Names of the classes. @@ -60,37 +52,6 @@ class RAPS(APS): def __init__(self) -> None: super().__init__() - def set_external_attributes( - self, - method: str = 'raps', - classes: Optional[ArrayLike] = None, - random_state: Optional[Union[int, np.random.RandomState]] = None, - **kwargs - ) -> None: - """ - Set attributes that are not provided by the user. - - Parameters - ---------- - method: str - Method to choose for prediction interval estimates. - Methods available in this class: ``aps``, ``raps`` and ``naive``. - - By default ``aps`` for APS method. - - classes: Optional[ArrayLike] - Names of the classes. - - By default ``None``. - - random_state: Optional[Union[int, RandomState]] - Pseudo random number generator state. - """ - super().set_external_attributes(**kwargs) - self.method = method - self.classes = classes - self.random_state = random_state - @staticmethod def _regularize_conformity_score( k_star: NDArray, diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 94303563d..9723b8a27 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -31,13 +31,6 @@ class TopK(BaseClassificationScore): Attributes ---------- - method: str - Method to choose for prediction interval estimates. - This attribute is for compatibility with ``MapieClassifier`` - which previously used a string instead of a score class. - - By default, ``top_k`` for Top-K method. - classes: Optional[ArrayLike] Names of the classes. @@ -53,7 +46,6 @@ def __init__(self) -> None: def set_external_attributes( self, - method: str = 'top_k', classes: Optional[int] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, **kwargs @@ -63,12 +55,6 @@ def set_external_attributes( Parameters ---------- - method: str - Method to choose for prediction interval estimates. - Methods available in this class: ``top_k``. - - By default ``top_k`` for Top-K method. - classes: Optional[ArrayLike] Names of the classes. @@ -78,7 +64,6 @@ def set_external_attributes( Pseudo random number generator state. """ super().set_external_attributes(**kwargs) - self.method = method self.classes = classes self.random_state = random_state From ee89b53a3cbf3627099fc9544fc95c45e015fe59 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Thu, 4 Jul 2024 16:24:29 +0200 Subject: [PATCH 15/46] UPD: reduce doctring --- mapie/conformity_scores/sets/aps.py | 16 ++-------------- mapie/conformity_scores/sets/raps.py | 17 +++-------------- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 1885a1863..af8c4ffcb 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -14,20 +14,8 @@ class APS(Naive): """ Adaptive Prediction Sets (APS) method-based non-conformity score. - Three differents method are available: - - - ``"naive"``, that is based on the sum of the probabilities until the - 1-alpha threshold. See ``"Naive"`` class for more details. - - - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction - Sets method. It is based on the sum of the softmax outputs of the - labels until the true label is reached, on the calibration set. - See [1] for more details. - - - ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the - same technique as ``"aps"`` method but with a penalty term - to reduce the size of prediction sets. - See ``"RAPS"`` class for more details. 
+ It is based on the sum of the softmax outputs of the labels until the true + label is reached, on the calibration set. See [1] for more details. References ---------- diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index ec39f8e2e..cc0dd3e03 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -15,20 +15,9 @@ class RAPS(APS): """ Regularized Adaptive Prediction Sets (RAPS) method-based non-conformity - score. Three differents method are available: - - - ``"naive"``, that is based on the sum of the probabilities until the - 1-alpha threshold. See ``"Naive"`` class for more details. - - - ``"aps"`` (formerly called "cumulated_score"), Adaptive Prediction - Sets method. It is based on the sum of the softmax outputs of the - labels until the true label is reached, on the calibration set. - See ``"APS"`` class for more details. - - - ``"raps"``, Regularized Adaptive Prediction Sets method. It uses the - same technique as ``"aps"`` method but with a penalty term to reduce - the size of prediction sets. See [1] for more details. For now, this - method only works with ``"prefit"`` and ``"split"`` strategies. + score. It uses the same technique as ``APS`` class but with a penalty term + to reduce the size of prediction sets. See [1] for more details. For now, + this method only works with ``"prefit"`` and ``"split"`` strategies. References ---------- From d9e498903e18e9d5cd6c5f29922604e995e425b3 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Thu, 4 Jul 2024 16:26:17 +0200 Subject: [PATCH 16/46] UPD: change method name --- mapie/conformity_scores/classification.py | 4 ++-- mapie/conformity_scores/sets/aps.py | 2 +- mapie/conformity_scores/sets/lac.py | 2 +- mapie/conformity_scores/sets/naive.py | 2 +- mapie/conformity_scores/sets/raps.py | 2 +- mapie/conformity_scores/sets/topk.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py index ace093661..727cde104 100644 --- a/mapie/conformity_scores/classification.py +++ b/mapie/conformity_scores/classification.py @@ -53,7 +53,7 @@ def get_predictions( """ @abstractmethod - def get_conformity_quantiles( + def get_conformity_score_quantiles( self, conformity_scores: NDArray, alpha_np: NDArray, @@ -157,7 +157,7 @@ def get_sets( ) # Choice of the quantile - self.quantiles_ = self.get_conformity_quantiles( + self.quantiles_ = self.get_conformity_score_quantiles( conformity_scores, alpha_np, estimator, **kwargs ) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index af8c4ffcb..fbb9186e2 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -82,7 +82,7 @@ def get_conformity_scores( return conformity_scores - def get_conformity_quantiles( + def get_conformity_score_quantiles( self, conformity_scores: NDArray, alpha_np: NDArray, diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 3708587aa..464f6096d 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -143,7 +143,7 @@ def get_predictions( return y_pred_proba - def get_conformity_quantiles( + def get_conformity_score_quantiles( self, conformity_scores: NDArray, alpha_np: NDArray, diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index 91b2115ae..868aeecdf 100644 --- a/mapie/conformity_scores/sets/naive.py +++ 
b/mapie/conformity_scores/sets/naive.py @@ -124,7 +124,7 @@ def get_predictions( ) return y_pred_proba - def get_conformity_quantiles( + def get_conformity_score_quantiles( self, conformity_scores: NDArray, alpha_np: NDArray, diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index cc0dd3e03..4d3e95f2b 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -218,7 +218,7 @@ def _find_lambda_star( return lambda_star - def get_conformity_quantiles( + def get_conformity_score_quantiles( self, conformity_scores: NDArray, alpha_np: NDArray, diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 9723b8a27..346592452 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -139,7 +139,7 @@ def get_predictions( ) return y_pred_proba - def get_conformity_quantiles( + def get_conformity_score_quantiles( self, conformity_scores: NDArray, alpha_np: NDArray, From 115c08086c007a55498075659b501cdfaaa4443e Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 5 Jul 2024 10:30:14 +0200 Subject: [PATCH 17/46] DOC: change docstring and useless cast --- mapie/conformity_scores/classification.py | 2 +- mapie/conformity_scores/interface.py | 10 +++++----- mapie/conformity_scores/regression.py | 14 +++++++------- mapie/conformity_scores/sets/aps.py | 1 - 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py index 727cde104..f6e45d380 100644 --- a/mapie/conformity_scores/classification.py +++ b/mapie/conformity_scores/classification.py @@ -180,7 +180,7 @@ def predict_set( Parameters: ----------- - X: NDArray of shape (n_samples, ...) + X: NDArray of shape (n_samples,) The input data or samples for prediction. alpha_np: NDArray of shape (n_alpha, ) diff --git a/mapie/conformity_scores/interface.py b/mapie/conformity_scores/interface.py index c8e163844..3979149c0 100644 --- a/mapie/conformity_scores/interface.py +++ b/mapie/conformity_scores/interface.py @@ -45,15 +45,15 @@ def get_conformity_scores( Parameters ---------- - y: NDArray of shape (n_samples, ...) + y: NDArray of shape (n_samples,) Observed target values. - y_pred: NDArray of shape (n_samples, ...) + y_pred: NDArray of shape (n_samples,) Predicted target values. Returns ------- - NDArray of shape (n_samples, ...) + NDArray of shape (n_samples,) Conformity scores. """ @@ -70,7 +70,7 @@ def get_quantile( Parameters ---------- - conformity_scores: NDArray of shape (n_samples, ...) + conformity_scores: NDArray of shape (n_samples,) Values from which the quantile is computed. alpha_np: NDArray of shape (n_alpha,) @@ -137,7 +137,7 @@ def predict_set( Parameters: ----------- - X: NDArray of shape (n_samples, ...) + X: NDArray of shape (n_samples,) The input data or samples for prediction. alpha_np: NDArray of shape (n_alpha, ) diff --git a/mapie/conformity_scores/regression.py b/mapie/conformity_scores/regression.py index fa151d5e5..1e58cc163 100644 --- a/mapie/conformity_scores/regression.py +++ b/mapie/conformity_scores/regression.py @@ -127,13 +127,13 @@ def check_consistency( Parameters ---------- - y: NDArray of shape (n_samples, ...) + y: NDArray of shape (n_samples,) Observed target values. - y_pred: NDArray of shape (n_samples, ...) + y_pred: NDArray of shape (n_samples,) Predicted target values. - conformity_scores: NDArray of shape (n_samples, ...) 
+ conformity_scores: NDArray of shape (n_samples,) Conformity scores. Raises @@ -175,15 +175,15 @@ def get_estimation_distribution( Parameters ---------- - y_pred: NDArray of shape (n_samples, ...) + y_pred: NDArray of shape (n_samples,) Predicted target values. - conformity_scores: NDArray of shape (n_samples, ...) + conformity_scores: NDArray of shape (n_samples,) Conformity scores. Returns ------- - NDArray of shape (n_samples, ...) + NDArray of shape (n_samples,) Observed values. """ @@ -392,7 +392,7 @@ def predict_set( Parameters: ----------- - X: NDArray of shape (n_samples, ...) + X: NDArray of shape (n_samples,) The input data or samples for prediction. alpha_np: NDArray of shape (n_alpha, ) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index fbb9186e2..35d191836 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -76,7 +76,6 @@ def get_conformity_scores( y_pred, y_enc.reshape(-1, 1), axis=1 ) random_state = check_random_state(self.random_state) - random_state = cast(np.random.RandomState, random_state) u = random_state.uniform(size=len(y_pred)).reshape(-1, 1) conformity_scores -= u * y_proba_true From 64973c0b559d675fb33b49bff9b0249e6fe14b22 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 5 Jul 2024 11:38:08 +0200 Subject: [PATCH 18/46] UPD: remove useless cast + reformat typing --- mapie/conformity_scores/sets/utils.py | 22 ++++++++++++---------- mapie/estimator/classifier.py | 5 ++--- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/mapie/conformity_scores/sets/utils.py b/mapie/conformity_scores/sets/utils.py index 5917a6cb7..6ede57ea1 100644 --- a/mapie/conformity_scores/sets/utils.py +++ b/mapie/conformity_scores/sets/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union, cast +from typing import Optional, Tuple, Union import numpy as np from sklearn.calibration import label_binarize @@ -66,8 +66,9 @@ def get_true_label_cumsum_proba( true_label_cumsum_proba = np.take_along_axis( y_pred_sorted_cumsum, cutoff.reshape(-1, 1), axis=1 ) + cutoff += 1 - return true_label_cumsum_proba, cutoff + 1 + return true_label_cumsum_proba, cutoff def check_include_last_label( @@ -117,7 +118,7 @@ def check_include_last_label( def check_proba_normalized( - y_pred_proba: ArrayLike, + y_pred_proba: NDArray, axis: int = 1 ) -> NDArray: """ @@ -125,7 +126,7 @@ def check_proba_normalized( Parameters ---------- - y_pred_proba: ArrayLike of shape (n_samples, n_classes) or + y_pred_proba: NDArray of shape (n_samples, n_classes) or (n_samples, n_train_samples, n_classes) Softmax output of a model. @@ -139,12 +140,13 @@ def check_proba_normalized( ValueError If the sum of the scores is not equal to one. """ - sum_proba = np.sum(y_pred_proba, axis=axis) - err_msg = "The sum of the scores is not equal to one." 
- np.testing.assert_allclose(sum_proba, 1, err_msg=err_msg, rtol=1e-5) - y_pred_proba = cast(NDArray, y_pred_proba).astype(np.float64) - - return y_pred_proba + np.testing.assert_allclose( + np.sum(y_pred_proba, axis=axis), + 1, + err_msg="The sum of the scores is not equal to one.", + rtol=1e-5 + ) + return y_pred_proba.astype(np.float64) def get_last_index_included( diff --git a/mapie/estimator/classifier.py b/mapie/estimator/classifier.py index 16df810e2..fc0ad12ce 100644 --- a/mapie/estimator/classifier.py +++ b/mapie/estimator/classifier.py @@ -189,7 +189,7 @@ def _fit_oof_estimator( def _check_proba_normalized( y_pred_proba: ArrayLike, axis: int = 1 - ) -> NDArray: + ) -> ArrayLike: """ Check if, for all the observations, the sum of the probabilities is equal to one. @@ -216,8 +216,7 @@ def _check_proba_normalized( err_msg="The sum of the scores is not equal to one.", rtol=1e-5 ) - y_pred_proba = cast(NDArray, y_pred_proba).astype(np.float64) - return y_pred_proba + return y_pred_proba.astype(np.float64) def _predict_proba_oof_estimator( self, From cdc27166d4d9f8f41866dd31fc8060e0ea21e0b2 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 5 Jul 2024 11:39:02 +0200 Subject: [PATCH 19/46] FIX: float conversion removed --- mapie/estimator/classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapie/estimator/classifier.py b/mapie/estimator/classifier.py index fc0ad12ce..0c7fa16c1 100644 --- a/mapie/estimator/classifier.py +++ b/mapie/estimator/classifier.py @@ -216,7 +216,7 @@ def _check_proba_normalized( err_msg="The sum of the scores is not equal to one.", rtol=1e-5 ) - return y_pred_proba.astype(np.float64) + return y_pred_proba def _predict_proba_oof_estimator( self, From a8d47e53e0b13cbde55b2b406308a585e29bc6b8 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 5 Jul 2024 12:12:57 +0200 Subject: [PATCH 20/46] UPD: move methods relative to aps, from naive to aps --- mapie/conformity_scores/sets/aps.py | 277 +++++++++++++++++++++++++- mapie/conformity_scores/sets/naive.py | 216 ++------------------ 2 files changed, 288 insertions(+), 205 deletions(-) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 35d191836..ff0ebb6d7 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -1,12 +1,16 @@ -from typing import Optional, cast +from typing import Optional, Union, cast import numpy as np from sklearn.dummy import check_random_state from mapie.conformity_scores.sets.naive import Naive -from mapie.conformity_scores.sets.utils import get_true_label_cumsum_proba +from mapie.conformity_scores.sets.utils import ( + check_include_last_label, check_proba_normalized, + get_true_label_cumsum_proba +) from mapie.estimator.classifier import EnsembleClassifier +from mapie._machine_precision import EPSILON from mapie._typing import NDArray from mapie.utils import compute_quantiles @@ -38,6 +42,49 @@ class APS(Naive): def __init__(self) -> None: super().__init__() + def get_predictions( + self, + X: NDArray, + alpha_np: NDArray, + estimator: EnsembleClassifier, + agg_scores: Optional[str] = "mean", + **kwargs + ) -> NDArray: + """ + Get predictions from an EnsembleClassifier. + + Parameters: + ----------- + X: NDArray of shape (n_samples, n_features) + Observed feature values. + + alpha_np: NDArray of shape (n_alpha,) + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. 
+
+        estimator: EnsembleClassifier
+            Estimator that is fitted to predict y from X.
+
+        agg_scores: Optional[str]
+            Method to aggregate the scores from the base estimators.
+            If "mean", the scores are averaged. If "crossval", the scores are
+            obtained from cross-validation.
+
+            By default ``"mean"``.
+
+        Returns:
+        --------
+        NDArray
+            Array of predictions.
+        """
+        y_pred_proba = estimator.predict(X, agg_scores)
+        y_pred_proba = check_proba_normalized(y_pred_proba, axis=1)
+        if agg_scores != "crossval":
+            y_pred_proba = np.repeat(
+                y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2
+            )
+        return y_pred_proba
+
     def get_conformity_scores(
         self,
         y: NDArray,
@@ -124,3 +171,229 @@ def get_conformity_score_quantiles(
         quantiles_ = (n + 1) * (1 - alpha_np)
 
         return quantiles_
+
+    def _compute_vs_parameter(
+        self,
+        y_proba_last_cumsumed: NDArray,
+        threshold: NDArray,
+        y_pred_proba_last: NDArray,
+        prediction_sets: NDArray,
+        **kwargs
+    ) -> NDArray:
+        """
+        Compute the V parameters from Romano+(2020).
+
+        Parameters:
+        -----------
+        y_proba_last_cumsumed: NDArray of shape (n_samples, n_alpha)
+            Cumulated score of the last included label.
+
+        threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,)
+            Threshold to compare with y_proba_last_cumsum.
+
+        y_pred_proba_last: NDArray of shape (n_samples, 1, n_alpha)
+            Last included probability.
+
+        prediction_sets: NDArray of shape (n_samples, n_alpha)
+            Prediction sets.
+
+        Returns:
+        --------
+        NDArray of shape (n_samples, n_alpha)
+            Vs parameters.
+        """
+        # compute V parameter from Romano+(2020)
+        vs = (
+            (y_proba_last_cumsumed - threshold.reshape(1, -1)) /
+            y_pred_proba_last[:, 0, :]
+        )
+        return vs
+
+    def _add_random_tie_breaking(
+        self,
+        prediction_sets: NDArray,
+        y_pred_index_last: NDArray,
+        y_pred_proba_cumsum: NDArray,
+        y_pred_proba_last: NDArray,
+        threshold: NDArray,
+        random_state: Optional[Union[int, np.random.RandomState]] = None,
+        **kwargs
+    ) -> NDArray:
+        """
+        Randomly remove last label from prediction set based on the
+        comparison between a random number and the difference between
+        cumulated score of the last included label and the quantile.
+
+        Parameters
+        ----------
+        prediction_sets: NDArray of shape
+            (n_samples, n_classes, n_threshold)
+            Prediction set for each observation and each alpha.
+
+        y_pred_index_last: NDArray of shape (n_samples, threshold)
+            Index of the last included label.
+
+        y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes)
+            Cumsumed probability of the model in the original order.
+
+        y_pred_proba_last: NDArray of shape (n_samples, 1, threshold)
+            Last included probability.
+
+        threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,)
+            Threshold to compare with y_proba_last_cumsum, can be either:
+
+            - the quantiles associated with alpha values when
+              ``cv`` == "prefit", ``cv`` == "split"
+              or ``agg_scores`` is "mean"
+
+            - the conformity score from training samples otherwise (i.e., when
+              ``cv`` is CV splitter and ``agg_scores`` is "crossval")
+
+        method: str
+            Method that determines how to remove last label in the prediction
+            set.
+
+            - if "cumulated_score" or "aps", compute V parameter
+              from Romano+(2020)
+
+            - else compute V parameter from Angelopoulos+(2020)
+
+        lambda_star: Optional[Union[NDArray, float]] of shape (n_alpha):
+            Optimal value of the regularizer lambda.
+
+        k_star: Optional[NDArray] of shape (n_alpha):
+            Optimal value of the regularizer k.
+
+        Returns
+        -------
+        NDArray of shape (n_samples, n_classes, n_alpha)
+            Updated version of prediction_sets with randomly removed labels.
+        """
+        # get cumsumed probabilities up to last retained label
+        y_proba_last_cumsumed = np.squeeze(
+            np.take_along_axis(
+                y_pred_proba_cumsum,
+                y_pred_index_last,
+                axis=1
+            ), axis=1
+        )
+
+        # get the V parameter from Romano+(2020) or Angelopoulos+(2020)
+        vs = self._compute_vs_parameter(
+            y_proba_last_cumsumed,
+            threshold,
+            y_pred_proba_last,
+            prediction_sets
+        )
+
+        # get random numbers for each observation and alpha value
+        random_state = check_random_state(random_state)
+        random_state = cast(np.random.RandomState, random_state)
+        us = random_state.uniform(size=(prediction_sets.shape[0], 1))
+        # remove last label from comparison between uniform number and V
+        vs_less_than_us = np.less_equal(vs - us, EPSILON)
+        np.put_along_axis(
+            prediction_sets,
+            y_pred_index_last,
+            vs_less_than_us[:, np.newaxis, :],
+            axis=1
+        )
+        return prediction_sets
+
+    def get_prediction_sets(
+        self,
+        y_pred_proba: NDArray,
+        conformity_scores: NDArray,
+        alpha_np: NDArray,
+        estimator: EnsembleClassifier,
+        agg_scores: Optional[str] = "mean",
+        include_last_label: Optional[Union[bool, str]] = True,
+        **kwargs
+    ) -> NDArray:
+        """
+        Generate prediction sets based on the probability predictions,
+        the conformity scores and the uncertainty level.
+
+        Parameters:
+        -----------
+        y_pred_proba: NDArray of shape (n_samples, n_classes)
+            Target prediction.
+
+        conformity_scores: NDArray of shape (n_samples,)
+            Conformity scores for each sample.
+
+        alpha_np: NDArray of shape (n_alpha,)
+            NDArray of floats between 0 and 1, representing the uncertainty
+            of the confidence interval.
+
+        estimator: EnsembleClassifier
+            Estimator that is fitted to predict y from X.
+
+        agg_scores: Optional[str]
+            Method to aggregate the scores from the base estimators.
+            If "mean", the scores are averaged. If "crossval", the scores are
+            obtained from cross-validation.
+
+            By default ``"mean"``.
+
+        include_last_label: Optional[Union[bool, str]]
+            Whether or not to include last label in prediction sets.
+            Choose among ``False``, ``True`` or ``"randomized"``.
+
+            By default, ``True``.
+
+        Returns:
+        --------
+        NDArray
+            Array of prediction sets for each sample and each alpha.
+ """ + include_last_label = check_include_last_label(include_last_label) + + # specify which thresholds will be used + if estimator.cv == "prefit" or agg_scores in ["mean"]: + thresholds = self.quantiles_ + else: + thresholds = conformity_scores.ravel() + + # sort labels by decreasing probability + y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last = ( + self._get_last_included_proba( + y_pred_proba, + thresholds, + include_last_label, + prediction_phase=True, + **kwargs + ) + ) + # get the prediction set by taking all probabilities above the last one + if estimator.cv == "prefit" or agg_scores in ["mean"]: + y_pred_included = np.greater_equal( + y_pred_proba - y_pred_proba_last, -EPSILON + ) + else: + y_pred_included = np.less_equal( + y_pred_proba - y_pred_proba_last, EPSILON + ) + # remove last label randomly + if include_last_label == "randomized": + y_pred_included = self._add_random_tie_breaking( + y_pred_included, + y_pred_index_last, + y_pred_proba_cumsum, + y_pred_proba_last, + thresholds, + self.random_state, + **kwargs + ) + if estimator.cv == "prefit" or agg_scores in ["mean"]: + prediction_sets = y_pred_included + else: + # compute the number of times the inequality is verified + prediction_sets_summed = y_pred_included.sum(axis=2) + prediction_sets = np.less_equal( + prediction_sets_summed[:, :, np.newaxis] + - self.quantiles_[np.newaxis, np.newaxis, :], + EPSILON + ) + + return prediction_sets diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index 868aeecdf..259753021 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -1,11 +1,10 @@ -from typing import Optional, Tuple, Union, cast +from typing import Optional, Tuple, Union import numpy as np -from sklearn.dummy import check_random_state from mapie.conformity_scores.classification import BaseClassificationScore from mapie.conformity_scores.sets.utils import ( - check_include_last_label, check_proba_normalized, get_last_index_included + check_proba_normalized, get_last_index_included ) from mapie.estimator.classifier import EnsembleClassifier @@ -86,7 +85,6 @@ def get_predictions( X: NDArray, alpha_np: NDArray, estimator: EnsembleClassifier, - agg_scores: Optional[str] = "mean", **kwargs ) -> NDArray: """ @@ -104,24 +102,16 @@ def get_predictions( estimator: EnsembleClassifier Estimator that is fitted to predict y from X. - agg_scores: Optional[str] - Method to aggregate the scores from the base estimators. - If "mean", the scores are averaged. If "crossval", the scores are - obtained from cross-validation. - - By default ``"mean"``. - Returns: -------- NDArray Array of predictions. """ - y_pred_proba = estimator.predict(X, agg_scores) + y_pred_proba = estimator.predict(X, agg_scores='mean') y_pred_proba = check_proba_normalized(y_pred_proba, axis=1) - if agg_scores != "crossval": - y_pred_proba = np.repeat( - y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 - ) + y_pred_proba = np.repeat( + y_pred_proba[:, :, np.newaxis], len(alpha_np), axis=2 + ) return y_pred_proba def get_conformity_score_quantiles( @@ -262,142 +252,12 @@ def _get_last_included_proba( return y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last - def _compute_vs_parameter( - self, - y_proba_last_cumsumed: NDArray, - threshold: NDArray, - y_pred_proba_last: NDArray, - prediction_sets: NDArray, - **kwargs - ) -> NDArray: - """ - Compute the V parameters from Romano+(2020). 
- - Parameters: - ----------- - y_proba_last_cumsumed: NDArray of shape (n_samples, n_alpha) - Cumulated score of the last included label. - - threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) - Threshold to compare with y_proba_last_cumsum. - - y_pred_proba_last: NDArray of shape (n_samples, 1, n_alpha) - Last included probability. - - predicition_sets: NDArray of shape (n_samples, n_alpha) - Prediction sets. - - Returns: - -------- - NDArray of shape (n_samples, n_alpha) - Vs parameters. - """ - # compute V parameter from Romano+(2020) - vs = ( - (y_proba_last_cumsumed - threshold.reshape(1, -1)) / - y_pred_proba_last[:, 0, :] - ) - return vs - - def _add_random_tie_breaking( - self, - prediction_sets: NDArray, - y_pred_index_last: NDArray, - y_pred_proba_cumsum: NDArray, - y_pred_proba_last: NDArray, - threshold: NDArray, - random_state: Optional[Union[int, np.random.RandomState]] = None, - **kwargs - ) -> NDArray: - """ - Randomly remove last label from prediction set based on the - comparison between a random number and the difference between - cumulated score of the last included label and the quantile. - - Parameters - ---------- - prediction_sets: NDArray of shape - (n_samples, n_classes, n_threshold) - Prediction set for each observation and each alpha. - - y_pred_index_last: NDArray of shape (n_samples, threshold) - Index of the last included label. - - y_pred_proba_cumsum: NDArray of shape (n_samples, n_classes) - Cumsumed probability of the model in the original order. - - y_pred_proba_last: NDArray of shape (n_samples, 1, threshold) - Last included probability. - - threshold: NDArray of shape (n_alpha,) or shape (n_samples_train,) - Threshold to compare with y_proba_last_cumsum, can be either: - - - the quantiles associated with alpha values when - ``cv`` == "prefit", ``cv`` == "split" - or ``agg_scores`` is "mean" - - - the conformity score from training samples otherwise (i.e., when - ``cv`` is CV splitter and ``agg_scores`` is "crossval") - - method: str - Method that determines how to remove last label in the prediction - set. - - - if "cumulated_score" or "aps", compute V parameter - from Romano+(2020) - - - else compute V parameter from Angelopoulos+(2020) - - lambda_star: Optional[Union[NDArray, float]] of shape (n_alpha): - Optimal value of the regulizer lambda. - - k_star: Optional[NDArray] of shape (n_alpha): - Optimal value of the regulizer k. - - Returns - ------- - NDArray of shape (n_samples, n_classes, n_alpha) - Updated version of prediction_sets with randomly removed labels. 
- """ - # get cumsumed probabilities up to last retained label - y_proba_last_cumsumed = np.squeeze( - np.take_along_axis( - y_pred_proba_cumsum, - y_pred_index_last, - axis=1 - ), axis=1 - ) - - # get the V parameter from Romano+(2020) or Angelopoulos+(2020) - vs = self._compute_vs_parameter( - y_proba_last_cumsumed, - threshold, - y_pred_proba_last, - prediction_sets - ) - - # get random numbers for each observation and alpha value - random_state = check_random_state(random_state) - random_state = cast(np.random.RandomState, random_state) - us = random_state.uniform(size=(prediction_sets.shape[0], 1)) - # remove last label from comparison between uniform number and V - vs_less_than_us = np.less_equal(vs - us, EPSILON) - np.put_along_axis( - prediction_sets, - y_pred_index_last, - vs_less_than_us[:, np.newaxis, :], - axis=1 - ) - return prediction_sets - def get_prediction_sets( self, y_pred_proba: NDArray, conformity_scores: NDArray, alpha_np: NDArray, estimator: EnsembleClassifier, - agg_scores: Optional[str] = "mean", - include_last_label: Optional[Union[bool, str]] = True, **kwargs ) -> NDArray: """ @@ -419,72 +279,22 @@ def get_prediction_sets( estimator: EnsembleClassifier Estimator that is fitted to predict y from X. - agg_scores: Optional[str] - Method to aggregate the scores from the base estimators. - If "mean", the scores are averaged. If "crossval", the scores are - obtained from cross-validation. - - By default ``"mean"``. - - include_last_label: Optional[Union[bool, str]] - Whether or not to include last label in prediction sets. - Choose among ``False``, ``True`` or ``"randomized"``. - - By default, ``True``. - Returns: -------- NDArray Array of quantiles with respect to alpha_np. """ - include_last_label = check_include_last_label(include_last_label) - - # specify which thresholds will be used - if estimator.cv == "prefit" or agg_scores in ["mean"]: - thresholds = self.quantiles_ - else: - thresholds = conformity_scores.ravel() - # sort labels by decreasing probability - y_pred_proba_cumsum, y_pred_index_last, y_pred_proba_last = ( + _, _, y_pred_proba_last = ( self._get_last_included_proba( y_pred_proba, - thresholds, - include_last_label, - prediction_phase=True, - **kwargs + thresholds=self.quantiles_, + include_last_label=True ) ) - # get the prediction set by taking all probabilities - # above the last one - if estimator.cv == "prefit" or agg_scores in ["mean"]: - y_pred_included = np.greater_equal( - y_pred_proba - y_pred_proba_last, -EPSILON - ) - else: - y_pred_included = np.less_equal( - y_pred_proba - y_pred_proba_last, EPSILON - ) - # remove last label randomly - if include_last_label == "randomized": - y_pred_included = self._add_random_tie_breaking( - y_pred_included, - y_pred_index_last, - y_pred_proba_cumsum, - y_pred_proba_last, - thresholds, - self.random_state, - **kwargs - ) - if estimator.cv == "prefit" or agg_scores in ["mean"]: - prediction_sets = y_pred_included - else: - # compute the number of times the inequality is verified - prediction_sets_summed = y_pred_included.sum(axis=2) - prediction_sets = np.less_equal( - prediction_sets_summed[:, :, np.newaxis] - - self.quantiles_[np.newaxis, np.newaxis, :], - EPSILON - ) + # get the prediction set by taking all probabilities above the last one + prediction_sets = np.greater_equal( + y_pred_proba - y_pred_proba_last, -EPSILON + ) return prediction_sets From 2f0ed146b4939e40ce0c0197ca8839173af17da3 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 5 Jul 2024 12:19:24 +0200 Subject: [PATCH 21/46] 
UPD: move get_true_label_cumsum_proba to class method --- mapie/conformity_scores/sets/aps.py | 49 ++++++++++++++++++++++++--- mapie/conformity_scores/sets/raps.py | 3 +- mapie/conformity_scores/sets/utils.py | 44 ++---------------------- mapie/tests/test_classification.py | 12 +++---- 4 files changed, 53 insertions(+), 55 deletions(-) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index ff0ebb6d7..e8cf5c1c3 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -1,17 +1,17 @@ -from typing import Optional, Union, cast +from typing import Optional, Tuple, Union, cast import numpy as np from sklearn.dummy import check_random_state +from sklearn.calibration import label_binarize from mapie.conformity_scores.sets.naive import Naive from mapie.conformity_scores.sets.utils import ( - check_include_last_label, check_proba_normalized, - get_true_label_cumsum_proba + check_include_last_label, check_proba_normalized ) from mapie.estimator.classifier import EnsembleClassifier from mapie._machine_precision import EPSILON -from mapie._typing import NDArray +from mapie._typing import ArrayLike, NDArray from mapie.utils import compute_quantiles @@ -85,6 +85,45 @@ def get_predictions( ) return y_pred_proba + @staticmethod + def get_true_label_cumsum_proba( + y: ArrayLike, + y_pred_proba: NDArray, + classes: ArrayLike + ) -> Tuple[NDArray, NDArray]: + """ + Compute the cumsumed probability of the true label. + + Parameters + ---------- + y: NDArray of shape (n_samples, ) + Array with the labels. + + y_pred_proba: NDArray of shape (n_samples, n_classes) + Predictions of the model. + + classes: NDArray of shape (n_classes, ) + Array with the classes. + + Returns + ------- + Tuple[NDArray, NDArray] of shapes (n_samples, 1) and (n_samples, ). + The first element is the cumsum probability of the true label. + The second is the sorted position of the true label. 
+ """ + y_true = label_binarize(y=y, classes=classes) + index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1)) + y_pred_sorted = np.take_along_axis(y_pred_proba, index_sorted, axis=1) + y_true_sorted = np.take_along_axis(y_true, index_sorted, axis=1) + y_pred_sorted_cumsum = np.cumsum(y_pred_sorted, axis=1) + cutoff = np.argmax(y_true_sorted, axis=1) + true_label_cumsum_proba = np.take_along_axis( + y_pred_sorted_cumsum, cutoff.reshape(-1, 1), axis=1 + ) + cutoff += 1 + + return true_label_cumsum_proba, cutoff + def get_conformity_scores( self, y: NDArray, @@ -117,7 +156,7 @@ def get_conformity_scores( # Conformity scores conformity_scores, self.cutoff = ( - get_true_label_cumsum_proba(y, y_pred, classes) + self.get_true_label_cumsum_proba(y, y_pred, classes) ) y_proba_true = np.take_along_axis( y_pred, y_enc.reshape(-1, 1), axis=1 diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index 4d3e95f2b..320b3bbf0 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -3,7 +3,6 @@ import numpy as np from mapie.conformity_scores.sets.aps import APS -from mapie.conformity_scores.sets.utils import get_true_label_cumsum_proba from mapie.estimator.classifier import EnsembleClassifier from mapie._machine_precision import EPSILON @@ -179,7 +178,7 @@ def _find_lambda_star( for lambda_ in [.001, .01, .1, .2, .5]: # values given in paper[1] true_label_cumsum_proba, cutoff = ( - get_true_label_cumsum_proba( + self.get_true_label_cumsum_proba( y_raps_no_enc, y_pred_proba_raps[:, :, 0], classes diff --git a/mapie/conformity_scores/sets/utils.py b/mapie/conformity_scores/sets/utils.py index 6ede57ea1..5912607fb 100644 --- a/mapie/conformity_scores/sets/utils.py +++ b/mapie/conformity_scores/sets/utils.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np -from sklearn.calibration import label_binarize -from mapie._typing import ArrayLike, NDArray +from mapie._typing import NDArray from mapie._machine_precision import EPSILON @@ -32,45 +31,6 @@ def get_true_label_position( return position -def get_true_label_cumsum_proba( - y: ArrayLike, - y_pred_proba: NDArray, - classes: ArrayLike -) -> Tuple[NDArray, NDArray]: - """ - Compute the cumsumed probability of the true label. - - Parameters - ---------- - y: NDArray of shape (n_samples, ) - Array with the labels. - - y_pred_proba: NDArray of shape (n_samples, n_classes) - Predictions of the model. - - classes: NDArray of shape (n_classes, ) - Array with the classes. - - Returns - ------- - Tuple[NDArray, NDArray] of shapes (n_samples, 1) and (n_samples, ). - The first element is the cumsum probability of the true label. - The second is the sorted position of the true label. 
- """ - y_true = label_binarize(y=y, classes=classes) - index_sorted = np.fliplr(np.argsort(y_pred_proba, axis=1)) - y_pred_sorted = np.take_along_axis(y_pred_proba, index_sorted, axis=1) - y_true_sorted = np.take_along_axis(y_true, index_sorted, axis=1) - y_pred_sorted_cumsum = np.cumsum(y_pred_sorted, axis=1) - cutoff = np.argmax(y_true_sorted, axis=1) - true_label_cumsum_proba = np.take_along_axis( - y_pred_sorted_cumsum, cutoff.reshape(-1, 1), axis=1 - ) - cutoff += 1 - - return true_label_cumsum_proba, cutoff - - def check_include_last_label( include_last_label: Optional[Union[bool, str]] ) -> Optional[Union[bool, str]]: diff --git a/mapie/tests/test_classification.py b/mapie/tests/test_classification.py index c0ad000f4..497b8cea5 100644 --- a/mapie/tests/test_classification.py +++ b/mapie/tests/test_classification.py @@ -23,10 +23,8 @@ from mapie._typing import ArrayLike, NDArray from mapie.classification import MapieClassifier -from mapie.conformity_scores.sets.raps import RAPS -from mapie.conformity_scores.sets.utils import ( - check_proba_normalized, get_true_label_cumsum_proba -) +from mapie.conformity_scores import APS, RAPS +from mapie.conformity_scores.sets.utils import check_proba_normalized from mapie.metrics import classification_coverage_score from mapie.utils import check_alpha @@ -1759,7 +1757,7 @@ def test_get_true_label_cumsum_proba_shape() -> None: ) mapie_clf.fit(X, y) classes = mapie_clf.classes_ - cumsum_proba, cutoff = get_true_label_cumsum_proba(y, y_pred, classes) + cumsum_proba, cutoff = APS.get_true_label_cumsum_proba(y, y_pred, classes) assert cumsum_proba.shape == (len(X), 1) assert cutoff.shape == (len(X), ) @@ -1777,7 +1775,9 @@ def test_get_true_label_cumsum_proba_result() -> None: ) mapie_clf.fit(X_toy, y_toy) classes = mapie_clf.classes_ - cumsum_proba, cutoff = get_true_label_cumsum_proba(y_toy, y_pred, classes) + cumsum_proba, cutoff = APS.get_true_label_cumsum_proba( + y_toy, y_pred, classes + ) np.testing.assert_allclose( cumsum_proba, np.array( From 0a5ac6e392d53ae1139ac837d5f1fbffcde17b5f Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 5 Jul 2024 12:21:41 +0200 Subject: [PATCH 22/46] UPD: add test wrong method in conformity score --- mapie/tests/test_conformity_scores_sets.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py index b6349b4fc..b6d63fde6 100644 --- a/mapie/tests/test_conformity_scores_sets.py +++ b/mapie/tests/test_conformity_scores_sets.py @@ -10,6 +10,7 @@ cs_list = [None, LAC(), APS(), TopK()] method_list = [None, 'naive', 'aps', 'raps', 'lac', 'top_k'] +wrong_method_list = ['naive_', 'aps_', 'raps_', 'lac_', 'top_k_'] def test_error_mother_class_initialization() -> None: @@ -35,3 +36,11 @@ def test_check_classification_method( check_classification_conformity_score(method=method), BaseClassificationScore ) + + +@pytest.mark.parametrize("method", wrong_method_list) +def test_check_wrong_classification_method( + method: Optional[str] +) -> None: + with pytest.raises(ValueError, match="Invalid method.*"): + check_classification_conformity_score(method=method) From 1f916f3610c688a71d02b629c3fa4bd19c11eece Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 5 Jul 2024 18:05:18 +0200 Subject: [PATCH 23/46] FIX: add missing docstring --- mapie/classification.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mapie/classification.py b/mapie/classification.py index 4fe76bc2e..1f3abf10d 100644 --- a/mapie/classification.py 
+++ b/mapie/classification.py
@@ -116,6 +116,9 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
 
         By default ``None``.
 
+    conformity_score_function_: BaseClassificationScore
+        Score function that handles everything related to conformity scores.
+
     random_state: Optional[Union[int, RandomState]]
         Pseudo random number generator state used for random uniform sampling
         for evaluation quantiles and prediction sets.

From 17cfbcadba3bd0e82afd06505fa34ff99feae65e Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Fri, 5 Jul 2024 19:01:25 +0200
Subject: [PATCH 24/46] UPD: change default attribute as done in MapieRegressor

---
 mapie/classification.py            | 86 ++++++------------------
 mapie/conformity_scores/utils.py   | 73 +++++++++++++++++++++++--
 mapie/tests/test_classification.py |  6 ---
 3 files changed, 87 insertions(+), 78 deletions(-)

diff --git a/mapie/classification.py b/mapie/classification.py
index 1f3abf10d..153e859e8 100644
--- a/mapie/classification.py
+++ b/mapie/classification.py
@@ -9,14 +9,14 @@
                                      StratifiedShuffleSplit)
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils import _safe_indexing, check_random_state
-from sklearn.utils.multiclass import (check_classification_targets,
-                                      type_of_target)
 from sklearn.utils.validation import (_check_y, _num_samples, check_is_fitted,
                                       indexable)
 
 from mapie._typing import ArrayLike, NDArray
 from mapie.conformity_scores import BaseClassificationScore
-from mapie.conformity_scores.utils import check_classification_conformity_score
+from mapie.conformity_scores.utils import (
+    check_classification_conformity_score, check_target
+)
 from mapie.conformity_scores.sets.utils import get_true_label_position
 from mapie.estimator.classifier import EnsembleClassifier
 from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_cv,
@@ -68,7 +68,12 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
         prediction sets may be different from the others. See [3] for
         more details.
 
-        By default ``"lac"``.
+        - ``None``, which does not specify the method used.
+
+        In any case, the `method` parameter does not take precedence over the
+        `conformity_score` parameter to define the method used.
+
+        By default ``None``.
 
     cv: Optional[str]
         The cross-validation strategy for computing scores.
@@ -119,6 +124,11 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
     conformity_score_function_: BaseClassificationScore
         Score function that handles everything related to conformity scores.
 
+        In any case, the `conformity_score` parameter takes precedence over the
+        `method` parameter to define the method used.
+
+        By default ``None``.
+
     random_state: Optional[Union[int, RandomState]]
         Pseudo random number generator state used for random uniform sampling
         for evaluation quantiles and prediction sets.
@@ -193,9 +203,6 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
     """
 
     raps_valid_cv_ = ["prefit", "split"]
-    valid_methods_ = [
-        "naive", "score", "lac", "cumulated_score", "aps", "top_k", "raps"
-    ]
     fit_attributes = [
         "estimator_",
         "n_features_in_",
@@ -208,7 +215,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
     def __init__(
        self,
        estimator: Optional[ClassifierMixin] = None,
-        method: str = "lac",
+        method: Optional[str] = None,
        cv: Optional[Union[int, str, BaseCrossValidator]] = None,
        test_size: Optional[Union[int, float]] = None,
        n_jobs: Optional[int] = None,
@@ -234,70 +241,11 @@ def _check_parameters(self) -> None:
        ValueError
            If parameters are not valid.
""" - if self.method not in self.valid_methods_: - raise ValueError( - "Invalid method. " - f"Allowed values are {self.valid_methods_}." - ) check_n_jobs(self.n_jobs) check_verbose(self.verbose) check_random_state(self.random_state) - self._check_depreciated() self._check_raps() - def _check_depreciated(self) -> None: - """ - Check if the chosen method is outdated. - - Raises - ------ - Warning - If method is ``"score"`` (not ``"lac"``) or - if method is ``"cumulated_score"`` (not ``"aps"``). - """ - if self.method == "score": - warnings.warn( - "WARNING: Deprecated method. " - + "The method \"score\" is outdated. " - + "Prefer to use \"lac\" instead to keep " - + "the same behavior in the next release.", - DeprecationWarning - ) - if self.method == "cumulated_score": - warnings.warn( - "WARNING: Deprecated method. " - + "The method \"cumulated_score\" is outdated. " - + "Prefer to use \"aps\" instead to keep " - + "the same behavior in the next release.", - DeprecationWarning - ) - - def _check_target(self, y: ArrayLike) -> None: - """ - Check that if the type of target is binary, - (then the method have to be ``"lac"``), or multi-class. - - Parameters - ---------- - y: NDArray of shape (n_samples,) - Training labels. - - Raises - ------ - ValueError - If type of target is binary and method is not ``"lac"`` - or ``"score"`` or if type of target is not multi-class. - """ - check_classification_targets(y) - if type_of_target(y) == "binary" and \ - self.method not in ["score", "lac"]: - raise ValueError( - "Invalid method for binary target. " - "Your target is not of type multiclass and " - "allowed values for binary type are " - f"{['score', 'lac']}." - ) - def _check_raps(self): """ Check that if the method used is ``"raps"``, then @@ -448,8 +396,6 @@ def _check_fit_parameter( self.label_encoder_ = self._get_label_encoder() y_enc = self.label_encoder_.transform(y) - self._check_target(y) - cs_estimator = check_classification_conformity_score( conformity_score=self.conformity_score, method=self.method @@ -459,6 +405,8 @@ def _check_fit_parameter( random_state=self.random_state ) + check_target(cs_estimator, y) + return ( estimator, cs_estimator, cv, X, y, y_enc, sample_weight, groups, n_samples diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index d2b0c6cc9..801f2a5fe 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -1,10 +1,16 @@ from typing import Optional +import warnings + +from sklearn.utils.multiclass import (check_classification_targets, + type_of_target) from .regression import BaseRegressionScore from .classification import BaseClassificationScore from .bounds import AbsoluteConformityScore from .sets import APS, LAC, Naive, RAPS, TopK +from mapie._typing import ArrayLike + def check_regression_conformity_score( conformity_score: Optional[BaseRegressionScore], @@ -42,6 +48,68 @@ def check_regression_conformity_score( ) +def _check_depreciated( + method: str +) -> None: + """ + Check if the chosen method is outdated. + + Raises + ------ + Warning + If method is ``"score"`` (not ``"lac"``) or + if method is ``"cumulated_score"`` (not ``"aps"``). + """ + if method == "score": + warnings.warn( + "WARNING: Deprecated method. " + + "The method \"score\" is outdated. " + + "Prefer to use \"lac\" instead to keep " + + "the same behavior in the next release.", + DeprecationWarning + ) + if method == "cumulated_score": + warnings.warn( + "WARNING: Deprecated method. " + + "The method \"cumulated_score\" is outdated. 
" + + "Prefer to use \"aps\" instead to keep " + + "the same behavior in the next release.", + DeprecationWarning + ) + + +def check_target( + conformity_score: BaseClassificationScore, + y: ArrayLike +) -> None: + """ + Check that if the type of target is binary, + (then the method have to be ``"lac"``), or multi-class. + + Parameters + ---------- + conformity_score: BaseClassificationScore + Conformity score function. + + y: NDArray of shape (n_samples,) + Training labels. + + Raises + ------ + ValueError + If type of target is binary and method is not ``"lac"`` + or ``"score"`` or if type of target is not multi-class. + """ + check_classification_targets(y) + if type_of_target(y) == "binary" and not isinstance(conformity_score, LAC): + raise ValueError( + "Invalid method for binary target. " + "Your target is not of type multiclass and " + "allowed values for binary type are " + f"{['score', 'lac']}." + ) + + def check_classification_conformity_score( conformity_score: Optional[BaseClassificationScore] = None, method: Optional[str] = None, @@ -68,8 +136,8 @@ def check_classification_conformity_score( Must be None or a ConformityScore instance. """ allowed_methods = ['lac', 'naive', 'aps', 'raps', 'top_k'] - deprecated_methods = ['score', 'cumulated_score'] if method is not None: + _check_depreciated(method) if method in ['score', 'lac']: return LAC() if method in ['cumulated_score', 'aps']: @@ -82,8 +150,7 @@ def check_classification_conformity_score( return TopK() else: raise ValueError( - f"Invalid method. Allowed values are {allowed_methods}. " - f"Deprecated values are {deprecated_methods}. " + f"Invalid method. Allowed values are {allowed_methods}." ) elif isinstance(conformity_score, BaseClassificationScore): return conformity_score diff --git a/mapie/tests/test_classification.py b/mapie/tests/test_classification.py index 497b8cea5..ec9366a3e 100644 --- a/mapie/tests/test_classification.py +++ b/mapie/tests/test_classification.py @@ -912,12 +912,6 @@ def test_initialized() -> None: MapieClassifier() -def test_default_parameters() -> None: - """Test default values of input parameters.""" - mapie_clf = MapieClassifier() - assert mapie_clf.method == "lac" - - @pytest.mark.parametrize("cv", ["prefit", "split"]) @pytest.mark.parametrize("method", ["aps", "raps"]) def test_warning_binary_classif(cv: str, method: str) -> None: From 031a8d492d06f52e4d8b1b929814188b73c035f6 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Thu, 11 Jul 2024 16:28:19 +0200 Subject: [PATCH 25/46] UPD: manage class method with conflict warning --- mapie/classification.py | 8 ++-- mapie/conformity_scores/utils.py | 44 +++++++++++++--------- mapie/tests/test_conformity_scores_sets.py | 17 +++++++++ 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/mapie/classification.py b/mapie/classification.py index 153e859e8..9d2d63b53 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -147,9 +147,6 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): Attributes ---------- - valid_methods: List[str] - List of all valid methods. - estimator_: EnsembleClassifier Sklearn estimator that handle all that is related to the estimator. @@ -165,6 +162,9 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): quantiles_: ArrayLike of shape (n_alpha) The quantiles estimated from ``conformity_scores_`` and alpha values. + label_encoder_: LabelEncoder + Label encoder used to encode the labels. + References ---------- [1] Mauricio Sadinle, Jing Lei, and Larry Wasserman. 
@@ -634,7 +634,7 @@ def predict( When set to ``True`` or ``False``, it may result in a coverage higher than ``1 - alpha`` (because contrary to the "randomized" - setting, none of this methods create empty prediction sets). See + setting, none of these methods create empty prediction sets). See [2] and [3] for more details. By default ``True``. diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index 801f2a5fe..30069aca3 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -110,6 +110,17 @@ def check_target( ) +method_score_map = { + 'score': lambda: LAC(), + 'lac': lambda: LAC(), + 'cumulated_score': lambda: APS(), + 'aps': lambda: APS(), + 'naive': lambda: Naive(), + 'raps': lambda: RAPS(), + 'top_k': lambda: TopK() +} + + def check_classification_conformity_score( conformity_score: Optional[BaseClassificationScore] = None, method: Optional[str] = None, @@ -135,27 +146,26 @@ def check_classification_conformity_score( Invalid conformity_score argument. Must be None or a ConformityScore instance. """ - allowed_methods = ['lac', 'naive', 'aps', 'raps', 'top_k'] + if method is None and conformity_score is None: + return LAC() + elif conformity_score is not None: + if method is not None: + warnings.warn( + "WARNING: the `conformity_score` parameter takes precedence " + "over the `method` parameter to define the method used.", + UserWarning + ) + if isinstance(conformity_score, BaseClassificationScore): + return conformity_score if method is not None: - _check_depreciated(method) - if method in ['score', 'lac']: - return LAC() - if method in ['cumulated_score', 'aps']: - return APS() - if method in ['naive']: - return Naive() - if method in ['raps']: - return RAPS() - if method in ['top_k']: - return TopK() + if isinstance(method, str) and method in method_score_map: + _check_depreciated(method) + return method_score_map[method]() else: raise ValueError( - f"Invalid method. Allowed values are {allowed_methods}." + "Invalid method. " + f"Allowed values are {list(method_score_map.keys())}." 
) - elif isinstance(conformity_score, BaseClassificationScore): - return conformity_score - elif conformity_score is None: - return LAC() else: raise ValueError( "Invalid conformity_score argument.\n" diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py index b6d63fde6..ac66a16e7 100644 --- a/mapie/tests/test_conformity_scores_sets.py +++ b/mapie/tests/test_conformity_scores_sets.py @@ -38,6 +38,23 @@ def test_check_classification_method( ) +@pytest.mark.parametrize("method", method_list) +@pytest.mark.parametrize("conformity_score", cs_list) +def test_check_conflict_parameters( + method: Optional[str], + conformity_score: Optional[BaseClassificationScore] +) -> None: + if method is None or conformity_score is None: + return + with pytest.warns( + UserWarning, + match="WARNING: the `conformity_score` parameter takes precedence*" + ): + check_classification_conformity_score( + method=method, conformity_score=conformity_score + ) + + @pytest.mark.parametrize("method", wrong_method_list) def test_check_wrong_classification_method( method: Optional[str] From b1b425ea2cac4ccb3d888e5d12dbdd4298d1f567 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Thu, 11 Jul 2024 17:13:59 +0200 Subject: [PATCH 26/46] UPD: reorganize conformity score tests --- mapie/tests/test_classification.py | 114 +--------------- ...es.py => test_conformity_scores_bounds.py} | 0 mapie/tests/test_conformity_scores_sets.py | 129 +++++++++++++++++- ...res.py => test_conformity_scores_utils.py} | 0 4 files changed, 129 insertions(+), 114 deletions(-) rename mapie/tests/{test_conformity_scores.py => test_conformity_scores_bounds.py} (100%) rename mapie/tests/{test_utils_classification_conformity_scores.py => test_conformity_scores_utils.py} (100%) diff --git a/mapie/tests/test_classification.py b/mapie/tests/test_classification.py index ec9366a3e..7e4a1e1e4 100644 --- a/mapie/tests/test_classification.py +++ b/mapie/tests/test_classification.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Dict, Iterable, Optional, Union, cast +from typing import Any, Dict, Iterable, Optional, Union import numpy as np import pandas as pd @@ -23,10 +23,8 @@ from mapie._typing import ArrayLike, NDArray from mapie.classification import MapieClassifier -from mapie.conformity_scores import APS, RAPS from mapie.conformity_scores.sets.utils import check_proba_normalized from mapie.metrics import classification_coverage_score -from mapie.utils import check_alpha random_state = 42 @@ -734,12 +732,6 @@ ] } -REGULARIZATION_PARAMETERS = [ - [.001, [1]], - [[.01, .2], [1, 3]], - [.1, [2, 4]] -] - IMAGE_INPUT = [ { "X_calib": np.zeros((3, 1024, 1024, 1)), @@ -1500,8 +1492,7 @@ def test_cumulated_scores() -> None: include_last_label=True, alpha=alpha ) - computed_quantile = mapie_clf.conformity_score_function_.quantiles_ - np.testing.assert_allclose(computed_quantile, quantile) + np.testing.assert_allclose(mapie_clf.quantiles_, quantile) np.testing.assert_allclose(y_ps[:, :, 0], cumclf.y_pred_sets) @@ -1529,8 +1520,7 @@ def test_image_cumulated_scores(X: Dict[str, ArrayLike]) -> None: include_last_label=True, alpha=alpha ) - computed_quantile = mapie.conformity_score_function_.quantiles_ - np.testing.assert_allclose(computed_quantile, quantile) + np.testing.assert_allclose(mapie.quantiles_, quantile) np.testing.assert_allclose(y_ps[:, :, 0], cumclf.y_pred_sets) @@ -1723,104 +1713,6 @@ def test_classif_float32(cv) -> None: ).all() 
-@pytest.mark.parametrize("k_lambda", REGULARIZATION_PARAMETERS) -def test_regularize_conf_scores_shape(k_lambda) -> None: - """ - Test that the conformity scores have the correct shape. - """ - lambda_, k = k_lambda[0], k_lambda[1] - conf_scores = np.random.rand(100, 1) - cutoff = np.cumsum(np.ones(conf_scores.shape)) - 1 - reg_conf_scores = RAPS._regularize_conformity_score( - k, lambda_, conf_scores, cutoff - ) - - assert reg_conf_scores.shape == (100, 1, len(k)) - - -def test_get_true_label_cumsum_proba_shape() -> None: - """ - Test that the true label cumsumed probabilities - have the correct shape. - """ - clf = LogisticRegression() - clf.fit(X, y) - y_pred = clf.predict_proba(X) - mapie_clf = MapieClassifier( - estimator=clf, random_state=random_state - ) - mapie_clf.fit(X, y) - classes = mapie_clf.classes_ - cumsum_proba, cutoff = APS.get_true_label_cumsum_proba(y, y_pred, classes) - assert cumsum_proba.shape == (len(X), 1) - assert cutoff.shape == (len(X), ) - - -def test_get_true_label_cumsum_proba_result() -> None: - """ - Test that the true label cumsumed probabilities - are the expected ones. - """ - clf = LogisticRegression() - clf.fit(X_toy, y_toy) - y_pred = clf.predict_proba(X_toy) - mapie_clf = MapieClassifier( - estimator=clf, random_state=random_state - ) - mapie_clf.fit(X_toy, y_toy) - classes = mapie_clf.classes_ - cumsum_proba, cutoff = APS.get_true_label_cumsum_proba( - y_toy, y_pred, classes - ) - np.testing.assert_allclose( - cumsum_proba, - np.array( - [ - y_pred[0, 0], y_pred[1, 0], - y_pred[2, 0] + y_pred[2, 1], - y_pred[3, 0] + y_pred[3, 1], - y_pred[4, 1], y_pred[5, 1], - y_pred[6, 1] + y_pred[6, 2], - y_pred[7, 1] + y_pred[7, 2], - y_pred[8, 2] - ] - )[:, np.newaxis] - ) - np.testing.assert_allclose(cutoff, np.array([1, 1, 2, 2, 1, 1, 2, 2, 1])) - - -@pytest.mark.parametrize("k_lambda", REGULARIZATION_PARAMETERS) -@pytest.mark.parametrize("strategy", [*STRATEGIES]) -def test_get_last_included_proba_shape(k_lambda, strategy): - """ - Test that the outputs of _get_last_included_proba method - have the correct shape. 
- """ - lambda_, k = k_lambda[0], k_lambda[1] - if len(k) == 1: - thresholds = .2 - else: - thresholds = np.random.rand(len(k)) - thresholds = cast(NDArray, check_alpha(thresholds)) - clf = LogisticRegression() - clf.fit(X, y) - y_pred_proba = clf.predict_proba(X) - y_pred_proba = np.repeat( - y_pred_proba[:, :, np.newaxis], len(thresholds), axis=2 - ) - - include_last_label = STRATEGIES[strategy][1]["include_last_label"] - y_p_p_c, y_p_i_l, y_p_p_i_l = \ - RAPS._get_last_included_proba( - RAPS(), y_pred_proba, thresholds, include_last_label, - lambda_=lambda_, k_star=k - ) - - assert y_p_p_c.shape == (len(X), len(np.unique(y)), len(thresholds)) - assert y_p_i_l.shape == (len(X), 1, len(thresholds)) - assert y_p_p_i_l.shape == (len(X), 1, len(thresholds)) - - @pytest.mark.parametrize("cv", [5, None]) def test_error_raps_cv_not_prefit(cv: Union[int, None]) -> None: """ diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores_bounds.py similarity index 100% rename from mapie/tests/test_conformity_scores.py rename to mapie/tests/test_conformity_scores_bounds.py diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py index ac66a16e7..2425c5409 100644 --- a/mapie/tests/test_conformity_scores_sets.py +++ b/mapie/tests/test_conformity_scores_sets.py @@ -1,17 +1,43 @@ -from typing import Optional +from typing import Optional, cast import pytest +import numpy as np +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression -# from mapie._typing import ArrayLike, NDArray +from mapie._typing import NDArray +from mapie.classification import MapieClassifier from mapie.conformity_scores import BaseClassificationScore -from mapie.conformity_scores.sets import APS, LAC, TopK +from mapie.conformity_scores.sets import APS, LAC, RAPS, TopK from mapie.conformity_scores.utils import check_classification_conformity_score +from mapie.utils import check_alpha +random_state = 42 + cs_list = [None, LAC(), APS(), TopK()] method_list = [None, 'naive', 'aps', 'raps', 'lac', 'top_k'] wrong_method_list = ['naive_', 'aps_', 'raps_', 'lac_', 'top_k_'] +REGULARIZATION_PARAMETERS = [ + [.001, [1]], + [[.01, .2], [1, 3]], + [.1, [2, 4]] +] + +X_toy = np.arange(9).reshape(-1, 1) +y_toy = np.array([0, 0, 1, 0, 1, 1, 2, 1, 2]) +y_toy_string = np.array(["0", "0", "1", "0", "1", "1", "2", "1", "2"]) + +n_classes = 4 +X, y = make_classification( + n_samples=500, + n_features=10, + n_informative=3, + n_classes=n_classes, + random_state=random_state, +) + def test_error_mother_class_initialization() -> None: with pytest.raises(TypeError): @@ -61,3 +87,100 @@ def test_check_wrong_classification_method( ) -> None: with pytest.raises(ValueError, match="Invalid method.*"): check_classification_conformity_score(method=method) + + +@pytest.mark.parametrize("k_lambda", REGULARIZATION_PARAMETERS) +def test_regularize_conf_scores_shape(k_lambda) -> None: + """ + Test that the conformity scores have the correct shape. + """ + lambda_, k = k_lambda[0], k_lambda[1] + conf_scores = np.random.rand(100, 1) + cutoff = np.cumsum(np.ones(conf_scores.shape)) - 1 + reg_conf_scores = RAPS._regularize_conformity_score( + k, lambda_, conf_scores, cutoff + ) + + assert reg_conf_scores.shape == (100, 1, len(k)) + + +def test_get_true_label_cumsum_proba_shape() -> None: + """ + Test that the true label cumsumed probabilities + have the correct shape. 
+    """
+    clf = LogisticRegression()
+    clf.fit(X, y)
+    y_pred = clf.predict_proba(X)
+    mapie_clf = MapieClassifier(
+        estimator=clf, random_state=random_state
+    )
+    mapie_clf.fit(X, y)
+    classes = mapie_clf.classes_
+    cumsum_proba, cutoff = APS.get_true_label_cumsum_proba(y, y_pred, classes)
+    assert cumsum_proba.shape == (len(X), 1)
+    assert cutoff.shape == (len(X), )
+
+
+def test_get_true_label_cumsum_proba_result() -> None:
+    """
+    Test that the true label cumsumed probabilities
+    are the expected ones.
+    """
+    clf = LogisticRegression()
+    clf.fit(X_toy, y_toy)
+    y_pred = clf.predict_proba(X_toy)
+    mapie_clf = MapieClassifier(
+        estimator=clf, random_state=random_state
+    )
+    mapie_clf.fit(X_toy, y_toy)
+    classes = mapie_clf.classes_
+    cumsum_proba, cutoff = APS.get_true_label_cumsum_proba(
+        y_toy, y_pred, classes
+    )
+    np.testing.assert_allclose(
+        cumsum_proba,
+        np.array(
+            [
+                y_pred[0, 0], y_pred[1, 0],
+                y_pred[2, 0] + y_pred[2, 1],
+                y_pred[3, 0] + y_pred[3, 1],
+                y_pred[4, 1], y_pred[5, 1],
+                y_pred[6, 1] + y_pred[6, 2],
+                y_pred[7, 1] + y_pred[7, 2],
+                y_pred[8, 2]
+            ]
+        )[:, np.newaxis]
+    )
+    np.testing.assert_allclose(cutoff, np.array([1, 1, 2, 2, 1, 1, 2, 2, 1]))
+
+
+@pytest.mark.parametrize("k_lambda", REGULARIZATION_PARAMETERS)
+@pytest.mark.parametrize("include_last_label", [True, False])
+def test_get_last_included_proba_shape(k_lambda, include_last_label):
+    """
+    Test that the outputs of _get_last_included_proba method
+    have the correct shape.
+    """
+    lambda_, k = k_lambda[0], k_lambda[1]
+    if len(k) == 1:
+        thresholds = .2
+    else:
+        thresholds = np.random.rand(len(k))
+    thresholds = cast(NDArray, check_alpha(thresholds))
+    clf = LogisticRegression()
+    clf.fit(X, y)
+    y_pred_proba = clf.predict_proba(X)
+    y_pred_proba = np.repeat(
+        y_pred_proba[:, :, np.newaxis], len(thresholds), axis=2
+    )
+
+    y_p_p_c, y_p_i_l, y_p_p_i_l = \
+        RAPS._get_last_included_proba(
+            RAPS(), y_pred_proba, thresholds, include_last_label,
+            lambda_=lambda_, k_star=k
+        )
+
+    assert y_p_p_c.shape == (len(X), len(np.unique(y)), len(thresholds))
+    assert y_p_i_l.shape == (len(X), 1, len(thresholds))
+    assert y_p_p_i_l.shape == (len(X), 1, len(thresholds))
diff --git a/mapie/tests/test_utils_classification_conformity_scores.py b/mapie/tests/test_conformity_scores_utils.py
similarity index 100%
rename from mapie/tests/test_utils_classification_conformity_scores.py
rename to mapie/tests/test_conformity_scores_utils.py

From 5fa0fee06e81bdcdccc9344f3a1ff82165eeec18 Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Thu, 11 Jul 2024 17:26:51 +0200
Subject: [PATCH 27/46] UPD: corrected docstring + comments + minor corrections

---
 mapie/conformity_scores/sets/aps.py   | 19 +------------------
 mapie/conformity_scores/sets/naive.py |  2 +-
 mapie/conformity_scores/sets/raps.py  |  3 +++
 3 files changed, 5 insertions(+), 19 deletions(-)

diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py
index e8cf5c1c3..9d3a6d9a2 100644
--- a/mapie/conformity_scores/sets/aps.py
+++ b/mapie/conformity_scores/sets/aps.py
@@ -255,7 +255,6 @@ def _add_random_tie_breaking(
         y_pred_proba_cumsum: NDArray,
         y_pred_proba_last: NDArray,
         threshold: NDArray,
-        random_state: Optional[Union[int, np.random.RandomState]] = None,
         **kwargs
     ) -> NDArray:
         """
@@ -288,21 +287,6 @@ def _add_random_tie_breaking(
         - the conformity score from training samples otherwise
           (i.e., when ``cv`` is CV splitter and ``agg_scores`` is "crossval")
 
-        method: str
-            Method that determines how to remove last label in the prediction
-
set. - - - if "cumulated_score" or "aps", compute V parameter - from Romano+(2020) - - - else compute V parameter from Angelopoulos+(2020) - - lambda_star: Optional[Union[NDArray, float]] of shape (n_alpha): - Optimal value of the regulizer lambda. - - k_star: Optional[NDArray] of shape (n_alpha): - Optimal value of the regulizer k. - Returns ------- NDArray of shape (n_samples, n_classes, n_alpha) @@ -326,7 +310,7 @@ def _add_random_tie_breaking( ) # get random numbers for each observation and alpha value - random_state = check_random_state(random_state) + random_state = check_random_state(self.random_state) random_state = cast(np.random.RandomState, random_state) us = random_state.uniform(size=(prediction_sets.shape[0], 1)) # remove last label from comparison between uniform number and V @@ -421,7 +405,6 @@ def get_prediction_sets( y_pred_proba_cumsum, y_pred_proba_last, thresholds, - self.random_state, **kwargs ) if estimator.cv == "prefit" or agg_scores in ["mean"]: diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index 259753021..9d25f3e9f 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -213,7 +213,7 @@ def _get_last_included_proba( y_pred_proba_sorted_cumsum = np.cumsum(y_pred_proba_sorted, axis=1) y_pred_proba_sorted_cumsum = self._add_regualization( y_pred_proba_sorted_cumsum, **kwargs - ) + ) # Do nothing as no regularization for the naive method # get cumulated score at their original position y_pred_proba_cumsum = np.take_along_axis( diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index 320b3bbf0..f2844fd5f 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -152,6 +152,9 @@ def _find_lambda_star( Parameters ---------- + y_raps_no_enc: NDArray of shape (n_samples, ) + True labels (after applying `label_encoder_.inverse_transform`). 
+ y_pred_proba_raps: NDArray of shape (n_samples, n_labels, n_alphas) Predictions of the model repeated on the last axis as many times as the number of alphas From d2cf4434487759d06c5281ec0dec4c85ded1ce4a Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 12 Jul 2024 11:47:36 +0200 Subject: [PATCH 28/46] UPD: move split data into conformity score + side effect changes --- mapie/classification.py | 161 ++++--------------- mapie/conformity_scores/interface.py | 41 ++++- mapie/conformity_scores/sets/lac.py | 1 + mapie/conformity_scores/sets/naive.py | 1 + mapie/conformity_scores/sets/raps.py | 216 +++++++++++++++++++++++--- mapie/conformity_scores/sets/topk.py | 1 + mapie/tests/test_classification.py | 2 +- 7 files changed, 266 insertions(+), 157 deletions(-) diff --git a/mapie/classification.py b/mapie/classification.py index 9d2d63b53..55e32f813 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -5,19 +5,16 @@ import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin -from sklearn.model_selection import (BaseCrossValidator, BaseShuffleSplit, - StratifiedShuffleSplit) +from sklearn.model_selection import BaseCrossValidator from sklearn.preprocessing import LabelEncoder -from sklearn.utils import _safe_indexing, check_random_state -from sklearn.utils.validation import (_check_y, _num_samples, check_is_fitted, - indexable) +from sklearn.utils import check_random_state +from sklearn.utils.validation import (_check_y, check_is_fitted, indexable) from mapie._typing import ArrayLike, NDArray from mapie.conformity_scores import BaseClassificationScore from mapie.conformity_scores.utils import ( check_classification_conformity_score, check_target ) -from mapie.conformity_scores.sets.utils import get_true_label_position from mapie.estimator.classifier import EnsembleClassifier from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_cv, check_estimator_classification, check_n_features_in, @@ -75,7 +72,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): By default ``None``. - cv: Optional[str] + cv: Optional[Union[int, str, BaseCrossValidator]] The cross-validation strategy for computing scores. It directly drives the distinction between jackknife and cv variants. Choose among: @@ -202,7 +199,6 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): [False False True]] """ - raps_valid_cv_ = ["prefit", "split"] fit_attributes = [ "estimator_", "n_features_in_", @@ -244,26 +240,6 @@ def _check_parameters(self) -> None: check_n_jobs(self.n_jobs) check_verbose(self.verbose) check_random_state(self.random_state) - self._check_raps() - - def _check_raps(self): - """ - Check that if the method used is ``"raps"``, then - the cross validation strategy is ``"prefit"``. - - Raises - ------ - ValueError - If ``method`` is ``"raps"`` and ``cv`` is not ``"prefit"``. - """ - if (self.method == "raps") and not ( - (self.cv in self.raps_valid_cv_) - or isinstance(self.cv, BaseShuffleSplit) - ): - raise ValueError( - "RAPS method can only be used " - f"with cv in {self.raps_valid_cv_}." - ) def _get_classes_info( self, estimator: ClassifierMixin, y: NDArray @@ -336,6 +312,7 @@ def _check_fit_parameter( y: ArrayLike, sample_weight: Optional[ArrayLike] = None, groups: Optional[ArrayLike] = None, + size_raps: Optional[float] = None, ): """ Perform several checks on class parameters. 
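The hunk below ([PATCH 28/46]) moves method resolution, data splitting and
target checking into the conformity-score object itself. As a minimal
illustrative sketch of that flow (not part of the patch; toy data and the
post-refactor API are assumed):

    import numpy as np
    from sklearn.preprocessing import LabelEncoder
    from mapie.conformity_scores.utils import (
        check_classification_conformity_score, check_target
    )

    X = np.arange(6).reshape(-1, 1)
    y = np.array([0, 0, 1, 1, 2, 2])
    y_enc = LabelEncoder().fit_transform(y)

    # Resolve the score object; `conformity_score` would take precedence
    # over `method` if both were given.
    cs = check_classification_conformity_score(method="lac")
    cs.set_external_attributes(classes=np.unique(y), random_state=42)

    # The base split_data is a pass-through; RAPS keeps a share aside.
    X, y, y_enc, sample_weight, groups = cs.split_data(X, y, y_enc, None, None)
    check_target(cs, y)  # binary targets are only allowed with LAC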
@@ -390,103 +367,44 @@ def _check_fit_parameter( estimator = check_estimator_classification(X, y, cv, self.estimator) self.n_features_in_ = check_n_features_in(X, cv, estimator) - n_samples = _num_samples(y) - self.n_classes_, self.classes_ = self._get_classes_info(estimator, y) self.label_encoder_ = self._get_label_encoder() y_enc = self.label_encoder_.transform(y) cs_estimator = check_classification_conformity_score( conformity_score=self.conformity_score, - method=self.method + method=self.method, ) + # TODO test size_raps depreciated cs_estimator.set_external_attributes( + cv=self.cv, classes=self.classes_, + label_encoder=self.label_encoder_, + size_raps=size_raps, random_state=self.random_state ) - check_target(cs_estimator, y) - - return ( - estimator, cs_estimator, cv, - X, y, y_enc, sample_weight, groups, n_samples - ) - - def _split_data( - self, - X: ArrayLike, - y_enc: ArrayLike, - sample_weight: Optional[ArrayLike] = None, - groups: Optional[ArrayLike] = None, - size_raps: Optional[float] = None, - ): - """Split data for raps method - - Parameters - ---------- - X: ArrayLike - Observed values. - - y_enc: ArrayLike - Target values as normalized encodings. - - sample_weight: Optional[ArrayLike] of shape (n_samples,) - Non-null sample weights. + # Cast + X, y_enc, y = cast(NDArray, X), cast(NDArray, y_enc), cast(NDArray, y) + sample_weight = cast(NDArray, sample_weight) + groups = cast(NDArray, groups) - groups: Optional[ArrayLike] of shape (n_samples,) - Group labels for the samples used while splitting the dataset into - train/test set. - By default ``None``. + X, y, y_enc, sample_weight, groups = \ + cs_estimator.split_data(X, y, y_enc, sample_weight, groups) + self.n_samples_ = cs_estimator.n_samples_ - size_raps: : Optional[float] - Percentage of the data to be used for choosing lambda_star and - k_star for the RAPS method. 
+ check_target(cs_estimator, y) - Returns - ------- - Tuple[NDArray, NDArray, NDArray, NDArray, Optional[NDArray], - Optional[NDArray]] - - NDArray of shape (n_samples, n_features) - - NDArray of shape (n_samples,) - - NDArray of shape (n_samples,) - - NDArray of shape (n_samples,) - - NDArray of shape (n_samples,) - - NDArray of shape (n_samples,) - """ - # Split data for raps method - raps_split = StratifiedShuffleSplit( - n_splits=1, test_size=size_raps, random_state=self.random_state - ) - train_raps_index, val_raps_index = next(raps_split.split(X, y_enc)) - X, self.X_raps, y_enc, self.y_raps = ( - _safe_indexing(X, train_raps_index), - _safe_indexing(X, val_raps_index), - _safe_indexing(y_enc, train_raps_index), - _safe_indexing(y_enc, val_raps_index), + return ( + estimator, cs_estimator, cv, X, y, y_enc, sample_weight, groups ) - # Decode y_raps for use in the RAPS method - self.y_raps_no_enc = self.label_encoder_.inverse_transform(self.y_raps) - y = self.label_encoder_.inverse_transform(y_enc) - - # Cast to NDArray for type checking - y_enc = cast(NDArray, y_enc) - n_samples = _num_samples(y_enc) - if sample_weight is not None: - sample_weight = cast(NDArray, sample_weight) - sample_weight = sample_weight[train_raps_index] - if groups is not None: - groups = cast(NDArray, groups) - groups = groups[train_raps_index] - - return X, y_enc, y, n_samples, sample_weight, groups - def fit( self, X: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike] = None, - size_raps: Optional[float] = 0.2, + size_raps: Optional[float] = None, groups: Optional[ArrayLike] = None, **fit_params, ) -> MapieClassifier: @@ -514,7 +432,7 @@ def fit( Percentage of the data to be used for choosing lambda_star and k_star for the RAPS method. - By default ``0.2``. + By default ``None``. 
groups: Optional[ArrayLike] of shape (n_samples,)
             Group labels for the samples used while splitting the dataset into
@@ -538,14 +456,9 @@ def fit(
             y,
             y_enc,
             sample_weight,
-            groups,
-            n_samples) = self._check_fit_parameter(X, y, sample_weight, groups)
-        self.n_samples_ = n_samples
-
-        if self.method == "raps":
-            (X, y_enc, y, n_samples, sample_weight, groups) = self._split_data(
-                X, y_enc, sample_weight, groups, size_raps
-            )
+            groups) = self._check_fit_parameter(
+                X, y, sample_weight, groups, size_raps
+            )
 
         # Cast
         X, y_enc, y = cast(NDArray, X), cast(NDArray, y_enc), cast(NDArray, y)
@@ -573,19 +486,12 @@ def fit(
             X, y, y_enc, groups
         )
 
-        # RAPS: compute y_pred and position on the RAPS validation dataset
-        if self.method == "raps":
-            self.y_pred_proba_raps = (
-                self.estimator_.single_estimator_.predict_proba(self.X_raps)
-            )
-            self.position_raps = get_true_label_position(
-                self.y_pred_proba_raps, self.y_raps
-            )
-
         # Compute the conformity scores
+        self.conformity_score_function_.set_ref_predictor(self.estimator_)
         self.conformity_scores_ = \
             self.conformity_score_function_.get_conformity_scores(
-                y, y_pred_proba, y_enc=y_enc, X=X
+                y, y_pred_proba, y_enc=y_enc, X=X,
+                sample_weight=sample_weight, groups=groups
             )
 
         return self
@@ -678,23 +584,12 @@ def predict(
         check_alpha_and_n_samples(alpha_np, n)
 
         # Estimate prediction sets
-        if self.method == "raps":
-            kwargs = {
-                'X_raps': self.X_raps,
-                'y_raps_no_enc': self.y_raps_no_enc,
-                'y_pred_proba_raps': self.y_pred_proba_raps,
-                'position_raps': self.position_raps,
-            }
-        else:
-            kwargs = {}
-
         prediction_sets = self.conformity_score_function_.predict_set(
             X, alpha_np,
             estimator=self.estimator_,
             conformity_scores=self.conformity_scores_,
             include_last_label=include_last_label,
             agg_scores=agg_scores,
-            **kwargs
         )
 
         self.quantiles_ = self.conformity_score_function_.quantiles_
diff --git a/mapie/conformity_scores/interface.py b/mapie/conformity_scores/interface.py
index 3979149c0..e7eaa151c 100644
--- a/mapie/conformity_scores/interface.py
+++ b/mapie/conformity_scores/interface.py
@@ -1,6 +1,8 @@
 from abc import ABCMeta, abstractmethod
+from typing import Optional
 
 import numpy as np
+from sklearn.base import BaseEstimator
 
 from mapie._compatibility import np_nanquantile
 from mapie._typing import NDArray
@@ -27,7 +29,44 @@ def set_external_attributes(
         particularly when the attributes are known after the object
         has been instantiated.
         """
-        pass
+
+    def set_ref_predictor(
+        self,
+        predictor: BaseEstimator
+    ):
+        """
+        Set the reference predictor.
+
+        Parameters
+        ----------
+        predictor: BaseEstimator
+            Reference predictor.
+        """
+        self.predictor = predictor
+
+    def split_data(
+        self,
+        X: NDArray,
+        y: NDArray,
+        y_enc: NDArray,
+        sample_weight: Optional[NDArray] = None,
+        groups: Optional[NDArray] = None,
+    ):
+        """
+        Split data. Keeps part of the data for the calibration estimator
+        (separate from the calibration data).
+
+        Parameters
+        ----------
+        X, y, y_enc, sample_weight, groups: NDArray
+            Data arrays to split.
+
+        Returns
+        -------
+        Tuple of NDArray
+            Split data for training and calibration.
+ """ + self.n_samples_ = len(X) + return X, y, y_enc, sample_weight, groups @abstractmethod def get_conformity_scores( diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index 464f6096d..cc7017ea4 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -42,6 +42,7 @@ def __init__(self) -> None: def set_external_attributes( self, + *, classes: Optional[ArrayLike] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, **kwargs diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index 9d25f3e9f..6b512c675 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -34,6 +34,7 @@ def __init__(self) -> None: def set_external_attributes( self, + *, classes: Optional[ArrayLike] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, **kwargs diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index f2844fd5f..2fcbd69cd 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -1,8 +1,14 @@ from typing import Optional, Tuple, Union, cast import numpy as np +from sklearn.calibration import LabelEncoder +from sklearn.model_selection import (BaseCrossValidator, BaseShuffleSplit, + StratifiedShuffleSplit) +from sklearn.utils import _safe_indexing +from sklearn.utils.validation import _num_samples from mapie.conformity_scores.sets.aps import APS +from mapie.conformity_scores.sets.utils import get_true_label_position from mapie.estimator.classifier import EnsembleClassifier from mapie._machine_precision import EPSILON @@ -25,6 +31,12 @@ class RAPS(APS): "Uncertainty Sets for Image Classifiers using Conformal Prediction." International Conference on Learning Representations 2021. + Parameters + ---------- + size_raps: Optional[float] + Percentage of the data to be used for choosing lambda_star and + k_star for the RAPS method. + Attributes ---------- classes: Optional[ArrayLike] @@ -37,8 +49,176 @@ class RAPS(APS): The quantiles estimated from ``get_sets`` method. """ - def __init__(self) -> None: + valid_cv_ = ["prefit", "split"] + + def __init__( + self, + size_raps: Optional[float] = 0.2 + ) -> None: super().__init__() + self.size_raps = size_raps + + def set_external_attributes( + self, + *, + cv: Union[str, BaseCrossValidator, BaseShuffleSplit] = None, + label_encoder: LabelEncoder = None, + size_raps: Optional[float] = None, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Parameters + ---------- + cv: Optional[Union[int, str, BaseCrossValidator]] + The cross-validation strategy for computing scores. + + label_encoder: Optional[LabelEncoder] + The label encoder used to encode the labels. + + By default ``None``. + + size_raps: Optional[float] + Percentage of the data to be used for choosing lambda_star and + k_star for the RAPS method. + + By default ``None``. + """ + super().set_external_attributes(**kwargs) + self.cv = cv + self.label_encoder_ = label_encoder + self.size_raps = size_raps + + def _check_cv(self): + """ + Check that if the method used is ``"raps"``, then + the cross validation strategy is ``"prefit"``. + + Raises + ------ + ValueError + If ``method`` is ``"raps"`` and ``cv`` is not ``"prefit"``. + """ + if not ( + self.cv in self.valid_cv_ or isinstance(self.cv, BaseShuffleSplit) + ): + raise ValueError( + "RAPS method can only be used " + f"with cv in {self.valid_cv_}." 
+ ) + + def split_data( + self, + X: NDArray, + y: NDArray, + y_enc: NDArray, + sample_weight: Optional[NDArray] = None, + groups: Optional[NDArray] = None, + ): + """Split data + + Parameters + ---------- + X: ArrayLike + Observed values. + + y: ArrayLike + Target values. + + y_enc: ArrayLike + Target values as normalized encodings. + + sample_weight: Optional[ArrayLike] of shape (n_samples,) + Non-null sample weights. + + groups: Optional[ArrayLike] of shape (n_samples,) + Group labels for the samples used while splitting the dataset into + train/test set. + By default ``None``. + + Returns + ------- + Tuple[NDArray, NDArray, NDArray, NDArray, Optional[NDArray], + Optional[NDArray]] + - NDArray of shape (n_samples, n_features) + - NDArray of shape (n_samples,) + - NDArray of shape (n_samples,) + - NDArray of shape (n_samples,) + - NDArray of shape (n_samples,) + - NDArray of shape (n_samples,) + """ + # Checks + self._check_cv() + + # Split data for raps method + raps_split = StratifiedShuffleSplit( + n_splits=1, + test_size=self.size_raps, random_state=self.random_state + ) + train_raps_index, val_raps_index = next(raps_split.split(X, y_enc)) + X, self.X_raps, y_enc, self.y_raps = ( + _safe_indexing(X, train_raps_index), + _safe_indexing(X, val_raps_index), + _safe_indexing(y_enc, train_raps_index), + _safe_indexing(y_enc, val_raps_index), + ) + + # Decode y_raps for use in the RAPS method + self.y_raps_no_enc = self.label_encoder_.inverse_transform(self.y_raps) + y = self.label_encoder_.inverse_transform(y_enc) + + # Cast to NDArray for type checking + y_enc = cast(NDArray, y_enc) + if sample_weight is not None: + sample_weight = cast(NDArray, sample_weight) + sample_weight = sample_weight[train_raps_index] + if groups is not None: + groups = cast(NDArray, groups) + groups = groups[train_raps_index] + + # Keep sample data size for training and calibration + self.n_samples_ = _num_samples(y_enc) + + return X, y, y_enc, sample_weight, groups + + def get_conformity_scores( + self, + y: NDArray, + y_pred: NDArray, + y_enc: Optional[NDArray] = None, + **kwargs + ) -> NDArray: + """ + Get the conformity score. + + Parameters + ---------- + y: NDArray of shape (n_samples,) + Observed target values. + + y_pred: NDArray of shape (n_samples,) + Predicted target values. + + y_enc: NDArray of shape (n_samples,) + Target values as normalized encodings. + + Returns + ------- + NDArray of shape (n_samples,) + Conformity scores. + """ + # Compute y_pred and position on the RAPS validation dataset + self.y_pred_proba_raps = ( + self.predictor.single_estimator_.predict_proba(self.X_raps) + ) + self.position_raps = get_true_label_position( + self.y_pred_proba_raps, self.y_raps + ) + + return super().get_conformity_scores( + y, y_pred, y_enc=y_enc, **kwargs + ) @staticmethod def _regularize_conformity_score( @@ -79,11 +259,7 @@ def _regularize_conformity_score( cutoff[:, np.newaxis], len(k_star), axis=1 ) conf_score += np.maximum( - np.expand_dims( - lambda_ * (cutoff - k_star), - axis=1 - ), - 0 + np.expand_dims(lambda_ * (cutoff - k_star), axis=1), 0 ) return conf_score @@ -126,9 +302,8 @@ def _update_size_and_lambda( and the new best sizes. 
""" sizes = [ - classification_mean_width_score( - y_ps[:, :, i] - ) for i in range(len(alpha_np)) + classification_mean_width_score(y_ps[:, :, i]) + for i in range(len(alpha_np)) ] sizes_improve = (sizes < best_sizes - EPSILON) @@ -209,8 +384,9 @@ def _find_lambda_star( ) y_ps = np.greater_equal( - y_pred_proba_raps - y_pred_proba_last, -EPSILON + y_pred_proba_raps - y_pred_proba_last, -EPSILON ) + lambda_star, best_sizes = self._update_size_and_lambda( best_sizes, alpha_np, y_ps, lambda_, lambda_star ) @@ -227,10 +403,6 @@ def get_conformity_score_quantiles( estimator: EnsembleClassifier, agg_scores: Optional[str] = "mean", include_last_label: Optional[Union[bool, str]] = True, - X_raps: Optional[NDArray] = None, - y_raps_no_enc: Optional[NDArray] = None, - y_pred_proba_raps: Optional[NDArray] = None, - position_raps: Optional[NDArray] = None, **kwargs ) -> NDArray: """ @@ -289,23 +461,23 @@ def get_conformity_score_quantiles( Array of quantiles with respect to alpha_np. """ # Casting to NDArray to avoid mypy errors - X_raps = cast(NDArray, X_raps) - y_raps_no_enc = cast(NDArray, y_raps_no_enc) - y_pred_proba_raps = cast(NDArray, y_pred_proba_raps) - position_raps = cast(NDArray, position_raps) + # X_raps = cast(NDArray, X_raps) + # y_raps_no_enc = cast(NDArray, y_raps_no_enc) + # y_pred_proba_raps = cast(NDArray, y_pred_proba_raps) + # position_raps = cast(NDArray, position_raps) - check_alpha_and_n_samples(alpha_np, X_raps.shape[0]) + check_alpha_and_n_samples(alpha_np, self.X_raps.shape[0]) self.k_star = compute_quantiles( - position_raps, + self.position_raps, alpha_np ) + 1 y_pred_proba_raps = np.repeat( - y_pred_proba_raps[:, :, np.newaxis], + self.y_pred_proba_raps[:, :, np.newaxis], len(alpha_np), axis=2 ) self.lambda_star = self._find_lambda_star( - y_raps_no_enc, + self.y_raps_no_enc, y_pred_proba_raps, alpha_np, include_last_label, diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 346592452..d46ee08e1 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -46,6 +46,7 @@ def __init__(self) -> None: def set_external_attributes( self, + *, classes: Optional[int] = None, random_state: Optional[Union[int, np.random.RandomState]] = None, **kwargs diff --git a/mapie/tests/test_classification.py b/mapie/tests/test_classification.py index 7e4a1e1e4..30b26a8fd 100644 --- a/mapie/tests/test_classification.py +++ b/mapie/tests/test_classification.py @@ -1408,7 +1408,7 @@ def test_toy_dataset_predictions(strategy: str) -> None: else: clf = LogisticRegression() mapie_clf = MapieClassifier(estimator=clf, **args_init) - mapie_clf.fit(X_toy, y_toy, size_raps=.5) + mapie_clf.fit(X_toy, y_toy, size_raps=0.5) _, y_ps = mapie_clf.predict( X_toy, alpha=0.5, From ba8021be0208cc2acd71ad9fb741497ee0de67cb Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 12 Jul 2024 12:06:59 +0200 Subject: [PATCH 29/46] FIx: type-check casting --- mapie/conformity_scores/sets/raps.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index 2fcbd69cd..dd0e3e433 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -61,8 +61,8 @@ def __init__( def set_external_attributes( self, *, - cv: Union[str, BaseCrossValidator, BaseShuffleSplit] = None, - label_encoder: LabelEncoder = None, + cv: Optional[Union[str, BaseCrossValidator, BaseShuffleSplit]] = None, + label_encoder: 
         size_raps: Optional[float] = None,
         **kwargs
     ) -> None:
@@ -74,6 +74,8 @@ def set_external_attributes(
         cv: Optional[Union[int, str, BaseCrossValidator]]
             The cross-validation strategy for computing scores.
 
+            By default ``None``.
+
         label_encoder: Optional[LabelEncoder]
             The label encoder used to encode the labels.
 
@@ -86,8 +88,8 @@ def set_external_attributes(
             By default ``None``.
         """
         super().set_external_attributes(**kwargs)
-        self.cv = cv
-        self.label_encoder_ = label_encoder
+        self.cv = cast(Union[str, BaseCrossValidator, BaseShuffleSplit], cv)
+        self.label_encoder_ = cast(LabelEncoder, label_encoder)
         self.size_raps = size_raps
 
     def _check_cv(self):

From a2f022ce76a4f153e419b4ee4c4937ae4e5e0d09 Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Fri, 12 Jul 2024 14:46:23 +0200
Subject: [PATCH 30/46] UPD: add attributes in docstring

---
 mapie/conformity_scores/sets/raps.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py
index dd0e3e433..9894f991a 100644
--- a/mapie/conformity_scores/sets/raps.py
+++ b/mapie/conformity_scores/sets/raps.py
@@ -39,14 +39,24 @@ class RAPS(APS):
 
     Attributes
     ----------
-    classes: Optional[ArrayLike]
+    classes: ArrayLike
         Names of the classes.
 
-    random_state: Optional[Union[int, RandomState]]
+    random_state: Union[int, RandomState]
         Pseudo random number generator state.
 
     quantiles_: ArrayLike of shape (n_alpha)
         The quantiles estimated from ``get_sets`` method.
+
+    cv: Union[int, str, BaseCrossValidator]
+        The cross-validation strategy for computing scores.
+
+    label_encoder: LabelEncoder
+        The label encoder used to encode the labels.
+
+    size_raps: float
+        Percentage of the data to be used for choosing lambda_star and
+        k_star for the RAPS method.
@@ -118,7 +128,9 @@ def split_data(
         sample_weight: Optional[NDArray] = None,
         groups: Optional[NDArray] = None,
     ):
-        """Split data
+        """
+        Split data. Keeps part of the data for the calibration estimator
+        (separate from the calibration data).
 
         Parameters
         ----------

From 3eb40eb09639325ee636e99563493cfbf58bfcda Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Fri, 12 Jul 2024 15:19:10 +0200
Subject: [PATCH 31/46] UPD: add description in tests

---
 mapie/tests/test_conformity_scores_sets.py | 19 +++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py
index 2425c5409..26e02f43b 100644
--- a/mapie/tests/test_conformity_scores_sets.py
+++ b/mapie/tests/test_conformity_scores_sets.py
@@ -40,6 +40,9 @@
 
 
 def test_error_mother_class_initialization() -> None:
+    """
+    Test that the mother class BaseClassificationScore cannot be instantiated.
+    """
     with pytest.raises(TypeError):
         BaseClassificationScore()  # type: ignore
 
@@ -48,6 +51,10 @@ def test_error_mother_class_initialization() -> None:
 def test_check_classification_conformity_score(
     conformity_score: Optional[BaseClassificationScore]
 ) -> None:
+    """
+    Test that the function check_classification_conformity_score returns
+    an instance of BaseClassificationScore when using conformity_score.
+ """ assert isinstance( check_classification_conformity_score(conformity_score), BaseClassificationScore @@ -58,6 +65,10 @@ def test_check_classification_conformity_score( def test_check_classification_method( method: Optional[str] ) -> None: + """ + Test that the function check_classification_conformity_score returns + an instance of BaseClassificationScore when using method. + """ assert isinstance( check_classification_conformity_score(method=method), BaseClassificationScore @@ -70,6 +81,10 @@ def test_check_conflict_parameters( method: Optional[str], conformity_score: Optional[BaseClassificationScore] ) -> None: + """ + Test that the function check_classification_conformity_score raises + a warning when both method and conformity_score are provided. + """ if method is None or conformity_score is None: return with pytest.warns( @@ -85,6 +100,10 @@ def test_check_conflict_parameters( def test_check_wrong_classification_method( method: Optional[str] ) -> None: + """ + Test that the function check_classification_conformity_score raises + a ValueError when using a wrong method. + """ with pytest.raises(ValueError, match="Invalid method.*"): check_classification_conformity_score(method=method) From 2e0171b8575d53ffe5171153897c729908cf894d Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Fri, 12 Jul 2024 15:41:55 +0200 Subject: [PATCH 32/46] UPD: add all conformity scores to test + test same results with method and score parameters --- mapie/conformity_scores/utils.py | 8 ++-- mapie/tests/test_classification.py | 43 +++++++++++++++++++++- mapie/tests/test_conformity_scores_sets.py | 13 ++++--- 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index 30069aca3..ce8735d53 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -110,7 +110,7 @@ def check_target( ) -method_score_map = { +METHOD_SCORE_MAP = { 'score': lambda: LAC(), 'lac': lambda: LAC(), 'cumulated_score': lambda: APS(), @@ -158,13 +158,13 @@ def check_classification_conformity_score( if isinstance(conformity_score, BaseClassificationScore): return conformity_score if method is not None: - if isinstance(method, str) and method in method_score_map: + if isinstance(method, str) and method in METHOD_SCORE_MAP: _check_depreciated(method) - return method_score_map[method]() + return METHOD_SCORE_MAP[method]() else: raise ValueError( "Invalid method. " - f"Allowed values are {list(method_score_map.keys())}." + f"Allowed values are {list(METHOD_SCORE_MAP.keys())}." 
)
         else:
             raise ValueError(
                 "Invalid conformity_score argument.\n"
diff --git a/mapie/tests/test_classification.py b/mapie/tests/test_classification.py
index 30b26a8fd..24b37d612 100644
--- a/mapie/tests/test_classification.py
+++ b/mapie/tests/test_classification.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from copy import deepcopy
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, Optional, Union, cast
 
 import numpy as np
 import pandas as pd
@@ -23,6 +23,7 @@
 
 from mapie._typing import ArrayLike, NDArray
 from mapie.classification import MapieClassifier
+from mapie.conformity_scores.utils import METHOD_SCORE_MAP
 from mapie.conformity_scores.sets.utils import check_proba_normalized
 from mapie.metrics import classification_coverage_score
 
@@ -1444,6 +1445,46 @@ def test_large_dataset_predictions(strategy: str) -> None:
     )
 
 
+@pytest.mark.parametrize("strategy", [*LARGE_COVERAGES])
+def test_same_result_with_score_and_method(strategy: str) -> None:
+    """
+    Test that prediction sets estimated by MapieClassifier on a larger dataset
+    achieve the same coverage with the conformity_score or method parameters.
+    """
+
+    def get_results(args_init, args_predict):
+        if "split" not in strategy:
+            clf = LogisticRegression().fit(X, y)
+        else:
+            clf = LogisticRegression()
+        mapie_clf = MapieClassifier(estimator=clf, **args_init)
+        mapie_clf.fit(X, y, size_raps=0.5)
+        _, y_ps = mapie_clf.predict(
+            X,
+            alpha=0.2,
+            include_last_label=args_predict["include_last_label"],
+            agg_scores=args_predict["agg_scores"]
+        )
+        return classification_coverage_score(y, y_ps[:, :, 0])
+
+    # Take args of the strategy to test
+    args_init = cast(dict, deepcopy(STRATEGIES[strategy][0]))
+    args_predict = cast(dict, deepcopy(STRATEGIES[strategy][1]))
+
+    # Test with method parameters
+    cov_method = get_results(args_init, args_predict)
+
+    # Change method to conformity_score
+    method = args_init.pop('method', None)
+    args_init['conformity_score'] = METHOD_SCORE_MAP[method]()
+
+    # Test with conformity_score parameter
+    cov_conformity_score = get_results(args_init, args_predict)
+
+    # Test that results are the same
+    np.testing.assert_allclose(cov_method, cov_conformity_score, rtol=1e-2)
+
+
 @pytest.mark.parametrize("strategy", [*STRATEGIES_BINARY])
 def test_toy_binary_dataset_predictions(strategy: str) -> None:
     """
diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py
index 26e02f43b..213ab9129 100644
--- a/mapie/tests/test_conformity_scores_sets.py
+++ b/mapie/tests/test_conformity_scores_sets.py
@@ -8,15 +8,16 @@
 from mapie._typing import NDArray
 from mapie.classification import MapieClassifier
 from mapie.conformity_scores import BaseClassificationScore
-from mapie.conformity_scores.sets import APS, LAC, TopK
+from mapie.conformity_scores.sets import APS, LAC, Naive, RAPS, TopK
 from mapie.conformity_scores.utils import check_classification_conformity_score
 from mapie.utils import check_alpha
 
 random_state = 42
 
-cs_list = [None, LAC(), APS(), TopK()]
-method_list = [None, 'naive', 'aps', 'raps', 'lac', 'top_k']
+cs_list = [None, LAC(), APS(), RAPS(), Naive(), TopK()]
+valid_method_list = ['naive', 'aps', 'raps', 'lac', 'top_k']
+all_method_list = valid_method_list + [None]
 wrong_method_list = ['naive_', 'aps_', 'raps_', 'lac_', 'top_k_']
 
 REGULARIZATION_PARAMETERS = [
@@ -61,7 +62,7 @@ def test_check_classification_conformity_score(
 )
 
 
-@pytest.mark.parametrize("method", method_list)
+@pytest.mark.parametrize("method", all_method_list)
 def test_check_classification_method(
     method: Optional[str]
 ) -> None:
@@ -75,7 +76,7 @@ def test_check_classification_method(
     )
 
 
-@pytest.mark.parametrize("method", method_list)
+@pytest.mark.parametrize("method", valid_method_list)
 @pytest.mark.parametrize("conformity_score", cs_list)
 def test_check_conflict_parameters(
     method: Optional[str],
@@ -85,7 +86,7 @@ def test_check_conflict_parameters(
     Test that the function check_classification_conformity_score raises
     a warning when both method and conformity_score are provided.
     """
-    if method is None or conformity_score is None:
+    if conformity_score is None:
         return
     with pytest.warns(
         UserWarning,

From e92a71324d3da730c32d79280d0dcf7860149c6e Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Fri, 12 Jul 2024 17:35:07 +0200
Subject: [PATCH 33/46] UPD: docstring parameters

---
 mapie/conformity_scores/utils.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py
index ce8735d53..20f5886dd 100644
--- a/mapie/conformity_scores/utils.py
+++ b/mapie/conformity_scores/utils.py
@@ -19,6 +19,18 @@ def check_regression_conformity_score(
     """
     Check parameter ``conformity_score`` for regression task.
 
+    Parameters
+    ----------
+    conformity_score: BaseRegressionScore
+        Conformity score function.
+
+        By default, `None`.
+
+    sym: bool
+        Whether to use symmetric bounds.
+
+        By default, `True`.
+
     Raises
     ------
     ValueError
@@ -128,6 +140,18 @@ def check_classification_conformity_score(
     """
     Check parameter ``conformity_score`` for classification task.
 
+    Parameters
+    ----------
+    conformity_score: BaseClassificationScore
+        Conformity score function.
+
+        By default, `None`.
+
+    method: str
+        Method to compute the conformity score.
+
+        By default, `None`.
+
     Raises
     ------
     ValueError

From c3fee465429efdc683bcec67a221f2d00ef72a87 Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Fri, 12 Jul 2024 17:46:15 +0200
Subject: [PATCH 34/46] UPD: move dict at the top of file

---
 mapie/conformity_scores/utils.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py
index 20f5886dd..71b3eaed2 100644
--- a/mapie/conformity_scores/utils.py
+++ b/mapie/conformity_scores/utils.py
@@ -12,6 +12,17 @@
 from mapie._typing import ArrayLike
 
 
+METHOD_SCORE_MAP = {
+    'score': lambda: LAC(),
+    'lac': lambda: LAC(),
+    'cumulated_score': lambda: APS(),
+    'aps': lambda: APS(),
+    'naive': lambda: Naive(),
+    'raps': lambda: RAPS(),
+    'top_k': lambda: TopK()
+}
+
+
 def check_regression_conformity_score(
     conformity_score: Optional[BaseRegressionScore],
     sym: bool = True,
@@ -122,17 +133,6 @@ def check_target(
         )
 
 
-METHOD_SCORE_MAP = {
-    'score': lambda: LAC(),
-    'lac': lambda: LAC(),
-    'cumulated_score': lambda: APS(),
-    'aps': lambda: APS(),
-    'naive': lambda: Naive(),
-    'raps': lambda: RAPS(),
-    'top_k': lambda: TopK()
-}
-
-
 def check_classification_conformity_score(
     conformity_score: Optional[BaseClassificationScore] = None,
     method: Optional[str] = None,

From 933d4d9b3f3e858b2f20646601876ceaeacb54fd Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Fri, 12 Jul 2024 18:07:30 +0200
Subject: [PATCH 35/46] UPD: add all check tests when parameters are wrong

---
 mapie/conformity_scores/utils.py             | 24 +++++++++++++-----------
 mapie/tests/test_conformity_scores_bounds.py | 17 ++++++++++++++
 mapie/tests/test_conformity_scores_sets.py   | 15 +++++++++++-
 3 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py
index 71b3eaed2..c6cfb91c9 100644
--- a/mapie/conformity_scores/utils.py
+++ b/mapie/conformity_scores/utils.py
@@ -29,6 +29,7 @@ def check_regression_conformity_score(
 ) -> BaseRegressionScore:
     """
     Check parameter ``conformity_score`` for regression task.
+    By default, return an AbsoluteConformityScore instance.
 
     Parameters
     ----------
@@ -58,7 +59,7 @@
     ...     print(exception)
     ...
     Invalid conformity_score argument.
-    Must be None or a ConformityScore instance.
+    Must be None or a BaseRegressionScore instance.
     """
     if conformity_score is None:
         return AbsoluteConformityScore(sym=sym)
@@ -67,7 +68,7 @@
     else:
         raise ValueError(
             "Invalid conformity_score argument.\n"
-            "Must be None or a ConformityScore instance."
+            "Must be None or a BaseRegressionScore instance."
         )
 
 
@@ -139,6 +140,7 @@ def check_classification_conformity_score(
 ) -> BaseClassificationScore:
     """
     Check parameter ``conformity_score`` for classification task.
+    By default, return a LAC instance.
 
     Parameters
     ----------
@@ -168,11 +170,9 @@
     ...     print(exception)
     ...
     Invalid conformity_score argument.
-    Must be None or a ConformityScore instance.
+    Must be None or a BaseClassificationScore instance.
     """
-    if method is None and conformity_score is None:
-        return LAC()
-    elif conformity_score is not None:
+    if conformity_score is not None:
         if method is not None:
             warnings.warn(
                 "WARNING: the `conformity_score` parameter takes precedence "
@@ -181,7 +181,12 @@
             )
         if isinstance(conformity_score, BaseClassificationScore):
             return conformity_score
-    if method is not None:
+        else:
+            raise ValueError(
+                "Invalid conformity_score argument.\n"
+                "Must be None or a BaseClassificationScore instance."
+            )
+    elif method is not None:
         if isinstance(method, str) and method in METHOD_SCORE_MAP:
             _check_depreciated(method)
             return METHOD_SCORE_MAP[method]()
         else:
@@ -191,7 +196,4 @@
                 "Invalid method. "
                 f"Allowed values are {list(METHOD_SCORE_MAP.keys())}."
             )
     else:
-        raise ValueError(
-            "Invalid conformity_score argument.\n"
-            "Must be None or a ConformityScore instance."
-        )
+        return LAC()
diff --git a/mapie/tests/test_conformity_scores_bounds.py b/mapie/tests/test_conformity_scores_bounds.py
index 06dfca94b..345c33652 100644
--- a/mapie/tests/test_conformity_scores_bounds.py
+++ b/mapie/tests/test_conformity_scores_bounds.py
@@ -1,3 +1,4 @@
+from typing import Any
 import numpy as np
 import pytest
 from sklearn.linear_model import LinearRegression
@@ -10,6 +11,8 @@
     ResidualNormalisedScore
 )
 from mapie.regression import MapieRegressor
+from mapie.conformity_scores.utils import check_regression_conformity_score
+
 
 X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1)
 y_toy = np.array([5, 7, 9, 11, 13, 15])
@@ -21,6 +24,8 @@
 )
 random_state = 42
 
+wrong_cs_list = [object(), "AbsoluteConformityScore", 1]
+
 
 class DummyConformityScore(BaseRegressionScore):
     def __init__(self) -> None:
@@ -48,6 +53,18 @@ def test_error_mother_class_initialization(sym: bool) -> None:
     BaseRegressionScore(sym)  # type: ignore
 
 
+@pytest.mark.parametrize("score", wrong_cs_list)
+def test_check_wrong_regression_score(
+    score: Any
+) -> None:
+    """
+    Test that the function check_regression_conformity_score raises
+    a ValueError when using a wrong score.
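+    (i.e. an object that is neither None nor a BaseRegressionScore instance).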
+ """ + with pytest.raises(ValueError, match="Invalid conformity_score argument*"): + check_regression_conformity_score(conformity_score=score) + + @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) def test_absolute_conformity_score_get_conformity_scores( y_pred: NDArray, diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py index 213ab9129..e6154602c 100644 --- a/mapie/tests/test_conformity_scores_sets.py +++ b/mapie/tests/test_conformity_scores_sets.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import Any, Optional, cast import pytest import numpy as np @@ -16,6 +16,7 @@ random_state = 42 cs_list = [None, LAC(), APS(), RAPS(), Naive(), TopK()] +wrong_cs_list = [object(), "LAC", 1] valid_method_list = ['naive', 'aps', 'raps', 'lac', 'top_k'] all_method_list = valid_method_list + [None] wrong_method_list = ['naive_', 'aps_', 'raps_', 'lac_', 'top_k_'] @@ -109,6 +110,18 @@ def test_check_wrong_classification_method( check_classification_conformity_score(method=method) +@pytest.mark.parametrize("score", wrong_cs_list) +def test_check_wrong_classification_score( + score: Any +) -> None: + """ + Test that the function check_classification_conformity_score raises + a ValueError when using a wrong score. + """ + with pytest.raises(ValueError, match="Invalid conformity_score argument*"): + check_classification_conformity_score(conformity_score=score) + + @pytest.mark.parametrize("k_lambda", REGULARIZATION_PARAMETERS) def test_regularize_conf_scores_shape(k_lambda) -> None: """ From a096b51c8d02fc5e11bf10f42f9d5f5eec9e6922 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Mon, 15 Jul 2024 10:31:19 +0200 Subject: [PATCH 36/46] UPD: add deprecated value check --- mapie/classification.py | 5 +-- mapie/conformity_scores/utils.py | 39 +++++++++++++++++----- mapie/tests/test_conformity_scores_sets.py | 16 +++++++++ 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/mapie/classification.py b/mapie/classification.py index 55e32f813..d73085c0e 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -13,7 +13,8 @@ from mapie._typing import ArrayLike, NDArray from mapie.conformity_scores import BaseClassificationScore from mapie.conformity_scores.utils import ( - check_classification_conformity_score, check_target + check_depreciated_size_raps, check_classification_conformity_score, + check_target ) from mapie.estimator.classifier import EnsembleClassifier from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_cv, @@ -375,7 +376,7 @@ def _check_fit_parameter( conformity_score=self.conformity_score, method=self.method, ) - # TODO test size_raps depreciated + check_depreciated_size_raps(size_raps) cs_estimator.set_external_attributes( cv=self.cv, classes=self.classes_, diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index c6cfb91c9..1cae0f1c8 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -72,7 +72,7 @@ def check_regression_conformity_score( ) -def _check_depreciated( +def check_depreciated_score( method: str ) -> None: """ @@ -87,17 +87,40 @@ def _check_depreciated( if method == "score": warnings.warn( "WARNING: Deprecated method. " - + "The method \"score\" is outdated. " - + "Prefer to use \"lac\" instead to keep " - + "the same behavior in the next release.", + "The method \"score\" is outdated. 
" + "Prefer to use \"lac\" instead to keep " + "the same behavior in the next release.", DeprecationWarning ) if method == "cumulated_score": warnings.warn( "WARNING: Deprecated method. " - + "The method \"cumulated_score\" is outdated. " - + "Prefer to use \"aps\" instead to keep " - + "the same behavior in the next release.", + "The method \"cumulated_score\" is outdated. " + "Prefer to use \"aps\" instead to keep " + "the same behavior in the next release.", + DeprecationWarning + ) + + +def check_depreciated_size_raps( + size_raps: Optional[float] +) -> None: + """ + Check if the parameter ``size_raps`` is used. If so, raise a warning. + + Raises + ------ + Warning + If ``size_raps`` is not ``None``. + """ + if not (size_raps is None): + warnings.warn( + "WARNING: Deprecated parameter. " + "The parameter `size_raps` is deprecated. " + "In the next release, `RAPS` takes precedence over " + "`MapieClassifier` for setting the size used. " + "Prefer to define `size_raps` in `RAPS` rather than " + "in the `fit` method of `MapieClassifier`.", DeprecationWarning ) @@ -188,7 +211,7 @@ def check_classification_conformity_score( ) elif method is not None: if isinstance(method, str) and method in METHOD_SCORE_MAP: - _check_depreciated(method) + check_depreciated_score(method) return METHOD_SCORE_MAP[method]() else: raise ValueError( diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py index e6154602c..a5197f341 100644 --- a/mapie/tests/test_conformity_scores_sets.py +++ b/mapie/tests/test_conformity_scores_sets.py @@ -122,6 +122,22 @@ def test_check_wrong_classification_score( check_classification_conformity_score(conformity_score=score) +@pytest.mark.parametrize("cv", ['prefit', 'split']) +@pytest.mark.parametrize("size_raps", [0.2, 0.5, 0.8]) +def test_check_depreciated_size_raps(size_raps: float, cv: str) -> None: + """ + Test that the function check_classification_conformity_score raises + a DeprecationWarning when using size_raps. + """ + clf = LogisticRegression().fit(X, y) + mapie_clf = MapieClassifier(estimator=clf, conformity_score=RAPS(), cv=cv) + with pytest.warns( + DeprecationWarning, + match="The parameter `size_raps` is deprecated.*" + ): + mapie_clf.fit(X, y, size_raps=size_raps) + + @pytest.mark.parametrize("k_lambda", REGULARIZATION_PARAMETERS) def test_regularize_conf_scores_shape(k_lambda) -> None: """ From 7b64f6f4c134cba43e49ac776eba4f0050a37882 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Mon, 15 Jul 2024 10:32:35 +0200 Subject: [PATCH 37/46] UPD: short value check command --- mapie/conformity_scores/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index 1cae0f1c8..58ff6f05f 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -210,7 +210,7 @@ def check_classification_conformity_score( "Must be None or a BaseClassificationScore instance." 
) elif method is not None: - if isinstance(method, str) and method in METHOD_SCORE_MAP: + if method in METHOD_SCORE_MAP: check_depreciated_score(method) return METHOD_SCORE_MAP[method]() else: From 262a96a64eaad535fbc63a0e96255eb8a3f351a8 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Mon, 15 Jul 2024 10:39:58 +0200 Subject: [PATCH 38/46] FIX: unhashable list --- mapie/conformity_scores/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index 58ff6f05f..1cae0f1c8 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -210,7 +210,7 @@ def check_classification_conformity_score( "Must be None or a BaseClassificationScore instance." ) elif method is not None: - if method in METHOD_SCORE_MAP: + if isinstance(method, str) and method in METHOD_SCORE_MAP: check_depreciated_score(method) return METHOD_SCORE_MAP[method]() else: From 1e0b66cce8be0e9b3714c508093505c01a38a866 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Mon, 15 Jul 2024 10:48:31 +0200 Subject: [PATCH 39/46] UPD: set_external_attributes common method --- mapie/conformity_scores/classification.py | 29 ++++++++++++++++++++++- mapie/conformity_scores/sets/lac.py | 28 ++-------------------- mapie/conformity_scores/sets/naive.py | 28 ++-------------------- mapie/conformity_scores/sets/topk.py | 26 +------------------- 4 files changed, 33 insertions(+), 78 deletions(-) diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py index f6e45d380..2e2b010c1 100644 --- a/mapie/conformity_scores/classification.py +++ b/mapie/conformity_scores/classification.py @@ -1,9 +1,12 @@ from abc import ABCMeta, abstractmethod +from typing import Optional, Union + +import numpy as np from mapie.conformity_scores.interface import BaseConformityScore from mapie.estimator.classifier import EnsembleClassifier -from mapie._typing import NDArray +from mapie._typing import ArrayLike, NDArray class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta): @@ -21,6 +24,30 @@ class BaseClassificationScore(BaseConformityScore, metaclass=ABCMeta): def __init__(self) -> None: super().__init__() + def set_external_attributes( + self, + *, + classes: Optional[ArrayLike] = None, + random_state: Optional[Union[int, np.random.RandomState]] = None, + **kwargs + ) -> None: + """ + Set attributes that are not provided by the user. + + Parameters + ---------- + classes: Optional[ArrayLike] + Names of the classes. + + By default ``None``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state. 
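+
+            By default ``None``.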
+ """ + super().set_external_attributes(**kwargs) + self.classes = classes + self.random_state = random_state + @abstractmethod def get_predictions( self, diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index cc7017ea4..a2b48795c 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, cast +from typing import Optional, cast import numpy as np @@ -7,7 +7,7 @@ from mapie.estimator.classifier import EnsembleClassifier from mapie._machine_precision import EPSILON -from mapie._typing import ArrayLike, NDArray +from mapie._typing import NDArray from mapie.utils import compute_quantiles @@ -40,30 +40,6 @@ class LAC(BaseClassificationScore): def __init__(self) -> None: super().__init__() - def set_external_attributes( - self, - *, - classes: Optional[ArrayLike] = None, - random_state: Optional[Union[int, np.random.RandomState]] = None, - **kwargs - ) -> None: - """ - Set attributes that are not provided by the user. - - Parameters - ---------- - classes: Optional[ArrayLike] - Names of the classes. - - By default ``None``. - - random_state: Optional[Union[int, RandomState]] - Pseudo random number generator state. - """ - super().set_external_attributes(**kwargs) - self.classes = classes - self.random_state = random_state - def get_conformity_scores( self, y: NDArray, diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index 6b512c675..9ec6c2399 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Tuple, Union import numpy as np @@ -9,7 +9,7 @@ from mapie.estimator.classifier import EnsembleClassifier from mapie._machine_precision import EPSILON -from mapie._typing import ArrayLike, NDArray +from mapie._typing import NDArray class Naive(BaseClassificationScore): @@ -32,30 +32,6 @@ class Naive(BaseClassificationScore): def __init__(self) -> None: super().__init__() - def set_external_attributes( - self, - *, - classes: Optional[ArrayLike] = None, - random_state: Optional[Union[int, np.random.RandomState]] = None, - **kwargs - ) -> None: - """ - Set attributes that are not provided by the user. - - Parameters - ---------- - classes: Optional[ArrayLike] - Names of the classes. - - By default ``None``. - - random_state: Optional[Union[int, RandomState]] - Pseudo random number generator state. - """ - super().set_external_attributes(**kwargs) - self.classes = classes - self.random_state = random_state - def get_conformity_scores( self, y: NDArray, diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index d46ee08e1..2d5693cc1 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, cast +from typing import Optional, cast import numpy as np @@ -44,30 +44,6 @@ class TopK(BaseClassificationScore): def __init__(self) -> None: super().__init__() - def set_external_attributes( - self, - *, - classes: Optional[int] = None, - random_state: Optional[Union[int, np.random.RandomState]] = None, - **kwargs - ) -> None: - """ - Set attributes that are not provided by the user. - - Parameters - ---------- - classes: Optional[ArrayLike] - Names of the classes. - - By default ``None``. - - random_state: Optional[Union[int, RandomState]] - Pseudo random number generator state. 
- """ - super().set_external_attributes(**kwargs) - self.classes = classes - self.random_state = random_state - def get_conformity_scores( self, y: NDArray, From e505a2215aadfd1310c3bbea0258172701517e61 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 16 Jul 2024 14:55:31 +0200 Subject: [PATCH 40/46] UPD: change with correct conformity score name --- mapie/conformity_scores/regression.py | 2 +- mapie/regression/regression.py | 6 +++--- mapie/tests/test_conformity_scores_bounds.py | 2 +- mapie/tests/test_regression.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mapie/conformity_scores/regression.py b/mapie/conformity_scores/regression.py index 1e58cc163..a3dcb45e8 100644 --- a/mapie/conformity_scores/regression.py +++ b/mapie/conformity_scores/regression.py @@ -149,7 +149,7 @@ def check_consistency( if max_conf_score > self.eps: raise ValueError( "The two functions get_conformity_scores and " - "get_estimation_distribution of the BaseConformityScore class " + "get_estimation_distribution of the BaseRegressionScore class " "are not consistent. " "The following equation must be verified: " "self.get_estimation_distribution(y_pred, " diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index 88e827368..aeb68b5bf 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -138,8 +138,8 @@ class MapieRegressor(BaseEstimator, RegressorMixin): By default ``0``. - conformity_score: Optional[ConformityScore] - ConformityScore instance. + conformity_score: Optional[BaseRegressionScore] + BaseRegressionScore instance. It defines the link between the observed values, the predicted ones and the conformity scores. For instance, the default ``None`` value correspondonds to a conformity score which assumes @@ -147,7 +147,7 @@ class MapieRegressor(BaseEstimator, RegressorMixin): - ``None``, to use the default ``AbsoluteConformityScore`` conformity score - - ConformityScore: any ``ConformityScore`` class + - BaseRegressionScore: any ``BaseRegressionScore`` class By default ``None``. diff --git a/mapie/tests/test_conformity_scores_bounds.py b/mapie/tests/test_conformity_scores_bounds.py index 345c33652..bd7b9209d 100644 --- a/mapie/tests/test_conformity_scores_bounds.py +++ b/mapie/tests/test_conformity_scores_bounds.py @@ -222,7 +222,7 @@ def test_gamma_conformity_score_check_predicted_value( def test_check_consistency() -> None: """ - Test that a dummy ConformityScore class that gives inconsistent scores + Test that a dummy BaseRegressionScore class that gives inconsistent scores and distributions raises an error. 
""" dummy_conf_score = DummyConformityScore() diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index c35ebec34..da81798a2 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -367,7 +367,7 @@ def test_calibration_data_size_asymmetric_score(delta: float) -> None: # Define an asymmetric conformity score score = AbsoluteConformityScore(sym=False) - # Test when ConformityScore is asymmetric + # Test when BaseRegressionScore is asymmetric # and calibration data size is sufficient n_calib_sufficient = int(np.ceil(1/(1-delta) * 2)) + 1 Xc, Xt, yc, _ = train_test_split(Xct, yct, train_size=n_calib_sufficient) @@ -377,7 +377,7 @@ def test_calibration_data_size_asymmetric_score(delta: float) -> None: mapie_reg.fit(Xc, yc) mapie_reg.predict(Xt, alpha=1-delta) - # Test when ConformityScore is asymmetric + # Test when BaseRegressionScore is asymmetric # and calibration data size is too low with pytest.raises( ValueError, match=r"Number of samples of the score is too low*" From 04e52d42f863785db25fcea70996b5b6d3a3c254 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 16 Jul 2024 15:23:36 +0200 Subject: [PATCH 41/46] UPD: change class and method names --- mapie/conformity_scores/__init__.py | 15 ++++++----- mapie/conformity_scores/sets/__init__.py | 20 +++++++------- mapie/conformity_scores/sets/aps.py | 8 +++--- mapie/conformity_scores/sets/lac.py | 2 +- mapie/conformity_scores/sets/naive.py | 6 ++--- mapie/conformity_scores/sets/raps.py | 15 ++++++----- mapie/conformity_scores/sets/topk.py | 2 +- mapie/conformity_scores/utils.py | 31 ++++++++++++---------- mapie/tests/test_conformity_scores_sets.py | 28 ++++++++++++------- 9 files changed, 72 insertions(+), 55 deletions(-) diff --git a/mapie/conformity_scores/__init__.py b/mapie/conformity_scores/__init__.py index 88a3530be..d8f6b1f5b 100644 --- a/mapie/conformity_scores/__init__.py +++ b/mapie/conformity_scores/__init__.py @@ -3,7 +3,10 @@ from .bounds import ( AbsoluteConformityScore, GammaConformityScore, ResidualNormalisedScore ) -from .sets import APS, LAC, Naive, RAPS, TopK +from .sets import ( + APSConformityScore, LACConformityScore, NaiveConformityScore, + RAPSConformityScore, TopKConformityScore +) __all__ = [ @@ -12,9 +15,9 @@ "AbsoluteConformityScore", "GammaConformityScore", "ResidualNormalisedScore", - "Naive", - "LAC", - "APS", - "RAPS", - "TopK" + "NaiveConformityScore", + "LACConformityScore", + "APSConformityScore", + "RAPSConformityScore", + "TopKConformityScore" ] diff --git a/mapie/conformity_scores/sets/__init__.py b/mapie/conformity_scores/sets/__init__.py index 36f203cc5..9db834634 100644 --- a/mapie/conformity_scores/sets/__init__.py +++ b/mapie/conformity_scores/sets/__init__.py @@ -1,14 +1,14 @@ -from .naive import Naive -from .lac import LAC -from .aps import APS -from .raps import RAPS -from .topk import TopK +from .naive import NaiveConformityScore +from .lac import LACConformityScore +from .aps import APSConformityScore +from .raps import RAPSConformityScore +from .topk import TopKConformityScore __all__ = [ - "Naive", - "LAC", - "APS", - "RAPS", - "TopK", + "NaiveConformityScore", + "LACConformityScore", + "APSConformityScore", + "RAPSConformityScore", + "TopKConformityScore", ] diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 9d3a6d9a2..9c7affd0b 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -4,7 +4,7 @@ from sklearn.dummy import check_random_state from 
sklearn.calibration import label_binarize -from mapie.conformity_scores.sets.naive import Naive +from mapie.conformity_scores.sets.naive import NaiveConformityScore from mapie.conformity_scores.sets.utils import ( check_include_last_label, check_proba_normalized ) @@ -15,7 +15,7 @@ from mapie.utils import compute_quantiles -class APS(Naive): +class APSConformityScore(NaiveConformityScore): """ Adaptive Prediction Sets (APS) method-based non-conformity score. It is based on the sum of the softmax outputs of the labels until the true @@ -211,7 +211,7 @@ def get_conformity_score_quantiles( return quantiles_ - def _compute_vs_parameter( + def _compute_v_parameter( self, y_proba_last_cumsumed: NDArray, threshold: NDArray, @@ -302,7 +302,7 @@ def _add_random_tie_breaking( ) # get the V parameter from Romano+(2020) or Angelopoulos+(2020) - vs = self._compute_vs_parameter( + vs = self._compute_v_parameter( y_proba_last_cumsumed, threshold, y_pred_proba_last, diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index a2b48795c..a81d39240 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -11,7 +11,7 @@ from mapie.utils import compute_quantiles -class LAC(BaseClassificationScore): +class LACConformityScore(BaseClassificationScore): """ Least Ambiguous set-valued Classifier (LAC) method-based non conformity score (also formerly called ``"score"``). diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index 9ec6c2399..79ba4407c 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -12,7 +12,7 @@ from mapie._typing import NDArray -class Naive(BaseClassificationScore): +class NaiveConformityScore(BaseClassificationScore): """ Naive classification non-conformity score method that is based on the cumulative sum of probabilities until the 1-alpha threshold. @@ -121,7 +121,7 @@ def get_conformity_score_quantiles( quantiles_ = 1 - alpha_np return quantiles_ - def _add_regualization( + def _add_regularization( self, y_pred_proba_sorted_cumsum: NDArray, **kwargs @@ -188,7 +188,7 @@ def _get_last_included_proba( ) # get sorted cumulated score y_pred_proba_sorted_cumsum = np.cumsum(y_pred_proba_sorted, axis=1) - y_pred_proba_sorted_cumsum = self._add_regualization( + y_pred_proba_sorted_cumsum = self._add_regularization( y_pred_proba_sorted_cumsum, **kwargs ) # Do nothing as no regularization for the naive method diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index 9894f991a..070cf4b2a 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -7,7 +7,7 @@ from sklearn.utils import _safe_indexing from sklearn.utils.validation import _num_samples -from mapie.conformity_scores.sets.aps import APS +from mapie.conformity_scores.sets.aps import APSConformityScore from mapie.conformity_scores.sets.utils import get_true_label_position from mapie.estimator.classifier import EnsembleClassifier @@ -17,12 +17,13 @@ from mapie.utils import check_alpha_and_n_samples, compute_quantiles -class RAPS(APS): +class RAPSConformityScore(APSConformityScore): """ Regularized Adaptive Prediction Sets (RAPS) method-based non-conformity - score. It uses the same technique as ``APS`` class but with a penalty term - to reduce the size of prediction sets. See [1] for more details. For now, - this method only works with ``"prefit"`` and ``"split"`` strategies. + score. 
It uses the same technique as ``APSConformityScore`` class but with + a penalty term to reduce the size of prediction sets. See [1] for more + details. For now, this method only works with ``"prefit"`` and ``"split"`` + strategies. References ---------- @@ -511,7 +512,7 @@ def get_conformity_score_quantiles( return quantiles_ - def _add_regualization( + def _add_regularization( self, y_pred_proba_sorted_cumsum: NDArray, lambda_: Optional[float] = None, @@ -571,7 +572,7 @@ def _add_regualization( return y_pred_proba_sorted_cumsum - def _compute_vs_parameter( + def _compute_v_parameter( self, y_proba_last_cumsumed: NDArray, threshold: NDArray, diff --git a/mapie/conformity_scores/sets/topk.py b/mapie/conformity_scores/sets/topk.py index 2d5693cc1..4e86a2671 100644 --- a/mapie/conformity_scores/sets/topk.py +++ b/mapie/conformity_scores/sets/topk.py @@ -13,7 +13,7 @@ from mapie.utils import compute_quantiles -class TopK(BaseClassificationScore): +class TopKConformityScore(BaseClassificationScore): """ Top-K method-based non-conformity score. diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index 1cae0f1c8..04295e794 100644 --- a/mapie/conformity_scores/utils.py +++ b/mapie/conformity_scores/utils.py @@ -7,19 +7,22 @@ from .regression import BaseRegressionScore from .classification import BaseClassificationScore from .bounds import AbsoluteConformityScore -from .sets import APS, LAC, Naive, RAPS, TopK +from .sets import ( + APSConformityScore, LACConformityScore, NaiveConformityScore, + RAPSConformityScore, TopKConformityScore +) from mapie._typing import ArrayLike METHOD_SCORE_MAP = { - 'score': lambda: LAC(), - 'lac': lambda: LAC(), - 'cumulated_score': lambda: APS(), - 'aps': lambda: APS(), - 'naive': lambda: Naive(), - 'raps': lambda: RAPS(), - 'top_k': lambda: TopK() + 'score': lambda: LACConformityScore(), + 'lac': lambda: LACConformityScore(), + 'cumulated_score': lambda: APSConformityScore(), + 'aps': lambda: APSConformityScore(), + 'naive': lambda: NaiveConformityScore(), + 'raps': lambda: RAPSConformityScore(), + 'top_k': lambda: TopKConformityScore() } @@ -117,10 +120,10 @@ def check_depreciated_size_raps( warnings.warn( "WARNING: Deprecated parameter. " "The parameter `size_raps` is deprecated. " - "In the next release, `RAPS` takes precedence over " + "In the next release, `RAPSConformityScore` takes precedence over " "`MapieClassifier` for setting the size used. " - "Prefer to define `size_raps` in `RAPS` rather than " - "in the `fit` method of `MapieClassifier`.", + "Prefer to define `size_raps` in `RAPSConformityScore` rather " + "than in the `fit` method of `MapieClassifier`.", DeprecationWarning ) @@ -148,7 +151,7 @@ def check_target( or ``"score"`` or if type of target is not multi-class. """ check_classification_targets(y) - if type_of_target(y) == "binary" and not isinstance(conformity_score, LAC): + if type_of_target(y) == "binary" and not isinstance(conformity_score, LACConformityScore): raise ValueError( "Invalid method for binary target. " "Your target is not of type multiclass and " @@ -163,7 +166,7 @@ def check_classification_conformity_score( ) -> BaseClassificationScore: """ Check parameter ``conformity_score`` for classification task. - By default, return a LAC instance. + By default, return a LACConformityScore instance. Parameters ---------- @@ -219,4 +222,4 @@ def check_classification_conformity_score( f"Allowed values are {list(METHOD_SCORE_MAP.keys())}." 
) else: - return LAC() + return LACConformityScore() diff --git a/mapie/tests/test_conformity_scores_sets.py b/mapie/tests/test_conformity_scores_sets.py index a5197f341..2e258a160 100644 --- a/mapie/tests/test_conformity_scores_sets.py +++ b/mapie/tests/test_conformity_scores_sets.py @@ -8,14 +8,20 @@ from mapie._typing import NDArray from mapie.classification import MapieClassifier from mapie.conformity_scores import BaseClassificationScore -from mapie.conformity_scores.sets import APS, LAC, Naive, RAPS, TopK +from mapie.conformity_scores.sets import ( + APSConformityScore, LACConformityScore, NaiveConformityScore, + RAPSConformityScore, TopKConformityScore +) from mapie.conformity_scores.utils import check_classification_conformity_score from mapie.utils import check_alpha random_state = 42 -cs_list = [None, LAC(), APS(), RAPS(), Naive(), TopK()] +cs_list = [ + None, LACConformityScore(), APSConformityScore(), RAPSConformityScore(), + NaiveConformityScore(), TopKConformityScore() +] wrong_cs_list = [object(), "LAC", 1] valid_method_list = ['naive', 'aps', 'raps', 'lac', 'top_k'] all_method_list = valid_method_list + [None] @@ -130,7 +136,9 @@ def test_check_depreciated_size_raps(size_raps: float, cv: str) -> None: a DeprecationWarning when using size_raps. """ clf = LogisticRegression().fit(X, y) - mapie_clf = MapieClassifier(estimator=clf, conformity_score=RAPS(), cv=cv) + mapie_clf = MapieClassifier( + estimator=clf, conformity_score=RAPSConformityScore(), cv=cv + ) with pytest.warns( DeprecationWarning, match="The parameter `size_raps` is deprecated.*" @@ -146,7 +154,7 @@ def test_regularize_conf_scores_shape(k_lambda) -> None: lambda_, k = k_lambda[0], k_lambda[1] conf_scores = np.random.rand(100, 1) cutoff = np.cumsum(np.ones(conf_scores.shape)) - 1 - reg_conf_scores = RAPS._regularize_conformity_score( + reg_conf_scores = RAPSConformityScore._regularize_conformity_score( k, lambda_, conf_scores, cutoff ) @@ -166,7 +174,9 @@ def test_get_true_label_cumsum_proba_shape() -> None: ) mapie_clf.fit(X, y) classes = mapie_clf.classes_ - cumsum_proba, cutoff = APS.get_true_label_cumsum_proba(y, y_pred, classes) + cumsum_proba, cutoff = APSConformityScore.get_true_label_cumsum_proba( + y, y_pred, classes + ) assert cumsum_proba.shape == (len(X), 1) assert cutoff.shape == (len(X), ) @@ -184,7 +194,7 @@ def test_get_true_label_cumsum_proba_result() -> None: ) mapie_clf.fit(X_toy, y_toy) classes = mapie_clf.classes_ - cumsum_proba, cutoff = APS.get_true_label_cumsum_proba( + cumsum_proba, cutoff = APSConformityScore.get_true_label_cumsum_proba( y_toy, y_pred, classes ) np.testing.assert_allclose( @@ -225,9 +235,9 @@ def test_get_last_included_proba_shape(k_lambda, include_last_label): ) y_p_p_c, y_p_i_l, y_p_p_i_l = \ - RAPS._get_last_included_proba( - RAPS(), y_pred_proba, thresholds, include_last_label, - lambda_=lambda_, k_star=k + RAPSConformityScore._get_last_included_proba( + RAPSConformityScore(), y_pred_proba, thresholds, + include_last_label, lambda_=lambda_, k_star=k ) assert y_p_p_c.shape == (len(X), len(np.unique(y)), len(thresholds)) From b724c35e4ffc70952412e11df9989cdaa8ed6590 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 16 Jul 2024 15:27:24 +0200 Subject: [PATCH 42/46] FIX: line too long --- mapie/conformity_scores/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mapie/conformity_scores/utils.py b/mapie/conformity_scores/utils.py index 04295e794..b995926c9 100644 --- a/mapie/conformity_scores/utils.py +++ 
b/mapie/conformity_scores/utils.py
@@ -151,7 +151,10 @@ def check_target(
         or ``"score"`` or if type of target is not multi-class.
     """
     check_classification_targets(y)
-    if type_of_target(y) == "binary" and not isinstance(conformity_score, LACConformityScore):
+    if (
+        type_of_target(y) == "binary" and
+        not isinstance(conformity_score, LACConformityScore)
+    ):
         raise ValueError(
             "Invalid method for binary target. "
             "Your target is not of type multiclass and "

From ebf107a3cd9b7300335f301fb84289dd4953a6a3 Mon Sep 17 00:00:00 2001
From: Thibault Cordier
Date: Tue, 16 Jul 2024 15:53:45 +0200
Subject: [PATCH 43/46] UPD: move check cv - cs function

---
 mapie/classification.py              | 18 +++++++++++---
 mapie/conformity_scores/sets/raps.py | 36 +---------------------------
 2 files changed, 16 insertions(+), 38 deletions(-)

diff --git a/mapie/classification.py b/mapie/classification.py
index d73085c0e..aa1f321b8 100644
--- a/mapie/classification.py
+++ b/mapie/classification.py
@@ -5,13 +5,14 @@
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
-from sklearn.model_selection import BaseCrossValidator
+from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils import check_random_state
 from sklearn.utils.validation import (_check_y, check_is_fitted, indexable)
 
 from mapie._typing import ArrayLike, NDArray
 from mapie.conformity_scores import BaseClassificationScore
+from mapie.conformity_scores.sets.raps import RAPSConformityScore
 from mapie.conformity_scores.utils import (
     check_depreciated_size_raps, check_classification_conformity_score,
     check_target
 )
@@ -39,6 +40,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
     If ``None``, estimator defaults to a ``LogisticRegression`` instance.
 
     method: Optional[str]
+        [DEPRECATED: use conformity_score instead]
         Method to choose for prediction interval estimates.
 
         Choose among:
@@ -119,7 +121,7 @@
 
     By default ``None``.
 
-    conformity_score_function_: BaseClassificationScore
+    conformity_score: BaseClassificationScore
         Score function that handle all that is related to conformity scores.
 
     In any case, the `conformity_score` parameter takes precedence over the
@@ -378,12 +380,22 @@ def _check_fit_parameter(
         )
         check_depreciated_size_raps(size_raps)
         cs_estimator.set_external_attributes(
-            cv=self.cv,
             classes=self.classes_,
             label_encoder=self.label_encoder_,
             size_raps=size_raps,
             random_state=self.random_state
         )
+        if (
+            isinstance(cs_estimator, RAPSConformityScore) and
+            not (
+                self.cv in ["split", "prefit"] or
+                isinstance(self.cv, BaseShuffleSplit)
+            )
+        ):
+            raise ValueError(
+                "RAPS method can only be used "
+                "with ``cv='split'`` and ``cv='prefit'``."
+ ) # Cast X, y_enc, y = cast(NDArray, X), cast(NDArray, y_enc), cast(NDArray, y) diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index 070cf4b2a..c03c2b48e 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -2,8 +2,7 @@ import numpy as np from sklearn.calibration import LabelEncoder -from sklearn.model_selection import (BaseCrossValidator, BaseShuffleSplit, - StratifiedShuffleSplit) +from sklearn.model_selection import StratifiedShuffleSplit from sklearn.utils import _safe_indexing from sklearn.utils.validation import _num_samples @@ -49,9 +48,6 @@ class RAPSConformityScore(APSConformityScore): quantiles_: ArrayLike of shape (n_alpha) The quantiles estimated from ``get_sets`` method. - cv: Union[int, str, BaseCrossValidator] - The cross-validation strategy for computing scores. - label_encoder: LabelEncoder The label encoder used to encode the labels. @@ -60,8 +56,6 @@ class RAPSConformityScore(APSConformityScore): k_star for the RAPS method. """ - valid_cv_ = ["prefit", "split"] - def __init__( self, size_raps: Optional[float] = 0.2 @@ -72,7 +66,6 @@ def __init__( def set_external_attributes( self, *, - cv: Optional[Union[str, BaseCrossValidator, BaseShuffleSplit]] = None, label_encoder: Optional[LabelEncoder] = None, size_raps: Optional[float] = None, **kwargs @@ -82,11 +75,6 @@ def set_external_attributes( Parameters ---------- - cv: Optional[Union[int, str, BaseCrossValidator]] - The cross-validation strategy for computing scores. - - By default ``None``. - label_encoder: Optional[LabelEncoder] The label encoder used to encode the labels. @@ -99,28 +87,9 @@ def set_external_attributes( By default ``None``. """ super().set_external_attributes(**kwargs) - self.cv = cast(Union[str, BaseCrossValidator, BaseShuffleSplit], cv) self.label_encoder_ = cast(LabelEncoder, label_encoder) self.size_raps = size_raps - def _check_cv(self): - """ - Check that if the method used is ``"raps"``, then - the cross validation strategy is ``"prefit"``. - - Raises - ------ - ValueError - If ``method`` is ``"raps"`` and ``cv`` is not ``"prefit"``. - """ - if not ( - self.cv in self.valid_cv_ or isinstance(self.cv, BaseShuffleSplit) - ): - raise ValueError( - "RAPS method can only be used " - f"with cv in {self.valid_cv_}." - ) - def split_data( self, X: NDArray, @@ -162,9 +131,6 @@ def split_data( - NDArray of shape (n_samples,) - NDArray of shape (n_samples,) """ - # Checks - self._check_cv() - # Split data for raps method raps_split = StratifiedShuffleSplit( n_splits=1, From c3d9025fe8c37e4be2ae5c2b4d8152121e1993d1 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 16 Jul 2024 17:39:38 +0200 Subject: [PATCH 44/46] FIX: typo in docstring and variable names --- mapie/classification.py | 2 +- mapie/conformity_scores/classification.py | 6 +++--- mapie/conformity_scores/interface.py | 8 ++++---- mapie/conformity_scores/regression.py | 6 +++--- mapie/conformity_scores/sets/aps.py | 12 ++++++------ mapie/conformity_scores/sets/lac.py | 2 +- mapie/conformity_scores/sets/naive.py | 2 +- mapie/conformity_scores/sets/raps.py | 16 ++++++++-------- 8 files changed, 27 insertions(+), 27 deletions(-) diff --git a/mapie/classification.py b/mapie/classification.py index aa1f321b8..f4e19ba45 100644 --- a/mapie/classification.py +++ b/mapie/classification.py @@ -47,7 +47,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin): - ``"naive"``, sum of the probabilities until the 1-alpha threshold. 
- ``"lac"`` (formerly called ``"score"``), Least Ambiguous set-valued - Classifier. It is based on the the scores + Classifier. It is based on the scores (i.e. 1 minus the softmax score of the true label) on the calibration set. See [1] for more details. diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py index 2e2b010c1..e450514fd 100644 --- a/mapie/conformity_scores/classification.py +++ b/mapie/conformity_scores/classification.py @@ -203,7 +203,7 @@ def predict_set( ): """ Compute the prediction sets on new samples based on the uncertainty of - the target confidence interval. + the target confidence set. Parameters: ----------- @@ -211,14 +211,14 @@ def predict_set( The input data or samples for prediction. alpha_np: NDArray of shape (n_alpha, ) - Represents the uncertainty of the confidence interval to produce. + Represents the uncertainty of the confidence set to produce. **kwargs: dict Additional keyword arguments. Returns: -------- - The output strcture depend on the ``get_sets`` method. + The output structure depend on the ``get_sets`` method. The prediction sets for each sample and each alpha level. """ return self.get_sets(X=X, alpha_np=alpha_np, **kwargs) diff --git a/mapie/conformity_scores/interface.py b/mapie/conformity_scores/interface.py index e7eaa151c..29f8ff282 100644 --- a/mapie/conformity_scores/interface.py +++ b/mapie/conformity_scores/interface.py @@ -39,7 +39,7 @@ def set_ref_predictor( Parameters ---------- - predictor: BaeEstimator + predictor: BaseEstimator Reference predictor. """ self.predictor = predictor @@ -172,7 +172,7 @@ def predict_set( ): """ Compute the prediction sets on new samples based on the uncertainty of - the target confidence interval. + the target confidence set. Parameters: ----------- @@ -180,13 +180,13 @@ def predict_set( The input data or samples for prediction. alpha_np: NDArray of shape (n_alpha, ) - Represents the uncertainty of the confidence interval to produce. + Represents the uncertainty of the confidence set to produce. **kwargs: dict Additional keyword arguments. Returns: -------- - The output strcture depend on the subclass. + The output structure depend on the subclass. The prediction sets for each sample and each alpha level. """ diff --git a/mapie/conformity_scores/regression.py b/mapie/conformity_scores/regression.py index a3dcb45e8..e6e098464 100644 --- a/mapie/conformity_scores/regression.py +++ b/mapie/conformity_scores/regression.py @@ -388,7 +388,7 @@ def predict_set( ): """ Compute the prediction sets on new samples based on the uncertainty of - the target confidence interval. + the target confidence set. Parameters: ----------- @@ -396,14 +396,14 @@ def predict_set( The input data or samples for prediction. alpha_np: NDArray of shape (n_alpha, ) - Represents the uncertainty of the confidence interval to produce. + Represents the uncertainty of the confidence set to produce. **kwargs: dict Additional keyword arguments. Returns: -------- - The output strcture depend on the ``get_bounds`` method. + The output structure depend on the ``get_bounds`` method. The prediction sets for each sample and each alpha level. """ return self.get_bounds(X=X, alpha_np=alpha_np, **kwargs) diff --git a/mapie/conformity_scores/sets/aps.py b/mapie/conformity_scores/sets/aps.py index 9c7affd0b..8e5cb7d27 100644 --- a/mapie/conformity_scores/sets/aps.py +++ b/mapie/conformity_scores/sets/aps.py @@ -242,11 +242,11 @@ def _compute_v_parameter( Vs parameters. 
""" # compute V parameter from Romano+(2020) - vs = ( + v_param = ( (y_proba_last_cumsumed - threshold.reshape(1, -1)) / y_pred_proba_last[:, 0, :] ) - return vs + return v_param def _add_random_tie_breaking( self, @@ -302,7 +302,7 @@ def _add_random_tie_breaking( ) # get the V parameter from Romano+(2020) or Angelopoulos+(2020) - vs = self._compute_v_parameter( + v_param = self._compute_v_parameter( y_proba_last_cumsumed, threshold, y_pred_proba_last, @@ -312,13 +312,13 @@ def _add_random_tie_breaking( # get random numbers for each observation and alpha value random_state = check_random_state(self.random_state) random_state = cast(np.random.RandomState, random_state) - us = random_state.uniform(size=(prediction_sets.shape[0], 1)) + u_param = random_state.uniform(size=(prediction_sets.shape[0], 1)) # remove last label from comparison between uniform number and V - vs_less_than_us = np.less_equal(vs - us, EPSILON) + label_to_keep = np.less_equal(v_param - u_param, EPSILON) np.put_along_axis( prediction_sets, y_pred_index_last, - vs_less_than_us[:, np.newaxis, :], + label_to_keep[:, np.newaxis, :], axis=1 ) return prediction_sets diff --git a/mapie/conformity_scores/sets/lac.py b/mapie/conformity_scores/sets/lac.py index a81d39240..bf5bcbd01 100644 --- a/mapie/conformity_scores/sets/lac.py +++ b/mapie/conformity_scores/sets/lac.py @@ -16,7 +16,7 @@ class LACConformityScore(BaseClassificationScore): Least Ambiguous set-valued Classifier (LAC) method-based non conformity score (also formerly called ``"score"``). - It is based on the the scores (i.e. 1 minus the softmax score of the true + It is based on the scores (i.e. 1 minus the softmax score of the true label) on the calibration set. References diff --git a/mapie/conformity_scores/sets/naive.py b/mapie/conformity_scores/sets/naive.py index 79ba4407c..19b0e42c9 100644 --- a/mapie/conformity_scores/sets/naive.py +++ b/mapie/conformity_scores/sets/naive.py @@ -156,7 +156,7 @@ def _get_last_included_proba( ) -> Tuple[NDArray, NDArray, NDArray]: """ Function that returns the smallest score - among those which are included in the prediciton set. + among those which are included in the prediction set. 
Parameters ---------- diff --git a/mapie/conformity_scores/sets/raps.py b/mapie/conformity_scores/sets/raps.py index c03c2b48e..1c39aed8f 100644 --- a/mapie/conformity_scores/sets/raps.py +++ b/mapie/conformity_scores/sets/raps.py @@ -125,11 +125,11 @@ def split_data( ------- Tuple[NDArray, NDArray, NDArray, NDArray, Optional[NDArray], Optional[NDArray]] - - NDArray of shape (n_samples, n_features) - - NDArray of shape (n_samples,) - - NDArray of shape (n_samples,) - - NDArray of shape (n_samples,) - - NDArray of shape (n_samples,) + - X: NDArray of shape (n_samples, n_features) + - y: NDArray of shape (n_samples,) + - y_enc: NDArray of shape (n_samples,) + - sample_weight: Optional[NDArray] of shape (n_samples,) + - groups: Optional[NDArray] of shape (n_samples,) """ # Split data for raps method raps_split = StratifiedShuffleSplit( @@ -258,7 +258,7 @@ def _update_size_and_lambda( Parameters ---------- best_sizes: NDArray of shape (n_alphas, ) - Smallest average prediciton set size before testing + Smallest average prediction set size before testing for the new value of lambda_ alpha_np: NDArray of shape (n_alphas) @@ -570,7 +570,7 @@ def _compute_v_parameter( """ # compute V parameter from Angelopoulos+(2020) L = np.sum(prediction_sets, axis=1) - vs = ( + v_param = ( (y_proba_last_cumsumed - threshold.reshape(1, -1)) / ( y_pred_proba_last[:, 0, :] - @@ -578,4 +578,4 @@ def _compute_v_parameter( self.lambda_star * (L > self.k_star) ) ) - return vs + return v_param From d7b484757060e359da5e5a462514b1634ad183f8 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Tue, 16 Jul 2024 17:42:20 +0200 Subject: [PATCH 45/46] UPD: change interval to set --- mapie/conformity_scores/classification.py | 8 ++++---- mapie/conformity_scores/interface.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mapie/conformity_scores/classification.py b/mapie/conformity_scores/classification.py index e450514fd..00e397128 100644 --- a/mapie/conformity_scores/classification.py +++ b/mapie/conformity_scores/classification.py @@ -68,7 +68,7 @@ def get_predictions( alpha_np: NDArray of shape (n_alpha,) NDArray of floats between ``0`` and ``1``, represents the - uncertainty of the confidence interval. + uncertainty of the confidence set. estimator: EnsembleClassifier Estimator that is fitted to predict y from X. @@ -99,7 +99,7 @@ def get_conformity_score_quantiles( alpha_np: NDArray of shape (n_alpha,) NDArray of floats between 0 and 1, representing the uncertainty - of the confidence interval. + of the confidence set. estimator: EnsembleClassifier Estimator that is fitted to predict y from X. @@ -135,7 +135,7 @@ def get_prediction_sets( alpha_np: NDArray of shape (n_alpha,) NDArray of floats between 0 and 1, representing the uncertainty - of the confidence interval. + of the confidence set. estimator: EnsembleClassifier Estimator that is fitted to predict y from X. @@ -165,7 +165,7 @@ def get_sets( alpha_np: NDArray of shape (n_alpha,) NDArray of floats between 0 and 1, representing the uncertainty - of the confidence interval. + of the confidence set. estimator: EnsembleClassifier Estimator that is fitted to predict y from X. 
diff --git a/mapie/conformity_scores/interface.py b/mapie/conformity_scores/interface.py index 29f8ff282..07345d3e4 100644 --- a/mapie/conformity_scores/interface.py +++ b/mapie/conformity_scores/interface.py @@ -114,7 +114,7 @@ def get_quantile( alpha_np: NDArray of shape (n_alpha,) NDArray of floats between ``0`` and ``1``, represents the - uncertainty of the confidence interval. + uncertainty of the confidence set. axis: int The axis from which to compute the quantile. @@ -128,7 +128,7 @@ def get_quantile( By default ``False``. unbounded: bool - Boolean specifying whether infinite prediction intervals + Boolean specifying whether infinite prediction sets could be produced (when alpha_np is greater than or equal to 1.). By default ``False``. From 4c97a005f3a164eb0a2db6b891498a03e3812c51 Mon Sep 17 00:00:00 2001 From: Thibault Cordier Date: Wed, 17 Jul 2024 11:00:41 +0200 Subject: [PATCH 46/46] UPD: documentation with score api --- doc/api.rst | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index a36957f36..411221efd 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -73,8 +73,8 @@ Metrics metrics.spiegelhalter_statistic metrics.top_label_ece -Conformity scores -================= +Conformity scores (regression) +============================== .. autosummary:: :toctree: generated/ @@ -84,10 +84,20 @@ Conformity scores conformity_scores.AbsoluteConformityScore conformity_scores.GammaConformityScore conformity_scores.ResidualNormalisedScore + +Conformity scores (classification) +================================== + +.. autosummary:: + :toctree: generated/ + :template: class.rst + conformity_scores.BaseClassificationScore - conformity_scores.LAC - conformity_scores.APS - conformity_scores.TopK + conformity_scores.NaiveConformityScore + conformity_scores.LACConformityScore + conformity_scores.APSConformityScore + conformity_scores.RAPSConformityScore + conformity_scores.TopKConformityScore Resampling ==========
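
Usage example: a minimal sketch of the refactored classification API
introduced by this series. The toy dataset, the LogisticRegression model and
the alpha level below are illustrative assumptions; only the class names and
the ``conformity_score`` parameter come from the patches above.

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    from mapie.classification import MapieClassifier
    from mapie.conformity_scores import LACConformityScore

    # Toy multiclass data (illustrative assumption, not from the patches).
    X = np.arange(12).reshape(-1, 1)
    y = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

    # With cv="prefit", the classifier is fitted beforehand and MAPIE only
    # calibrates the conformity scores.
    clf = LogisticRegression().fit(X, y)

    # New API: pass a conformity score instance through ``conformity_score``
    # instead of the deprecated ``method`` string ("lac", "aps", "raps", ...).
    mapie_clf = MapieClassifier(
        estimator=clf, conformity_score=LACConformityScore(), cv="prefit"
    )
    mapie_clf.fit(X, y)

    # y_ps has shape (n_samples, n_classes, n_alpha).
    y_pred, y_ps = mapie_clf.predict(X, alpha=0.2)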