Skip to content

Commit d69698c

Browse files
committed
KNN实现手写识别
1 parent 76748c1 commit d69698c

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

MachineLearning/B/use_neighbors2.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np #导入numpy工具包
2+
from os import listdir #使用listdir模块,用于访问本地文件
3+
from sklearn import neighbors
4+
5+
def img2vector(fileName):
    """Read a 32x32 text image of digit characters into a flat vector.

    Args:
        fileName: Path of a text file whose first 32 lines each begin with
            at least 32 digit characters (the data files use '0' and '1').

    Returns:
        A one-dimensional integer numpy array of length 1024; element
        ``i * 32 + j`` holds the digit found at row ``i``, column ``j``.

    Raises:
        IndexError: If the file has fewer than 32 lines, or one of the first
            32 lines has fewer than 32 characters.
        ValueError: If a character inside the 32x32 window is not a digit.
    """
    retMat = np.zeros([1024], int)  # flattened 1 x 1024 result
    # Read inside a context manager so the file handle is always closed,
    # even when a malformed file makes the parsing below raise.
    with open(fileName) as fr:
        lines = fr.readlines()
    for i in range(32):  # rows of the image
        for j in range(32):  # columns of the image
            # Convert the character explicitly instead of relying on
            # numpy's implicit string-to-int coercion on assignment.
            retMat[i * 32 + j] = int(lines[i][j])
    return retMat
13+
14+
def readDataSet(path):
    """Load every digit file found in a directory.

    Each file name is expected to start with its integer label followed by
    an underscore; the text before the first underscore is parsed with
    ``int``.

    Args:
        path: Directory containing the digit text files.

    Returns:
        A ``(dataSet, hwLabels)`` tuple. ``dataSet`` is an integer array of
        shape ``(number_of_files, 1024)`` with one flattened image per row.
        ``hwLabels`` is a one-dimensional array (numpy's default float
        dtype) holding the plain digit label of each row, not a one-hot
        encoding.
    """
    names = listdir(path)  # every entry in the directory
    count = len(names)
    dataSet = np.zeros([count, 1024], int)  # one image per row
    hwLabels = np.zeros([count])  # raw digit labels, not one-hot vectors
    for row, name in enumerate(names):
        # The label is the part of the file name before the first underscore.
        prefix = name.partition('_')[0]
        hwLabels[row] = int(prefix)
        dataSet[row] = img2vector('/'.join([path, name]))
    return dataSet, hwLabels
25+
26+
# Load the training set and fit a 3-nearest-neighbour classifier that uses
# a kd-tree for the neighbour search.
train_dataSet, train_hwLabels = readDataSet('trainingDigits')
knn = neighbors.KNeighborsClassifier(n_neighbors=3, algorithm='kd_tree')
knn.fit(train_dataSet, train_hwLabels)

# Load the test set.
dataSet, hwLabels = readDataSet('testDigits')

# Predict the test set and report how many predictions disagree with the
# true labels.
res = knn.predict(dataSet)
error_num = np.sum(res != hwLabels)  # number of misclassified samples
num = len(dataSet)  # size of the test set
print(
    "Total num:", num, " Wrong num:",
    error_num, " WrongRate:", error_num / float(num)
)

0 commit comments

Comments
 (0)