Skip to content

Commit d656528

Browse files
Merge pull request #350 from lvgig/feature/remove_no_nulls_error_nearest_mean
adjusted behaviour of NearestMeanResponse transformer to be able to c…
2 parents 1792bbe + 1e474e2 commit d656528

File tree

3 files changed

+97
-29
lines changed

3 files changed

+97
-29
lines changed

CHANGELOG.rst

+15-5
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,32 @@ Subsections for each version can be one of the following;
1616

1717
Each individual change should have a link to the pull request after the description of the change.
1818

19-
1.4.1 (unreleased)
19+
1.4.2 (unreleased)
2020
------------------
2121

2222
Changed
2323
^^^^^^^
2424

25-
- Refactored BaseImputer to utilise narwhals `#314 <https://github.com/lvgig/tubular/issues/314>_`
26-
- Converted test dfs to flexible pandas/polars setup
27-
- Converted BaseNominalTransformer to utilise narwhals `#334 <https://github.com/lvgig/tubular/issues/334>_`
28-
- narwhalified CheckNumericMixin `#336 <https://github.com/lvgig/tubular/issues/336>_`
2925
- placeholder
3026
- placeholder
3127
- placeholder
3228
- placeholder
3329
- placeholder
3430

31+
1.4.1 (2024-12-02)
32+
------------------
33+
34+
Changed
35+
^^^^^^^
36+
37+
- Refactored BaseImputer to utilise narwhals `#314 <https://github.com/lvgig/tubular/issues/314>`_
38+
- Converted test dfs to flexible pandas/polars setup
39+
- Converted BaseNominalTransformer to utilise narwhals `#334 <https://github.com/lvgig/tubular/issues/334>`_
40+
- narwhalified CheckNumericMixin `#336 <https://github.com/lvgig/tubular/issues/336>`_
41+
- Changed behaviour of NearestMeanResponseImputer so that if there are no nulls at fit,
42+
it warns and has no effect at transform, as opposed to erroring. The error was problematic for e.g.
43+
lightweight test runs where nulls are less likely to be present.
44+
3545
1.4.0 (2024-10-15)
3646
------------------
3747

tests/imputers/test_NearestMeanResponseImputer.py

+58-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import narwhals as nw
12
import numpy as np
23
import pytest
34

@@ -10,6 +11,7 @@
1011
from tests.imputers.test_BaseImputer import (
1112
GenericImputerTransformTests,
1213
)
14+
from tests.utils import assert_frame_equal_dispatch, dataframe_init_dispatch
1315
from tubular.imputers import NearestMeanResponseImputer
1416

1517

@@ -42,18 +44,23 @@ def test_null_values_in_response_error(self, library):
4244
transformer.fit(df, df["a"])
4345

4446
@pytest.mark.parametrize("library", ["pandas", "polars"])
45-
def test_columns_with_no_nulls_error(self, library):
46-
"""Test an error is raised if a non-response column contains no nulls."""
47+
def test_columns_with_no_nulls_warning(self, library):
48+
"""Test a warning is raised if a non-response column contains no nulls."""
4749
df = d.create_numeric_df_1(library=library)
4850

49-
transformer = NearestMeanResponseImputer(columns=["b", "c"])
51+
transformer = NearestMeanResponseImputer(columns=["c"])
5052

51-
with pytest.raises(
52-
ValueError,
53-
match="NearestMeanResponseImputer: Column b has no missing values, cannot use this transformer.",
53+
with pytest.warns(
54+
UserWarning,
55+
match="NearestMeanResponseImputer: Column c has no missing values, this transformer will have no effect for this column.",
5456
):
5557
transformer.fit(df, df["c"])
5658

59+
expected_impute_values = {"c": None}
60+
assert (
61+
transformer.impute_values_ == expected_impute_values
62+
), f"impute_values_ attr not as expected, expected {expected_impute_values} but got {transformer.impute_values_}"
63+
5764
@pytest.mark.parametrize("library", ["pandas", "polars"])
5865
def test_learnt_values(self, library):
5966
"""Test that the nearest response values learnt during fit are expected."""
@@ -78,3 +85,48 @@ class TestTransform(
7885
@classmethod
7986
def setup_class(cls):
8087
cls.transformer_name = "NearestMeanResponseImputer"
88+
89+
@pytest.mark.parametrize("library", ["pandas", "polars"])
90+
@pytest.mark.parametrize(
91+
("fit_col", "transform_col"),
92+
[
93+
# try a few types, with and without nulls in transform col
94+
([1, 2, 3], [1.0, np.nan, np.nan]),
95+
([4, 5, 6], [7, 8, 9]),
96+
(["a", "b", "c"], ["a", None, "d"]),
97+
(["c", "d", "e"], ["f", "g", "h"]),
98+
([4.0, 5.0, 6.0], [8.0, np.nan, 6.0]),
99+
([1.0, 2.0, 3.0], [4.0, 3.0, 2.0]),
100+
([True, False, False], [True, True, None]),
101+
([True, False, True], [True, False, True]),
102+
],
103+
)
104+
def test_no_effect_when_fit_on_null_free_col(self, fit_col, transform_col, library):
105+
"test that when transformer fits on a col with no nulls, transform has no effect"
106+
107+
df_fit_dict = {
108+
"a": fit_col,
109+
"b": [1] * len(fit_col),
110+
}
111+
112+
df_fit = dataframe_init_dispatch(df_fit_dict, library=library)
113+
114+
df_transform_dict = {
115+
"a": transform_col,
116+
}
117+
118+
df_transform = dataframe_init_dispatch(df_transform_dict, library=library)
119+
120+
transformer = NearestMeanResponseImputer(columns=["a"])
121+
122+
transformer.fit(df_fit, df_fit["b"])
123+
124+
df_transform = nw.from_native(df_transform)
125+
126+
expected_output = df_transform.clone().to_native()
127+
128+
df_transform = nw.to_native(df_transform)
129+
130+
output = transformer.transform(df_transform)
131+
132+
assert_frame_equal_dispatch(output, expected_output)

tubular/imputers.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ def transform(self, X: FrameT) -> FrameT:
5555
X = nw.from_native(super().transform(X))
5656

5757
new_col_expressions = [
58-
nw.col(c).fill_null(self.impute_values_[c]) for c in self.columns
58+
nw.col(c).fill_null(self.impute_values_[c])
59+
if self.impute_values_[c]
60+
else nw.col(c)
61+
for c in self.columns
5962
]
6063

6164
return X.with_columns(
@@ -424,7 +427,8 @@ class NearestMeanResponseImputer(BaseImputer):
424427
----------
425428
columns : None or str or list, default = None
426429
Columns to impute, if the default of None is supplied all columns in X are used
427-
when the transform method is called.
430+
when the transform method is called. If the column does not contain nulls at fit,
431+
a warning will be issued and this transformer will have no effect on that column.
428432
429433
Attributes
430434
----------
@@ -478,26 +482,28 @@ def fit(self, X: FrameT, y: nw.Series) -> FrameT:
478482
c_nulls = X.select(nw.col(c).is_null())[c]
479483

480484
if c_nulls.sum() == 0:
481-
msg = f"{self.classname()}: Column {c} has no missing values, cannot use this transformer."
482-
raise ValueError(msg)
485+
msg = f"{self.classname()}: Column {c} has no missing values, this transformer will have no effect for this column."
486+
warnings.warn(msg, stacklevel=2)
487+
self.impute_values_[c] = None
483488

484-
mean_response_by_levels = (
485-
X_y.filter(~c_nulls).group_by(c).agg(nw.col(response_column).mean())
486-
)
489+
else:
490+
mean_response_by_levels = (
491+
X_y.filter(~c_nulls).group_by(c).agg(nw.col(response_column).mean())
492+
)
487493

488-
mean_response_nulls = X_y.filter(c_nulls)[response_column].mean()
494+
mean_response_nulls = X_y.filter(c_nulls)[response_column].mean()
489495

490-
mean_response_by_levels = mean_response_by_levels.with_columns(
491-
(nw.col(response_column) - mean_response_nulls)
492-
.abs()
493-
.alias("abs_diff_response"),
494-
)
496+
mean_response_by_levels = mean_response_by_levels.with_columns(
497+
(nw.col(response_column) - mean_response_nulls)
498+
.abs()
499+
.alias("abs_diff_response"),
500+
)
495501

496-
# take first value having the minimum difference in terms of average response
497-
self.impute_values_[c] = mean_response_by_levels.filter(
498-
mean_response_by_levels["abs_diff_response"]
499-
== mean_response_by_levels["abs_diff_response"].min(),
500-
)[c].item(index=0)
502+
# take first value having the minimum difference in terms of average response
503+
self.impute_values_[c] = mean_response_by_levels.filter(
504+
mean_response_by_levels["abs_diff_response"]
505+
== mean_response_by_levels["abs_diff_response"].min(),
506+
)[c].item(index=0)
501507

502508
return self
503509

0 commit comments

Comments
 (0)