8
8
from mpl_toolkits .mplot3d import Axes3D
9
9
10
10
11
-
12
11
def generate_Swissroll (n ):
13
- t = (3 * np .pi ) / 2 * (1 + 2 * tf .random .uniform ([1 , n ], minval = 0 , maxval = 1 , dtype = tf .float32 ))
12
+ t = (
13
+ (3 * np .pi )
14
+ / 2
15
+ * (1 + 2 * tf .random .uniform ([1 , n ], minval = 0 , maxval = 1 , dtype = tf .float32 ))
16
+ )
14
17
h = 20 * tf .random .uniform ([1 , n ], minval = 0 , maxval = 1 , dtype = tf .float32 )
15
18
a1 = tf .constant (t * tf .cos (t )) ##映射第一个轴
16
19
a3 = tf .constant (t * tf .sin (t )) ##映射第三个轴 ,第二个轴是h
17
20
X = tf .concat ([a1 , h , a3 ], axis = 0 ) ##组成数据样本
18
21
return X .numpy ().T
19
22
23
+
20
24
def whole_remove (dataset , knn , unit_hop , ratio , percentage ):
21
25
print ("reading" )
22
26
edist = dra .eucli_distance_all (dataset )
@@ -29,15 +33,18 @@ def whole_remove(dataset, knn, unit_hop, ratio, percentage):
29
33
print ("dict, ave_egr" )
30
34
path_index = dra .bone_path (path_dict , gdist )
31
35
weight = dra .bone_weight (path_dict , path_index )
32
- #print(weight)
33
- remove_tag = dra .dataset_compression_index (ave_egr , path_dict , gdist , unit_hop , ratio , path_index , weight )
36
+ # print(weight)
37
+ remove_tag = dra .dataset_compression_index (
38
+ ave_egr , path_dict , gdist , unit_hop , ratio , path_index , weight
39
+ )
34
40
print ("remove_tag" )
35
- #print(remove_tag)
41
+ # print(remove_tag)
36
42
rsub_data , rem_data = dra .dataset_compress (dataset , remove_tag , percentage )
37
43
print ("data" )
38
44
print (rsub_data .shape )
39
45
return rsub_data , rem_data
40
46
47
+
41
48
def whole_augment (dataset , knn , unit_hop , ratio , percentage ):
42
49
print ("reading" )
43
50
edist = dra .eucli_distance_all (dataset )
@@ -50,42 +57,48 @@ def whole_augment(dataset, knn, unit_hop, ratio, percentage):
50
57
print ("dict, ave_egr" )
51
58
path_index = dra .bone_path (path_dict , gdist )
52
59
weight = dra .bone_weight (path_dict , path_index )
53
- add_tag = dra .dataset_augment_index (ave_egr , path_dict , gdist , unit_hop , ratio , path_index , weight )
60
+ add_tag = dra .dataset_augment_index (
61
+ ave_egr , path_dict , gdist , unit_hop , ratio , path_index , weight
62
+ )
54
63
print ("add_tag" )
55
- asub_data , add_data = dra .dataset_augment (dataset , add_tag , percentage , edist , path_dict )
64
+ asub_data , add_data = dra .dataset_augment (
65
+ dataset , add_tag , percentage , edist , path_dict
66
+ )
56
67
print (np .shape (asub_data ))
57
68
return asub_data , add_data
58
69
70
+
59
71
def polt_swissroll (data , change ):
60
72
plt .figure ()
61
73
x , y , z = list (data .T [0 ]), list (data .T [1 ]), list (data .T [2 ])
62
74
x1 , y1 , z1 = list (change .T [0 ]), list (change .T [1 ]), list (change .T [2 ])
63
- ax = plt .subplot (111 , projection = '3d' )
64
- ax .scatter (x , y , z , s = 10 , alpha = 0.3 , c = 'r' )
65
- ax .scatter (x1 , y1 , z1 ,s = 10 , alpha = 0.8 , c = 'b' )
66
- ax .set_zlabel ('Z' ) # 坐标轴
67
- ax .set_ylabel ('Y' )
68
- ax .set_xlabel ('X' )
75
+ ax = plt .subplot (111 , projection = "3d" )
76
+ ax .scatter (x , y , z , s = 10 , alpha = 0.3 , c = "r" )
77
+ ax .scatter (x1 , y1 , z1 , s = 10 , alpha = 0.8 , c = "b" )
78
+ ax .set_zlabel ("Z" ) # 坐标轴
79
+ ax .set_ylabel ("Y" )
80
+ ax .set_xlabel ("X" )
69
81
plt .show ()
70
82
71
- #dataset = generate_Swissroll(500)
72
- dataset , t = skl .make_swiss_roll (n_samples = 1000 , noise = 0.1 )
73
- x ,y ,z = list (dataset .T [0 ]), list (dataset .T [1 ]), list (dataset .T [2 ])
74
- ax = plt .subplot (111 , projection = '3d' )
75
- ax .scatter (x , y , z , s = 10 , alpha = 0.3 , c = 'r' )
76
- ax .set_zlabel ('Z' ) # 坐标轴
77
- ax .set_ylabel ('Y' )
78
- ax .set_xlabel ('X' )
83
+
84
+ # dataset = generate_Swissroll(500)
85
+ dataset , t = skl .make_swiss_roll (n_samples = 1000 , noise = 0.1 )
86
+ x , y , z = list (dataset .T [0 ]), list (dataset .T [1 ]), list (dataset .T [2 ])
87
+ ax = plt .subplot (111 , projection = "3d" )
88
+ ax .scatter (x , y , z , s = 10 , alpha = 0.3 , c = "r" )
89
+ ax .set_zlabel ("Z" ) # 坐标轴
90
+ ax .set_ylabel ("Y" )
91
+ ax .set_xlabel ("X" )
79
92
plt .show ()
80
- #polt_swissroll(dataset)
93
+ # polt_swissroll(dataset)
81
94
82
95
dataset_cafter , sub_data = whole_remove (dataset , 5 , 0.3 , 0.9 , 0.1 )
83
96
print (np .shape (sub_data ))
84
97
polt_swissroll (dataset_cafter , np .array (sub_data ))
85
98
86
- dataset_aafter , add_data = whole_augment (dataset ,5 , 0.9 , 0.9 , 0.1 )
99
+ dataset_aafter , add_data = whole_augment (dataset , 5 , 0.9 , 0.9 , 0.1 )
87
100
polt_swissroll (dataset , np .array (add_data ))
88
101
89
- np .savetxt ("../swiss_roll/dataset.txt" , dataset , fmt = '%f' , delimiter = ',' )
90
- np .savetxt ("../swiss_roll/dataset_rem.txt" , dataset_cafter , fmt = '%f' , delimiter = ',' )
91
- np .savetxt ("../swiss_roll/dataset_add.txt" , dataset_aafter , fmt = '%f' , delimiter = ',' )
102
+ np .savetxt ("../swiss_roll/dataset.txt" , dataset , fmt = "%f" , delimiter = "," )
103
+ np .savetxt ("../swiss_roll/dataset_rem.txt" , dataset_cafter , fmt = "%f" , delimiter = "," )
104
+ np .savetxt ("../swiss_roll/dataset_add.txt" , dataset_aafter , fmt = "%f" , delimiter = "," )
0 commit comments