我正在用pytorch实现mnist手写数字识别,但是我的loss从一开始就不变,这是为什么? 以下是我的代码:
import gzip
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.autograd import variable
#读取训练数据
f = gzip.open("./mnist.pkl.gz","rb")
train_data, val_data, test_data = pickle.load(f,encoding='latin1')
f.close()
# 将50000张训练图片分为250组,每组200张图片,图片大小
train_data_img = train_data[0].reshape(250,200,28,28)
train_data_ans = train_data[1].reshape(250,200)
#搭建网络
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.fc1 = nn.Linear(784, 64)
self.fc2 = nn.Linear(64, 32)
self.fc3 = nn.Linear(32, 10)
def forward(self,x):
x = F.leaky_relu(self.fc1(x))
x = F.leaky_relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.0)
#迭代
losses = []
loss = 1
for epoch in range(10):
for data,ans in zip(train_data_img,train_data_ans):
out = net(torch.tensor(data.reshape(200,784)))
ans = F.one_hot(torch.tensor(ans))
loss = F.mse_loss(out, ans)
optimizer.zero_grad()
loss.backword()
optimizer.step()
losses.append(loss.item())
print(losses)
xlabel = np.linspace(0,len(losses),len(losses))
plt.plot(xlabel,losses)
plt.show()