qq_29780267 2023-03-30 16:26 采纳率: 0%
浏览 10

对抗自编码器训练出错

问题遇到的现象和发生背景

这是一个ADAE的训练过程,但是编码器的损失一直在0.5到0.6之间,判别器的损失在0.6左右,生成器损失接近0,请问是哪里存在问题?


 for i, (data,label) in enumerate(dataloader):
            P.zero_grad()
            Q.zero_grad()
            D.zero_grad()
            data = np.array(data)
            data = torch.tensor(data)
            #训练自编码器
            noise = torch.rand(data.shape).to(device)
            data_noise = data + noise * 0.1
            data_noise = torch.clamp(data_noise,0.,1.)

            code = Q(data_noise)  # 将真实图片放入判别器中
            decode=P(code)
            loss = F.binary_cross_entropy(decode+ EPS, data+ EPS)
            #loss=nn.MSELoss(decode.float(), data.float())
            # ae_optimizer.zero_grad()  # 在反向传播之前,先将梯度归0
            loss.backward()  # 将误差反向传播
            optim_P.step()
            optim_Q_enc.step()
            
            #训练判别器
            Q.eval()
            batch_size = data.shape[0]
            output_real = D(data)
            
            fake_data = Q(data_noise)
            output_fake = D(fake_data.detach())
            
            D_loss = -torch.mean(torch.log(output_real + EPS) + torch.log(1 - output_fake + EPS))

            D_loss.backward()
            optimizerD.step()
            
            #训练生成器
            Q.train()
            z_fake= Q(data_noise)
            D_fake = D(z_fake)
    
            errAE = -torch.mean(torch.log(1-D_fake+EPS))
            errAE.backward()
            optim_Q_gen.step()  

            # try:
            if i%100==0:
                    print('[%d/%d][%d/%d] loss: %.4f D_loss: %.4f '
                           'errAE: %.6f '
                          % (epoch, nepoch, i, len(dataloader),
                             loss,
                             D_loss,
                             errAE))


遇到的现象和发生背景,请写出第一个错误信息
用代码块功能插入代码,请勿粘贴截图。 不用代码块回答率下降 50%
运行结果及详细报错内容
[994/1000][0/10] loss: 0.5886 D_loss: 0.6867 errAE: 0.000097 
[995/1000][0/10] loss: 0.5867 D_loss: 0.6893 errAE: 0.000543 
[996/1000][0/10] loss: 0.5960 D_loss: 0.6840 errAE: 0.000071 
[997/1000][0/10] loss: 0.5951 D_loss: 0.7006 errAE: 0.000915 
[998/1000][0/10] loss: 0.5907 D_loss: 0.6920 errAE: 0.000074 
[999/1000][0/10] loss: 0.5930 D_loss: 0.6931 errAE: 0.000026 

我的解答思路和尝试过的方法,不写自己思路的,回答率下降 60%
我想要达到的结果,如果你需要快速回答,请尝试 “付费悬赏”
  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2023-03-30 19:18
    关注
    评论

报告相同问题?

问题事件

  • 创建了问题 3月30日

悬赏问题

  • ¥15 fluent里模拟降膜反应的UDF编写
  • ¥15 MYSQL 多表拼接link
  • ¥15 关于某款2.13寸墨水屏的问题
  • ¥15 obsidian的中文层级自动编号
  • ¥15 同一个网口一个电脑连接有网,另一个电脑连接没网
  • ¥15 神经网络模型一直不能上GPU
  • ¥15 pyqt怎么把滑块和输入框相互绑定,求解决!
  • ¥20 wpf datagrid单元闪烁效果失灵
  • ¥15 券商软件上市公司信息获取问题
  • ¥100 ensp启动设备蓝屏,代码clock_watchdog_timeout