|
1 | 1 | from unittest import TestCase # or `from unittest import ...` if on Python 3.4+
|
2 |
| -from category_encoders.utils import convert_input_vector, convert_inputs, get_categorical_cols |
| 2 | +import pytest |
| 3 | +from category_encoders.utils import convert_input_vector, convert_inputs, get_categorical_cols, BaseEncoder |
| 4 | + |
| 5 | +from sklearn.base import BaseEstimator, TransformerMixin |
| 6 | +from sklearn import __version__ as skl_version |
| 7 | +from packaging.version import Version |
3 | 8 | import pandas as pd
|
4 | 9 | import numpy as np
|
5 | 10 |
|
@@ -120,3 +125,26 @@ def test_get_categorical_cols(self):
|
120 | 125 | self.assertEqual(get_categorical_cols(df.astype("object")), ["col"])
|
121 | 126 | self.assertEqual(get_categorical_cols(df.astype("category")), ["col"])
|
122 | 127 | self.assertEqual(get_categorical_cols(df.astype("string")), ["col"])
|
| 128 | + |
| 129 | + |
| 130 | +class TestBaseEncoder(TestCase): |
| 131 | + def setUp(self): |
| 132 | + class DummyEncoder(BaseEncoder, BaseEstimator, TransformerMixin): |
| 133 | + def _fit(self, X, y=None): |
| 134 | + return self |
| 135 | + |
| 136 | + def transform(self, X, y=None, override_return_df=False): |
| 137 | + return X |
| 138 | + |
| 139 | + self.encoder = DummyEncoder() |
| 140 | + |
| 141 | + @pytest.mark.skipif(Version(skl_version) < Version('1.2'), reason="requires sklean > 1.2") |
| 142 | + def test_sklearn_pandas_out_refit(self): |
| 143 | + # Thanks to Issue#437 |
| 144 | + df = pd.DataFrame({"C1": ["a", "a"], "C2": ["c", "d"]}) |
| 145 | + self.encoder.set_output(transform="pandas") |
| 146 | + self.encoder.fit_transform(df.iloc[:1]) |
| 147 | + out = self.encoder.fit_transform( |
| 148 | + df.rename(columns={'C1': 'X1', 'C2': 'X2'}) |
| 149 | + ) |
| 150 | + self.assertTrue(list(out.columns) == ['X1', 'X2']) |
0 commit comments