weixin_41514309 2021-05-06 09:16 采纳率: 0%
浏览 38

pytorch 中写的深度学习模型在测试集验证时出错?

如果测试集的dataloader中的参数batch_size设置不是整个测试集的个数时模型在测试集上表现很差相当与胡乱预测,当batch_size设置成整个测试集个数时表现正常。代码检查没有什么问题,猜测是pytorch出现了bug

  • 写回答

1条回答 默认 最新

  • GitCode 官方 企业官方账号 2021-05-06 09:48
    关注

    pytorch在测试集上评估模型准确率的时候要注意一个点,需要把模型从train状态转换成eval状态,因为pytorch的梯度是累积的,所以才会出现你说的batch_size如果是整个数据集大小的时候,表现正常,反之则不正常,这个不是bug

    下面是我自己之前封装的一个准确率评估代码片段:

    # 计算Acc
    def calc_acc(net, dataloader, device):
        net.eval()
        total = 0.0
        correct = 0.0
        with torch.no_grad():
            for data in dataloader:
                images, labels = data[0].to(device), data[1].to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # print('-----', correct)
        # print('-----', total)
        print('Accuracy of the network on the dataset: %.4f %%' % (
                100.0 * correct / total))
        return 100.0 * correct / total
    评论

报告相同问题?

悬赏问题

  • ¥15 github符合条件20分钟秒到账,github空投 提供github账号可兑换💰感兴趣的可以找我交流一下
  • ¥50 永磁型步进电机PID算法
  • ¥15 sqlite 附加(attach database)加密数据库时,返回26是什么原因呢?
  • ¥88 找成都本地经验丰富懂小程序开发的技术大咖
  • ¥15 如何处理复杂数据表格的除法运算
  • ¥15 如何用stc8h1k08的片子做485数据透传的功能?(关键词-串口)
  • ¥15 有兄弟姐妹会用word插图功能制作类似citespace的图片吗?
  • ¥200 uniapp长期运行卡死问题解决
  • ¥15 latex怎么处理论文引理引用参考文献
  • ¥15 请教:如何用postman调用本地虚拟机区块链接上的合约?