利威levi 2023-05-06 20:36

Loss never decreases during training

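I've only just started with deep learning and wanted to train a model myself: an NIN network on a handwritten-character dataset (MNIST). Once the code was written, the loss would not go down. It basically does not decrease at all. The NIN architecture, the batch size and the lr all follow the values in Li Mu's book (Dive into Deep Learning), and I still cannot find the problem. Could someone please point me in the right direction?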

Training code:


import torch
import torchvision
from torch.utils.data import DataLoader
from model1 import Nin  # the model definition further below, saved as model1.py

from torch.nn import CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224,224)),
                                            torchvision.transforms.ToTensor()]
                                           )

train_dataset = torchvision.datasets.MNIST("MNI_datasets",train=True,transform=transform,download=True)
test_dataset = torchvision.datasets.MNIST("MNI_datasets",train=False,transform=transform,download=True)

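# shuffle=True reshuffles the training samples every epoch
# (note: D2L's NIN section trains with batch_size=128)
train_dataload = DataLoader(train_dataset,batch_size=32,shuffle=True)
test_dataload = DataLoader(test_dataset,batch_size=32)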

test_len = len(test_dataset)
train_len = len(train_dataset)

writer = SummaryWriter("logs")
nin = Nin()
nin = nin.cuda()
loss_fn = CrossEntropyLoss()
loss_fn = loss_fn.cuda()
lr = 0.1
opti = torch.optim.SGD(nin.parameters(),lr = lr)

train_step = 0
test_step = 0
epoch = 10
# Start training
for i in range(epoch):
    print("------ Epoch {} training starts ------".format(i+1))
    # Training mode (Dropout active). Must be set again every epoch,
    # because nin.eval() is called before the test loop below.
    nin.train()
    train_total_loss = 0.0
    test_total_loss = 0.0
    total_accuracy = 0
    for data in train_dataload:
        imgs,targets = data
        imgs = imgs.cuda()
        targets = targets.cuda()
        output = nin(imgs)
        loss = loss_fn(output,targets)
        opti.zero_grad()
        loss.backward()
        opti.step()
        train_step+=1
        # .item() gives a Python float; adding the loss tensor itself
        # would keep every step's autograd graph alive
        train_total_loss += loss.item()

        if train_step%100==0:
            #writer.add_scalar("train_loss",loss.item(),train_step)
            print("Training step {}, training loss {}".format(train_step,loss.item()))
    # Test set
    nin.eval()  # evaluation mode: Dropout disabled
    with torch.no_grad():
        for data in test_dataload:
            imgs, targets = data
            imgs = imgs.cuda()
            targets = targets.cuda()
            output = nin(imgs)
            loss = loss_fn(output,targets)
            test_total_loss += loss.item()
            # number of correct predictions in this batch, as a Python int
            total_accuracy += (output.argmax(1)==targets).sum().item()
            test_step+=1
            if test_step%100==0:
                #writer.add_scalar("test_loss",loss.item(),test_step)
                print("Test step {}, test loss {}".format(test_step,loss.item()))
    print("Total training loss this epoch: {}".format(train_total_loss))
    print("Total test loss this epoch: {}".format(test_total_loss))
    print("Test accuracy this epoch: {}".format(total_accuracy/test_len))

    #torch.save(nin,"nin_{}.pth".format(i+1))

#writer.close()
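A quick way to read the printed loss: with 10 classes, softmax cross-entropy is exactly ln 10 ≈ 2.3026 whenever all 10 logits are equal, whatever the label. If the training loss hovers around 2.30 and never moves (the post does not show the actual numbers), the network is giving every class the same score. One plausible cause in this model is that the final Nin_block ends with a ReLU, so if all of its outputs die to zero, every logit is 0. The sketch below re-implements softmax cross-entropy in plain Python for illustration only; cross_entropy is a hypothetical helper written for this example, not a PyTorch function.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample: log(sum(exp(z))) - z[target]."""
    m = max(logits)  # subtract the max before exp() for numerical stability (log-sum-exp trick)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# All logits equal (e.g. all zero because the last ReLU outputs nothing):
# the prediction is uniform over the 10 classes, so the loss is ln(10) for any label.
uniform_loss = cross_entropy([0.0] * 10, target=3)
print(round(uniform_loss, 4))   # 2.3026
print(round(math.log(10), 4))   # 2.3026

# A network that actually favours the correct class gets a much lower loss.
confident = [0.0] * 10
confident[3] = 5.0
print(cross_entropy(confident, target=3))
```

A loss that starts near 2.30 is normal at initialization; what matters is whether it stays pinned there for whole epochs.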

  
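D2L does not only fix the architecture and hyperparameters: its training helper train_ch6 also initializes every Conv2d and Linear weight with Xavier-uniform initialization (nn.init.xavier_uniform_) before training. The script above leaves that step out and relies on PyTorch's default initialization. The post does not let us confirm that this alone explains the stuck loss; the sketch below only shows how the same initialization could be applied. So that it runs on its own, it uses a small stand-in network (demo, a hypothetical name) instead of the full Nin class.

```python
import math

import torch
from torch import nn


def init_weights(m: nn.Module) -> None:
    """Xavier-uniform init for Conv2d/Linear weights; biases keep PyTorch's default init."""
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)


torch.manual_seed(0)  # reproducible demo

# Small stand-in for Nin: one NIN-style block (11x11 conv + two 1x1 convs, each followed by ReLU)
demo = nn.Sequential(
    nn.Conv2d(1, 96, kernel_size=11, stride=4), nn.ReLU(),
    nn.Conv2d(96, 96, kernel_size=1), nn.ReLU(),
    nn.Conv2d(96, 96, kernel_size=1), nn.ReLU(),
)
demo.apply(init_weights)  # Module.apply calls init_weights on every submodule recursively

# xavier_uniform_ samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
fan_in = 1 * 11 * 11    # in_channels * kernel area
fan_out = 96 * 11 * 11  # out_channels * kernel area
bound = math.sqrt(6.0 / (fan_in + fan_out))
print(bound, demo[0].weight.abs().max().item())
```

In the question's script this would amount to calling nin.apply(init_weights) right after nin = Nin(), before training starts.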

Model definition (the model1 module imported by the training script):

import torch.nn as nn
import torch
from torch.nn import Conv2d,ReLU,MaxPool2d,Dropout,AdaptiveMaxPool2d,Flatten

# One NIN block: a k x k convolution followed by two 1x1 convolutions, each with ReLU
def Nin_block(in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
        Conv2d(in_channels,out_channels,kernel_size,padding=padding,stride=stride),
        ReLU(),
        Conv2d(out_channels,out_channels,kernel_size=1),
        ReLU(),
        Conv2d(out_channels,out_channels,kernel_size=1),
        ReLU()
    )

class Nin(nn.Module):
    def __init__(self):
        super(Nin,self).__init__()
        self.model = nn.Sequential(
            Nin_block(1,96,11,4,0),
            MaxPool2d(3,2),
            Nin_block(96,256,5,1,2),
            MaxPool2d(3,2),
            Nin_block(256,384,3,1,1),
            MaxPool2d(3,2),
            Dropout(0.5),
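            Nin_block(384,10,3,1,1),  # 10 output channels, one per digit class
            nn.AdaptiveAvgPool2d((1,1)),  # global average pooling, as in D2L's NIN: each channel -> one logit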
            Flatten()
        )

    def forward(self,input):
        output = self.model(input)
        return output

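if __name__ == '__main__':
    nin = Nin()
    # Shape check: one 1x224x224 image should give logits of shape (1, 10)
    print(nin(torch.rand(1, 1, 224, 224)).shape)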
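The layer sizes can be checked by hand with the usual convolution/pooling output-size formula, floor((n + 2p - k) / s) + 1, where n is the input height/width, k the kernel size, s the stride and p the padding. Below is a pure-Python trace for the 224x224 input used in the training script. out_size is a hypothetical helper written for this example; it assumes dilation 1 and PyTorch's default floor rounding (ceil_mode=False).

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output height/width of a Conv2d or MaxPool2d layer (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) for every layer in Nin that changes the spatial size;
# the 1x1 convolutions inside each Nin_block keep the size unchanged.
layers = [
    ("Nin_block 11x11, stride 4", 11, 4, 0),
    ("MaxPool2d(3, 2)", 3, 2, 0),
    ("Nin_block 5x5, padding 2", 5, 1, 2),
    ("MaxPool2d(3, 2)", 3, 2, 0),
    ("Nin_block 3x3, padding 1", 3, 1, 1),
    ("MaxPool2d(3, 2)", 3, 2, 0),
    ("Nin_block 3x3, padding 1", 3, 1, 1),
]

size = 224  # the training script resizes the 28x28 MNIST images to 224x224
sizes = []
for name, k, s, p in layers:
    size = out_size(size, k, s, p)
    sizes.append(size)
    print(f"{name:28s} -> {size}x{size}")
```

The last block therefore outputs a 10-channel 5x5 map, which the adaptive pooling layer reduces to 10 logits, so the shapes themselves are consistent.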

1 answer

  • Zouia Gail (in training) 2023-05-06 22:16

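    You could add a few more layers. It looks like you only have one kind of layer here; you could try adding a fully connected layer.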
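For context on this suggestion: NIN has no fully connected layers by design. The last Nin_block emits one channel per class (10 here), and a global pooling layer turns each channel into one logit. D2L's version uses nn.AdaptiveAvgPool2d((1, 1)) for this step, while AdaptiveMaxPool2d((1, 1)) would keep only the largest value in each channel. The pure-Python sketch below (hypothetical helper names, made-up values) contrasts the two on a single channel: average pooling passes gradient to every spatial position, max pooling only to the position holding the maximum.

```python
def global_avg_pool(fmap):
    """Mean of all H*W values in one channel (what AdaptiveAvgPool2d((1, 1)) computes)."""
    values = [v for row in fmap for v in row]
    return sum(values) / len(values)


def global_max_pool(fmap):
    """Largest value in one channel (what AdaptiveMaxPool2d((1, 1)) computes)."""
    return max(v for row in fmap for v in row)


def pool_gradients(fmap, mode):
    """d(pooled value) / d(each input position); for "max", assumes a unique maximum."""
    h, w = len(fmap), len(fmap[0])
    if mode == "avg":
        return [[1.0 / (h * w)] * w for _ in range(h)]
    if mode == "max":
        m = global_max_pool(fmap)
        return [[1.0 if v == m else 0.0 for v in row] for row in fmap]
    raise ValueError(f"unknown mode: {mode}")


# One 3x3 channel of a final feature map (illustrative values; the real map is 5x5 per class)
fmap = [[0.0, 1.0, 2.0],
        [3.0, 4.0, 5.0],
        [6.0, 7.0, 9.0]]
print(global_avg_pool(fmap), global_max_pool(fmap))
print(pool_gradients(fmap, "avg"))
print(pool_gradients(fmap, "max"))
```

In NIN this pooled value per channel plays the role that a fully connected classifier layer plays in networks such as AlexNet.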
