diff --git a/imblearn/base.py b/imblearn/base.py
index 39141029a..eeabf8322 100644
--- a/imblearn/base.py
+++ b/imblearn/base.py
@@ -22,6 +22,14 @@
 METHODS.append("fit_transform")
 METHODS.append("fit_resample")
 
+try:
+    from sklearn.utils._metadata_requests import SIMPLE_METHODS
+
+    SIMPLE_METHODS.append("fit_resample")
+except ImportError:
+    # in older versions of scikit-learn, only METHODS is used
+    pass
+
 
 class SamplerMixin(metaclass=ABCMeta):
     """Mixin class for samplers with abstract method.
@@ -33,7 +41,7 @@ class SamplerMixin(metaclass=ABCMeta):
     _estimator_type = "sampler"
 
     @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y):
+    def fit(self, X, y, **params):
         """Check inputs and statistics of the sampler.
 
         You should use ``fit_resample`` in all cases.
@@ -47,6 +55,9 @@ def fit(self, X, y):
         y : array-like of shape (n_samples,)
             Target array.
 
+        **params : dict
+            Extra parameters to use by the sampler.
+
         Returns
         -------
         self : object
@@ -58,7 +69,8 @@ def fit(self, X, y):
         )
         return self
 
-    def fit_resample(self, X, y):
+    @_fit_context(prefer_skip_nested_validation=True)
+    def fit_resample(self, X, y, **params):
         """Resample the dataset.
 
         Parameters
@@ -70,6 +82,9 @@ def fit_resample(self, X, y):
         y : array-like of shape (n_samples,)
             Corresponding label for each sample in X.
 
+        **params : dict
+            Extra parameters to use by the sampler.
+
         Returns
         -------
         X_resampled : {array-like, dataframe, sparse matrix} of shape \
@@ -87,7 +102,7 @@ def fit_resample(self, X, y):
             self.sampling_strategy, y, self._sampling_type
         )
 
-        output = self._fit_resample(X, y)
+        output = self._fit_resample(X, y, **params)
 
         y_ = (
             label_binarize(output[1], classes=np.unique(y)) if binarize_y else output[1]
@@ -97,7 +112,7 @@ def fit_resample(self, X, y):
         return (X_, y_) if len(output) == 2 else (X_, y_, output[2])
 
     @abstractmethod
-    def _fit_resample(self, X, y):
+    def _fit_resample(self, X, y, **params):
         """Base method defined in each sampler to defined the sampling
         strategy.
 
@@ -109,6 +124,9 @@ def _fit_resample(self, X, y):
         y : array-like of shape (n_samples,)
             Corresponding label for each sample in X.
 
+        **params : dict
+            Extra parameters to use by the sampler.
+
         Returns
         -------
         X_resampled : {ndarray, sparse matrix} of shape \
@@ -139,7 +157,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
         X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
         return X, y, binarize_y
 
-    def fit(self, X, y):
+    def fit(self, X, y, **params):
         """Check inputs and statistics of the sampler.
 
         You should use ``fit_resample`` in all cases.
@@ -158,10 +176,9 @@ def fit(self, X, y):
         self : object
             Return the instance itself.
         """
-        self._validate_params()
-        return super().fit(X, y)
+        return super().fit(X, y, **params)
 
-    def fit_resample(self, X, y):
+    def fit_resample(self, X, y, **params):
         """Resample the dataset.
 
         Parameters
@@ -182,8 +199,7 @@ def fit_resample(self, X, y):
         y_resampled : array-like of shape (n_samples_new,)
             The corresponding label of `X_resampled`.
         """
-        self._validate_params()
-        return super().fit_resample(X, y)
+        return super().fit_resample(X, y, **params)
 
     def _more_tags(self):
         return {"X_types": ["2darray", "sparse", "dataframe"]}
diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py
index b82b9a543..0449fd3f5 100644
--- a/imblearn/pipeline.py
+++ b/imblearn/pipeline.py
@@ -1168,34 +1168,45 @@ def get_metadata_routing(self):
         router = MetadataRouter(owner=self.__class__.__name__)
 
         # first we add all steps except the last one
-        for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
+        for _, name, trans in self._iter(
+            with_final=False, filter_passthrough=True, filter_resample=False
+        ):
             method_mapping = MethodMapping()
             # fit, fit_predict, and fit_transform call fit_transform if it
             # exists, or else fit and transform
             if hasattr(trans, "fit_transform"):
-                method_mapping.add(caller="fit", callee="fit_transform")
-                method_mapping.add(caller="fit_transform", callee="fit_transform")
-                method_mapping.add(caller="fit_predict", callee="fit_transform")
-                method_mapping.add(caller="fit_resample", callee="fit_transform")
+                (
+                    method_mapping.add(caller="fit", callee="fit_transform")
+                    .add(caller="fit_transform", callee="fit_transform")
+                    .add(caller="fit_predict", callee="fit_transform")
+                )
             else:
-                method_mapping.add(caller="fit", callee="fit")
-                method_mapping.add(caller="fit", callee="transform")
-                method_mapping.add(caller="fit_transform", callee="fit")
-                method_mapping.add(caller="fit_transform", callee="transform")
-                method_mapping.add(caller="fit_predict", callee="fit")
-                method_mapping.add(caller="fit_predict", callee="transform")
-                method_mapping.add(caller="fit_resample", callee="fit")
-                method_mapping.add(caller="fit_resample", callee="transform")
-
-            method_mapping.add(caller="predict", callee="transform")
-            method_mapping.add(caller="predict", callee="transform")
-            method_mapping.add(caller="predict_proba", callee="transform")
-            method_mapping.add(caller="decision_function", callee="transform")
-            method_mapping.add(caller="predict_log_proba", callee="transform")
-            method_mapping.add(caller="transform", callee="transform")
-            method_mapping.add(caller="inverse_transform", callee="inverse_transform")
-            method_mapping.add(caller="score", callee="transform")
-            method_mapping.add(caller="fit_resample", callee="transform")
+                (
+                    method_mapping.add(caller="fit", callee="fit")
+                    .add(caller="fit", callee="transform")
+                    .add(caller="fit_transform", callee="fit")
+                    .add(caller="fit_transform", callee="transform")
+                    .add(caller="fit_predict", callee="fit")
+                    .add(caller="fit_predict", callee="transform")
+                )
+
+            (
+                # handling sampler if the fit_* stage
+                method_mapping.add(caller="fit", callee="fit_resample")
+                .add(caller="fit_transform", callee="fit_resample")
+                .add(caller="fit_predict", callee="fit_resample")
+            )
+            (
+                method_mapping.add(caller="predict", callee="transform")
+                .add(caller="predict", callee="transform")
+                .add(caller="predict_proba", callee="transform")
+                .add(caller="decision_function", callee="transform")
+                .add(caller="predict_log_proba", callee="transform")
+                .add(caller="transform", callee="transform")
+                .add(caller="inverse_transform", callee="inverse_transform")
+                .add(caller="score", callee="transform")
+                .add(caller="fit_resample", callee="transform")
+            )
 
             router.add(method_mapping=method_mapping, **{name: trans})
 
@@ -1207,23 +1218,24 @@ def get_metadata_routing(self):
         method_mapping = MethodMapping()
         if hasattr(final_est, "fit_transform"):
             method_mapping.add(caller="fit_transform", callee="fit_transform")
-            method_mapping.add(caller="fit_resample", callee="fit_transform")
         else:
+            (
+                method_mapping.add(caller="fit", callee="fit").add(
+                    caller="fit", callee="transform"
+                )
+            )
+        (
             method_mapping.add(caller="fit", callee="fit")
-            method_mapping.add(caller="fit", callee="transform")
-            method_mapping.add(caller="fit_resample", callee="fit")
-            method_mapping.add(caller="fit_resample", callee="transform")
-
-        method_mapping.add(caller="fit", callee="fit")
-        method_mapping.add(caller="predict", callee="predict")
-        method_mapping.add(caller="fit_predict", callee="fit_predict")
-        method_mapping.add(caller="predict_proba", callee="predict_proba")
-        method_mapping.add(caller="decision_function", callee="decision_function")
-        method_mapping.add(caller="predict_log_proba", callee="predict_log_proba")
-        method_mapping.add(caller="transform", callee="transform")
-        method_mapping.add(caller="inverse_transform", callee="inverse_transform")
-        method_mapping.add(caller="score", callee="score")
-        method_mapping.add(caller="fit_resample", callee="fit_resample")
+            .add(caller="predict", callee="predict")
+            .add(caller="fit_predict", callee="fit_predict")
+            .add(caller="predict_proba", callee="predict_proba")
+            .add(caller="decision_function", callee="decision_function")
+            .add(caller="predict_log_proba", callee="predict_log_proba")
+            .add(caller="transform", callee="transform")
+            .add(caller="inverse_transform", callee="inverse_transform")
+            .add(caller="score", callee="score")
+            .add(caller="fit_resample", callee="fit_resample")
+        )
 
         router.add(method_mapping=method_mapping, **{final_name: final_est})
         return router
diff --git a/imblearn/tests/test_pipeline.py b/imblearn/tests/test_pipeline.py
index dfe74442a..b9c6f9df3 100644
--- a/imblearn/tests/test_pipeline.py
+++ b/imblearn/tests/test_pipeline.py
@@ -34,6 +34,7 @@
 )
 from sklearn.utils.fixes import parse_version
 
+from imblearn.base import BaseSampler
 from imblearn.datasets import make_imbalance
 from imblearn.pipeline import Pipeline, make_pipeline
 from imblearn.under_sampling import EditedNearestNeighbours as ENN
@@ -1495,3 +1496,24 @@ def test_transform_input_sklearn_version():
 
 # end of transform_input tests
 # =============================
+
+
+def test_metadata_routing_with_sampler():
+    """Check that we can use a sampler with metadata routing."""
+    X, y = make_classification()
+    cost_matrix = np.random.rand(X.shape[0], 2, 2)
+
+    class CostSensitiveSampler(BaseSampler):
+        def fit_resample(self, X, y, cost_matrix=None):
+            return self._fit_resample(X, y, cost_matrix=cost_matrix)
+
+        def _fit_resample(self, X, y, cost_matrix=None):
+            self.cost_matrix_ = cost_matrix
+            return X, y
+
+    with config_context(enable_metadata_routing=True):
+        sampler = CostSensitiveSampler().set_fit_resample_request(cost_matrix=True)
+        pipeline = Pipeline([("sampler", sampler), ("model", LogisticRegression())])
+        pipeline.fit(X, y, cost_matrix=cost_matrix)
+
+    assert_allclose(pipeline[0].cost_matrix_, cost_matrix)