Skip to content

Commit 97f7c0f

Browse files
Merge pull request #5 from parsing-science/Fix-some-errors-in-v2
Use bytes with joblib and remove output variable from create_model
2 parents a98e1ed + 4c96f1c commit 97f7c0f

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,5 @@ ENV/
8989
# Rope project settings
9090
.ropeproject
9191

92+
# HLM models
93+
HLM_jar/

ps_toolkit/pymc3_models/HLM.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def create_model(self):
2020
"""
2121
Creates and returns the PyMC3 model.
2222
23-
Returns the model and the output variable. The latter is for use in ADVI minibatch.
23+
Returns the model.
2424
"""
2525
model_input = theano.shared(np.zeros([1, self.num_pred]))
2626

@@ -54,7 +54,7 @@ def create_model(self):
5454

5555
o = pm.Bernoulli('o', p, observed=model_output)
5656

57-
return model, o
57+
return model
5858

5959
def fit(self, X, y, cats, n=200000, batch_size=100):
6060
"""
@@ -76,7 +76,7 @@ def fit(self, X, y, cats, n=200000, batch_size=100):
7676
num_samples, self.num_pred = X.shape
7777

7878
if self.cached_model is None:
79-
self.cached_model, o = self.create_model()
79+
self.cached_model = self.create_model()
8080

8181
with self.cached_model:
8282

@@ -109,7 +109,7 @@ def predict_proba(self, X, cats, return_std=False):
109109
num_samples = X.shape[0]
110110

111111
if self.cached_model is None:
112-
self.cached_model, o = self.create_model()
112+
self.cached_model = self.create_model()
113113

114114
self._set_shared_vars({'model_input': X, 'model_output': np.zeros(num_samples), 'model_cats': cats})
115115

ps_toolkit/pymc3_models/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,12 @@ def save(self, file_prefix, custom_params=None):
102102
103103
custom_params: Dictionary of custom parameters to save. Defaults to None
104104
"""
105-
fileObject = open(file_prefix + 'advi_trace.pickle', 'w')
105+
fileObject = open(file_prefix + 'advi_trace.pickle', 'wb')
106106
joblib.dump(self.advi_trace, fileObject)
107107
fileObject.close()
108108

109109
if custom_params:
110-
fileObject = open(file_prefix + 'params.pickle', 'w')
110+
fileObject = open(file_prefix + 'params.pickle', 'wb')
111111
joblib.dump(custom_params, fileObject)
112112
fileObject.close()
113113

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
setup(
55
name='PS_Toolkit',
6-
version='2.0.0',
6+
version='2.0.1',
77
packages=find_packages(),
88
include_package_data=False,
99
zip_safe=False,

0 commit comments

Comments
 (0)