1
+ import numpy as np #导入numpy工具包
2
+ from os import listdir #使用listdir模块,用于访问本地文件
3
+ from sklearn import neighbors
4
+
5
def img2vector(fileName):
    """Flatten a 32x32 text image of digit characters into a 1-D vector.

    The file is read as text; the first 32 characters of each of the first
    32 lines are converted to integers and stored row by row.

    Args:
        fileName: Path of the text file holding the 32x32 character grid.

    Returns:
        A NumPy integer array of length 1024 where element ``i * 32 + j``
        holds the value of character ``j`` on line ``i``.

    Raises:
        IndexError: If the file has fewer than 32 lines or a line has fewer
            than 32 characters.
        ValueError: If a character cannot be parsed as an integer.
    """
    retMat = np.zeros([1024], int)  # flat 1x1024 result vector
    # Use a context manager so the file handle is always closed, even if
    # reading fails (the original left the handle open).
    with open(fileName) as fr:
        lines = fr.readlines()
    for i in range(32):
        row = lines[i]
        for j in range(32):
            # Convert explicitly instead of relying on NumPy's implicit
            # string-to-int casting on assignment.
            retMat[i * 32 + j] = int(row[j])
    return retMat
13
+
14
def readDataSet(path):
    """Load every digit file in a directory into a data matrix and labels.

    The label of each sample is taken from its file name: the text before
    the first underscore is parsed with ``int``. Labels are stored as plain
    digits rather than one-hot vectors.

    Args:
        path: Directory containing the digit text files.

    Returns:
        A tuple ``(dataSet, hwLabels)``. ``dataSet`` is an integer array with
        one 1024-element row per file; ``hwLabels`` is a 1-D array (NumPy's
        default dtype) holding the digit label for the matching row.
    """
    names = listdir(path)  # every entry in the directory
    count = len(names)
    dataSet = np.zeros([count, 1024], int)  # one flattened image per row
    hwLabels = np.zeros([count])  # plain digit labels, not one-hot
    for row, name in enumerate(names):
        # The label is the part of the file name before the first '_'.
        hwLabels[row] = int(name.split('_')[0])
        dataSet[row] = img2vector(path + '/' + name)
    return dataSet, hwLabels
25
+
26
# Read the training set and fit a 3-nearest-neighbour classifier that uses
# a KD-tree for the neighbour search.
train_dataSet, train_hwLabels = readDataSet('trainingDigits')
knn = neighbors.KNeighborsClassifier(algorithm='kd_tree', n_neighbors=3)
knn.fit(train_dataSet, train_hwLabels)

# Read the testing set.
dataSet, hwLabels = readDataSet('testDigits')

res = knn.predict(dataSet)  # predicted label for every test sample
error_num = np.sum(res != hwLabels)  # number of misclassified samples
num = len(dataSet)  # number of test samples
print("Total num:", num, " Wrong num:", error_num,
      " WrongRate:", error_num / float(num))
0 commit comments