qq_51599850 2023-05-09 13:40 采纳率: 0%
浏览 13

PyTorch加载模型错误

我把在PyTorch框架下用torch.save()训练模型保存下来,但是我加载不了我保存下来的模型。
import torch
model=torch.load("xxx.pt)总是出现没有“model”的错误

img

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-05-09 16:24
    关注
    • 帮你找了个相似的问题, 你可以看下: https://ask.csdn.net/questions/7596107
    • 我还给你找了一篇非常好的博客,你可以看看是否有帮助,链接:PyTorch模型保存torch.save()与加载torch.load()
    • 你还可以看下pytorch参考手册中的 pytorch torch.nn到底是什么?
    • 除此之外, 这篇博客: Pytorch——保存训练好的模型参数中的 2.torch.save(保存模型) 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
    • 首先,先搭建一个神经网络

      import torch
      from torch import nn
      import matplotlib.pyplot as plt
      torch.manual_seed(11)    # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样
      
      
      x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # [100] -> [100,1]
      y = x.pow(2) + 0.5*torch.rand(x.size())  # y的形状与x一样
      
      
      def make_and_save_model():
          
          network = torch.nn.Sequential(
              torch.nn.Linear(1, 8),
              torch.nn.ReLU(),
              torch.nn.Linear(8, 1)
          )
          optimizer = torch.optim.SGD(network.parameters(), lr=0.3)   #优化器
          criterion = torch.nn.MSELoss()     #损失函数
      
          # 训练
          for i in range(200):
              prediction = network(x)      #数据放入模型后得到预测值
              loss = criterion(prediction, y)    #计算预测值与真实值之间的误差
              optimizer.zero_grad()       #清空梯度
              loss.backward()           #误差反向传播
              optimizer.step()          #更新参数
          torch.save(network, 'network.pth')  # 保存整个网络
          torch.save(network.state_dict(), 'network_params.pth')   # 只保存网络中的参数
          
          plt.figure(1, figsize = (10,3))
          plt.subplot(131)
          plt.title('network')
          plt.scatter(x.data.numpy(), y.data.numpy())
          plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
          plt.pause(1)
      
    • 以下回答由chatgpt基于相关博客总结生成:

      针对该问题,可能有以下几个原因导致无法成功加载模型: 1. 加载模型的路径错误; 2. 保存模型时出现了错误,导致模型文件已损坏; 3. 加载模型的代码存在问题,导致无法成功加载。

      针对这些可能的原因,可以采取以下方法逐一排查问题:

      1. 确认加载模型的路径是否正确,可以使用os.path.exists()方法检查文件是否存在;
      2. 确认保存模型时没有出现错误,可以查看保存模型时的输出信息或者捕获异常信息;
      3. 检查加载模型的代码是否正确,包括模型文件格式是否正确,是否使用了正确的模型加载方法等。

      具体的代码示例如下:

      import torch
      
      # 加载模型的路径
      MODEL_DIR = './model.pt'
      
      def load_model():
          if not os.path.exists(MODEL_DIR):
              print('模型文件不存在!')
              return None
          # 加载模型文件
          try:
              model = torch.load(MODEL_DIR)
          except Exception as e:
              print(f'加载模型失败,错误信息:{e}')
              return None
          return model
      
      # 在主函数中调用加载模型的方法
      if __name__ == '__main__':
          # 加载模型
          model = load_model()
          if model is None:
              print('无法加载模型!')
          else:
              # 对加载的模型进行测试等操作
              pass
      

      针对第三种可能的问题,如果确信模型文件格式正确且加载代码没有误,可以尝试使用其他的模型加载方法,例如按照state_dict方式加载:

      import torch
      
      # 加载模型的路径
      MODEL_DIR = './model.pth'
      
      def load_model():
          if not os.path.exists(MODEL_DIR):
              print('模型文件不存在!')
              return None
          # 加载模型文件
          try:
              # 使用state_dict方式加载
              model = torch.nn.Sequential(
                  torch.nn.Linear(1, 8),
                  torch.nn.ReLU(),
                  torch.nn.Linear(8, 1)
              )
              checkpoint = torch.load(MODEL_DIR)
              model.load_state_dict(checkpoint)
          except Exception as e:
              print(f'加载模型失败,错误信息:{e}')
              return None
          return model
      
      # 在主函数中调用加载模型的方法
      if __name__ == '__main__':
          # 加载模型
          model = load_model()
          if model is None:
              print('无法加载模型!')
          else:
              # 对加载的模型进行测试等操作
              pass
      
    评论

报告相同问题?

问题事件

  • 创建了问题 5月9日

悬赏问题

  • ¥15 R语言中lasso回归报错
  • ¥15 网站突然不能访问了,上午还好好的
  • ¥15 semrush,SEO,内嵌网站,api
  • ¥15 Stata:为什么reghdfe后的因变量没有被发现识别啊
  • ¥15 关于#c语言#的问题,请各位专家解答!
  • ¥15 这个如何解决详细步骤
  • ¥15 在微信h5支付申请中,别人给钱就能用我的软件,这个的所属行业是啥?
  • ¥30 靶向捕获探针设计软件包
  • ¥15 别人给钱就能用我的软件,这个的经营场景是啥?
  • ¥15 react-diff-viewer组件,如何解决数据量过大卡顿问题