m0_61524481 2022-09-30 18:14 采纳率: 66.7%
浏览 104
已结题

提升knn算法的准确率

不使用 sklearn写的knn算法,识别mnist数据集,准确率只有百分之六十, 如何进一步提高识别的准确率
已经尝试过使用不同的k值和对图片进行归一化处理



def load_mnist():
    X_train = np.fromfile('mnist_data/train-images-idx3-ubyte', dtype=np.uint8, offset=16)
    X_train = X_train.reshape(int(6e4), 28, 28)
    X_test = np.fromfile('mnist_data/t10k-images-idx3-ubyte', dtype=np.uint8, offset=16)
    X_test = X_test.reshape(int(1e4), 28, 28)
    y_train = np.fromfile('mnist_data/train-labels-idx1-ubyte', dtype=np.uint8, offset=8)
    y_train = y_train.reshape(int(6e4))
    y_test = np.fromfile('mnist_data/t10k-labels-idx1-ubyte', dtype=np.uint8, offset=8)
    y_test = y_test.reshape(int(1e4))


class Knn(object):

    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        self.X = X
        self.y = y

    def predict(self, X):
        dataset = self.X
        labels = self.y
        k = self.k
        predict_labels = []
        X = np.reshape(X, (X.shape[0], -1))
        dataset = np.reshape(dataset, (dataset.shape[0], -1))

        scalar = MaxAbsScaler()
        scalar.fit(dataset)
        dataset = scalar.transform(dataset)
        X = scalar.transform(X)

        print(dataset[0])

        dataset_size = dataset.shape[0]
        for i in tqdm(range(X.shape[0])):
            diff_mat = np.tile(X[i], (dataset_size, 1)) - dataset
            sq_diff_mat = diff_mat ** 2
            sq_distances = sq_diff_mat.sum(axis=1)
            distances = sq_distances ** 0.5
            sorted_dist_indicies = distances.argsort()
            class_count = {}
            for j in range(k):
                vote_label = labels[sorted_dist_indicies[i]]
                class_count[vote_label] = class_count.get(vote_label, 0) + 1
            sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
            predict_labels.append(sorted_class_count[0][0])
        predict_labels = np.array(predict_labels)
        return predict_labels


  • 写回答

5条回答 默认 最新

  • youcans_ 人工智能领域优质创作者 2022-10-03 23:15
    关注

    首先,手写识别的关键是特征描述,如果这一步没有做好,用什么方法,怎么调参,也不会有好的结果。
    将图像像素值直接作为输入向量,原则上是不适当的。
    推荐实现方法如下:
    (1)首先,样本均匀,标准化,归一化,这些必要的准备工作就不说了,
    (2)特征提取,或者说特征向量构造,将字符图像转换为特征向量作为模型的输入,
    (3)KNN,可以选择不同的K值,2~5之间有些影响,5 以上没必要。
    关于特征构造,推荐两种方法:
    1,HOG,方向梯度直方图
    2,小波特征,例如Haar
    我查了一下以前的程序,检验集识别准确率大约 80~90%。
    给出一段 HOG 特征描述符的构造例程,这类似于SIFT的特征描述符,效果不错。

    import cv2 as cv
    
        # (2) 构造 HOG 描述符
        # HOGDescriptor
        winSize = (20, 20)
        blockSize = (10, 10)
        blockStride = (5, 5)
        cellSize = (5, 5)
        nbins = 8
        derivAperture = 1
        winSigma = -1.
        histogramNormType = 0
        L2HysThreshold = 0.2
        gammaCorrection = 1
        nlevels = 16
        signedGradients = True
        hog = cv.HOGDescriptor(winSize, blockSize, blockStride, cellSize, nbins,
                               derivAperture, winSigma, histogramNormType,
                               L2HysThreshold, gammaCorrection, nlevels)
        p = (1+(20-10)//5)*(1+(20-10)//5)*(10//5)*(10//5)*8  # 特征描述符长度,288
    

    参考结果:

    Recognition of handwritten digits by KNN-HOG
    k=2, correct=938, accuracy=93.80%
    k=3, correct=939, accuracy=93.90%
    k=4, correct=940, accuracy=94.00%
    k=5, correct=938, accuracy=93.80%
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论 编辑记录
查看更多回答(4条)

报告相同问题?

问题事件

  • 系统已结题 10月13日
  • 已采纳回答 10月5日
  • 创建了问题 9月30日

悬赏问题

  • ¥15 前置放大电路与功率放大电路相连放大倍数出现问题
  • ¥30 关于<main>标签页面跳转的问题
  • ¥80 部署运行web自动化项目
  • ¥15 腾讯云如何建立同一个项目中物模型之间的联系
  • ¥30 VMware 云桌面水印如何添加
  • ¥15 用ns3仿真出5G核心网网元
  • ¥15 matlab答疑 关于海上风电的爬坡事件检测
  • ¥88 python部署量化回测异常问题
  • ¥30 酬劳2w元求合作写文章
  • ¥15 在现有系统基础上增加功能