tjdnbj 2024-01-28 12:24 采纳率: 41.2%
浏览 6

深度学习softmax回归提问

以下是我的代码,想问一下运行后为何出现如图所示错误?(NotImplementedError: Module [FlattenLayer] is missing the required "forward" function)该如何修改?

img

img


import torchvision
import torchvision.transforms as transforms
import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append("C:/Users/zyx20/Desktop/深度学习编程/pythonProject")
import d2lzh_pytorch as d2l
batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4

mnist_train = torchvision.datasets.FashionMNIST(root='C:/Users/zyx20/Desktop/深度学习编程/MNIST/raw', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='C:/Users/zyx20/Desktop/深度学习编程/MNIST/raw', train=False, download=True, transform=transforms.ToTensor())
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

num_inputs=784
num_outputs=10
class LinearNet(nn.Module):
    def __init__(self,num_inputs,num_outputs):
        super(LinearNet,self).__init__()
        self.linear=nn.Linear(num_inputs,num_outputs)
    def forward(self,x):
        y=self.linear(x.view(x.shape[0],-1))
        return y

net=LinearNet(num_inputs,num_outputs)

class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer,self).__init__()
    def foward(self,x):
        return x.view(x.shape[0],-1)
from collections import OrderedDict
net=nn.Sequential(
    #FlattenLayer(),
    #nn.Linear(num_inputs,num_outputs)
    OrderedDict([('flatten',FlattenLayer()),
                 ('linear',nn.Linear(num_inputs,num_outputs))])
)
init.normal_(net.linear.weight,mean=0,std=0.01)
init.constant_(net.linear.bias,val=0)
#定义交叉熵损失函数
loss=nn.CrossEntropyLoss()
#定义优化算法
optimizer=torch.optim.SGD(net.parameters(),lr=0.1)
#训练模型
num_epochs=5
def train_ch3(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum,train_acc_sum,n=0.0,0.0,0
        for X,y in train_iter:
            y_hat=net(X)
            l=loss(y_hat,y).sum()

            #梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            l.backward()
            if optimizer is None:
                d2l.sgd(params,lr,batch_size)
            else:
                optimizer.step()

            train_l_sum+=l.item()
            train_acc_sum+=(y_hat.argmax(dim=1)==y).sum().item()
            n+=y.shape[0]
        test_acc=evaluate_accuracy(test_iter,net)
        print('epoch %d,loss %.4f,train acc %.3f,test acc %.3f'%(epoch+1,train_l_sum/n,train_acc_sum/n,test_acc))
train_ch3(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None)
  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2024-01-28 14:06
    关注

    【相关推荐】



    • 这篇文章:深度学习:softmax激活实现多元分类 也许能够解决你的问题,你可以看下
    • 您还可以看一下 李立宗老师的讲给入门者的深度学习课程中的 softmax函数小节, 巩固相关知识点
    • 除此之外, 这篇博客: 动手学深度学习 图像分类数据集(二) softmax回归的从零开始实现中的 计算分类准确率 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:

      解析:
      给定一个类别的预测概率分布y_hat ,我们把预测概率最大的类别作为输出类别。如果它与真实类别y 一致,说明这次预测是正确的。分类准确率即正确预测数量与总预测数量之比。

      def accuracy(y_hat, y):
          return (y_hat.argmax(dim=1) == y).float().mean().item()
      

      举例说明: 假设对于一个三分类问题,其预测值 y_hat如下 真实值y如下
      [0.1000, 0.3000, 0.6000] 代表对于第一个样本, 每个类别的概率值
      在这里插入图片描述
      argmax(dim=1) 函数的作用是返回每一行最大值的索引
      在这里 刚好标签对应的就是索引, 最大值对应的是最大概率 所以这个所以就是我们预测的标签值
      y_hat.argmax(dim=1) == y判断预测值是否与真实值相等
      在这里插入图片描述
      最终的结果计算出了准确率

      放到本题的模型中,计算分类准确率

      def evaluate_accuracy(data_iter, net):
          acc_sum, n = 0.0, 0
          for X, y in data_iter:
              acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
              n += y.shape[0]
          return acc_sum / n
      
      print(evaluate_accuracy(test_iter, net)) 
      

      0.0647


    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

问题事件

  • 修改了问题 1月29日
  • 修改了问题 1月29日
  • 创建了问题 1月28日

悬赏问题

  • ¥15 vscode编译ros找不到头文件,cmake.list文件出问题,如何解决?(语言-c++|操作系统-linux)
  • ¥15 通过AT指令控制esp8266发送信息
  • ¥15 有哪些AI工具提供可以通过代码上传EXCEL文件的API接口,并反馈分析结果
  • ¥15 二维装箱算法、矩形排列算法(相关搜索:二维装箱)
  • ¥20 nrf2401上电之后执行特定任务概率性一直处于最大重发状态
  • ¥15 二分图中俩集合中节点数与连边概率的关系
  • ¥20 wordpress如何限制ip访问频率
  • ¥15 自研小游戏,需要后台服务器存储用户数据关卡配置等数据
  • ¥15 请求解答odoo17外发加工某工序的实操方法
  • ¥20 IDEA ssm项目 跳转页面报错500