打豆豆1234 2023-01-22 15:20 采纳率: 100%
浏览 46
已结题

关于#pytorch#的问题:pytorch实现mnist手写数字识别

我正在用pytorch实现mnist手写数字识别,但是我的loss从一开始就不变,这是为什么? 以下是我的代码:


import gzip
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.autograd import variable

#读取训练数据
f = gzip.open("./mnist.pkl.gz","rb")
train_data, val_data, test_data = pickle.load(f,encoding='latin1')
f.close()

# 将50000张训练图片分为250组,每组200张图片,图片大小
train_data_img = train_data[0].reshape(250,200,28,28)
train_data_ans = train_data[1].reshape(250,200)
#搭建网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self,x):
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.0)

#迭代
losses = []
loss = 1
for epoch in range(10):
    for data,ans in zip(train_data_img,train_data_ans):
        out = net(torch.tensor(data.reshape(200,784)))
        ans = F.one_hot(torch.tensor(ans))
        loss = F.mse_loss(out, ans)
        optimizer.zero_grad()
        loss.backword()
        optimizer.step()
    losses.append(loss.item())

print(losses)
xlabel = np.linspace(0,len(losses),len(losses))
plt.plot(xlabel,losses)
plt.show()

  • 写回答

2条回答 默认 最新

  • 元气少女缘结神 2023-01-22 17:40
    关注

    可以将每一次的w、b、loss、dw、db都打印出来,看是否随机梯度下降没起作用

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 5月19日
  • 已采纳回答 5月11日
  • 创建了问题 1月22日

悬赏问题

  • ¥20 Wpf Datarid单元格闪烁效果的实现
  • ¥15 图像分割、图像边缘提取
  • ¥15 sqlserver执行存储过程报错
  • ¥100 nuxt、uniapp、ruoyi-vue 相关发布问题
  • ¥15 浮窗和全屏应用同时存在,全屏应用输入法无法弹出
  • ¥100 matlab2009 32位一直初始化
  • ¥15 Expected type 'str | PathLike[str]…… bytes' instead
  • ¥15 三极管电路求解,已知电阻电压和三级关放大倍数
  • ¥15 ADS时域 连续相位观察方法
  • ¥15 Opencv配置出错