2
2
import matplotlib .pyplot as plt
3
3
import seaborn as sns
4
4
import sklearn .datasets as data
5
+ from scipy .cluster import hierarchy
6
+ from collections import deque
5
7
6
8
7
9
class Utils :
@@ -10,21 +12,29 @@ def __init__(self):
10
12
11
13
@staticmethod
12
14
def get_data():
    """Generate the demo point set used by the clustering walkthrough.

    Two noisy half-moons are sampled first, then two Gaussian blobs; the
    samples are stacked row-wise in that order.

    Returns:
        Array holding the 50 moon samples followed by the 50 blob samples.
    """
    blob_centers = [(-0.75, 2.25), (1.0, 2.0)]
    # Labels returned by the generators are not needed, only the coordinates.
    moon_points = data.make_moons(n_samples=50, noise=0.1)[0]
    blob_points = data.make_blobs(n_samples=50, centers=blob_centers, cluster_std=0.25)[0]
    return np.vstack([moon_points, blob_points])
18
20
19
21
@staticmethod
20
def plot_data(test_data, edges, transformed_distance, color='b'):
    """Scatter the points, overlay every non-zero edge, and show the figure.

    Args:
        test_data: Point coordinates; column 0 is plotted as x, column 1 as y.
        edges: Matrix indexed as ``edges[i, j]``; a truthy entry draws a
            segment between point ``i`` and point ``j``.
        transformed_distance: Matrix indexed as ``[i, j]`` supplying the line
            width of the segment between ``i`` and ``j``.
        color: Matplotlib colour used for the scatter markers.

    Returns:
        Always ``True``.
    """
    sns.set_context('poster')
    sns.set_style('white')
    sns.set_color_codes()

    xs, ys = test_data.T[0], test_data.T[1]
    plt.scatter(xs, ys, color=color, alpha=0.5, s=80, linewidths=0)

    n_points = test_data.shape[0]
    for src in range(n_points):
        for dst in range(n_points):
            if not edges[src, dst]:
                continue
            plt.plot(
                [test_data[src, 0], test_data[dst, 0]],
                [test_data[src, 1], test_data[dst, 1]],
                'k-',
                linewidth=transformed_distance[src, dst],
            )
    plt.show()

    return True
28
38
29
39
@staticmethod
30
40
def get_dist (test_data ):
@@ -39,18 +49,19 @@ def get_dist(test_data):
39
49
def core_dist(dist_mat, k=5):
    """Return, for every row, the k-th smallest distance in that row.

    Args:
        dist_mat: 2-D pairwise distance matrix.
        k: 1-based rank of the distance to pick after sorting each row in
            ascending order (any zero self-distance on the diagonal counts
            as one of the ranked entries).

    Returns:
        1-D float array with one core distance per row of ``dist_mat``.
    """
    # Sort every row at once instead of looping row by row.
    ascending_rows = np.sort(dist_mat, axis=1)
    return ascending_rows[:, k - 1].astype(float)
46
56
47
57
@staticmethod
48
58
def transform_space(dist_mat, core_mat):
    """Build the mutual-reachability matrix without touching ``dist_mat``.

    Entry ``[i, j]`` of the result is the largest of ``core_mat[i]``,
    ``core_mat[j]`` and ``dist_mat[i, j]``.

    Args:
        dist_mat: Square pairwise distance matrix.
        core_mat: Per-point core distances, indexable by point number.

    Returns:
        New float matrix with ``dist_mat.shape[0]`` rows and columns.
    """
    n_points = dist_mat.shape[0]
    core = np.asarray(core_mat)[:n_points]
    # Broadcast the core distances against themselves: max(core[i], core[j]).
    pairwise_core = np.maximum(core[:, np.newaxis], core[np.newaxis, :])
    return np.maximum(pairwise_core, dist_mat[:n_points, :n_points]).astype(float)
54
65
55
66
@staticmethod
56
67
def prims_algorithm (dist_mat ):
@@ -60,23 +71,102 @@ def prims_algorithm(dist_mat):
60
71
visited = [0 ]
61
72
62
73
while len (not_visited ):
63
- vertex = np .argmin (dist_mat [visited , not_visited ])
64
- q , r = divmod (int (vertex ), len (visited ))
65
- if len (visited ) == 1 :
66
- r , q = q , r
74
+ vertex = np .argmin (dist_mat [visited , :][:, not_visited ])
75
+ q , r = divmod (int (vertex ), len (not_visited ))
67
76
q , r = visited [q ], not_visited [r ]
68
- edges [q , r ] = 1
77
+ edges [q , r ] = dist_mat [ q , r ]
69
78
visited .append (r )
70
79
not_visited .remove (r )
71
80
72
81
return edges
73
82
83
+ @staticmethod
84
def cluster_hierarchy(edges, num_points):
    """Convert a spanning-tree edge matrix into a bottom-up merge table.

    The ``num_points - 1`` largest entries of ``edges`` are taken as the tree
    edges and processed in ascending weight order; each one merges the two
    clusters that currently contain its endpoints.

    Args:
        edges: Square NumPy array where entry ``[q, r]`` holds the weight of
            the edge between points ``q`` and ``r``. Assumed to contain
            exactly ``num_points - 1`` positive entries and zeros elsewhere;
            confirm against the producer of this matrix.
        num_points: Number of points (rows/columns of ``edges``).

    Returns:
        A tuple ``(out, hierarchy_tree, hierarchy_clusters)``:
        ``out`` is a ``(num_points - 1, 4)`` float array whose rows are
        ``[left id, right id, merge distance, merged size]``;
        ``hierarchy_tree`` maps each new cluster id to
        ``[[left id, left size], [right id, right size], merge distance]``;
        ``hierarchy_clusters`` maps every cluster id (points included) to the
        list of point indices it contains.
    """
    n_merges = num_points - 1
    out = np.zeros(shape=(max(n_merges, 0), 4))
    # Per point: [id of the cluster currently holding it, size of that cluster].
    hierarchy_nodes = {i: [i, 1] for i in range(num_points)}
    hierarchy_clusters = {i: [i] for i in range(num_points)}
    hierarchy_tree = dict()

    if n_merges < 1:
        # With nothing to merge, ``[-0:]`` would select every entry instead
        # of none, so bail out before slicing.
        return out, hierarchy_tree, hierarchy_clusters

    e_sorted = np.argsort(edges.flatten())[-n_merges:]
    curr_hierarchy = num_points - 1
    for i, flat_index in enumerate(e_sorted):
        q, r = divmod(int(flat_index), num_points)
        left_id, left_size = hierarchy_nodes[q]
        right_id, right_size = hierarchy_nodes[r]
        left_members = hierarchy_clusters[left_id]
        right_members = hierarchy_clusters[right_id]

        # Largest edge weight running from the left cluster to the right one.
        merge_dist = np.amax(edges[left_members, :][:, right_members])
        merged_size = left_size + right_size
        out[i] = (left_id, right_id, merge_dist, merged_size)

        curr_hierarchy += 1
        merged_members = left_members + right_members
        for member in merged_members:
            hierarchy_nodes[member] = [curr_hierarchy, merged_size]
        hierarchy_clusters[curr_hierarchy] = merged_members
        hierarchy_tree[curr_hierarchy] = [
            [left_id, left_size],
            [right_id, right_size],
            merge_dist,
        ]
    return out, hierarchy_tree, hierarchy_clusters
105
+
106
+ @staticmethod
107
def plot_dendrogram(out):
    """Draw the dendrogram for a merge table and display the figure.

    Args:
        out: Merge table passed straight to ``hierarchy.dendrogram``.

    Returns:
        Whatever ``hierarchy.dendrogram`` returned for ``out``.
    """
    tree_plot = hierarchy.dendrogram(out)
    plt.show()
    return tree_plot
112
+
113
+ @staticmethod
114
def condense_cluster_tree(hierarchy_tree, hierarchy_clusters, min_cluster_size, num_points):
    """Walk the merge tree top-down and score each surviving cluster.

    Starting at the root (id ``2 * num_points - 2``), a cluster is followed
    downwards while at most one child has ``min_cluster_size`` points or
    more; children below that size are treated as points falling out of the
    cluster at ``lambda = 1 / distance``. When both children are large enough
    the cluster ends and both children are queued as new clusters. When both
    children are too small the cluster dissolves completely and nothing is
    queued.

    Args:
        hierarchy_tree: Maps a cluster id to
            ``[[left id, left size], [right id, right size], distance]``.
        hierarchy_clusters: Maps a cluster id to the list of points in it.
        min_cluster_size: Children smaller than this are not clusters.
        num_points: Number of original points; fixes the root id.

    Returns:
        Dict mapping each visited cluster id to its accumulated score.
        An empty ``hierarchy_tree`` yields an empty dict.
    """
    clusters_stabilities = dict()
    if not hierarchy_tree:
        return clusters_stabilities

    clusters_stack = deque([(num_points * 2) - 2])

    while clusters_stack:
        i = clusters_stack.pop()
        v = hierarchy_tree[i]
        cluster_birth = 1 / v[2]
        points_lambda = []
        dissolved = False

        while True:
            (left_id, left_size), (right_id, right_size), distance = v
            left_small = left_size < min_cluster_size
            right_small = right_size < min_cluster_size
            if not (left_small or right_small):
                # Genuine split: both children survive as clusters.
                break

            shed_lambda = 1 / distance
            if left_small:
                points_lambda.extend([shed_lambda] * len(hierarchy_clusters[left_id]))
            if right_small:
                points_lambda.extend([shed_lambda] * len(hierarchy_clusters[right_id]))

            if left_small and right_small:
                # No child is large enough to continue with; previously this
                # case looked up ``hierarchy_tree[None]`` and raised KeyError.
                dissolved = True
                break

            v = hierarchy_tree[right_id if left_small else left_id]

        cluster_death_points_fall = len(hierarchy_clusters[i]) - len(points_lambda)
        cluster_death = 1 / v[2]
        # NOTE(review): both terms come out non-positive because birth <= death
        # in lambda terms; confirm whether the sign is intended.
        sum_stabilities = (cluster_birth - cluster_death) * cluster_death_points_fall
        sum_stabilities += -np.sum(np.array(points_lambda) - cluster_birth)
        clusters_stabilities[i] = sum_stabilities

        if not dissolved:
            for child_id in (v[0][0], v[1][0]):
                # Single points have no entry in the tree and cannot be expanded.
                if child_id in hierarchy_tree:
                    clusters_stack.append(child_id)

    return clusters_stabilities
152
+
153
+ @staticmethod
154
def extract_clusters(hierarchy_clusters, clusters_stabilities):
    """Placeholder for selecting the final clusters; currently a no-op.

    Iterates over every entry of ``hierarchy_clusters`` except the last one
    without doing any work yet.

    Args:
        hierarchy_clusters: Maps a cluster id to the list of points in it.
        clusters_stabilities: Maps a cluster id to its score (unused so far).

    Returns:
        Always ``True``.
    """
    # Dict views cannot be sliced on Python 3; materialise them first.
    for k, v in list(hierarchy_clusters.items())[:-1]:
        pass

    return True
160
+
74
161
75
162
if __name__ == '__main__':
    # Demo pipeline: sample points, build the mutual-reachability graph,
    # extract its spanning tree, then derive and condense the hierarchy.
    # The points are deliberately NOT stored in a variable called ``data``:
    # that would rebind the module-level ``sklearn.datasets`` alias used by
    # ``Utils.get_data``.
    points = Utils.get_data()
    dist = Utils.get_dist(points)
    core_dist = Utils.core_dist(dist)
    transformed_dist = Utils.transform_space(dist, core_dist)
    num_points = transformed_dist.shape[0]

    e = Utils.prims_algorithm(transformed_dist)
    plot = Utils.plot_data(points, e, transformed_dist)

    Z, tree, clusters = Utils.cluster_hierarchy(e, num_points=num_points)
    # dn = Utils.plot_dendrogram(Z)
    c_stabilities = Utils.condense_cluster_tree(tree, clusters, min_cluster_size=5, num_points=num_points)
    # Kept separate from ``clusters`` so the membership mapping is not
    # overwritten by the stub's boolean return value.
    extracted = Utils.extract_clusters(clusters, c_stabilities)
0 commit comments