Skip to content

Commit fa315f0

Browse files
committedFeb 25, 2016
update utils
1 parent 6766702 commit fa315f0

15 files changed

+234
-0
lines changed
 

‎dataset/MNIST/README.md

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
## 介绍
2+
3+
在学习机器学习的时候,首先要做的就是准备一份通用的数据集,方便与其他算法进行比较。在这里,我写了一个用于加载MNIST数据集的方法,并将其进行封装,主要用于将MNIST数据集转换成numpy.array()格式的训练数据。下面直接看代码吧!
4+
5+
MNIST数据集原网址:[http://yann.lecun.com/exdb/mnist/](http://yann.lecun.com/exdb/mnist/)
6+
7+
## 文件目录
8+
9+
- ../../utils/data_util.py 用于加载MNIST数据集方法文件
10+
- ../../utils/test.py 用于测试的文件,一个简单的KNN测试MNIST数据集
11+
- ./train-images.idx3-ubyte 训练集X
12+
- ./train-labels.idx1-ubyte 训练集y
13+
- ./t10k-images.idx3-ubyte 测试集X
14+
- ./t10k-labels.idx1-ubyte 测试集y
15+
16+
17+
18+
## 源码
19+
20+
[../../utils/data_util.py](../../utils/data_util.py)文件
21+
22+
23+
[../../utils/test.py](../../utils/test.py)文件:简单地测试了一下KNN算法,代码如下
24+
25+
```
26+
# -*- coding: utf-8 -*-
27+
"""
28+
Created on Thu Feb 25 16:09:58 2016
29+
Test MNIST dataset
30+
@author: liudiwei
31+
"""
32+
33+
from sklearn import neighbors
34+
from data_util import DataUtils
35+
import datetime
36+
37+
38+
def main():
    """Load the MNIST training and test splits as numpy arrays.

    The four IDX files are read from ``../dataset/MNIST`` (relative to the
    current working directory) in the order: training images, training
    labels, test images, test labels.

    Returns:
        A 4-tuple ``(train_X, train_y, test_X, test_y)``.
    """
    # Each split is an (images file, labels file) pair.
    train_files = ('../dataset/MNIST/train-images.idx3-ubyte',
                   '../dataset/MNIST/train-labels.idx1-ubyte')
    test_files = ('../dataset/MNIST/t10k-images.idx3-ubyte',
                  '../dataset/MNIST/t10k-labels.idx1-ubyte')

    loaded = []
    for images_path, labels_path in (train_files, test_files):
        loaded.append(DataUtils(filename=images_path).getImage())
        loaded.append(DataUtils(filename=labels_path).getLabel())
    return tuple(loaded)
49+
50+
51+
def testKNN():
    """Fit a 3-nearest-neighbour classifier on MNIST and report its error.

    Loads the data through ``main()``, fits the classifier on the training
    split, predicts the whole test split, then prints the elapsed time
    (fitting plus prediction) and the test error rate.
    """
    train_X, train_y, test_X, test_y = main()
    startTime = datetime.datetime.now()
    knn = neighbors.KNeighborsClassifier(n_neighbors=3)
    knn.fit(train_X, train_y)

    # predict() expects a 2-D (n_samples, n_features) array. Passing the
    # whole test matrix avoids feeding it 1-D rows (rejected by current
    # scikit-learn) and replaces one Python-level call per sample with a
    # single batched call.
    predictions = knn.predict(test_X)
    match = int((predictions == test_y).sum())

    endTime = datetime.datetime.now()
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('use time: ' + str(endTime - startTime))
    print('error rate: ' + str(1 - (match * 1.0 / len(test_y))))
65+
66+
if __name__ == "__main__":
67+
testKNN()
68+
```
69+
70+
通过main方法,最后直接返回numpy.array()格式的数据:train_X, train_y, test_X, test_y。如果你需要,直接调用main方法即可!
71+
72+
---

‎dataset/MNIST/mnist.zip

28.8 MB
Binary file not shown.
1.57 MB
Binary file not shown.

‎dataset/MNIST/t10k-images.idx3-ubyte

7.48 MB
Binary file not shown.
4.44 KB
Binary file not shown.

‎dataset/MNIST/t10k-labels.idx1-ubyte

9.77 KB
Binary file not shown.
9.45 MB
Binary file not shown.

‎dataset/MNIST/train-images.idx3-ubyte

44.9 MB
Binary file not shown.
28.2 KB
Binary file not shown.

‎dataset/MNIST/train-labels.idx1-ubyte

58.6 KB
Binary file not shown.

‎doc/05_stata_boosting.pdf

644 KB
Binary file not shown.

‎utils/README.md

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
## 目录介绍
2+
3+
- data_util.py 用于加载MNIST数据集
4+
- test.py 使用KNN算法测试MNIST数据集
5+
6+
## 依赖
7+
8+
- numpy
9+
- matplotlib
10+
- scikit-learn

‎utils/data_util.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Thu Feb 25 14:40:06 2016
4+
load MNIST dataset
5+
@author: liudiwei
6+
"""
7+
import numpy as np
8+
import struct
9+
import matplotlib.pyplot as plt
10+
import os
11+
12+
class DataUtils(object):
    """Load MNIST IDX binary files and return their content as numpy arrays.

    Example usage::

        from data_util import DataUtils

        def main():
            trainfile_X = '../dataset/MNIST/train-images.idx3-ubyte'
            trainfile_y = '../dataset/MNIST/train-labels.idx1-ubyte'
            testfile_X = '../dataset/MNIST/t10k-images.idx3-ubyte'
            testfile_y = '../dataset/MNIST/t10k-labels.idx1-ubyte'

            train_X = DataUtils(filename=trainfile_X).getImage()
            train_y = DataUtils(filename=trainfile_y).getLabel()
            test_X = DataUtils(testfile_X).getImage()
            test_y = DataUtils(testfile_y).getLabel()

            # Optionally save the images to local PNG files:
            # path_trainset = "../dataset/MNIST/imgs_train"
            # path_testset = "../dataset/MNIST/imgs_test"
            # if not os.path.exists(path_trainset):
            #     os.mkdir(path_trainset)
            # if not os.path.exists(path_testset):
            #     os.mkdir(path_testset)
            # DataUtils(outpath=path_trainset).outImg(train_X, train_y)
            # DataUtils(outpath=path_testset).outImg(test_X, test_y)

            return train_X, train_y, test_X, test_y
    """

    def __init__(self, filename=None, outpath=None):
        """Remember the input file and/or the output directory.

        Args:
            filename: path of the IDX file read by ``getImage``/``getLabel``.
            outpath: directory into which ``outImg`` writes PNG files.
        """
        self._filename = filename
        self._outpath = outpath

        # ``struct`` format pieces; '>' selects big-endian byte order.
        self._tag = '>'
        self._twoBytes = 'II'        # label-file header: magic, item count
        self._fourBytes = 'IIII'     # image-file header: magic, count, rows, cols
        self._pictureBytes = '784B'  # one 28*28 image, one byte per pixel
        self._labelByte = '1B'       # one label
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte

    def getImage(self):
        """Convert the image file into binarized pixel features.

        Returns:
            An integer array of shape ``(numImgs, numRows * numCols)`` in
            which every non-zero pixel is replaced by 1.

        Raises:
            struct.error: if the file is shorter than its header announces.
        """
        # ``with`` guarantees the handle is closed even if reading fails.
        with open(self._filename, 'rb') as binfile:
            buf = binfile.read()

        headerSize = struct.calcsize(self._fourBytes2)
        _magic, numImgs, numRows, numCols = struct.unpack_from(
            self._fourBytes2, buf, 0)

        # Take the image size from the header instead of assuming 784 bytes.
        pixelsPerImage = numRows * numCols
        numPixels = numImgs * pixelsPerImage
        if len(buf) - headerSize < numPixels:
            raise struct.error(
                'image file is truncated: expected %d pixel bytes, found %d'
                % (numPixels, len(buf) - headerSize))

        # Decode all pixels in one pass rather than one struct call per image.
        pixels = np.frombuffer(buf, dtype=np.uint8, count=numPixels,
                               offset=headerSize)
        # Clamp to {0, 1}: 0 stays 0, every other value becomes 1.
        images = np.minimum(pixels, 1).astype(int)
        return images.reshape(numImgs, pixelsPerImage)

    def getLabel(self):
        """Convert the label file into an array of label numbers.

        Returns:
            A 1-D integer array holding one label per item.

        Raises:
            struct.error: if the file is shorter than its header announces.
        """
        with open(self._filename, 'rb') as binFile:
            buf = binFile.read()

        headerSize = struct.calcsize(self._twoBytes2)
        _magic, numItems = struct.unpack_from(self._twoBytes2, buf, 0)
        if len(buf) - headerSize < numItems:
            raise struct.error(
                'label file is truncated: expected %d labels, found %d'
                % (numItems, len(buf) - headerSize))

        # Each label is a single unsigned byte following the header.
        labels = np.frombuffer(buf, dtype=np.uint8, count=numItems,
                               offset=headerSize)
        return labels.astype(int)

    def outImg(self, arrX, arrY, count=1):
        """Write images as PNG files named ``<index>_<label>.png``.

        Args:
            arrX: 2-D array of pixel features, one 28*28 image per row.
            arrY: labels matching the rows of ``arrX``.
            count: number of leading images to write. Defaults to 1, so by
                default only the first image is written.
        """
        numSamples, _ = np.shape(arrX)
        for i in range(min(count, numSamples)):
            # Every image is 28*28 = 784 pixels.
            img = np.array(arrX[i]).reshape(28, 28)
            outfile = str(i) + "_" + str(arrY[i]) + ".png"
            fig = plt.figure()
            try:
                plt.imshow(img, cmap='binary')  # black-and-white rendering
                plt.savefig(os.path.join(self._outpath, outfile))
            finally:
                # Release the figure so repeated calls do not accumulate them.
                plt.close(fig)
110+

‎utils/data_util.pyc

4 KB
Binary file not shown.

‎utils/test.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Created on Thu Feb 25 16:09:58 2016
4+
Test MNIST dataset
5+
@author: liudiwei
6+
"""
7+
8+
from sklearn import neighbors
9+
from data_util import DataUtils
10+
import datetime
11+
12+
13+
def main():
    """Load the MNIST training and test splits as numpy arrays.

    The four IDX files are read from ``../dataset/MNIST`` (relative to the
    current working directory) in the order: training images, training
    labels, test images, test labels.

    Returns:
        A 4-tuple ``(train_X, train_y, test_X, test_y)``.
    """
    # Each split is an (images file, labels file) pair.
    train_files = ('../dataset/MNIST/train-images.idx3-ubyte',
                   '../dataset/MNIST/train-labels.idx1-ubyte')
    test_files = ('../dataset/MNIST/t10k-images.idx3-ubyte',
                  '../dataset/MNIST/t10k-labels.idx1-ubyte')

    loaded = []
    for images_path, labels_path in (train_files, test_files):
        loaded.append(DataUtils(filename=images_path).getImage())
        loaded.append(DataUtils(filename=labels_path).getLabel())
    return tuple(loaded)
24+
25+
26+
def testKNN():
    """Fit a 3-nearest-neighbour classifier on MNIST and report its error.

    Loads the data through ``main()``, fits the classifier on the training
    split, predicts the whole test split, then prints the elapsed time
    (fitting plus prediction) and the test error rate.
    """
    train_X, train_y, test_X, test_y = main()
    startTime = datetime.datetime.now()
    knn = neighbors.KNeighborsClassifier(n_neighbors=3)
    knn.fit(train_X, train_y)

    # predict() expects a 2-D (n_samples, n_features) array. Passing the
    # whole test matrix avoids feeding it 1-D rows (rejected by current
    # scikit-learn) and replaces one Python-level call per sample with a
    # single batched call.
    predictions = knn.predict(test_X)
    match = int((predictions == test_y).sum())

    endTime = datetime.datetime.now()
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('use time: ' + str(endTime - startTime))
    print('error rate: ' + str(1 - (match * 1.0 / len(test_y))))
40+
41+
if __name__ == "__main__":
42+
testKNN()

0 commit comments

Comments
 (0)
Please sign in to comment.