yunfanxiang 2019-09-23 10:11 采纳率: 0%
浏览 1732

pytorch测试集看每类准确率遇到了一点bug

报错如下:
class_correct[label] +=( c[i].item())
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

代码原本就用的.item(),不知道为何依然报这个错。。试着改了几次都不行

源代码就是网上常用的

    N_CLASSES=6;
    BATCH_SIZE=16

    classes = ('Sun', 'Rain', 'SmallFog', 'MidFog','BigFog','Snow')
    class_correct = list(0. for i in range(N_CLASSES))  
    class_total = list(0. for i in range(N_CLASSES))        
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device1), labels.to(device1)
            outputs = Incep(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            print(c.size())
            for i in range(BATCH_SIZE):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
        for i in range(N_CLASSES):
            print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))

  • 写回答

3条回答 默认 最新

  • threenewbee 2019-09-23 10:40
    关注

    class_correct[label] +=( c.item())

    评论

报告相同问题?

悬赏问题

  • ¥15 如何在scanpy上做差异基因和通路富集?
  • ¥20 关于#硬件工程#的问题,请各位专家解答!
  • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
  • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
  • ¥30 截图中的mathematics程序转换成matlab
  • ¥15 动力学代码报错,维度不匹配
  • ¥15 Power query添加列问题
  • ¥50 Kubernetes&Fission&Eleasticsearch
  • ¥15 報錯:Person is not mapped,如何解決?
  • ¥15 c++头文件不能识别CDialog