-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtriest_base.py
87 lines (65 loc) · 2.36 KB
/
triest_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import random
from edge_sample import EdgeSample
from collections import defaultdict
class TriestBase:
    """Fixed-memory streaming triangle counter following the TRIEST-BASE scheme.

    Edges arrive one at a time through :meth:`run`. The first ``M`` edges are
    always kept; afterwards the t-th edge is kept with probability ``M / t``,
    evicting a previously sampled edge to make room (reservoir sampling).
    A global triangle counter and per-node local triangle counters are kept
    up to date for the *sampled* subgraph, and :meth:`return_counters` scales
    them by ``max(1, t(t-1)(t-2) / (M(M-1)(M-2)))`` to estimate the counts
    for the whole stream.

    NOTE(review): ``EdgeSample`` is project-local; this class relies on its
    ``add_edge``, ``remove_random_edge`` (presumably uniform over sampled
    edges -- confirm) and ``get_intersection_neighborhood`` (presumably the
    common neighbours of ``u`` and ``v`` in the sample -- confirm).
    """

    def __init__(self, M):
        """Create an empty counter whose edge sample holds at most ``M`` edges."""
        self._M = M                  # sample capacity (maximum number of stored edges)
        self._sample = EdgeSample()  # the sampled subgraph
        self._globalT = 0            # triangles in the sampled subgraph
        self._localT = {}            # node -> sampled triangles containing it; zero entries removed
        self._t = 0                  # number of edges seen so far

    def sample_edge(self, u, v):
        """Decide whether the current (t-th) edge enters the sample.

        ``u`` and ``v`` are accepted for interface symmetry; the decision does
        not depend on them.

        Returns:
            True if the edge should be added. When the sample is already full
            and the biased coin comes up heads, a sampled edge is evicted (and
            its triangles subtracted from the counters) before returning True.
        """
        # ``run`` increments ``_t`` before calling us, so the first M edges
        # are always accepted.
        if self._t <= self._M:
            return True
        if self.flip_biased_coin():
            u_dash, v_dash = self._sample.remove_random_edge()
            self.update_counters(u_dash, v_dash, '-')
            return True
        return False

    def update_counters(self, u, v, op):
        """Add or subtract the triangles closed by edge ``(u, v)`` in the sample.

        Every common neighbour ``c`` of ``u`` and ``v`` in the sample forms a
        triangle ``{u, v, c}``; each such triangle changes the global counter
        and the local counters of ``c``, ``u`` and ``v`` by one.

        Args:
            u, v: Endpoints of the edge being inserted or removed.
            op: ``'+'`` to count the triangles, ``'-'`` to uncount them.

        Raises:
            ValueError: If ``op`` is neither ``'+'`` nor ``'-'``.
        """
        if op == '+':
            delta = 1
        elif op == '-':
            delta = -1
        else:
            raise ValueError(f"op must be '+' or '-', got {op!r}")
        common_neighborhood = self._sample.get_intersection_neighborhood(u, v)
        if not common_neighborhood:
            return
        for c in common_neighborhood:
            self._globalT += delta
            self._bump_local(c, delta)
            self._bump_local(u, delta)
            self._bump_local(v, delta)

    def _bump_local(self, node, delta):
        """Add ``delta`` to ``node``'s local counter, dropping the entry at zero."""
        count = self._localT.get(node, 0) + delta
        if count:
            self._localT[node] = count
        else:
            # Keep the dict sparse: only nodes in >= 1 sampled triangle appear.
            self._localT.pop(node, None)

    def flip_biased_coin(self):
        """Return True with probability ``M / t``."""
        # random() is uniform on [0, 1), so P(random() < p) == p exactly;
        # '<=' would add a (tiny) bias and could fire even when M == 0.
        return random.random() < self._M / self._t

    def return_counters(self):
        """Return the scaled triangle estimates without modifying internal state.

        Returns:
            ``{'global': int, 'local': {node: int}}``, where each sampled count
            is multiplied by ``max(1, t(t-1)(t-2) / (M(M-1)(M-2)))`` and
            truncated to ``int``. A fresh ``'local'`` dict is returned, so the
            method may be called repeatedly mid-stream.

        Raises:
            ValueError: If more than ``M`` edges have been seen and ``M < 3``,
                in which case the correction factor is undefined.
        """
        t, M = self._t, self._M
        if t <= M:
            # Every edge seen so far is still in the sample: no correction.
            estimate = 1
        elif M < 3:
            raise ValueError(
                "triangle estimate is undefined for M < 3 once the stream "
                "exceeds M edges"
            )
        else:
            estimate = max(1, (t * (t - 1) * (t - 2)) / (M * (M - 1) * (M - 2)))
        global_estimate = int(estimate * self._globalT)
        # Build a new dict: scaling self._localT in place would corrupt the
        # running counters for later calls and later stream updates.
        local_estimate = {
            node: int(count * estimate) for node, count in self._localT.items()
        }
        return {'global': global_estimate, 'local': local_estimate}

    def run(self, u, v):
        """Process the next stream edge ``(u, v)``."""
        self._t += 1
        if self.sample_edge(u, v):
            self._sample.add_edge(u, v)
            self.update_counters(u, v, '+')