WSS_ 2023-04-16 21:35 采纳率: 37.5%
浏览 8
已结题

请问这个报错为什么会发生?

在应用CNN进行张量融合时报错:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x18240 and 10x18240)
模型代码如下

class MyCNN(nn.Module):
    def __init__(self):
        super(MyCNN, self).__init__()
        self.conv1 = nn.Conv2d(6, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(10,18240)
        self.fc2 = nn.Linear(18240,1)

    def forward(self, xp, xc):
        batch_size = xp.size(0)
        x = torch.cat((xp.unsqueeze(1), xc), dim=1) #(10,6,20,456)
        print(x.shape)
        x = F.relu(self.conv1(x))
        print(x.shape)
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(batch_size, -1)
        print(x.shape)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.squeeze()


为什么会出现这样的问题?要怎么修改?

  • 写回答

2条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-04-22 05:03
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:
    • 这篇博客: PyTorch深度学习快速入门教程【小土堆】 学习笔记中的 RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x1024 and 10240x64) 原因 部分也许能够解决你的问题, 你可以仔细阅读以下内容或者直接跳转源博客中阅读:

      在这里插入图片描述
      原因: 网络参数设置不合适!!

      from torch import nn
      from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
      import torch
      
      class Lyy(nn.Module):
          def __init__(self):
              super(Lyy, self).__init__()
              self.conv1 = Conv2d(3, 32, 5, padding=2)
              self.maxpool1 = MaxPool2d(2)
              self.conv2 = Conv2d(32, 32, 5, padding=2)
              self.maxpool2 = MaxPool2d(2)
              self.conv3 = Conv2d(32, 64, 5, padding=2)
              self.maxpool3 = MaxPool2d(2)
              self.flatten = Flatten()
              self.linear1 = Linear(1024, 64)
              self.linear2 = Linear(64, 10)
      
          def forward(self, x):
              x = self.conv1(x)
              x = self.maxpool1(x)
              x = self.conv2(x)
              x = self.maxpool2(x)
              x = self.conv3(x)
              x = self.maxpool3(x)
              x = self.flatten(x)
              x = self.linear1(x)
              x = self.linear2(x)
              return x
      
      lyy = Lyy()
      input = torch.ones((64,3,32,32))
      print(input.shape)
      output = lyy(input)
      print(output.shape)
      

      上述代码的,等价替换版本

      from torch import nn
      from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
      import torch
      
      class Lyy(nn.Module):
          def __init__(self):
              super(Lyy, self).__init__()
              self.model1 = Sequential(
                  Conv2d(3, 32, 5, padding=2),
                  MaxPool2d(2),
                  Conv2d(32, 32, 5, padding=2),
                  MaxPool2d(2),
                  Conv2d(32, 64, 5, padding=2),
                  MaxPool2d(2),
                  Flatten(),
                  Linear(1024, 64),
                  Linear(64, 10)
              )
      
          def forward(self, x):
              x = self.model1(x)
              return x
      
      lyy = Lyy()
      input = torch.ones((64,3,32,32))
      print(input.shape)
      output = lyy(input)
      print(output.shape)
      
    • 以下回答来自chatgpt:

      根据报错信息,mat1和mat2的大小分别为10x18240和10x18240,在矩阵相乘时不符合尺寸要求。可能原因是在进行张量融合时,数据维度和模型参数设置不符合。为了解决这个问题,需要对代码进行调整,确保数据和参数的维度匹配。

      具体的操作步骤如下:

      1.首先检查输入数据的维度是否和模型参数的要求一致,比如输入的数据维度是否为4维,卷积层的输入通道数是否和输入数据通道数一致等。

      2.如果前面的检查都没有问题,那么需要检查模型参数的设置是否符合实际应用需要。比如卷积核的大小、通道数、池化层的大小等是否设置得当。

      3.如果数据和参数都没有问题,那么建议将数据和参数的维度打印出来,检查究竟是哪一步出现了维度不匹配的问题。

      4.最后,如果还不能解决问题,建议通过查询文献或者请教更有经验的同行来寻求帮助。


    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 11月1日
  • 已采纳回答 10月24日
  • 创建了问题 4月16日

悬赏问题

  • ¥15 神经网络预测均方误差很小 但是图像上看着差别太大
  • ¥15 Oracle中如何从clob类型截取特定字符串后面的字符
  • ¥15 想通过pywinauto自动电机应用程序按钮,但是找不到应用程序按钮信息
  • ¥15 如何在炒股软件中,爬到我想看的日k线
  • ¥15 seatunnel 怎么配置Elasticsearch
  • ¥15 PSCAD安装问题 ERROR: Visual Studio 2013, 2015, 2017 or 2019 is not found in the system.
  • ¥15 (标签-MATLAB|关键词-多址)
  • ¥15 关于#MATLAB#的问题,如何解决?(相关搜索:信噪比,系统容量)
  • ¥500 52810做蓝牙接受端
  • ¥15 基于PLC的三轴机械手程序