yunfanxiang
2019-09-23 10:11
采纳率: 0%
浏览 1.7k

pytorch测试集看每类准确率遇到了一点bug

报错如下:
class_correct[label] +=( c[i].item())
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

代码原本就用的.item(),不知道为何依然报这个错。。试着改了几次都不行

源代码就是网上常用的

    N_CLASSES=6;
    BATCH_SIZE=16

    classes = ('Sun', 'Rain', 'SmallFog', 'MidFog','BigFog','Snow')
    class_correct = list(0. for i in range(N_CLASSES))  
    class_total = list(0. for i in range(N_CLASSES))        
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device1), labels.to(device1)
            outputs = Incep(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            print(c.size())
            for i in range(BATCH_SIZE):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
        for i in range(N_CLASSES):
            print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))

3条回答 默认 最新

相关推荐 更多相似问题