我知道是网络层发生的问题,查看了网络层,但是实在不懂到底是哪里多了3个channels,急!
def init(self, ngf2=64, num_res_blocks=16, num_upsample=2):
super(Generator, self).init()
self.conv0 = nn.Conv2d(in_channels=3,out_channels=ngf2, kernel_size=3, padding=1)
# First layer
self.conv1 = nn.Conv2d(in_channels=ngf2,out_channels=ngf2, kernel_size=3, padding=1, stride=1)
# Residual blocks
self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(ngf2) for _ in range(num_res_blocks)])
# Second conv layer post residual blocks
self.conv2 = nn.Conv2d(in_channels=ngf2, out_channels = ngf2, kernel_size=3, stride=1, padding=1)
# Upsampling layers
upsample_layers = []
for _ in range(num_upsample):
upsample_layers += [
nn.Conv2d(ngf2, ngf2 * 4, kernel_size=1, stride=2, padding=1),
nn.LeakyReLU(),
nn.PixelShuffle(upscale_factor=2),
]
self.upsampling = nn.Sequential(*upsample_layers)
# Final output block
self.conv3 = nn.Sequential(
nn.Conv2d(in_channels=ngf2, out_channels = ngf2, kernel_size=5, stride=1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(in_channels=ngf2, out_channels =64, kernel_size=5, padding=1, stride=1),
)
#self.conv4 = nn.Conv2d(out_channels=64, kernel_size=3, padding=1, in_channels=6)
def forward(self, x):
out0 = self.conv0(x)
out1 = self.conv1(out0)
out = self.res_blocks(out1)
out2 = self.conv2(out)
out = torch.add(out1, out2)
out = self.upsampling(out)
out = self.conv3(out)
# out = self.conv4(out)
return out