Issue with hessian computation in vLEM for mp_srlds #171

Open
XiaoliangWang2001 opened this issue Jan 9, 2025 · 0 comments

Hi, thank you for the great package.

I am working with the transition module of mp_srlds and came across the code:

Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1
for k1 in range(self.K):
    for k2 in range(self.K):
        vtilde = vtildes[:,k1,k2][:,None] # SWAP?
        #Sticky terms
        if k1==k2:
            Rv = vtilde@self.Ss[k2:k2+1,:]
            hess += Ez[k1,k2] * \
                    ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Ss[k2:k2+1,:], self.Ss[k2:k2+1,:]) \
                    + np.einsum('ti, tj -> tij', Rv, Rv))
        #Switching terms
        else:
            Rv = vtilde@self.Rs[k2:k2+1,:]
            hess += Ez[k1,k2] * \
                    ( np.einsum('tn, ni, nj ->tij', -vtilde, self.Rs[k2:k2+1,:], self.Rs[k2:k2+1,:]) \
                    + np.einsum('ti, tj -> tij', Rv, Rv))

Here, on line 89, Ez is indexed by k1 and k2. However, it is defined on line 82 as:
Ez = np.sum(expected_joints, axis=2) # marginal over z from T=1 to T-1
so I checked the dimensions of expected_joints in:

ssm/ssm/messages.py

Lines 186 to 198 in 6c856ad

# Compute E[z_t, z_{t+1}] for t = 1, ..., T-1
# Note that this is an array of size T*K*K, which can be quite large.
# To be a bit more frugal with memory, first check if the given log_Ps
# are TxKxK. If so, instantiate the full expected joints as well, since
# we will need them for the M-step. However, if log_Ps is 1xKxK then we
# know that the transition matrix is stationary, and all we need for the
# M-step is the sum of the expected joints.
stationary = (Ps.shape[0] == 1)
if not stationary:
    expected_joints = alphas[:-1,:,None] + betas[1:,None,:] + ll[1:,None,:] + log_Ps
    expected_joints -= expected_joints.max((1,2))[:,None, None]
    expected_joints = np.exp(expected_joints)
    expected_joints /= expected_joints.sum((1,2))[:,None,None]

I believe expected_joints has dimensions (T-1, K, K) in this non-stationary case. Summing over axis=2 would then give Ez dimensions (T-1, K). In the loop above, however, Ez[k1,k2] indexes the first (time) dimension with k1, which is a bit confusing to me.
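
To make the shape concern concrete, here is a minimal sketch using random stand-in data rather than the package's own arrays. The sizes T and K are hypothetical. Summing over axis 0 is shown only as one way to obtain a (K, K) array; I am not claiming it is the intended quantity.

```python
import numpy as np

T, K = 6, 3  # hypothetical sizes, chosen only for illustration
rng = np.random.default_rng(0)

# Stand-in for expected_joints in the non-stationary case: one K x K joint
# distribution over (z_t, z_{t+1}) per time step, each normalized to sum to 1.
expected_joints = rng.random((T - 1, K, K))
expected_joints /= expected_joints.sum(axis=(1, 2), keepdims=True)

# What the transition code computes: sum over the last axis (z_{t+1}).
Ez_axis2 = expected_joints.sum(axis=2)  # shape (T-1, K)

# Assumption: a (K, K) array of summed expected transitions, if that is what
# Ez[k1, k2] is meant to index.
Ez_axis0 = expected_joints.sum(axis=0)  # shape (K, K)

print("sum over axis=2:", Ez_axis2.shape)
print("sum over axis=0:", Ez_axis0.shape)
```

With the axis=2 sum, Ez[k1, k2] reads the marginal probability of state k2 at time step k1, so only the first K time steps are ever touched.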

Could you clarify if this behavior is intentional, or if there might be a mistake in how Ez is used? I may be missing something here, so I’d appreciate your insight. Thanks for your time and support!
