We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 86b291c commit 0f4f40bCopy full SHA for 0f4f40b
src/qap_solvers/rrwm.py
@@ -20,7 +20,7 @@ def __init__(self, max_iter=50, sk_iter=20, alpha=0.2, beta=30):
20
self.max_iter = max_iter
21
self.alpha = alpha
22
self.beta = beta
23
- self.sk = Sinkhorn(max_iter=sk_iter,log_forward=False)
+ self.sk = Sinkhorn(max_iter=sk_iter)
24
25
def forward(self, M, num_src, ns_src, ns_tgt, v0=None):
26
d = M.sum(dim=2, keepdim=True)
@@ -45,7 +45,7 @@ def forward(self, M, num_src, ns_src, ns_tgt, v0=None):
45
s = v.view(batch_num, -1, num_src).transpose(1, 2)
46
s = torch.exp(self.beta * s / s.max(dim=1, keepdim=True).values.max(dim=2, keepdim=True).values)
47
48
- v = self.alpha * self.sk(s, ns_src, ns_tgt).transpose(1, 2).reshape(batch_num, mn, 1) + (1 - self.alpha) * v
+ v = self.alpha * self.sk(torch.log(s), ns_src, ns_tgt).transpose(1, 2).reshape(batch_num, mn, 1) + (1 - self.alpha) * v
49
n = torch.norm(v, p=1, dim=1, keepdim=True)
50
v = torch.matmul(v, 1 / n)
51
0 commit comments