想给键盘放天假 2020-06-29 00:25 采纳率: 100%
浏览 592
已采纳

pytorch的MNIST代码中loss输出的疑问

MNIST的训练集一共**60000**个,我设置**mini-batch=128**,分成**469批**

train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=False)

test_dataset = mnist.MNIST('./data', train=False, transform=transform)
print(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)[/code]

我看教程给的代码,里面loss的计算如下

        out = model(img)
        loss = criterion(out, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc

最后输出

    print('epoch:{},Train Loss:{:.4f}, Train Acc: {:.4f},Test Loss: {:.4f}, Test Acc: {:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader),eval_loss / len(test_loader), eval_acc / len(test_loader))

其中len(train_loader)=**469**,损失函数使用的是交叉熵函数,每个循环的loss都是当前批次的loss总和,没有除以N。我想请问为什么最后输出的loss不是除以6000/128,而是len(train_loader),毕竟6000不能被128整除,所以最后一个批次的数量是达不到128的,那么这一批的loss肯定也只是6000-128*468=**96**个数据的和,这样再除以469的话难免会有一些误差。acc也是同理

我特地找了下pytorch官方的案例,里面是把所有数据的loss加起来,最后除以总数,如下

        test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

我这里的代码是来自《Python深度学习:基于Pytorch》,是不是在最后的计算方法上不够严谨呢,还是说这点误差可以忽略不计

大家都是怎么计算loss和acc的呢

  • 写回答

1条回答 默认 最新

  • 关注
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论

报告相同问题?

悬赏问题

  • ¥15 安装svn网络有问题怎么办
  • ¥15 Python爬取指定微博话题下的内容,保存为txt
  • ¥15 vue2登录调用后端接口如何实现
  • ¥65 永磁型步进电机PID算法
  • ¥15 sqlite 附加(attach database)加密数据库时,返回26是什么原因呢?
  • ¥88 找成都本地经验丰富懂小程序开发的技术大咖
  • ¥15 如何处理复杂数据表格的除法运算
  • ¥15 如何用stc8h1k08的片子做485数据透传的功能?(关键词-串口)
  • ¥15 有兄弟姐妹会用word插图功能制作类似citespace的图片吗?
  • ¥15 latex怎么处理论文引理引用参考文献