Skip to content

Commit 11fbba6

Browse files
Merge pull request #438 from bmreiniger/fix_437
Fix for sklearn-pandas-out and refit
2 parents 06e46db + 5d153da commit 11fbba6

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

category_encoders/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def fit(self, X, y=None, **kwargs):
320320
self._fit(X, y, **kwargs)
321321

322322
# to find invariant columns, transform without y (as is done on the test set)
323+
self.feature_names_out_ = None # Issue#437
323324
X_transformed = self.transform(X, override_return_df=True)
324325
self.feature_names_out_ = X_transformed.columns.tolist()
325326

requirements-dev.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
sphinx
22
sphinx_rtd_theme
33
pytest
4-
numpydoc
4+
numpydoc
5+
packaging

tests/test_utils.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from unittest import TestCase # or `from unittest import ...` if on Python 3.4+
2-
from category_encoders.utils import convert_input_vector, convert_inputs, get_categorical_cols
2+
import pytest
3+
from category_encoders.utils import convert_input_vector, convert_inputs, get_categorical_cols, BaseEncoder
4+
5+
from sklearn.base import BaseEstimator, TransformerMixin
6+
from sklearn import __version__ as skl_version
7+
from packaging.version import Version
38
import pandas as pd
49
import numpy as np
510

@@ -120,3 +125,26 @@ def test_get_categorical_cols(self):
120125
self.assertEqual(get_categorical_cols(df.astype("object")), ["col"])
121126
self.assertEqual(get_categorical_cols(df.astype("category")), ["col"])
122127
self.assertEqual(get_categorical_cols(df.astype("string")), ["col"])
128+
129+
130+
class TestBaseEncoder(TestCase):
131+
def setUp(self):
132+
class DummyEncoder(BaseEncoder, BaseEstimator, TransformerMixin):
133+
def _fit(self, X, y=None):
134+
return self
135+
136+
def transform(self, X, y=None, override_return_df=False):
137+
return X
138+
139+
self.encoder = DummyEncoder()
140+
141+
@pytest.mark.skipif(Version(skl_version) < Version('1.2'), reason="requires sklearn >= 1.2")
142+
def test_sklearn_pandas_out_refit(self):
143+
# Thanks to Issue#437
144+
df = pd.DataFrame({"C1": ["a", "a"], "C2": ["c", "d"]})
145+
self.encoder.set_output(transform="pandas")
146+
self.encoder.fit_transform(df.iloc[:1])
147+
out = self.encoder.fit_transform(
148+
df.rename(columns={'C1': 'X1', 'C2': 'X2'})
149+
)
150+
self.assertTrue(list(out.columns) == ['X1', 'X2'])

0 commit comments

Comments
 (0)