西窗虫 2024-04-24 20:50

Problem building ResNet with PyTorch

I ran into a problem while reimplementing ResNet-18 in PyTorch:

import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader

trans = torchvision.transforms.ToTensor()
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",download=True,transform=trans,train=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",download=True,transform=trans,train=False)

train_dataloader = DataLoader(train_set,shuffle=True,batch_size=128)
test_dataloader = DataLoader(test_set,shuffle=False,batch_size=128)

class my_BasicBlock(nn.Module):
    def __init__(self,in_channal,out_channal,stride):
        super(my_BasicBlock,self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channal,out_channal,kernel_size=3,stride=stride[0],padding=1),
            nn.BatchNorm2d(out_channal),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channal,out_channal,kernel_size=3,stride=stride[1],padding=1),
            nn.BatchNorm2d(out_channal)
        )

        self.shortcut = nn.Sequential()
        if stride[0] != 1 or in_channal != out_channal:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channal,out_channal,kernel_size=1,stride=stride[0]),
                nn.BatchNorm2d(out_channal)
            )
    def forward(self,x):
        y = self.net(x)
        y += self.shortcut(x)
        y = nn.ReLU(y)
        return y


class my_resnet18(nn.Module):
    def __init__(self,my_BasicBlock,numclasser=10):
        super(my_resnet18,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,64,7,2,3),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        )
        self.conv2 = nn.Sequential(
            my_BasicBlock(64,64,[1,1]),
            my_BasicBlock(64,64,[1,1])
        )
        self.conv3 = nn.Sequential(
            my_BasicBlock(64,128,[2,1]),
            my_BasicBlock(128,128,[1,1])
        )

    def forward(self,x):
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)

        return y

resnet = my_resnet18(my_BasicBlock)
print(resnet)
for data,targer in train_dataloader:
    output =resnet(data)
    print(output.size())
    break

Here I created a BasicBlock and used it in my_resnet18. (The ResNet is incomplete; I only want to check the shape of the convolution output.) But running this raises an error:


TypeError: conv2d() received an invalid combination of arguments - got (ReLU, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (!ReLU!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)

The message says the arguments passed into the convolution layer are invalid, but when I print the resnet model it looks the same as a correct model. Where exactly does this code go wrong?


4 answers

  • 专家-赤兔 (programming frameworks) 2024-04-24 20:50


    The problem is in my_BasicBlock.forward(): the line y = nn.ReLU(y) does not apply ReLU to y. nn.ReLU is a module class, so this call constructs a new ReLU module, with y silently bound to its inplace parameter, and assigns that module object to y. The block therefore returns a ReLU module instead of a tensor, and when the next block feeds it into its first convolution you get exactly the error you saw: conv2d() received a ReLU where it expected a Tensor. This is also why print(resnet) looks fine: printing a model only shows the modules registered in __init__, not what forward() actually does.
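
    A minimal demonstration of the mix-up (my own sketch, separate from your code):

    import torch
    import torch.nn as nn

    y = torch.randn(1, 64, 8, 8)

    wrong = nn.ReLU(y)    # builds a ReLU module; y is silently taken as the inplace flag
    print(type(wrong))    # <class 'torch.nn.modules.activation.ReLU'>

    right = nn.ReLU()(y)  # instantiate first, then call the module on the tensor
    print(type(right))    # <class 'torch.Tensor'>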

    To fix it, apply the activation to the tensor itself: either instantiate the module and call it, nn.ReLU()(y), or use the functional form torch.relu(y). Here is the corrected code:

    class my_BasicBlock(nn.Module):
        def __init__(self, in_channal, out_channal, stride):
            super(my_BasicBlock, self).__init__()
            self.net = nn.Sequential(
                nn.Conv2d(in_channal, out_channal, kernel_size=3, stride=stride[0], padding=1),
                nn.BatchNorm2d(out_channal),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channal, out_channal, kernel_size=3, stride=stride[1], padding=1),
                nn.BatchNorm2d(out_channal)
            )
    
            self.shortcut = nn.Sequential()
            if stride[0] != 1 or in_channal != out_channal:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channal, out_channal, kernel_size=1, stride=stride[0]),
                    nn.BatchNorm2d(out_channal)
                )
    
        def forward(self, x):
            y = self.net(x)
            y += self.shortcut(x)
            
            # correct: instantiate ReLU, then call it on the tensor
            # (nn.ReLU(y) would construct a module instead of applying one)
            y = nn.ReLU()(y)
            return y
    

    Now the code should run without the TypeError. If you hit other problems or errors, feel free to ask.
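
    As an aside (a sketch of my own, not required for the fix): constructing a fresh nn.ReLU module on every forward pass works, but the more idiomatic patterns are to create the module once in __init__ and reuse it, as torchvision's own BasicBlock does, or to use the functional form:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TidyBlock(nn.Module):  # hypothetical minimal example, not the full block
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU(inplace=True)  # option 1: build once in __init__

        def forward(self, x):
            y = self.relu(x)  # option 1: reuse the stored module
            # option 2 (equivalent): y = F.relu(x) -- no module object needed
            return y

    print(TidyBlock()(torch.randn(2, 4)).min() >= 0)  # tensor(True): negatives clipped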

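    For reference (my own calculation, assuming the corrected classes above are defined): a 32x32 CIFAR-10 input is halved twice by conv1 (the stride-2 convolution, then stride-2 max pooling) to 8x8, kept at 8x8 by conv2, and halved to 4x4 with 128 channels by conv3's first stride-2 block, so the test loop should print torch.Size([128, 128, 4, 4]):

    import torch

    resnet = my_resnet18(my_BasicBlock)  # the corrected block from this thread
    x = torch.randn(128, 3, 32, 32)      # one CIFAR-10-sized batch
    print(resnet(x).size())              # expected: torch.Size([128, 128, 4, 4])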
