Skip to content

Commit 14e62a7

Browse files
committed
Added __getstate__ more appropriately, tests for pickling objects, and special tests for transform equivalence, new assert statement that makes the pickle a dict (not the most robust... but it works), added PickleError,
1 parent aea9128 commit 14e62a7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+420
-156
lines changed

patsy/categorical.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
pandas_Categorical_codes,
4848
safe_issubdtype,
4949
no_pickling, assert_no_pickling, check_pickle_version)
50-
from patsy.state import StatefulTransform
5150

5251
if have_pandas:
5352
import pandas
@@ -65,17 +64,14 @@ def __getstate__(self):
6564
data = getattr(self, 'data')
6665
contrast = getattr(self, 'contrast')
6766
levels = getattr(self, 'levels')
68-
return (0, data, contrast, levels)
67+
return {'version': 0, 'data': data, 'contrast': contrast,
68+
'levels': levels}
6969

7070
def __setstate__(self, pickle):
71-
version, data, contrast, levels = pickle
72-
check_pickle_version(version, 0, name=self.__class__.__name__)
73-
self.data = data
74-
self.contrast = contrast
75-
self.levels = levels
76-
77-
def __eq__(self, other):
78-
return self.__dict__ == other.__dict__
71+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
72+
self.data = pickle['data']
73+
self.contrast = pickle['contrast']
74+
self.levels = pickle['levels']
7975

8076

8177
def C(data, contrast=None, levels=None):
@@ -137,19 +133,18 @@ def test_C():
137133
assert c4.contrast == "NEW CONTRAST"
138134
assert c4.levels == "LEVELS"
139135

140-
# assert_no_pickling(c4)
141-
142136

143137
def test_C_pickle():
144138
from six.moves import cPickle as pickle
139+
from patsy.util import assert_pickled_equals
145140
c1 = C("asdf")
146-
assert c1 == pickle.loads(pickle.dumps(c1))
141+
assert_pickled_equals(c1, pickle.loads(pickle.dumps(c1)))
147142
c2 = C("DATA", "CONTRAST", "LEVELS")
148-
assert c2 == pickle.loads(pickle.dumps(c2))
143+
assert_pickled_equals(c2, pickle.loads(pickle.dumps(c2)))
149144
c3 = C(c2, levels="NEW LEVELS")
150-
assert c3 == pickle.loads(pickle.dumps(c3))
145+
assert_pickled_equals(c3, pickle.loads(pickle.dumps(c3)))
151146
c4 = C(c2, "NEW CONTRAST")
152-
assert c4 == pickle.loads(pickle.dumps(c4))
147+
assert_pickled_equals(c4, pickle.loads(pickle.dumps(c4)))
153148

154149

155150
def guess_categorical(data):
@@ -247,7 +242,7 @@ def sniff(self, data):
247242
# would be too. Otherwise we need to keep looking.
248243
return self._level_set == set([True, False])
249244

250-
# __getstate__ = no_pickling
245+
__getstate__ = no_pickling
251246

252247
def test_CategoricalSniffer():
253248
from patsy.missing import NAAction

patsy/constraint.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from patsy.util import (atleast_2d_column_default,
1919
repr_pretty_delegate, repr_pretty_impl,
2020
SortAnythingKey,
21-
no_pickling, assert_no_pickling)
21+
no_pickling, assert_no_pickling, check_pickle_version)
2222
from patsy.infix_parser import Token, Operator, ParseNode, infix_parse
2323

2424
class LinearConstraint(object):
@@ -65,10 +65,16 @@ def _repr_pretty_(self, p, cycle):
6565
return repr_pretty_impl(p, self,
6666
[self.variable_names, self.coefs, self.constants])
6767

68-
def __eq__(self, other):
69-
return self.__dict__ == other.__dict__
68+
def __getstate__(self):
69+
return {'version': 0, 'variable_names': self.variable_names,
70+
'coefs': self.coefs, 'constants': self.constants}
71+
72+
def __setstate__(self, pickle):
73+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
74+
self.variable_names = pickle['variable_names']
75+
self.coefs = pickle['coefs']
76+
self.constants = pickle['constants']
7077

71-
# __getstate__ = no_pickling
7278

7379
@classmethod
7480
def combine(cls, constraints):
@@ -121,8 +127,6 @@ def test_LinearConstraint():
121127
assert_raises(ValueError, LinearConstraint, ["a", "b"],
122128
np.zeros((0, 2)))
123129

124-
# assert_no_pickling(lc)
125-
126130
def test_LinearConstraint_combine():
127131
comb = LinearConstraint.combine([LinearConstraint(["a", "b"], [1, 0]),
128132
LinearConstraint(["a", "b"], [0, 1], [1])])

patsy/contrasts.py

-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def test_ContrastMatrix():
7575
from nose.tools import assert_raises
7676
assert_raises(PatsyError, ContrastMatrix, [[1], [0]], ["a", "b"])
7777

78-
# assert_no_pickling(cm)
7978

8079
# This always produces an object of the type that Python calls 'str' (whether
8180
# that be a Python 2 string-of-bytes or a Python 3 string-of-unicode). It does

patsy/desc.py

+25-20
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,12 @@ def name(self):
6666
return "Intercept"
6767

6868
def __getstate__(self):
69-
return (0, self.factors)
69+
return {'version': 0, 'factors': self.factors}
7070

7171
def __setstate__(self, pickle):
72-
version, factors = pickle
73-
check_pickle_version(version, 0, name=self.__class__.__name__)
74-
self.factors = factors
72+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
73+
self.factors = pickle['factors']
7574

76-
# __getstate__ = no_pickling
7775

7876
INTERCEPT = Term([])
7977

@@ -85,12 +83,6 @@ def __init__(self, name):
8583
def name(self):
8684
return self._name
8785

88-
def __eq__(self, other):
89-
return self.__dict__ == other.__dict__
90-
91-
def __hash__(self):
92-
return hash((_MockFactor, str(self._name)))
93-
9486

9587
def test_Term():
9688
assert Term([1, 2, 1]).factors == (1, 2)
@@ -102,11 +94,12 @@ def test_Term():
10294
assert Term([f2, f1]).name() == "b:a"
10395
assert Term([]).name() == "Intercept"
10496

105-
# assert_no_pickling(Term([]))
106-
10797
from six.moves import cPickle as pickle
98+
from patsy.util import assert_pickled_equals
10899
t = Term([f1, f2])
109-
assert t == pickle.loads(pickle.dumps(t, pickle.HIGHEST_PROTOCOL))
100+
t2 = pickle.loads(pickle.dumps(t, pickle.HIGHEST_PROTOCOL))
101+
assert_pickled_equals(t, t2)
102+
110103

111104
class ModelDesc(object):
112105
"""A simple container representing the termlists parsed from a formula.
@@ -168,7 +161,7 @@ def term_code(term):
168161
if term != INTERCEPT]
169162
result += " + ".join(term_names)
170163
return result
171-
164+
172165
@classmethod
173166
def from_formula(cls, tree_or_string):
174167
"""Construct a :class:`ModelDesc` from a formula string.
@@ -186,10 +179,15 @@ def from_formula(cls, tree_or_string):
186179
assert isinstance(value, cls)
187180
return value
188181

189-
def __eq__(self, other):
190-
return self.__dict__ == other.__dict__
182+
def __getstate__(self):
183+
return {'version': 0, 'lhs_termlist': self.lhs_termlist,
184+
'rhs_termlist': self.rhs_termlist}
185+
186+
def __setstate__(self, pickle):
187+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
188+
self.lhs_termlist = pickle['lhs_termlist']
189+
self.rhs_termlist = pickle['rhs_termlist']
191190

192-
# __getstate__ = no_pickling
193191

194192
def test_ModelDesc():
195193
f1 = _MockFactor("a")
@@ -202,7 +200,9 @@ def test_ModelDesc():
202200

203201
# assert_no_pickling(m)
204202
from six.moves import cPickle as pickle
205-
assert m == pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL))
203+
from patsy.util import assert_pickled_equals
204+
m2 = pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL))
205+
assert_pickled_equals(m, m2)
206206

207207
assert ModelDesc([], []).describe() == "~ 0"
208208
assert ModelDesc([INTERCEPT], []).describe() == "1 ~ 0"
@@ -234,7 +234,12 @@ def _pretty_repr_(self, p, cycle): # pragma: no cover
234234
[self.intercept, self.intercept_origin,
235235
self.intercept_removed, self.terms])
236236

237-
# __getstate__ = no_pickling
237+
__getstate__ = no_pickling
238+
239+
240+
def test_IntermediateExpr_smoke():
241+
assert_no_pickling(IntermediateExpr(False, None, True, []))
242+
238243

239244
def _maybe_add_intercept(doit, terms):
240245
if doit:

patsy/design_info.py

+39-28
Original file line numberDiff line numberDiff line change
@@ -121,16 +121,18 @@ def __repr__(self):
121121
kwlist.append(("categories", self.categories))
122122
repr_pretty_impl(p, self, [], kwlist)
123123

124-
def __eq__(self, other):
125-
return self.__dict__ == other.__dict__
124+
def __getstate__(self):
125+
return {'version': 0, 'factor': self.factor, 'type': self.type,
126+
'state': self.state, 'num_columns': self.num_columns,
127+
'categories': self.categories}
126128

127-
def __hash__(self):
128-
if not self.categories:
129-
categories = 'NoCategories'
130-
else:
131-
categories = frozenset(self.categories)
132-
return hash((FactorInfo, str(self.factor), str(self.type),
133-
str(self.state), str(self.num_columns), categories))
129+
def __setstate__(self, pickle):
130+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
131+
self.factor = pickle['factor']
132+
self.type = pickle['type']
133+
self.state = pickle['state']
134+
self.num_columns = pickle['num_columns']
135+
self.categories = pickle['categories']
134136

135137

136138
def test_FactorInfo():
@@ -245,10 +247,17 @@ def _repr_pretty_(self, p, cycle):
245247
("contrast_matrices", self.contrast_matrices),
246248
("num_columns", self.num_columns)])
247249

248-
def __eq__(self, other):
249-
return self.__dict__ == other.__dict__
250+
def __getstate__(self):
251+
return {'version': 0, 'factors': self.factors,
252+
'contrast_matrices': self.contrast_matrices,
253+
'num_columns': self.num_columns}
254+
255+
def __setstate__(self, pickle):
256+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
257+
self.factors = pickle['factors']
258+
self.contrast_matrices = pickle['contrast_matrices']
259+
self.num_columns = pickle['num_columns']
250260

251-
# __getstate__ = no_pickling
252261

253262
def test_SubtermInfo():
254263
cm = ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"])
@@ -706,21 +715,19 @@ def from_array(cls, array_like, default_column_prefix="column"):
706715
return DesignInfo(column_names)
707716

708717
def __getstate__(self):
709-
return (0, self.column_name_indexes, self.factor_infos,
710-
self.term_codings, self.term_slices, self.term_name_slices)
718+
return {'version': 0, 'column_name_indexes': self.column_name_indexes,
719+
'factor_infos': self.factor_infos,
720+
'term_codings': self.term_codings,
721+
'term_slices': self.term_slices,
722+
'term_name_slices': self.term_name_slices}
711723

712724
def __setstate__(self, pickle):
713-
(version, column_name_indexes, factor_infos, term_codings,
714-
term_slices, term_name_slices) = pickle
715-
check_pickle_version(version, 0, self.__class__.__name__)
716-
self.column_name_indexes = column_name_indexes
717-
self.factor_infos = factor_infos
718-
self.term_codings = term_codings
719-
self.term_slices = term_slices
720-
self.term_name_slices = term_name_slices
721-
722-
def __eq__(self, other):
723-
return self.__dict__ == other.__dict__
725+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
726+
self.column_name_indexes = pickle['column_name_indexes']
727+
self.factor_infos = pickle['factor_infos']
728+
self.term_codings = pickle['term_codings']
729+
self.term_slices = pickle['term_slices']
730+
self.term_name_slices = pickle['term_name_slices']
724731

725732

726733
class _MockFactor(object):
@@ -772,9 +779,12 @@ def test_DesignInfo():
772779

773780
# smoke test
774781
repr(di)
775-
from six.moves import cPickle as pickle
776782

777-
assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
783+
# Pickling check
784+
from six.moves import cPickle as pickle
785+
from patsy.util import assert_pickled_equals
786+
di2 = pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
787+
assert_pickled_equals(di, di2)
778788

779789
# One without term objects
780790
di = DesignInfo(["a1", "a2", "a3", "b"])
@@ -795,7 +805,8 @@ def test_DesignInfo():
795805
assert di.slice("a3") == slice(2, 3)
796806
assert di.slice("b") == slice(3, 4)
797807

798-
assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
808+
di2 = pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
809+
assert_pickled_equals(di, di2)
799810

800811
# Check intercept handling in describe()
801812
assert DesignInfo(["Intercept", "a", "b"]).describe() == "1 + a + b"

patsy/eval.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# for __future__ flags!
1111

1212
# These are made available in the patsy.* namespace
13-
__all__ = ["EvalEnvironment", "EvalFactor", "VarLookupDict"]
13+
__all__ = ["EvalEnvironment", "EvalFactor"]
1414

1515
import sys
1616
import __future__
@@ -62,9 +62,6 @@ def __contains__(self, key):
6262
else:
6363
return True
6464

65-
def __eq__(self, other):
66-
return self.__dict__ == other.__dict__
67-
6865
def get(self, key, default=None):
6966
try:
7067
return self[key]
@@ -98,7 +95,6 @@ def test_VarLookupDict():
9895
assert ds.get("c") is None
9996
assert isinstance(repr(ds), six.string_types)
10097

101-
# assert_no_pickling(ds)
10298

10399
def ast_names(code):
104100
"""Iterator that yields all the (ast) names in a Python expression.
@@ -255,8 +251,7 @@ def _namespace_ids(self):
255251
def __eq__(self, other):
256252
return (isinstance(other, EvalEnvironment)
257253
and self.flags == other.flags
258-
and self.namespace == other.namespace)
259-
# and self._namespace_ids() == other._namespace_ids())
254+
and self._namespace_ids() == other._namespace_ids())
260255

261256
def __ne__(self, other):
262257
return not self == other
@@ -382,7 +377,6 @@ def test_EvalEnvironment_capture_namespace():
382377

383378
assert_raises(TypeError, EvalEnvironment.capture, 1.2)
384379

385-
# assert_no_pickling(EvalEnvironment.capture())
386380

387381
def test_EvalEnvironment_capture_flags():
388382
if sys.version_info >= (3,):
@@ -649,15 +643,15 @@ def eval(self, memorize_state, data):
649643
data)
650644

651645
def __getstate__(self):
652-
return (0, self.code, self.origin)
646+
return {'version': 0, 'code': self.code, 'origin': self.origin}
653647

654-
def __setstate__(self, state):
655-
(version, code, origin) = state
656-
check_pickle_version(version, 0, self.__class__.__name__)
657-
self.code = code
658-
self.origin = origin
648+
def __setstate__(self, pickle):
649+
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
650+
self.code = pickle['code']
651+
self.origin = pickle['origin']
659652

660653
def test_EvalFactor_pickle_saves_origin():
654+
from patsy.util import assert_pickled_equals
661655
# The pickling tests use object equality before and after pickling
662656
# to test that pickling worked correctly. But EvalFactor's origin field
663657
# is not used in equality comparisons, so we need a separate test to
@@ -667,7 +661,7 @@ def test_EvalFactor_pickle_saves_origin():
667661
new_f = pickle.loads(pickle.dumps(f))
668662

669663
assert f.origin is not None
670-
assert f.origin == new_f.origin
664+
assert_pickled_equals(f, new_f)
671665

672666
def test_EvalFactor_basics():
673667
e = EvalFactor("a+b")

0 commit comments

Comments
 (0)