在python上使用knn算法识别mnist。正确率只有27%。求查错,自己看了好几天都找不出来哪出问题了
# -*- coding: UTF-8 -*-
from __future__ import division
import os
import struct
import numpy as np
import data
import heapq
'''knn 求距离公式'''
def euc(vec1, vec2):
npvec1, npvec2 = np.array(vec1), np.array(vec2)
return ((npvec1-npvec2)**2).sum()
'''data.image_data是mnist数据集,b是将这个数据集分成60000份'''
a=np.array([data.image_data])
b=a.reshape((60000,784))
'''data.image_test_data是mnist测试集,d是将这个数据集分成10000份'''
c=np.array([data.image_test_data])
d=c.reshape((10000,784))
'''i是测试次数,y是正确的次数'''
i=0
y=0
while i < 10000:
list1=[]
list2=[]
'''计算距离,并放入list1'''
for x in b:
list1.append(euc(d[i],x))
'''从list1里选3个最小的'''
result = map(list1.index, heapq.nsmallest(11, list1))
result.sort()
for x in result:
x1=data.label_data[x]
list2.append(x1)
if data.label_test_data[i]==max(set(list2), key=list2.count):
'''用百分比显示出正确率'''
y=y+1
print("correct",i+1,"%.4f%%" % (y/(i+1)*100))
else:
print("not correct",i+1)
i=i+1