Skip to content

Commit a985254

Browse files
authored
Merge pull request #329 from lvgig/feature/tests_dtseries
refactored SeriesDtTransformer tests
2 parents b885fc7 + 06e553e commit a985254

File tree

4 files changed

+64
-122
lines changed

4 files changed

+64
-122
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Changed
4343
- fixed a bug in CappingTransformer which was preventing use of .get_params method `#311 <https://github.com/lvgig/tubular/issues/311>`_
4444
- Setup requirements for narwhals, remove python3.8 from our build pipelines as incompatible with polars
4545
- Refactored ToDatetimeTransformer tests in new format `#300 <https://github.com/lvgig/tubular/issues/300>`_
46+
- Refactored tests for SeriesDtMethodTransformer in new format. Changed column arg to columns to fit generic format. `#299 <https://github.com/lvgig/tubular/issues/299>`_
4647
- Refactored OrdinalEncoderTransformer tests in new format `#330 <https://github.com/lvgig/tubular/issues/330>`_
4748

4849

tests/conftest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,9 @@ def minimal_attribute_dict():
250250
"columns": ["a", "b"],
251251
},
252252
"SeriesDtMethodTransformer": {
253-
"new_column_name": "a",
253+
"new_column_name": "new_column",
254254
"pd_method_name": "month",
255-
"column": "b",
255+
"columns": "b",
256256
},
257257
"SeriesStrMethodTransformer": {
258258
"columns": ["b"],
+52-113
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,56 @@
1+
import re
2+
13
import numpy as np
24
import pytest
35
import test_aide as ta
46

57
import tests.test_data as d
8+
from tests.base_tests import (
9+
ColumnStrListInitTests,
10+
DropOriginalInitMixinTests,
11+
DropOriginalTransformMixinTests,
12+
GenericTransformTests,
13+
NewColumnNameInitMixintests,
14+
OtherBaseBehaviourTests,
15+
)
16+
from tests.dates.test_BaseDatetimeTransformer import (
17+
DatetimeMixinTransformTests,
18+
)
619
from tubular.dates import SeriesDtMethodTransformer
720

821

9-
class TestInit:
22+
class TestInit(
23+
ColumnStrListInitTests,
24+
DropOriginalInitMixinTests,
25+
NewColumnNameInitMixintests,
26+
):
1027
"""Tests for SeriesDtMethodTransformer.init()."""
1128

29+
@classmethod
30+
def setup_class(cls):
31+
cls.transformer_name = "SeriesDtMethodTransformer"
32+
1233
def test_invalid_input_type_errors(self):
1334
"""Test that exceptions are raised for invalid input types."""
35+
bad_columns = ["b", "c"]
1436
with pytest.raises(
15-
TypeError,
16-
match=r"SeriesDtMethodTransformer: column should be a str but got \<class 'list'\>",
37+
ValueError,
38+
match=rf"SeriesDtMethodTransformer: column should be a str or list of len 1, got {re.escape(str(bad_columns))}",
1739
):
1840
SeriesDtMethodTransformer(
1941
new_column_name="a",
2042
pd_method_name=1,
21-
column=["b", "c"],
43+
columns=bad_columns,
2244
)
2345

2446
with pytest.raises(
2547
TypeError,
2648
match=r"SeriesDtMethodTransformer: unexpected type \(\<class 'int'\>\) for pd_method_name, expecting str",
27-
):
28-
SeriesDtMethodTransformer(new_column_name="a", pd_method_name=1, column="b")
29-
30-
with pytest.raises(
31-
TypeError,
32-
match=r"SeriesDtMethodTransformer: new_column_name should be str",
3349
):
3450
SeriesDtMethodTransformer(
35-
new_column_name=1.0,
36-
pd_method_name="year",
37-
column="b",
51+
new_column_name="a",
52+
pd_method_name=1,
53+
columns="b",
3854
)
3955

4056
with pytest.raises(
@@ -44,7 +60,7 @@ def test_invalid_input_type_errors(self):
4460
SeriesDtMethodTransformer(
4561
new_column_name="a",
4662
pd_method_name="year",
47-
column="b",
63+
columns="b",
4864
pd_method_kwargs=1,
4965
)
5066

@@ -55,7 +71,7 @@ def test_invalid_input_type_errors(self):
5571
SeriesDtMethodTransformer(
5672
new_column_name="a",
5773
pd_method_name="year",
58-
column="b",
74+
columns="b",
5975
pd_method_kwargs={"a": 1, 2: "b"},
6076
)
6177

@@ -68,52 +84,21 @@ def test_exception_raised_non_pandas_method_passed(self):
6884
SeriesDtMethodTransformer(
6985
new_column_name="a",
7086
pd_method_name="b",
71-
column="b",
87+
columns="b",
7288
)
7389

74-
def test_attributes_set(self):
75-
"""Test that the values passed for new_column_name, pd_method_name are saved to attributes on the object."""
76-
x = SeriesDtMethodTransformer(
77-
new_column_name="a",
78-
pd_method_name="year",
79-
column="b",
80-
pd_method_kwargs={"d": 1},
81-
)
82-
83-
ta.classes.test_object_attributes(
84-
obj=x,
85-
expected_attributes={
86-
"column": "b",
87-
"new_column_name": "a",
88-
"pd_method_name": "year",
89-
"pd_method_kwargs": {"d": 1},
90-
},
91-
msg="Attributes for SeriesDtMethodTransformer set in init",
92-
)
9390

94-
@pytest.mark.parametrize(
95-
("pd_method_name", "callable_attr"),
96-
[("year", False), ("to_period", True)],
97-
)
98-
def test_callable_attribute_set(self, pd_method_name, callable_attr):
99-
"""Test the _callable attribute is set to True if pd.Series.dt.pd_method_name is callable."""
100-
x = SeriesDtMethodTransformer(
101-
new_column_name="a",
102-
pd_method_name=pd_method_name,
103-
column="b",
104-
pd_method_kwargs={"d": 1},
105-
)
106-
107-
ta.classes.test_object_attributes(
108-
obj=x,
109-
expected_attributes={"_callable": callable_attr},
110-
msg="_callable attribute for SeriesDtMethodTransformer set in init",
111-
)
112-
113-
114-
class TestTransform:
91+
class TestTransform(
92+
DatetimeMixinTransformTests,
93+
DropOriginalTransformMixinTests,
94+
GenericTransformTests,
95+
):
11596
"""Tests for SeriesDtMethodTransformer.transform()."""
11697

98+
@classmethod
99+
def setup_class(cls):
100+
cls.transformer_name = "SeriesDtMethodTransformer"
101+
117102
def expected_df_1():
118103
"""Expected output of test_expected_output_no_overwrite."""
119104
df = d.create_datediff_test_df()
@@ -144,31 +129,6 @@ def expected_df_3():
144129

145130
return df
146131

147-
@pytest.mark.parametrize(
148-
("bad_column", "bad_type"),
149-
[
150-
("numeric_col", "int64"),
151-
("string_col", "object"),
152-
("bool_col", "bool"),
153-
("empty_col", "object"),
154-
("date_col", "date"),
155-
],
156-
)
157-
def test_input_data_check_column_errors(self, bad_column, bad_type):
158-
"""Check that errors are raised on a variety of different non datatypes"""
159-
x = SeriesDtMethodTransformer(
160-
new_column_name="a2",
161-
pd_method_name="year",
162-
column=bad_column,
163-
)
164-
165-
df = d.create_date_diff_incorrect_dtypes()
166-
167-
msg = rf"{x.classname()}: {x.columns[0]} type should be in \['datetime64'\] but got {bad_type}"
168-
169-
with pytest.raises(TypeError, match=msg):
170-
x.transform(df)
171-
172132
@pytest.mark.parametrize(
173133
("df", "expected"),
174134
ta.pandas.adjusted_dataframe_params(
@@ -181,7 +141,7 @@ def test_expected_output_no_overwrite(self, df, expected):
181141
x = SeriesDtMethodTransformer(
182142
new_column_name="a_year",
183143
pd_method_name="year",
184-
column="a",
144+
columns="a",
185145
pd_method_kwargs=None,
186146
)
187147

@@ -205,7 +165,7 @@ def test_expected_output_overwrite(self, df, expected):
205165
x = SeriesDtMethodTransformer(
206166
new_column_name="a",
207167
pd_method_name="year",
208-
column="a",
168+
columns="a",
209169
pd_method_kwargs=None,
210170
)
211171

@@ -229,7 +189,7 @@ def test_expected_output_callable(self, df, expected):
229189
x = SeriesDtMethodTransformer(
230190
new_column_name="b_new",
231191
pd_method_name="to_period",
232-
column="b",
192+
columns="b",
233193
pd_method_kwargs={"freq": "M"},
234194
)
235195

@@ -241,35 +201,14 @@ def test_expected_output_callable(self, df, expected):
241201
msg_tag="Unexpected values in SeriesDtMethodTransformer.transform with to_period",
242202
)
243203

244-
def test_attributes_unchanged_by_transform(self):
245-
"""Test that attributes set in init are unchanged by the transform method."""
246-
df = d.create_datediff_test_df()
247204

248-
x = SeriesDtMethodTransformer(
249-
new_column_name="b_new",
250-
pd_method_name="to_period",
251-
column="b",
252-
pd_method_kwargs={"freq": "M"},
253-
)
205+
class TestOtherBaseBehaviour(OtherBaseBehaviourTests):
206+
"""
207+
Class to run tests for BaseTransformerBehaviour outside the three standard methods.
254208
255-
x2 = SeriesDtMethodTransformer(
256-
new_column_name="b_new",
257-
pd_method_name="to_period",
258-
column="b",
259-
pd_method_kwargs={"freq": "M"},
260-
)
209+
May need to overwrite specific tests in this class if the tested transformer modifies this behaviour.
210+
"""
261211

262-
x.transform(df)
263-
264-
assert (
265-
x.new_column_name == x2.new_column_name
266-
), "new_column_name changed by SeriesDtMethodTransformer.transform"
267-
assert (
268-
x.pd_method_name == x2.pd_method_name
269-
), "pd_method_name changed by SeriesDtMethodTransformer.transform"
270-
assert (
271-
x.columns == x2.columns
272-
), "columns changed by SeriesDtMethodTransformer.transform"
273-
assert (
274-
x.pd_method_kwargs == x2.pd_method_kwargs
275-
), "pd_method_kwargs changed by SeriesDtMethodTransformer.transform"
212+
@classmethod
213+
def setup_class(cls):
214+
cls.transformer_name = "SeriesDtMethodTransformer"

tubular/dates.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -591,22 +591,24 @@ def __init__(
591591
self,
592592
new_column_name: str,
593593
pd_method_name: str,
594-
column: str,
594+
columns: list[str],
595595
pd_method_kwargs: dict[str, object] | None = None,
596596
drop_original: bool = False,
597597
**kwargs: dict[str, bool],
598598
) -> None:
599-
if type(column) is not str:
600-
msg = f"{self.classname()}: column should be a str but got {type(column)}"
601-
raise TypeError(msg)
602-
603599
super().__init__(
604-
columns=[column],
600+
columns=columns,
605601
new_column_name=new_column_name,
606602
drop_original=drop_original,
607603
**kwargs,
608604
)
609605

606+
if len(self.columns) > 1:
607+
msg = rf"{self.classname()}: column should be a str or list of len 1, got {self.columns}"
608+
raise ValueError(
609+
msg,
610+
)
611+
610612
if type(pd_method_name) is not str:
611613
msg = f"{self.classname()}: unexpected type ({type(pd_method_name)}) for pd_method_name, expecting str"
612614
raise TypeError(msg)
@@ -644,7 +646,7 @@ def __init__(
644646

645647
# This attribute is not for use in any method, use 'columns' instead.
646648
# Here only as a fix to allow string representation of transformer.
647-
self.column = column
649+
self.column = self.columns[0]
648650

649651
def transform(self, X: pd.DataFrame) -> pd.DataFrame:
650652
"""Transform specific column on input pandas.DataFrame (X) using the given pandas.Series.dt method and

0 commit comments

Comments
 (0)