Skip to content
This repository was archived by the owner on Jan 15, 2025. It is now read-only.

Commit 286da45

Browse files
committed
1. modified geo.py, as beta, not complete
2. dataset package 3. umap, swissroll 4. readme, req-hhw_code.txt modified
1 parent 98ac51f commit 286da45

25 files changed

+2003
-401
lines changed

README.md

+68-26
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,22 @@ conda install -n geo ipykernel --update-deps --force-reinstall
3030

3131
- `/code/ ``/code_new/`:原始代码code以及简单修改后的代码code_new
3232
- `/hhw_code/`:基于pytorch重构的代码
33-
- `main.ipynb`:notebook运行测试样例
34-
- `run_test.py`:脚本文件运行二测试样例
35-
- `geo.py`:实现数据集几何特征的分析,多进程处理
36-
- `utils.py`:通用工具函数
37-
- `network.py`:用于图像分类的模型、训练函数
38-
- `data_utils.py`:用于数据集处理、加载的相关函数
39-
- `app_utils.py`:应用数据集几何特征分析的结果,如数据集压缩或增强
40-
- `dim_reduce.py`:用于数据降维等预处理
41-
- `test.py`:用于测试项目中各个函数、网络模型
42-
- `results/`:存有数据集几何特征分析的结果,使用pickle存储
43-
- `pics/`:训练与测试分类网络时的图片
33+
- `main.ipynb`:notebook运行测试样例;
34+
- ~~`run_test.py`:脚本文件运行测试样例;~~
35+
- `geo.py`:实现数据集几何特征的分析,多进程处理;
36+
- `utils.py`:通用工具函数;
37+
- `network.py`:用于图像分类的模型、训练函数;
38+
- ~~`data_utils.py`:用于数据集处理、加载的相关函数;~~
39+
- `app_utils.py`:应用数据集几何特征分析的结果,如数据集压缩或增强;
40+
- `dim_reduce.py`:用于数据降维等预处理;
41+
- `test_all.py`:用于测试项目中各个函数、网络模型;
42+
- `datasets/`:数据集自写库
43+
- `data_utils.py`:仅有 load_data_mnist,暂无用处,下一步调整或删除;
44+
- `dataset.py`:数据集基类;
45+
- `mnist.py`:MNIST 数据集类,有待进一步扩充;
46+
- `results/`:存有数据集几何特征分析的结果,使用pickle存储;
47+
- `pics/`:训练与测试分类网络时的图片;
48+
- 其他,暂未列出。
4449

4550
### 运行
4651

@@ -80,21 +85,58 @@ req-hhw_code.txt:
8085
* `torchvision`:pytorch项目的一部分,由常用数据集、模型架构和计算机视觉常用图像转换组成。项目中用于加载与处理数据集。
8186
* `matplotlib`:用于在 Python 中创建静态、动画和交互式可视化的综合库。用于可视化,也就是绘图。
8287
* `seaborn`:基于 matplotlib 的 Python 数据可视化库。与matplotlib相比,其提供了更高级的封装,能更方便地绘制美观且信息丰富的统计图形。用于可视化,也就是绘图。
88+
* `umap-learn`:UMAP降维库。
89+
* `cProfile`:python 性能分析库。
90+
* `idx2numpy`:从下载的数据集文件中提取。
91+
* `pickle`:存储与加载 Python 对象。
8392

8493
#### 其他
8594

86-
- 编写 `/hhw_code/test.py` 中对 `get_class_geo_feature` 进行测试时发现,计算“平均欧式-测地距离比值” (`ave_egr` 时发现,按论文中描述的计算方法,无需使用 “k-近邻测地线距离” (`geo_dist`),可以极大地化简算法,有待进一步分析确认。
87-
88-
- 需要对算法中间所得的数据进行进一步分析,可视化、分析其分布,从而深入理解算法,产生新想法。
89-
90-
- 论文中提出用 “骨干路径” (bone_path) 来避免过多地考虑子路径。但在测试时发现, “骨干路径” 非常多,非常短,当选取 200 张 mnist 0 图像,k = 5 时,骨干路径占比超过 80%, 路径长度集中在 5-6 个结点(包括起始和目的结点)。由此感到 “骨干路径” 的特征描述能力较弱,能否提出更强更有效的特征描述指标呢?
91-
92-
- 分析发现中间数据矩阵稀疏性较高,如何利用稀疏矩阵来更有效地计算和存储数据呢?
93-
94-
- 对于高维,数值较大,噪音较大的数据,欧氏距离 L2 范数容易受极端值的影响,是否能用其他范数,如 L1 范数?
95-
96-
- 降维 / 特征提取 后再进行“几何”特征提取?
97-
98-
- 只考虑局部状态,却要进行全局计算,开销较大,如何解决?
99-
100-
- 如何分析其他模态的数据?如携带时序信息的文本。
95+
23.8.13
96+
97+
- [X] 编写 `/hhw_code/test.py` 中对 `get_class_geo_feature` 进行测试时发现,计算“平均欧式-测地距离比值” (`ave_egr` 时发现,按论文中描述的计算方法,无需使用 “k-近邻测地线距离” (`geo_dist`),可以极大地化简算法,有待进一步分析确认。
98+
- [X] 需要对算法中间所得的数据进行进一步分析,可视化、分析其分布,从而深入理解算法,产生新想法。
99+
- [ ] 论文中提出用 “骨干路径” (bone_path) 来避免过多地考虑子路径。但在测试时发现, “骨干路径” 非常多,非常短,当选取 200 张 mnist 0 图像,k = 5 时,骨干路径占比超过 80%, 路径长度集中在 5-6 个结点(包括起始和目的结点)。由此感到 “骨干路径” 的特征描述能力较弱,能否提出更强更有效的特征描述指标呢?
100+
- [ ] 分析发现中间数据矩阵稀疏性较高,如何利用稀疏矩阵来更有效地计算和存储数据呢?
101+
- [ ] 对于高维,数值较大,噪音较大的数据,欧氏距离 L2 范数容易受极端值的影响,是否能用其他范数,如 L1 范数?
102+
- [ ] 降维 / 特征提取 后再进行“几何”特征提取?
103+
- [ ] 只考虑局部状态,却要进行全局计算,开销较大,如何解决?
104+
- [ ] 如何分析其他模态的数据?如携带时序信息的文本。
105+
106+
23.8.15
107+
108+
* 8.13中无需使用 “k-近邻测地线距离”,进一步修改调整代码,砍去了很多不必要的数据,做了若干优化,还未完全融合,改进版函数暂时以 xx_beta 为函数名。
109+
* [ ] TODO:感觉可以将 feature 封装成“对象”,采用面向对象的设计理念,简化程序复杂的逻辑和函数调用。
110+
111+
- 在 MNIST 数字 0 上进行了测试,提取了几何特征,然后数据压缩。将压缩后的数据与去除的数据用 UMAP 降维到 2 维,绘图如下:
112+
`visualize.ipynb`
113+
![mnist_num0_umap_compress](image/README/mnist_num0_umap_compress.png)
114+
发现程混合状,降维后导致几何特征消失?
115+
从剩余的数据和被去除的数据中选取部分样本,绘图如下:
116+
剩余的数据样本:
117+
![mnist_tsz10000_k5_resdemo](image/README/mnist_tsz10000_k5_resdemo.png)
118+
被去除的数据样本:
119+
![mnist_tsz10000_k5_removed_demo](image/README/mnist_tsz10000_k5_removed_demo.png)
120+
- 在 SwissRoll 人造小数据集上测试,下面是 aegr 最大的若干测地线路径,`test_all.py test_swissroll`
121+
![swiss_roll_aegr](image/README/swiss_roll_aegr.png)
122+
- 在 hhw_code 中 data_compress 时,目前尚未利用 udist 即单位结点对应的测地线距离长度,而路径越长,aegr越接近 1 ,导致偏差。
123+
- **问题**:直接使用欧氏距离来评判两个样本是否“靠近”,无法反映出图片数据所具有的“平移不变性”,即两张相同的图片,其中一张略微位移一下,就会导致欧氏距离非常大。
124+
也就是说,“欧氏距离”无法反映图片数据的“空间信息” 。如果这样的话,基于“欧氏距离”的“数据集合特征”,就难以反映出数据的“语义”信息。
125+
- 对 minst 0 的 geo feature 数据进行分析,绘制成柱状图:
126+
![img](image/README/mnit_0_k15_sz1001_euc&geo.png)
127+
![img](image/README/mnit_0_k15_sz1001_weight&len.png)
128+
![img](image/README/mnit_0_k15_sz1001_aegr&udist.png)
129+
- TODO:对 swissroll 的geo feature 数据进行分析,绘制成柱状图:
130+
(初步发现欧几里得性质较好的 swissroll 数据,aegr 普遍接近于 1)
131+
- UMAP 对 MNIST 进行降维可视化:`umap.ipynb`
132+
![mnist_umap](image/README/mnist_umap.png)
133+
- 用压缩后的 MNIST 数据选出 1w 张进行分类测试,测试集为原 1w 测试数据,网络采用略修改的 LeNet,结果如下:
134+
135+
![](image/README/train_compare.png)
136+
137+
- - $(a)$:在 1w 数据上提取几何特征,做数据压缩,压缩率为 0.8, 然后在压缩后的数据上训练。
138+
- $(b)$:直接在 1w 数据上进行训练。
139+
- $(c)$:在 1w 数据上随机选出 0.8 * 1w 的数据进行训练。
140+
- 收敛更快、更稳定一些,test acc 后期更稳定,过拟合现象有所缓解。
141+
- 需要进一步地集中测试和分析。
142+
- 感觉”几何特征“消耗大,效果并不显著,几何特征与语义无法找到很好的映射关系,不太好解释(;へ:)

code/bone_swissroll.py

+39-26
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88
from mpl_toolkits.mplot3d import Axes3D
99

1010

11-
1211
def generate_Swissroll(n):
13-
t = (3 * np.pi) / 2 * (1 + 2 * tf.random.uniform([1, n], minval=0, maxval=1, dtype=tf.float32))
12+
t = (
13+
(3 * np.pi)
14+
/ 2
15+
* (1 + 2 * tf.random.uniform([1, n], minval=0, maxval=1, dtype=tf.float32))
16+
)
1417
h = 20 * tf.random.uniform([1, n], minval=0, maxval=1, dtype=tf.float32)
1518
a1 = tf.constant(t * tf.cos(t)) ##映射第一个轴
1619
a3 = tf.constant(t * tf.sin(t)) ##映射第三个轴 ,第二个轴是h
1720
X = tf.concat([a1, h, a3], axis=0) ##组成数据样本
1821
return X.numpy().T
1922

23+
2024
def whole_remove(dataset, knn, unit_hop, ratio, percentage):
2125
print("reading")
2226
edist = dra.eucli_distance_all(dataset)
@@ -29,15 +33,18 @@ def whole_remove(dataset, knn, unit_hop, ratio, percentage):
2933
print("dict, ave_egr")
3034
path_index = dra.bone_path(path_dict, gdist)
3135
weight = dra.bone_weight(path_dict, path_index)
32-
#print(weight)
33-
remove_tag = dra.dataset_compression_index(ave_egr, path_dict, gdist, unit_hop, ratio, path_index, weight)
36+
# print(weight)
37+
remove_tag = dra.dataset_compression_index(
38+
ave_egr, path_dict, gdist, unit_hop, ratio, path_index, weight
39+
)
3440
print("remove_tag")
35-
#print(remove_tag)
41+
# print(remove_tag)
3642
rsub_data, rem_data = dra.dataset_compress(dataset, remove_tag, percentage)
3743
print("data")
3844
print(rsub_data.shape)
3945
return rsub_data, rem_data
4046

47+
4148
def whole_augment(dataset, knn, unit_hop, ratio, percentage):
4249
print("reading")
4350
edist = dra.eucli_distance_all(dataset)
@@ -50,42 +57,48 @@ def whole_augment(dataset, knn, unit_hop, ratio, percentage):
5057
print("dict, ave_egr")
5158
path_index = dra.bone_path(path_dict, gdist)
5259
weight = dra.bone_weight(path_dict, path_index)
53-
add_tag = dra.dataset_augment_index(ave_egr, path_dict, gdist, unit_hop, ratio, path_index, weight)
60+
add_tag = dra.dataset_augment_index(
61+
ave_egr, path_dict, gdist, unit_hop, ratio, path_index, weight
62+
)
5463
print("add_tag")
55-
asub_data, add_data = dra.dataset_augment(dataset, add_tag, percentage, edist, path_dict)
64+
asub_data, add_data = dra.dataset_augment(
65+
dataset, add_tag, percentage, edist, path_dict
66+
)
5667
print(np.shape(asub_data))
5768
return asub_data, add_data
5869

70+
5971
def polt_swissroll(data, change):
6072
plt.figure()
6173
x, y, z = list(data.T[0]), list(data.T[1]), list(data.T[2])
6274
x1, y1, z1 = list(change.T[0]), list(change.T[1]), list(change.T[2])
63-
ax = plt.subplot(111, projection='3d')
64-
ax.scatter(x, y, z, s=10, alpha=0.3, c='r')
65-
ax.scatter(x1, y1, z1,s=10, alpha=0.8, c='b')
66-
ax.set_zlabel('Z') # 坐标轴
67-
ax.set_ylabel('Y')
68-
ax.set_xlabel('X')
75+
ax = plt.subplot(111, projection="3d")
76+
ax.scatter(x, y, z, s=10, alpha=0.3, c="r")
77+
ax.scatter(x1, y1, z1, s=10, alpha=0.8, c="b")
78+
ax.set_zlabel("Z") # 坐标轴
79+
ax.set_ylabel("Y")
80+
ax.set_xlabel("X")
6981
plt.show()
7082

71-
#dataset = generate_Swissroll(500)
72-
dataset, t = skl.make_swiss_roll(n_samples = 1000, noise = 0.1)
73-
x,y,z = list(dataset.T[0]), list(dataset.T[1]), list(dataset.T[2])
74-
ax = plt.subplot(111, projection='3d')
75-
ax.scatter(x, y, z, s=10, alpha=0.3, c='r')
76-
ax.set_zlabel('Z') # 坐标轴
77-
ax.set_ylabel('Y')
78-
ax.set_xlabel('X')
83+
84+
# dataset = generate_Swissroll(500)
85+
dataset, t = skl.make_swiss_roll(n_samples=1000, noise=0.1)
86+
x, y, z = list(dataset.T[0]), list(dataset.T[1]), list(dataset.T[2])
87+
ax = plt.subplot(111, projection="3d")
88+
ax.scatter(x, y, z, s=10, alpha=0.3, c="r")
89+
ax.set_zlabel("Z") # 坐标轴
90+
ax.set_ylabel("Y")
91+
ax.set_xlabel("X")
7992
plt.show()
80-
#polt_swissroll(dataset)
93+
# polt_swissroll(dataset)
8194

8295
dataset_cafter, sub_data = whole_remove(dataset, 5, 0.3, 0.9, 0.1)
8396
print(np.shape(sub_data))
8497
polt_swissroll(dataset_cafter, np.array(sub_data))
8598

86-
dataset_aafter, add_data = whole_augment(dataset,5, 0.9, 0.9, 0.1)
99+
dataset_aafter, add_data = whole_augment(dataset, 5, 0.9, 0.9, 0.1)
87100
polt_swissroll(dataset, np.array(add_data))
88101

89-
np.savetxt("../swiss_roll/dataset.txt", dataset, fmt='%f',delimiter=',')
90-
np.savetxt("../swiss_roll/dataset_rem.txt", dataset_cafter, fmt='%f',delimiter=',')
91-
np.savetxt("../swiss_roll/dataset_add.txt", dataset_aafter, fmt='%f',delimiter=',')
102+
np.savetxt("../swiss_roll/dataset.txt", dataset, fmt="%f", delimiter=",")
103+
np.savetxt("../swiss_roll/dataset_rem.txt", dataset_cafter, fmt="%f", delimiter=",")
104+
np.savetxt("../swiss_roll/dataset_add.txt", dataset_aafter, fmt="%f", delimiter=",")

0 commit comments

Comments
 (0)