用pytorch复刻Resnet 18的时候遇到了一个问题:
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
trans = torchvision.transforms.ToTensor()
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",download=True,transform=trans,train=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",download=True,transform=trans,train=False)
train_dataloader = DataLoader(train_set,shuffle=True,batch_size=128)
test_dataloader = DataLoader(test_set,shuffle=False,batch_size=128)
class my_BasicBlock(nn.Module):
def __init__(self,in_channal,out_channal,stride):
super(my_BasicBlock,self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channal,out_channal,kernel_size=3,stride=stride[0],padding=1),
nn.BatchNorm2d(out_channal),
nn.ReLU(inplace=True),
nn.Conv2d(out_channal,out_channal,kernel_size=3,stride=stride[1],padding=1),
nn.BatchNorm2d(out_channal)
)
self.shortcut = nn.Sequential()
if stride[0] != 1 or in_channal != out_channal:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channal,out_channal,kernel_size=1,stride=stride[0]),
nn.BatchNorm2d(out_channal)
)
def forward(self,x):
y = self.net(x)
y += self.shortcut(x)
y = nn.ReLU(y)
return y
class my_resnet18(nn.Module):
def __init__(self,my_BasicBlock,numclasser=10):
super(my_resnet18,self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(3,64,7,2,3),
nn.BatchNorm2d(64),
nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
)
self.conv2 = nn.Sequential(
my_BasicBlock(64,64,[1,1]),
my_BasicBlock(64,64,[1,1])
)
self.conv3 = nn.Sequential(
my_BasicBlock(64,128,[2,1]),
my_BasicBlock(128,128,[1,1])
)
def forward(self,x):
y = self.conv1(x)
y = self.conv2(y)
y = self.conv3(y)
return y
resnet = my_resnet18(my_BasicBlock)
print(resnet)
for data,targer in train_dataloader:
output =resnet(data)
print(output.size())
break
在这里,我创建了一个BasicBlock,然后在my_resnet中引用它,(Resnet不全,我只是想看卷积输出的参数格式)但是这样会报错:
TypeError: conv2d() received an invalid combination of arguments - got (ReLU, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
* (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
didn't match because some of the arguments have invalid types: (!ReLU!, !Parameter!, !Parameter!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, int)
提示我输入到卷积层的参数有误,但是我打印出resnet模型,发现他和正确模型是一样的,我想知道这样写代码究竟哪里出问题了