Εʟɪᴀᴜᴋ 2023-04-09 20:34 采纳率: 50%
浏览 24
已结题

基于CNN识别算法的问题

关于#神经网络#的问题:RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x4096 and 3136x512);
在导入图片训练集的时候出现了问题,大致知道是如下代码设置有问题:

class CNN(nn.Module):
    def __init__(self, num_classes=100):#100种汉字
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)#,输入图像通道数,卷积产生通道数,卷积核尺寸,卷积步长
        self.bn1 = nn.BatchNorm2d(32)#归一化处理
        self.relu1 = nn.ReLU()#激活函数
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)#池化层参数
        self.conv2 = nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=1)#mat1
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)#mat2
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)
        
    def forward(self, x):
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x

# 定义训练和测试数据的转换
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
'''
已经尝试过将dataloader中需要设置参数drop_last=True。即丢弃最后一个不足batchSize的样本,但是还是不行;
希望各位大佬帮我看看;

  • 写回答

2条回答 默认 最新

  • m0_61899108 2023-04-09 21:11
    关注

    这一句写错了,self.conv2 = nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=1)
    应该是,self.conv2 = nn.Conv2d(32, 64, kernel_size=1, stride=1, padding=0)

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 4月14日
  • 已采纳回答 4月9日
  • 创建了问题 4月9日

悬赏问题

  • ¥15 linux驱动,linux应用,多线程
  • ¥20 我要一个分身加定位两个功能的安卓app
  • ¥15 基于FOC驱动器,如何实现卡丁车下坡无阻力的遛坡的效果
  • ¥15 IAR程序莫名变量多重定义
  • ¥15 (标签-UDP|关键词-client)
  • ¥15 关于库卡officelite无法与虚拟机通讯的问题
  • ¥15 目标检测项目无法读取视频
  • ¥15 GEO datasets中基因芯片数据仅仅提供了normalized signal如何进行差异分析
  • ¥100 求采集电商背景音乐的方法
  • ¥15 数学建模竞赛求指导帮助