fengzhongyetu 2022-07-04 15:14 采纳率: 0%
浏览 64
已结题

convnext网络中添加混淆矩阵代码报错

**计算混淆矩阵出现报错 **


    @tf.function
    def val_step(val_images, val_labels):
        global output1

        output1 = model(val_images, training=False)
        loss = loss_object(val_labels, output1)

        val_loss(loss)
        val_accuracy(val_labels, output1)

    best_val_acc = 0.
    for epoch in range(epochs):
        train_loss.reset_states()  # clear history info
        train_accuracy.reset_states()  # clear history info
        val_loss.reset_states()  # clear history info
        val_accuracy.reset_states()  # clear history info

        # train
        train_bar = tqdm(train_ds, file=sys.stdout)
        for images, labels in train_bar:
            # update learning rate
            optimizer.learning_rate = next(scheduler)

            train_step(images, labels)

            # print train process
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}, acc:{:.3f}, lr:{:.5f}".format(
                epoch + 1,
                epochs,
                train_loss.result(),
                train_accuracy.result(),
                optimizer.learning_rate.numpy()
            )

        # validate
        val_bar = tqdm(val_ds, file=sys.stdout)#tqdm是进度条
        for images, labels in val_bar:
            val_step(images, labels)
            # 计算混淆矩阵
            cm = confusion_matrix(labels, output1)
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            plot_confusion_matrix(cm)
            print('每一类准确率:{}'.format(np.diagonal(cm)))
            with open('confusion_matrix.csv', 'w+') as f:
                for i in cm:
                    f.write(','.join(list(map(str, i))))
                    f.write('\n')

            # print val process
            val_bar.desc = "valid epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                               epochs,
                                                                               val_loss.result(),
                                                                               val_accuracy.result())

Traceback (most recent call last):
  File "D:/ConvNeXT/train.py", line 238, in <module>
    main()
  File "D:/ConvNeXT/train.py", line 143, in main
    cm = confusion_matrix(labels, output1)
  File "C:\Users\Miaomiao\AppData\Roaming\Python\Python36\site-packages\sklearn\utils\validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "C:\Users\Miaomiao\AppData\Roaming\Python\Python36\site-packages\sklearn\metrics\_classification.py", line 299, in confusion_matrix
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "C:\Users\Miaomiao\AppData\Roaming\Python\Python36\site-packages\sklearn\metrics\_classification.py", line 85, in _check_targets
    type_pred = type_of_target(y_pred)
  File "C:\Users\Miaomiao\AppData\Roaming\Python\Python36\site-packages\sklearn\utils\multiclass.py", line 261, in type_of_target
    if is_multilabel(y):
  File "C:\Users\Miaomiao\AppData\Roaming\Python\Python36\site-packages\sklearn\utils\multiclass.py", line 147, in is_multilabel
    y = np.asarray(y)
  File "D:\anconda\envs\tf-2.4.0\lib\site-packages\numpy\core\_asarray.py", line 83, in asarray
    return array(a, dtype, copy=False, order=order)
  File "D:\anconda\envs\tf-2.4.0\lib\site-packages\tensorflow\python\framework\ops.py", line 855, in __array__
    " a NumPy call, which is not supported".format(self.name))
NotImplementedError: Cannot convert a symbolic Tensor (conv_ne_xt/head/BiasAdd:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported

Process finished with exit code 1

尝试过更改numpy版本

可以实现矩阵效果

  • 写回答

0条回答 默认 最新

    报告相同问题?

    问题事件

    • 系统已结题 7月12日
    • 创建了问题 7月4日

    悬赏问题

    • ¥15 问题遇到的现象和发生背景 360导航页面千次ip是20元,但是我们是刷量的 超过100ip就不算量了,假量超过100就不算了 这是什么逻辑呢 有没有人能懂的 1000元红包感谢费
    • ¥30 计算机硬件实验报告寻代
    • ¥15 51单片机写代码,要求是图片上的要求,请大家积极参与,设计一个时钟,时间从12:00开始计时,液晶屏第一行显示time,第二行显示时间
    • ¥15 用C语言判断命题逻辑关系
    • ¥15 原子操作+O3编译,程序挂住
    • ¥15 使用STM32F103C6微控制器设计两个从0到F计数的一位数计数器(数字),同时,有一个控制按钮,可以选择哪个计数器工作:需要两个七段显示器和一个按钮。
    • ¥15 在yolo1到yolo11网络模型中,具体有哪些模型可以用作图像分类?
    • ¥15 AD9910输出波形向上偏移,波谷不为0V
    • ¥15 淘宝自动下单XPath自动点击插件无法点击特定<span>元素,如何解决?
    • ¥15 曙光1620-g30服务器安装硬盘后 看不到硬盘