这是我写的权重初始化：
下面是我写的生成器网络：
class BasicBlock(nn.Module):
    """Residual block built from two stride-1 transposed convolutions.

    Main path:  1x1 ConvTranspose2d (in_channel -> out_channel // 2) -> GroupNorm
                -> ReLU -> 3x3 ConvTranspose2d (-> out_channel) -> GroupNorm -> ReLU.
    Shortcut:   identity when ``in_channel == out_channel``; otherwise a 3x3
                ConvTranspose2d + GroupNorm projecting to ``out_channel``.
    The two paths are summed and passed through a final ReLU.

    Both transposed convolutions use stride 1 with size-preserving padding
    (k=1/p=0 and k=3/p=1), so the spatial size (H, W) is unchanged; only the
    channel count changes from ``in_channel`` to ``out_channel``.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by the block (may differ
            from ``in_channel``).
    """

    def __init__(self, in_channel, out_channel):
        super(BasicBlock, self).__init__()
        # Bottleneck width; clamp to 1 so a tiny out_channel (e.g. 1) does not
        # request a zero-channel convolution.
        mid_channel = max(1, out_channel // 2)
        self.conv1 = nn.ConvTranspose2d(in_channel, mid_channel, kernel_size=1,
                                        stride=1, padding=0, bias=False)
        self.bn1 = self._make_norm(mid_channel)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.ConvTranspose2d(mid_channel, out_channel, kernel_size=3,
                                        stride=1, padding=1, bias=False)
        self.bn2 = self._make_norm(out_channel)
        self.relu2 = nn.ReLU(inplace=True)
        if in_channel != out_channel:
            # Projection shortcut so the residual has out_channel channels.
            self.extra = nn.Sequential(
                nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3,
                                   stride=1, padding=1, bias=False),
                self._make_norm(out_channel),
            )
        else:
            # Shapes already match; forward() always calls self.extra, so it
            # must exist in this branch too.
            self.extra = nn.Identity()

    @staticmethod
    def _make_norm(num_channels, max_groups=4):
        """Return a GroupNorm over ``num_channels`` with at most ``max_groups`` groups.

        nn.GroupNorm requires num_channels to be divisible by num_groups, so a
        fixed value of 4 fails for e.g. 1 or 3 channels. Use the largest
        divisor of num_channels not exceeding max_groups; this is still 4
        whenever num_channels is a multiple of 4.
        """
        groups = min(max_groups, num_channels)
        while num_channels % groups:
            groups -= 1
        return nn.GroupNorm(groups, num_channels)

    def forward(self, x):
        """Apply the block: (N, in_channel, H, W) -> (N, out_channel, H, W)."""
        residual = self.extra(x)
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = out + residual
        return self.relu2(out)
class netG(nn.Module):
    """Generator made of four stacked BasicBlock stages.

    Channel progression: in_channels -> ngf -> ngf // 2 -> ngf // 4 -> 3.
    Every BasicBlock uses stride-1, size-preserving transposed convolutions,
    so an input of shape (N, in_channels, H, W) yields (N, 3, H, W).

    NOTE(review): nothing here upsamples. With a (N, nz, 1, 1) noise input
    the output is a 1x1 "image"; if upsampling is intended, some stage needs
    stride > 1 - confirm.
    NOTE(review): the last stage ends in ReLU, so outputs are >= 0. If the
    training images are normalized to [-1, 1], a Tanh output is the usual
    choice - confirm against the data pipeline.

    Args:
        in_channels: channels of the input tensor. Defaults to the
            module-level ``nz`` (defined outside this class), matching the
            behaviour of a plain ``netG()`` call.
        ngf: feature width of the first stage (default 512).
    """

    def __init__(self, in_channels=None, ngf=512):
        super(netG, self).__init__()
        if in_channels is None:
            # Resolve the global only when no explicit value is given, so
            # existing netG() calls behave exactly as before.
            in_channels = nz
        self.ngf = ngf
        # Attribute names SEB1..SEB4 are kept so saved state_dicts still load.
        self.SEB1 = BasicBlock(in_channels, self.ngf)
        self.SEB2 = BasicBlock(self.ngf, self.ngf // 2)
        self.SEB3 = BasicBlock(self.ngf // 2, self.ngf // 4)
        self.SEB4 = BasicBlock(self.ngf // 4, 3)

    def forward(self, x):
        """Run ``x`` through SEB1, SEB2, SEB3 and SEB4 in order."""
        x = self.SEB1(x)
        x = self.SEB2(x)
        x = self.SEB3(x)
        x = self.SEB4(x)
        return x