St_Louis 2019-03-18 17:45 采纳率: 0%
浏览 462

KNN问题:两段代码几乎相同,对同一个文件进行测试,一个错误率%1,一个80% ?

先上%1的

from numpy import *
import operator
from os import listdir

#---------------------------------------------
#分类模块
#@params
#   inX:输入向量、手写体识别的测试向量
#    dataSet:训练集样本、手写体识别的训练集向量
#    labels:训练集对应的标签向量
#    k:最近邻居数目、本实验为3
#---------------------------------------------
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]     #手写体样本集容量
    #(以下三行)距离计算
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2    
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5   #欧氏距离开平方
    sortedDistIndicies = distances.argsort()  #距离排序的索引排序
    classCount = {}    
    #(以下两行)选择距离最小的k个点
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(),
    #排序
    key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]

#手写识别的测试代码
def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir(path='trainingDigits')  #获取目录内容
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        #一下三行,从文件名解析分类数字
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])

        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s'%fileNameStr)
    testFileList = listdir(path='testDigits')

    errorCount = 0.0  #错误个数计数器
    mTest = len(testFileList)

    #从测试数据中提取数据
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]

        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s'% fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)

        print("the classifier came back with:%d,the real answer is:%d"%(classifierResult,classNumStr))
        if(classifierResult != classNumStr):
            errorCount += 1.0
    print("\nthe total number of errors is:%d"%errorCount)
    print("\nthe total error rate is: %f"%(errorCount/float(mTest)))

#识别手写字体模块-图像转向量32x32 to 1x1024
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect



运行结果:
the total number of errors is:10

the total error rate is: 0.010571

然后上80%的

'''
Created on Sep 16, 2010
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (NxM)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)有多少属性

Output:     the most popular class label

@author: pbharrin
'''

from numpy import *
import operator
from os import listdir

def creatDataSet():
    group = array([[1.0,1.1],
                   [1.0,1.0],
                   [0  ,0  ],
                   [0  ,0.1]])
    labels = ['A','A','B','B']
    return group,labels

def classify0(inX, dataSet, labels, k):      #(坐标,测试向量组,标签/属性,迭代次数)
    dataSetSize = dataSet.shape[0]       #dataSetSize 返回值为 4  因为group有四行
    diffMat = tile(inX, (dataSetSize,1)) - dataSet   #x2-x1  y2-y1
    sqDiffMat = diffMat**2                           #x**2 y**2
    sqDistances = sqDiffMat.sum(axis = 1)            #x**2 + y**2
    distances = sqDistances**0.5                     #根号下x**2+y**2
    sortedDistIndicies = distances.argsort()          #距离排大小
    classCount = {}                                    
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]

def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    classLabelVector = []                       #prepare labels return   
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector

def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals,(m,1))
    normDataSet = normDataSet/tile(ranges,(m,1))  #元素分割
    return normDataSet, ranges, minVals

def datingClassTest():
    hoRadio = 0.10
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRadio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],\
            datingLabels[numTestVecs:m],3)
        print("\n测试结果为 %d,正确的结果应该是 %d"%(classifierResult,datingLabels[i]))
        if(classifierResult != datingLabels[i]):
            errorCount += 1.0
    print("\n该分类器错误率为%f"%(errorCount/float(numTestVecs)))

# def datingClassTest():
#     hoRatio = 0.10      #hold out 10%
#     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
#     normMat, ranges, minVals = autoNorm(datingDataMat)
#     m = normMat.shape[0]
#     numTestVecs = int(m*hoRatio)
#     errorCount = 0.0
#     for i in range(numTestVecs):
#         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
#         print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])
#         if (classifierResult != datingLabels[i]): errorCount += 1.0
#     print "the total error rate is: %f" % (errorCount/float(numTestVecs))
#     print errorCount

def classifyPerson() :
    resultList = ['不喜欢','有点喜欢','特别喜欢']
    percentTats = float(input("每周玩多久游戏? :"))
    ffMiles = float(input("每年飞行多少英里? :"))
    iceCream = float(input("每周吃多少冰淇淋? :"))
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    inArr = array([ffMiles,percentTats,iceCream])
    result = classify0((inArr-minVals)/ranges,normMat,datingLabels,3)
    print("你将",resultList[result-1],"这个人")

def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
        return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir(path='trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))   
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s'%fileNameStr)
    testFileList = listdir(path='testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' %fileNameStr)
        classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("测试结果为:%d,正确的结果为:%d" % (classifyResult,classNumStr))
        if(classifyResult!=classNumStr):
            errorCount += 1.0
    print("\n错误结果总数为:%d" % errorCount)
    print("\n错误率为:%f" % (errorCount/float(mTest)))


运行结果:
错误结果总数为:777

错误率为:0.821353

  • 写回答

1条回答 默认 最新

  • huibinwei 2019-03-18 19:32
    关注

    应该过拟合了吧,最好标一下哪里改过,试一下其他算法,或者组合算法

    评论

报告相同问题?

悬赏问题

  • ¥15 (关键词-电路设计)
  • ¥15 如何解决MIPS计算是否溢出
  • ¥15 vue中我代理了iframe,iframe却走的是路由,没有显示该显示的网站,这个该如何处理
  • ¥15 操作系统相关算法中while();的含义
  • ¥15 CNVcaller安装后无法找到文件
  • ¥15 visual studio2022中文乱码无法解决
  • ¥15 关于华为5g模块mh5000-31接线问题
  • ¥15 keil L6007U报错
  • ¥15 webapi 发布到iis后无法访问
  • ¥15 初学者如何快速上手学习stm32?