-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtriest_impr.py
65 lines (50 loc) · 1.76 KB
/
triest_impr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import random
from edge_sample import EdgeSample
from collections import defaultdict
class TriestImpr:
    """Streaming estimator of global and per-node triangle counts.

    Edges arrive one at a time through :meth:`run`.  A bounded sample of
    edges is kept in an ``EdgeSample`` and every arriving edge is matched
    against that sample *before* the sampling decision is made; each
    triangle it closes is credited with a weight that grows with the
    number of edges seen so far (the scheme described as TRIEST-IMPR in
    the literature).

    NOTE(review): ``EdgeSample`` is defined in another module; this class
    only relies on its ``add_edge(u, v)``, ``remove_random_edge()`` and
    ``get_intersection_neighborhood(u, v)`` methods.
    """

    def __init__(self, M):
        """Initialise an empty estimator.

        Args:
            M: Sample-size budget.  While at most ``M`` edges have been
                seen every edge is kept; afterwards an edge is kept with
                probability ``M / t``.  The weight formula divides by
                ``M * (M - 1)``, so ``M`` is expected to be at least 2.
        """
        self._M = M
        self._sample = EdgeSample()
        self._globalT = 0   # running global triangle estimate
        self._localT = {}   # node -> running local triangle estimate
        self._t = 0         # number of edges processed so far

    def sample_edge(self, u, v):
        """Decide whether the current edge should enter the sample.

        While ``t <= M`` the edge is always accepted.  Afterwards it is
        accepted only when the biased coin lands heads, in which case one
        edge is first evicted from the sample via ``remove_random_edge()``.

        Args:
            u: First endpoint of the arriving edge (unused by the decision).
            v: Second endpoint of the arriving edge (unused by the decision).

        Returns:
            ``True`` if the caller should insert the edge, else ``False``.
        """
        if self._t <= self._M:
            return True
        if self.flip_biased_coin():
            # Evict one sampled edge; the caller then adds the new edge.
            self._sample.remove_random_edge()
            return True
        return False

    def update_counters(self, u, v, op):
        """Credit every triangle that edge ``(u, v)`` closes in the sample.

        For each common neighbour ``c`` of ``u`` and ``v`` in the sample,
        the global counter and the local counters of ``c``, ``u`` and ``v``
        are each increased by
        ``max(1, (t - 1) * (t - 2) / (M * (M - 1)))``.

        Args:
            u: First endpoint of the arriving edge.
            v: Second endpoint of the arriving edge.
            op: Operation marker; only ``'+'`` changes the counters.
        """
        common_neighborhood = self._sample.get_intersection_neighborhood(u, v)
        if not common_neighborhood or op != '+':
            return

        # Keep the fractional part of the weight: truncating it to an int
        # would systematically under-count once t exceeds M.
        seen_pairs = (self._t - 1) * (self._t - 2)
        sample_pairs = self._M * (self._M - 1)
        weight = max(1, seen_pairs / sample_pairs)

        local = self._localT
        for c in common_neighborhood:
            self._globalT += weight
            for node in (c, u, v):
                local[node] = local.get(node, 0) + weight

    def flip_biased_coin(self):
        """Return ``True`` with probability (approximately) ``M / t``."""
        return random.random() <= self._M / self._t

    def return_counters(self):
        """Return the current estimates.

        Returns:
            A dict with key ``'global'`` (the global estimate) and key
            ``'local'`` (the node -> estimate dict; this is the live
            internal dict, not a copy).
        """
        return {'global': self._globalT, 'local': self._localT}

    def run(self, u, v):
        """Process one arriving edge ``(u, v)``.

        The time step is advanced, the counters are updated against the
        current sample, and only then is the edge (possibly) inserted.
        """
        self._t += 1
        self.update_counters(u, v, '+')
        if self.sample_edge(u, v):
            self._sample.add_edge(u, v)