最近刚接触深度学习,想要自己做一下模型训练,使用NIN网络训练一个手写字符的数据集,写完后不知道为什么LOSS一直掉不下去,可以说是根本没有下降,NIN网络以及batchsize和lr都是按照李沐花书参数来的,死活找不到问题所在,恳请各位指条明路
训练代码如下
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from model1 import *
from torch.nn import CrossEntropyLoss
from torch.utils.tensorboard import SummaryWriter
from torch import argmax
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224,224)),
torchvision.transforms.ToTensor()]
)
train_dataset = torchvision.datasets.MNIST("MNI_datasets",train=True,transform=transform,download=True)
test_dataset = torchvision.datasets.MNIST("MNI_datasets",train=False,transform=transform,download=True)
train_dataload = DataLoader(train_dataset,batch_size=32)
test_dataload = DataLoader(test_dataset,batch_size=32)
test_len = len(test_dataset)
train_len = len(train_dataset)
writer = SummaryWriter("logs")
nin = Nin()
nin = nin.cuda()
loss_fn = CrossEntropyLoss()
loss_fn = loss_fn.cuda()
lr = 0.1
opti = torch.optim.SGD(nin.parameters(),lr = lr)
train_step = 0
test_step = 0
epoch = 10
#开始训练
nin.train()
for i in range(epoch):
print("------第{}轮训练开始------".format(i+1))
train_total_loss = 0.0
test_total_loss = 0.0
total_accuracy = 0.0
for data in train_dataload:
imgs,targets = data
imgs = imgs.cuda()
targets = targets.cuda()
output = nin(imgs)
loss = loss_fn(output,targets)
opti.zero_grad()
loss.backward()
opti.step()
train_step+=1
train_total_loss = train_total_loss+loss
if train_step%100==0:
#writer.add_scalar("test_line",loss.item(),train_step)
print("训练次数为{},训练损失为{}".format(train_step,loss))
#测试集
with torch.no_grad():
for data in test_dataload:
imgs, targets = data
imgs = imgs.cuda()
targets = targets.cuda()
output = nin(imgs)
loss = loss_fn(output,targets)
test_total_loss = loss.item()+test_total_loss
accuracy = (output.argmax(1)==targets).sum()
total_accuracy = accuracy+total_accuracy
test_step+=1
if test_step%100==0:
#writer.add_scalar("test_line",loss.item(),test_step)
print("测试次数为{},测试损失为{}".format(test_step,loss))
print("本轮训练总损失为{}".format(train_total_loss))
print("本轮测试总损失为{}".format(test_total_loss))
print("本轮测试准确率为{}".format(total_accuracy/test_len))
#torch.save(nin,"nin_{}.pth".format(i+1))
#writer.close()
网络模型如下
import torch.nn as nn
import torch
from torch.nn import Conv2d,ReLU,MaxPool2d,Dropout,AdaptiveMaxPool2d,Flatten
def Nin_block(in_channels,out_channels,kernel_size,stride,padding):
return nn.Sequential(
Conv2d(in_channels,out_channels,kernel_size,padding=padding,stride=stride),
ReLU(),
Conv2d(out_channels,out_channels,kernel_size=1),
ReLU(),
Conv2d(out_channels,out_channels,kernel_size=1),
ReLU()
)
class Nin(nn.Module):
def __init__(self):
super(Nin,self).__init__()
self.model = nn.Sequential(
Nin_block(1,96,11,4,0),
MaxPool2d(3,2),
Nin_block(96,256,5,1,2),
MaxPool2d(3,2),
Nin_block(256,384,3,1,1),
MaxPool2d(3,2),
Dropout(0.5),
Nin_block(384,10,3,1,1),
AdaptiveMaxPool2d((1,1)),
Flatten()
)
def forward(self,input):
output = self.model(input)
return output
if __name__ =='__main__':
nin = Nin()