抑郁症少年 2022-11-21 23:30 采纳率: 31.3%
浏览 129

头歌机器学习KNN算法

问题遇到的现象和发生背景
用代码块功能插入代码,请勿粘贴截图

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
np.random.seed(20)


def predict(x_test, x_train, y_train, k = 3):
    """
    :param x_test: 测试集特征值
    :param x_train:  训练集特征值
    :param y_train: 训练集目标值
    :param k: k邻居数,请根据具体情况调节k值的大小
    :return: 返回预测结果,类型为numpy数组
    """
    # 请根据注释在下面补充你的代码实现knn算法的过程
    # ********** Begin ********** #
    result = np.array(np.zeros(x_train.shape[0]).astype('int64'))
    # 对ndarray数组进行遍历,每次取数组中的一行。
    for rowData in x_train:
        # 对于测试集中的每一个样本,依次与训练集中的所有样本求欧几里得距离。
        dis = np.sqrt(np.sum((x_test - rowData)**2,axis=1))
        # 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引。并进行截断,只取前k个元素。【取距离最近的k个元素的索引】
        index = dis.argsort()[:k]
        # 返回数组中每个元素出现的次数。元素必须是非负的整数。【使用weights考虑权重,权重为距离的倒数。】
        count = np.bincount(y_train[index],weights=1/dis[index])
        # 返回ndarray数组中,值最大的元素对应的索引。该索引就是我们判定的类别。
        # 最大元素索引,就是出现次数最多的元素。
        result = count.argmax()
    # ********** End ********** #
    return result


if __name__ == '__main__':
    iris = load_iris()
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
    result = predict(x_test, x_train, y_train)
    score = (np.sum(result == y_test) / len(result))
    if score >= 0.9:
        print("测试通过")
    else:
        print("测试失败")

运行结果及报错内容
score = (np.sum(result == y_test) / len(result))
                                        ^^^^^^^^^^^
TypeError: object of type 'numpy.int64' has no len()

  • 写回答

1条回答 默认 最新

  • Love And Program 人工智能领域新星创作者 2022-11-22 00:03
    关注

    说明你这个result是numpy.int64类型啊,这个又不是列表啥的,不能用len看长度 你可以打印此数,然后强制转换成列表或numpy,即可使用len

    评论

报告相同问题?

问题事件

  • 创建了问题 11月21日

悬赏问题

  • ¥15 任意一个散点图自己下载其js脚本文件并做成独立的案例页面,不要作在线的,要离线状态。
  • ¥15 各位 帮我看看如何写代码,打出来的图形要和如下图呈现的一样,急
  • ¥30 c#打开word开启修订并实时显示批注
  • ¥15 如何解决ldsc的这条报错/index error
  • ¥15 VS2022+WDK驱动开发环境
  • ¥30 关于#java#的问题,请各位专家解答!
  • ¥30 vue+element根据数据循环生成多个table,如何实现最后一列 平均分合并
  • ¥20 pcf8563时钟芯片不启振
  • ¥20 pip2.40更新pip2.43时报错
  • ¥15 换yum源但仍然用不了httpd