本地磁盘A 2018-11-08 11:17 采纳率: 100%
浏览 4318
已结题

萌新求问,我神经网络书上的一段代码出错不知道为什么??

这是陈云的《深度学习框架Pytorch入门与实践》4.2.1节图像相关层的代码:

from PIL import Image
from torchvision.transforms import ToTensor,ToPILImage
import torch as t
from torch import nn
from torch.autograd import Variable as V
to_tensor=ToTensor()
to_pil=ToPILImage()
lena=Image.open('C:/Users/Desktop/lena.png')
input=to_tensor(lena).unsqueeze(0)

kernel=t.ones(3,3)/-9
kernel[1][1]=1
conv=nn.Conv2d(1,1,(3,3),1,bias=False)
conv.weight.data=kernel.view(1,1,3,3)

out=conv(V(input))
to_pil(out.data.squeeze(0))

但是结果报错:

File "D:\py\lib\site-packages\torch\nn\modules\conv.py", line 301, in forward
self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [1, 1, 3, 3], expected input[1, 3, 300, 300] to have 1 channels, but got 3 channels instead

这怎么解决啊。。。。。

  • 写回答

2条回答 默认 最新

  • xwang71785 2019-03-10 19:58
    关注

    lena=Image.open('C:/Users/Desktop/lena.png')
    我猜你用的lena.png是彩色的(所以是3个channels)
    下载一个黑白的试试(只有1个channel)

    评论

报告相同问题?

悬赏问题

  • ¥15 CSS实现渐隐虚线边框
  • ¥15 thinkphp6配合social login单点登录问题
  • ¥15 HFSS 中的 H 场图与 MATLAB 中绘制的 B1 场 部分对应不上
  • ¥15 如何在scanpy上做差异基因和通路富集?
  • ¥20 关于#硬件工程#的问题,请各位专家解答!
  • ¥15 关于#matlab#的问题:期望的系统闭环传递函数为G(s)=wn^2/s^2+2¢wn+wn^2阻尼系数¢=0.707,使系统具有较小的超调量
  • ¥15 FLUENT如何实现在堆积颗粒的上表面加载高斯热源
  • ¥30 截图中的mathematics程序转换成matlab
  • ¥15 动力学代码报错,维度不匹配
  • ¥15 Power query添加列问题