Skip to content

Commit 0f4f40b

Browse files
committed
fix RRWM sovler with the new Sinkhorn implementation
1 parent 86b291c commit 0f4f40b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/qap_solvers/rrwm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, max_iter=50, sk_iter=20, alpha=0.2, beta=30):
2020
self.max_iter = max_iter
2121
self.alpha = alpha
2222
self.beta = beta
23-
self.sk = Sinkhorn(max_iter=sk_iter,log_forward=False)
23+
self.sk = Sinkhorn(max_iter=sk_iter)
2424

2525
def forward(self, M, num_src, ns_src, ns_tgt, v0=None):
2626
d = M.sum(dim=2, keepdim=True)
@@ -45,7 +45,7 @@ def forward(self, M, num_src, ns_src, ns_tgt, v0=None):
4545
s = v.view(batch_num, -1, num_src).transpose(1, 2)
4646
s = torch.exp(self.beta * s / s.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values)
4747

48-
v = self.alpha * self.sk(s, ns_src, ns_tgt).transpose(1, 2).reshape(batch_num, mn, 1) + (1 - self.alpha) * v
48+
v = self.alpha * self.sk(torch.log(s), ns_src, ns_tgt).transpose(1, 2).reshape(batch_num, mn, 1) + (1 - self.alpha) * v
4949
n = torch.norm(v, p=1, dim=1, keepdim=True)
5050
v = torch.matmul(v, 1 / n)
5151

0 commit comments

Comments
 (0)