2401_82357774 2025-05-04 15:03

A problem with a pix2pix model

The model has run into a problem: because both the training set and the validation set have white backgrounds, the generator simply outputs a pure-white image.

```python
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import ToTensor, Normalize

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder entry block at full resolution
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Downsampling stages (each ends with a stride-2 conv)
        self.down1 = nn.Sequential(
            nn.Conv2d(96, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(192, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(384, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Bottleneck: one more stride-2 downsampling, then the first upsampling
        self.down_and_up = nn.Sequential(
            nn.Conv2d(768, 1536, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(1536),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1536, 1536, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(1536),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(1536, 768, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Decoder stages; each consumes a skip connection, so input channels are doubled
        self.up1 = nn.Sequential(
            nn.Conv2d(1536, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(768, 384, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.up2 = nn.Sequential(
            nn.Conv2d(768, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(384, 192, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.up3 = nn.Sequential(
            nn.Conv2d(384, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(192, 96, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Final refinement at full resolution
        self.conv2 = nn.Sequential(
            nn.Conv2d(192, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.output = nn.Sequential(
            nn.Conv2d(96, 3, kernel_size=1, stride=1),
            nn.Tanh()
        )

    def forward(self, x):
        # U-Net style forward pass with skip connections via channel concatenation
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.down_and_up(x4)
        x = torch.cat([x, x4], dim=1)
        x = self.up1(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = self.conv2(x)
        x = self.output(x)
        return x

class Distinguish(nn.Module):
    def __init__(self):
        super().__init__()
        # Discriminator over the concatenated (input, output) image pair (6 channels in)
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=11, padding=5, stride=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 128, kernel_size=5, padding=2),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=5, padding=2, stride=2),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=5, padding=2),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, kernel_size=5, padding=2),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, kernel_size=5, padding=2, stride=2),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, kernel_size=5, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=3, padding=1),
            nn.Flatten(),
        )

    def forward(self, I, O):
        # Condition on the input image by concatenating it with the (real or fake) output
        x = torch.cat([I, O], dim=1)
        x = self.model(x)
        return x

D = Distinguish()
try:
    D.load_state_dict(torch.load("D.pth"))
except:
    pass
G = Generator().cuda()
try:
    G.load_state_dict(torch.load("G.pth"))
except:
    pass

totensor = transforms.Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class Pairdata():
    def __init__(self):
        ...  # dataset setup elided in the post
    def __getitem__(self):
        ...  # returns an (input_img, output_img) pair; body elided in the post
    def __len__(self):
        return len(self.pair_img_path)
pairdate_loader = Pairdata()

D_optim=torch.optim.RMSprop(D.parameters(),lr=0.02)
G_optim=torch.optim.RMSprop(G.parameters(),lr=0.02)
l1loss=nn.L1Loss()
BCEloss=nn.BCEWithLogitsLoss()

# WGAN-style loop: RMSprop, mean critic outputs, weight clipping on D
for i in range(5):
    D_optim.zero_grad()
    G_optim.zero_grad()
    for j in range(4):
        input_img, output_img = pairdate_loader.__getitem__()

        # ---- discriminator (critic) loss ----
        with torch.no_grad():
            fake_img = G(input_img.cuda())
        pred_real = D(input_img, output_img)
        loss_real = -torch.mean(pred_real)
        pred_fake = D(input_img, fake_img.cpu())
        loss_fake = torch.mean(pred_fake)
        current_D_loss = (loss_real + loss_fake)
        (current_D_loss).backward()
        print('D:')
        print(current_D_loss.item())

        # ---- generator loss (adversarial + weighted L1) ----
        fake_img = G(input_img.cuda())
        pred_fake = D(input_img, fake_img.cpu())
        advloss = -torch.mean(pred_fake)
        l1_loss = l1loss(fake_img, output_img.cuda())
        current_G_loss = advloss + l1_loss * 2
        (current_G_loss).backward()
        print('G:')
        print(current_G_loss.item())

    D_optim.step()
    G_optim.step()
    # clamp the critic weights (WGAN weight clipping)
    clip_value = 0.02
    for param in D.parameters():
        param.data.clamp_(-clip_value, clip_value)
    print('返回梯度')  # "gradients applied"
    torch.save(D.state_dict(), "D.pth")
    torch.save(G.state_dict(), "G.pth")

unloader = transforms.ToPILImage()
def tensor_to_PIL(tensor):
    image = tensor.cpu().clone()
    image = image.squeeze(0)
    image = (image + 1) / 2.0   # map from [-1, 1] back to [0, 1]
    image = unloader(image)
    return image

tensor_to_PIL(G(pairdate_loader.__getitem__()[0].cuda())).show()
while True:
    pass
```


4 replies

  • 紫雾凌寒 (Rising-Star Creator, AI) 2025-05-04 15:06

    Answer from 紫雾凌寒; this reply was written with the help of DeepSeek.

    If you still have questions, please comment or message me via my profile; if this helped, please consider following, accepting, and upvoting.

    1. Key points:
      - Both the training set and the validation set have white backgrounds, and the generator ends up emitting pure-white images. This suggests that, faced with this particular background, the generator has not learned features that would let it produce meaningful image content.
      - The `__init__` methods in the posted snippet are incomplete, so it is hard to judge exactly what they do, but improper generator initialization could be a contributing factor.
    2. Suggested fixes:
      - Data preprocessing
      • For the white-background images, try removing the background with a segmentation technique so the training data contains richer foreground content for the generator to learn from. OpenCV's GrabCut algorithm is one option; a minimal sketch follows the normalization snippet below.
      • Normalize the images into a suitable value range for training. With PyTorch this can be done like so:
    import torchvision.transforms as transforms
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
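
    As a minimal sketch of the GrabCut suggestion (the file name and the foreground rectangle below are illustrative placeholders, not from the original post):

    ```python
    import cv2
    import numpy as np

    # Minimal GrabCut sketch; "input.jpg" and the rectangle are placeholders.
    img = cv2.imread("input.jpg")
    mask = np.zeros(img.shape[:2], np.uint8)
    bgd_model = np.zeros((1, 65), np.float64)   # scratch buffers required by grabCut
    fgd_model = np.zeros((1, 65), np.float64)
    rect = (10, 10, img.shape[1] - 20, img.shape[0] - 20)  # rough box around the foreground

    cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

    # Keep pixels labelled (probable) foreground, zero out the white background.
    fg = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 1, 0).astype("uint8")
    result = img * fg[:, :, None]
    cv2.imwrite("foreground.jpg", result)
    ```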
    
    • Model improvements
      • Check the generator architecture to make sure it can actually capture image features, for example by adding convolutional layers or adjusting kernel sizes to increase its capacity.
      • Add an attention mechanism to the generator, such as a Self-Attention module, so the model attends to different regions of the image instead of collapsing to a trivial all-white output.
    ```python
    import torch
    import torch.nn as nn

    class SelfAttention(nn.Module):
        def __init__(self, in_channels):
            super(SelfAttention, self).__init__()
            self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
            self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
            self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
            self.gamma = nn.Parameter(torch.zeros(1))

        def forward(self, x):
            batch_size, C, H, W = x.size()
            proj_query = self.query_conv(x).view(batch_size, -1, H * W).permute(0, 2, 1)
            proj_key = self.key_conv(x).view(batch_size, -1, H * W)
            energy = torch.bmm(proj_query, proj_key)
            attention = nn.functional.softmax(energy, dim=-1)
            proj_value = self.value_conv(x).view(batch_size, -1, H * W)
            out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(batch_size, C, H, W)
            out = self.gamma * out + x
            return out

    class Generator(nn.Module):
        def __init__(self):
            super().__init__()
            self.attention = SelfAttention(3)   # attention applied to the 3-channel input
            # Assume the rest is a complete generator, e.g. a U-Net style encoder/decoder;
            # only the first few encoder layers are shown here.
            self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
            self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
            # ... remaining encoder/decoder layers omitted
    ```
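
    As a quick, hypothetical sanity check (not part of the original reply), the attention block can be run on a dummy tensor to confirm it preserves shapes before wiring it into a generator, for example at the 768-channel bottleneck of the asker's network:

    ```python
    import torch

    # Hypothetical shape check for the SelfAttention module defined above.
    attn = SelfAttention(in_channels=768)   # 768 = bottleneck width in the question's Generator
    dummy = torch.randn(1, 768, 32, 32)     # (batch, channels, H, W); sizes chosen arbitrarily
    out = attn(dummy)
    print(out.shape)                        # torch.Size([1, 768, 32, 32]) -- same shape as the input
    ```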

    I hope this helps. If you have any questions, feel free to ask in the comments.
