qq_45781819 2023-04-11 10:10 采纳率: 36.4%
浏览 33
已结题

RuntimeError: Given groups=1, weight of size [32, 3, 5, 5]

python白 训练模型出现错误

RuntimeError: Given groups=1, weight of size [32, 3, 5, 5], expected input[1, 32, 16, 16] to have 3 channels, but got 32 channels instead

但是我输出图片格式显示

print(imgs.shape)

torch.Size([1, 3, 32, 32])

代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("./data",train=False,
                                       transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)

class Seq(nn.Module):
    def __init__(self):
        super(Seq,self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

#loss = nn.CrossEntropyLoss()
seq = Seq()
for data in dataloader:
    imgs,targets = data
    #print(imgs.shape)
    output = seq(imgs)
    #result = loss(outputs,target)


  • 写回答

2条回答 默认 最新

  • 瞬间的未来式 2023-04-11 10:19
    关注

    chatgpt:参考一下,不喜勿喷
    在构建Sequential模型时,第2、4、6层的输入通道数应该为32而不是3,因为输入图片是CIFAR-10数据集中的彩色图片,通道数是3。修改代码如下:

    class Seq(nn.Module):
        def __init__(self):
            super(Seq,self).__init__()
            self.model = Sequential(
                Conv2d(3, 32, 5, padding=2),
                MaxPool2d(2),
                Conv2d(32, 32, 5, padding=2),
                MaxPool2d(2),
                Conv2d(32, 32, 5, padding=2),
                MaxPool2d(2),
                Flatten(),
                Linear(1024, 64),
                Linear(64, 10)
            )
        def forward(self,x):
            x = self.model(x)
            return x
    
    
    
    

    这样修改后,第2、4、6层的输入通道数就是32,与imgs的通道数相匹配了,就不会出现输入输出通道数不匹配的错误了。

    本回答被题主选为最佳回答 , 对您是否有帮助呢?
    评论
查看更多回答(1条)

报告相同问题?

问题事件

  • 系统已结题 4月22日
  • 已采纳回答 4月14日
  • 修改了问题 4月11日
  • 创建了问题 4月11日

悬赏问题

  • ¥15 微信会员卡等级和折扣规则
  • ¥15 微信公众平台自制会员卡可以通过收款码收款码收款进行自动积分吗
  • ¥15 随身WiFi网络灯亮但是没有网络,如何解决?
  • ¥15 gdf格式的脑电数据如何处理matlab
  • ¥20 重新写的代码替换了之后运行hbuliderx就这样了
  • ¥100 监控抖音用户作品更新可以微信公众号提醒
  • ¥15 UE5 如何可以不渲染HDRIBackdrop背景
  • ¥70 2048小游戏毕设项目
  • ¥20 mysql架构,按照姓名分表
  • ¥15 MATLAB实现区间[a,b]上的Gauss-Legendre积分