lda_model.py
# Copyright 2021 Yifei Ma
# with references from "sklearn.decomposition.LatentDirichletAllocation"
# with the following original authors:
# * Chyi-Kwei Yau (the said scikit-learn implementation)
# * Matthew D. Hoffman (original onlineldavb implementation)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, functools, warnings
import numpy as np, scipy as sp
import torch
import dgl
from dgl import function as fn


def _lbeta(alpha, axis):
    return torch.lgamma(alpha).sum(axis) - torch.lgamma(alpha.sum(axis))
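
# _lbeta computes the log multivariate Beta function,
#     log B(alpha) = sum_i lgamma(alpha_i) - lgamma(sum_i alpha_i),
# i.e. the log normalizer of a Dirichlet(alpha) distribution; it appears in
# the KL terms of the ELBO assembled in perplexity() below.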

# Taken from scikit-learn. Worked better than uniform.
# Perhaps this is due to concentration around one.
_sklearn_random_init = torch.distributions.gamma.Gamma(100, 100)
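# With shape=100 and rate=100, this Gamma has mean 1 and standard deviation
# 0.1, so the random initial pseudo-counts are tightly concentrated around one.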


def _edge_update(edges, step_size=1):
    """The Gibbs posterior distribution of z is proportional to theta*beta.
    As step_size -> infty, the result becomes a MAP estimate on the dst nodes.
    """
    q = edges.src['weight'] * edges.dst['weight']
    marg = q.sum(axis=1, keepdims=True) + np.finfo(float).eps
    p = q / marg
    return {
        'z': p * step_size,
        'edge_elbo': marg.squeeze(1).log() * step_size,
    }
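
# On each doc-word edge, q[k] multiplies exp(E[log theta_dk]) from the doc end
# with exp(E[log beta_kw]) from the word end, so p is the per-edge topic
# responsibility phi of Hoffman et al. (2010), and 'edge_elbo' is the log of
# its normalizer, used for the data term of the bound in perplexity().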


def _weight_exp(z, ntype, prior):
    """Node weight is approximately normalized for VB along the ntype
    direction.
    """
    prior = prior + z * 0  # convert numpy to torch
    gamma = prior + z
    axis = 1 if ntype == 'doc' else 0  # word
    Elog = torch.digamma(gamma) - torch.digamma(gamma.sum(axis, keepdims=True))
    return Elog.exp()
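
# gamma = prior + expected counts is the variational Dirichlet parameter; by
# the identity E_q[log x_k] = digamma(gamma_k) - digamma(sum_j gamma_j), the
# returned tensor is exp(E[log theta]) for doc nodes (normalized over topics,
# axis=1) or exp(E[log beta]) for word nodes (normalized over the vocabulary,
# axis=0).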


def _node_update(nodes, prior):
    return {
        'z': nodes.data['z'],
        'weight': _weight_exp(nodes.data['z'], nodes.ntype, prior),
    }


def _update_all(G, ntype, prior, step_size, return_obj=False):
    """Follows Eq (5) of Hoffman et al., 2010."""
    G_prop = G.reverse() if ntype == 'doc' else G  # word
    msg_fn = lambda edges: _edge_update(edges, step_size)
    node_fn = lambda nodes: _node_update(nodes, prior)
    G_prop.update_all(msg_fn, fn.sum('z', 'z'), node_fn)
    if return_obj:
        G_prop.update_all(msg_fn, fn.sum('edge_elbo', 'edge_elbo'))
    G.nodes[ntype].data.update(G_prop.nodes[ntype].data)
    return dict(G.nodes[ntype].data)
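
# Updating doc nodes reverses the graph so messages flow word -> doc, while
# updating word nodes keeps the original doc -> word direction; either way the
# per-edge responsibilities are summed into new pseudo-counts 'z' and each
# node's 'weight' is refreshed from them via _weight_exp.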


class LatentDirichletAllocation:
    """LDA model that works with a HeteroGraph with doc/word node types.

    The model may overwrite node attributes of G arbitrarily, but it always
    reloads word_z when needed.
    This is inspired by [1] and its corresponding scikit-learn implementation.

    References
    ---
    [1] Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent
    Dirichlet Allocation. Advances in Neural Information Processing Systems 23
    (NIPS 2010).
    [2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model.
    """

    def __init__(
        self, G, n_components,
        prior=None,
        step_size={'doc': 1, 'word': 1},  # use larger value to get MAP
        word_rho=1,  # use smaller value for online update
        verbose=True,
    ):
        self.n_components = n_components
        if prior is None:
            prior = {'doc': 1. / n_components, 'word': 1. / n_components}
        self.prior = prior
        self.step_size = step_size
        self.word_rho = word_rho
        self.word_z = self._load_or_init(G, 'word')['z']
        self.verbose = verbose

    def _load_or_init(self, G, ntype, z=None):
        if z is None:
            z = _sklearn_random_init.sample(
                (G.num_nodes(ntype), self.n_components)
            ).to(G.device)
        G.nodes[ntype].data['z'] = z
        G.apply_nodes(
            lambda nodes: _node_update(nodes, self.prior[ntype]),
            ntype=ntype)
        return dict(G.nodes[ntype].data)

    def _e_step(self, G, reinit_doc=True, mean_change_tol=1e-3, max_iters=100):
        """Iterates the doc-side variational updates until convergence or
        max_iters is reached.
        """
        self._load_or_init(G, 'word', self.word_z)
        if reinit_doc or ('weight' not in G.nodes['doc'].data):
            self._load_or_init(G, 'doc')
        for i in range(max_iters):
            doc_z = dict(G.nodes['doc'].data)['z']
            doc_data = _update_all(
                G, 'doc', self.prior['doc'], self.step_size['doc'])
            mean_change = (doc_data['z'] - doc_z).abs().mean()
            if mean_change < mean_change_tol:
                break
        if self.verbose:
            print(f'e-step num_iters={i+1} with mean_change={mean_change:.4f}')
        return doc_data

    transform = _e_step
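    # transform is aliased to _e_step so the class mirrors the familiar
    # scikit-learn fit/transform interface: transform(G) returns the doc-side
    # variational statistics ('z' pseudo-counts and 'weight') for G.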

    def _m_step(self, G):
        """Updates the word-side statistics and stores them in word_z."""
        # assume G.nodes['doc'].data is up to date
        word_data = _update_all(
            G, 'word', self.prior['word'], self.step_size['word'])
        # online update
        self.word_z = (
            (1 - self.word_rho) * self.word_z
            + self.word_rho * word_data['z']
        )
        return word_data
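    # With word_rho=1 this is a plain batch VB update; a smaller word_rho
    # blends the new word pseudo-counts into the old ones, mirroring the
    # online/stochastic update of Hoffman et al. (2010).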

    def partial_fit(self, G):
        self._last_word_z = self.word_z
        self._e_step(G)
        self._m_step(G)
        return self

    def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
        for i in range(max_epochs):
            self.partial_fit(G)
            mean_change = (self.word_z - self._last_word_z).abs().mean()
            if self.verbose:
                print(f'epoch {i+1}, '
                      f'perplexity: {self.perplexity(G, False)}, '
                      f'mean_change: {mean_change:.4f}')
            if mean_change < mean_change_tol:
                break
        return self

    def perplexity(self, G, reinit_doc=True):
        """ppl = exp{-sum[log(p(w1,...,wn|d))] / n}

        Follows Eq (15) in Hoffman et al., 2010.
        """
        word_data = self._load_or_init(G, 'word', self.word_z)
        if reinit_doc or ('weight' not in G.nodes['doc'].data):
            self._e_step(G, reinit_doc)
        doc_data = _update_all(
            G, 'doc', self.prior['doc'], self.step_size['doc'],
            return_obj=True)

        # compute E[log p(docs | theta, beta)]
        edge_elbo = (
            doc_data['edge_elbo'].sum() / doc_data['z'].sum()
        ).cpu().numpy()
        if self.verbose:
            print(f'neg_elbo phi: {-edge_elbo:.3f}', end=' ')

        # compute E[log p(theta | alpha) - log q(theta | gamma)]
        doc_elbo = (
            (-doc_data['z'] * doc_data['weight'].log()).sum(axis=1)
            - _lbeta(self.prior['doc'] + doc_data['z'] * 0, axis=1)
            + _lbeta(self.prior['doc'] + doc_data['z'], axis=1)
        )
        doc_elbo = (doc_elbo.sum() / doc_data['z'].sum()).cpu().numpy()
        if self.verbose:
            print(f'theta: {-doc_elbo:.3f}', end=' ')

        # compute E[log p(beta | eta) - log q(beta | lambda)]
        word_elbo = (
            (-word_data['z'] * word_data['weight'].log()).sum(axis=0)
            - _lbeta(self.prior['word'] + word_data['z'] * 0, axis=0)
            + _lbeta(self.prior['word'] + word_data['z'], axis=0)
        )
        word_elbo = (word_elbo.sum() / word_data['z'].sum()).cpu().numpy()
        if self.verbose:
            print(f'beta: {-word_elbo:.3f}')

        return np.exp(-edge_elbo - doc_elbo - word_elbo)


if __name__ == '__main__':
    print('Testing LatentDirichletAllocation via task_example_test.sh ...')
    tf_uv = np.array(np.nonzero(np.random.rand(20, 10) < 0.5)).T
    G = dgl.heterograph({('doc', 'topic', 'word'): tf_uv.tolist()})
    model = LatentDirichletAllocation(G, 5, verbose=False)
    model.fit(G)
    model.transform(G)
    model.perplexity(G)
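
    # Sketch: one plausible way to read out point estimates of the doc-topic
    # (theta) and topic-word (beta) matrices, by normalizing the variational
    # pseudo-counts plus their priors.
    doc_data = model.transform(G)
    theta = doc_data['z'] + model.prior['doc']        # (num_docs, n_components)
    theta = theta / theta.sum(axis=1, keepdims=True)  # rows sum to one
    beta = model.word_z + model.prior['word']         # (num_words, n_components)
    beta = beta / beta.sum(axis=0, keepdims=True)     # columns sum to one
    print('theta shape:', tuple(theta.shape), 'beta shape:', tuple(beta.shape))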