m0_62112526 2022-08-22 10:22 采纳率: 50%
浏览 114
已结题

使用resnet152模型训练报错

在使用resnet152模型训练fashion_mnist时报错

具体代码如下

model = torchvision.models.resnet152(pretrained=True)
fashion_mnist = keras.datasets.fashion_mnist
data,label = fashion_mnist.load_data()

for param in model.parameters():
    param.requires_gradq = False
model.fc = torch.nn.Linear(2048, 28)
loss_fn = torch.nn.MSELoss(reduction='sum')
opt = torch.optim.SGD(model.fc.parameters(), lr=0.001)

然而在下面的代码,也就是训练模型的时候遇到报错

for i in range(200):
    y = model(data)
    loss = loss_fn(label, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss)

报错内容 :

TypeError                                 Traceback (most recent call last)
<ipython-input-20-042e5c6d56e7> in <module>
      1 for i in range(200):
----> 2     y = model(data)
      3     loss = loss_fn(label, y)
      4     opt.zero_grad()
      5     loss.backward()

~/yes/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/yes/lib/python3.8/site-packages/torchvision/models/resnet.py in forward(self, x)
    247 
    248     def forward(self, x: Tensor) -> Tensor:
--> 249         return self._forward_impl(x)
    250 
    251 

~/yes/lib/python3.8/site-packages/torchvision/models/resnet.py in _forward_impl(self, x)
    230     def _forward_impl(self, x: Tensor) -> Tensor:
    231         # See note [TorchScript super()]
--> 232         x = self.conv1(x)
    233         x = self.bn1(x)
    234         x = self.relu(x)

~/yes/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/yes/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    444 
    445     def forward(self, input: Tensor) -> Tensor:
--> 446         return self._conv_forward(input, self.weight, self.bias)
    447 
    448 class Conv3d(_ConvNd):

~/yes/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    440                             weight, bias, self.stride,
    441                             _pair(0), self.dilation, self.groups)
--> 442         return F.conv2d(input, weight, bias, self.stride,
    443                         self.padding, self.dilation, self.groups)
    444 

TypeError: conv2d() received an invalid combination of arguments - got (tuple, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!tuple!, !Parameter!, !NoneType!, !tuple!, !tuple!, !tuple!, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!tuple!, !Parameter!, !NoneType!, !tuple!, !tuple!, !tuple!, int)

我尝试将data和label转torch.tensor

data = torch.tensor(data)

可是依旧报错

TypeError                                 Traceback (most recent call last)
<ipython-input-32-42ca4109f755> in <module>
----> 1 data = torch.tensor(data)

TypeError: not a sequence
求指导!
提前致谢!
  • 写回答

2条回答 默认 最新

  • 爱晚乏客游 2022-08-22 14:36
    关注

    fashion_mnist.load_data()这个返回值应该有4个把,为啥你才两个?
    x_train,y_train,x_test,y_test=fashion.load_data()?

    评论

报告相同问题?

问题事件

  • 已结题 (查看结题原因) 8月23日
  • 赞助了问题酬金5元 8月22日
  • 创建了问题 8月22日

悬赏问题

  • ¥20 如何通过代码传输视频到亚马逊平台
  • ¥15 php查询mysql数据库并显示至下拉列表中
  • ¥15 freertos下使用外部中断失效
  • ¥15 输入的char字符转为int类型,不是对应的ascall码,如何才能使之转换为对应ascall码?或者使输入的char字符可以正常与其他字符比较?
  • ¥15 devserver配置完 启动服务 无法访问static上的资源
  • ¥15 解决websocket跟c#客户端通信
  • ¥30 Python调用dll文件输出Nan重置dll状态
  • ¥15 浮动div的高度控制问题。
  • ¥66 换电脑后应用程序报错
  • ¥50 array数据同步问题