超超不写代码 2023-04-03 21:57 采纳率: 100%
浏览 63
已结题

CNN的输入输出大小对应不上

#在复现一篇CNN的论文的时候,发现原文的给的CNN结构的参数一直对不上,
代码文件:https://github.com/yofchio/CNNQS
主程序是main文件,数据在data文件夹,CNN结构代码在CNNmodel,
PDf是论文,CNN细节在论文的49,50页,图在论文的第10页

img

输入的图片大小是64X60的矩阵(channel为1),然后第一层的卷积层,论文说是垂直步长为3,扩张概率为2,但是我算了下参数对不上,最后和全连接层的输入参数40680对不上,恳请大家看看这篇帖子,怎么修改模型参数,才能把这个模型的shape对应上,把代码跑通。

  • 写回答

6条回答 默认 最新

  • qq_54517157 2023-04-06 21:01
    关注

    更改了,结构,现在是欧克的了

    import torch.nn as nn
    num_classes = 2  # 类别数
    batch_size = 128  # 批次大小
    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=(5, 3), stride=(3, 1), dilation=(2, 1), padding=(12, 1)),
                nn.BatchNorm2d(64),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.MaxPool2d((2, 1), stride=(2, 1)),
            )
            self.layer2 = nn.Sequential(
                nn.Conv2d(64, 128, kernel_size=(5, 3),padding='same'),
                nn.BatchNorm2d(128),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.MaxPool2d((2, 1), stride=(2, 1)),
            )
            self.layer3 = nn.Sequential(
                nn.Conv2d(128, 256, kernel_size=(5, 3),padding='same'),
                nn.BatchNorm2d(256),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.MaxPool2d((2, 1), stride=(2, 1)),
            )
            self.fc1 = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Linear(46080, 2),
            )
            self.softmax = nn.Softmax(dim=1)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.xavier_uniform_(m.weight)
                elif isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
    
        def forward(self, x):
    
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = x.view(batch_size, -1)
            x = self.fc1(x)
            x = self.softmax(x)
            return x
    
    
    
    
    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(5条)

报告相同问题?

问题事件

  • 系统已结题 4月14日
  • 已采纳回答 4月6日
  • 创建了问题 4月3日

悬赏问题

  • ¥15 下图接收小电路,谁知道原理
  • ¥15 装 pytorch 的时候出了好多问题,遇到这种情况怎么处理?
  • ¥20 IOS游览器某宝手机网页版自动立即购买JavaScript脚本
  • ¥15 手机接入宽带网线,如何释放宽带全部速度
  • ¥30 关于#r语言#的问题:如何对R语言中mfgarch包中构建的garch-midas模型进行样本内长期波动率预测和样本外长期波动率预测
  • ¥15 ETLCloud 处理json多层级问题
  • ¥15 matlab中使用gurobi时报错
  • ¥15 这个主板怎么能扩出一两个sata口
  • ¥15 不是,这到底错哪儿了😭
  • ¥15 2020长安杯与连接网探