Skip to content

Commit eed4b5b

Browse files
committed
Merge remote-tracking branch 'origin/master'
2 parents dc4be85 + 7940d11 commit eed4b5b

File tree

2 files changed

+105
-14
lines changed

2 files changed

+105
-14
lines changed

hdbscan.py

+104-14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import matplotlib.pyplot as plt
33
import seaborn as sns
44
import sklearn.datasets as data
5+
from scipy.cluster import hierarchy
6+
from collections import deque
57

68

79
class Utils:
@@ -10,21 +12,29 @@ def __init__(self):
1012

1113
@staticmethod
1214
def get_data():
13-
moons, _ = data.make_moons(n_samples=50, noise=0.05)
15+
moons, _ = data.make_moons(n_samples=50, noise=0.1)
1416
blobs, _ = data.make_blobs(n_samples=50, centers=[(-0.75, 2.25), (1.0, 2.0)], cluster_std=0.25)
1517
test_data = np.vstack([moons, blobs])
1618

1719
return test_data
1820

1921
@staticmethod
20-
def plot_data(test_data, color='b'):
22+
def plot_data(test_data, edges, transformed_distance, color='b'):
2123
sns.set_context('poster')
2224
sns.set_style('white')
2325
sns.set_color_codes()
2426
plot_kwds = {'alpha': 0.5, 's': 80, 'linewidths': 0}
2527

2628
plt.scatter(test_data.T[0], test_data.T[1], color=color, **plot_kwds)
27-
# plt.show()
29+
30+
for i in range(test_data.shape[0]):
31+
for j in range(test_data.shape[0]):
32+
if edges[i, j]:
33+
plt.plot([test_data[i, 0], test_data[j, 0]], [test_data[i, 1], test_data[j, 1]], 'k-',
34+
linewidth=transformed_distance[i, j])
35+
plt.show()
36+
37+
return True
2838

2939
@staticmethod
3040
def get_dist(test_data):
@@ -39,18 +49,19 @@ def get_dist(test_data):
3949
def core_dist(dist_mat, k=5):
4050
core_mat = np.zeros(shape=dist_mat.shape[0])
4151
for i in range(dist_mat.shape[0]):
42-
dist = np.sort(dist_mat[i, :])[::-1]
43-
core_mat[i] = dist[k-1]
52+
distance = np.sort(dist_mat[i, :])
53+
core_mat[i] = distance[k-1]
4454

4555
return core_mat
4656

4757
@staticmethod
4858
def transform_space(dist_mat, core_mat):
59+
transformed_dist_mat = np.zeros(shape=(dist_mat.shape[0], dist_mat.shape[0]))
4960
for i in range(dist_mat.shape[0]):
5061
for j in range(dist_mat.shape[0]):
51-
dist_mat[i, j] = np.amax([core_mat[i], core_mat[j], dist_mat[i, j]])
62+
transformed_dist_mat[i, j] = np.amax([core_mat[i], core_mat[j], dist_mat[i, j]])
5263

53-
return dist_mat
64+
return transformed_dist_mat
5465

5566
@staticmethod
5667
def prims_algorithm(dist_mat):
@@ -60,23 +71,102 @@ def prims_algorithm(dist_mat):
6071
visited = [0]
6172

6273
while len(not_visited):
63-
vertex = np.argmin(dist_mat[visited, not_visited])
64-
q, r = divmod(int(vertex), len(visited))
65-
if len(visited) == 1:
66-
r, q = q, r
74+
vertex = np.argmin(dist_mat[visited, :][:, not_visited])
75+
q, r = divmod(int(vertex), len(not_visited))
6776
q, r = visited[q], not_visited[r]
68-
edges[q, r] = 1
77+
edges[q, r] = dist_mat[q, r]
6978
visited.append(r)
7079
not_visited.remove(r)
7180

7281
return edges
7382

83+
@staticmethod
84+
def cluster_hierarchy(edges, num_points):
85+
out = np.zeros(shape=(num_points-1, 4))
86+
e_sorted = np.argsort(edges.flatten())[-(num_points - 1):]
87+
hierarchy_nodes = {i: [i, 1] for i in range(num_points)}
88+
hierarchy_clusters = {i: [i] for i in range(num_points)}
89+
hierarchy_tree = dict()
90+
curr_hierarchy = num_points - 1
91+
for i in range(e_sorted.shape[0]):
92+
q, r = divmod(e_sorted[i], num_points)
93+
s, t = hierarchy_nodes[q][1], hierarchy_nodes[r][1]
94+
out[i, 0], out[i, 1] = hierarchy_nodes[q][0], hierarchy_nodes[r][0]
95+
out[i, 2] = np.amax(edges[hierarchy_clusters[out[i, 0]], :][:, hierarchy_clusters[out[i, 1]]])
96+
out[i, 3] = s + t
97+
curr_hierarchy += 1
98+
for j in hierarchy_clusters[out[i, 0]]:
99+
hierarchy_nodes[j] = [curr_hierarchy, out[i, 3]]
100+
for k in hierarchy_clusters[out[i, 1]]:
101+
hierarchy_nodes[k] = [curr_hierarchy, out[i, 3]]
102+
hierarchy_clusters.update({curr_hierarchy: hierarchy_clusters[out[i, 0]] + hierarchy_clusters[out[i, 1]]})
103+
hierarchy_tree.update({curr_hierarchy: [[out[i, 0], s], [out[i, 1], t], out[i, 2]]})
104+
return out, hierarchy_tree, hierarchy_clusters
105+
106+
@staticmethod
107+
def plot_dendrogram(out):
108+
dendrogram = hierarchy.dendrogram(out)
109+
plt.show()
110+
111+
return dendrogram
112+
113+
@staticmethod
114+
def condense_cluster_tree(hierarchy_tree, hierarchy_clusters, min_cluster_size, num_points):
115+
clusters_stabilities = dict()
116+
117+
start = (num_points*2) - 2
118+
119+
clusters_stack = deque()
120+
clusters_stack.append(start)
121+
122+
while len(clusters_stack):
123+
i = clusters_stack.pop()
124+
v = hierarchy_tree[i]
125+
cluster_birth = 1/v[2]
126+
points_lambda = []
127+
128+
while v[0][1] < min_cluster_size or v[1][1] < min_cluster_size:
129+
next_v = None
130+
if v[0][1] < min_cluster_size:
131+
for _ in hierarchy_clusters[v[0][0]]:
132+
points_lambda.append(1/v[2])
133+
else:
134+
next_v = v[0][0]
135+
if v[1][1] < min_cluster_size:
136+
for _ in hierarchy_clusters[v[1][0]]:
137+
points_lambda.append(1/v[2])
138+
else:
139+
next_v = v[1][0]
140+
v = hierarchy_tree[next_v]
141+
142+
cluster_death_points_fall = len(hierarchy_clusters[i]) - len(points_lambda)
143+
cluster_death = 1/v[2]
144+
sum_stabilities = (cluster_birth - cluster_death) * cluster_death_points_fall
145+
sum_stabilities += -np.sum(np.array(points_lambda) - cluster_birth)
146+
clusters_stabilities.update({i: sum_stabilities})
147+
148+
clusters_stack.append(v[0][0])
149+
clusters_stack.append(v[1][0])
150+
151+
return clusters_stabilities
152+
153+
@staticmethod
154+
def extract_clusters(hierarchy_clusters, clusters_stabilities):
155+
156+
for k, v in hierarchy_clusters.items()[:-1]:
157+
pass
158+
159+
return True
160+
74161

75162
if __name__ == '__main__':
76163
data = Utils.get_data()
77-
Utils.plot_data(data)
78164
dist = Utils.get_dist(data)
79165
core_dist = Utils.core_dist(dist)
80166
transformed_dist = Utils.transform_space(dist, core_dist)
81167
e = Utils.prims_algorithm(transformed_dist)
82-
pass
168+
plot = Utils.plot_data(data, e, transformed_dist)
169+
Z, tree, clusters = Utils.cluster_hierarchy(e, num_points=transformed_dist.shape[0])
170+
# dn = Utils.plot_dendrogram(Z)
171+
c_stabilities = Utils.condense_cluster_tree(tree, clusters, min_cluster_size=5, num_points=transformed_dist.shape[0])
172+
clusters = Utils.extract_clusters(clusters, c_stabilities)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
scipy
12
sklearn
23
seaborn
34
tensorflow

0 commit comments

Comments
 (0)