%KT% 2023-10-07 21:33 采纳率: 0%
浏览 7
已结题

GAN反向传播存在问题

在进行深度学习模型训练时,报错:

img

之前自己也花了大量时间查找原因,大概分析出是因为pytorch版本不符造成的。pytorch1.4版本之前的反向传播过程和1.4版本之后的有些许不同。但配置低版本虚拟环境过程过于繁琐,暂时不考虑该方案。
这是原始代码:

    def process(self, image):
        # process_outputs
        seg_mask, image_rec = self(image)
        """
        G and D process, this package is reusable
        """
        # zero optimizers
        self.optimizer_G.zero_grad()
        self.optimizer_D.zero_grad()
        gen_loss = 0
        dis_loss = 0
        real_B = image
        fake_B = image_rec

        # discriminator loss
        dis_input_real = real_B
        dis_input_fake = fake_B.detach()
        dis_real, dis_real_feat = self.model_D(dis_input_real)
        dis_fake, dis_fake_feat = self.model_D(dis_input_fake)
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2

        # generator adversarial loss
        gen_input_fake = fake_B
        gen_fake, gen_fake_feat = self.model_D(gen_input_fake)
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.args.lamd_gen
        gen_loss += gen_gan_loss
        # generator feature matching loss
        gen_fm_loss = 0
        for i in range(len(dis_real_feat)):
            gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
        gen_fm_loss = gen_fm_loss * self.args.lamd_fm
        gen_loss += gen_fm_loss
        # generator l1 loss
        gen_l1_loss = self.l1_loss(fake_B, real_B) * self.args.lamd_p
        gen_loss += gen_l1_loss

        # Backward and optimize discriminator
        dis_loss.backward()
        self.optimizer_D.step()
        # Backward and optimize generator
        gen_loss.backward()
        self.optimizer_G.step()
        # create logs
        logs = dict(
            gen_gan_loss=gen_gan_loss,
            gen_fm_loss=gen_fm_loss,
            gen_l1_loss=gen_l1_loss,
            # gen_content_loss=gen_content_loss,
            # gen_style_loss=gen_style_loss,
        )
        return seg_mask, fake_B, gen_loss, dis_loss, logs

目前已经实验过的方法有:
1、将dis_loss.backward()调到gen_loss.backward()前面,此方法虽能解决报错,但理论上仍存在错误,模型无法收敛
2、添加retain_graph=True操作,无法解决问题

希望有人能给出代码的完整修改方案。
https://github.com/pytorch/pytorch/issues/39141 此链接是pytorch开发人员针对该问题给出的方案,但本人刚学习神经网络,不懂得如何照着修改,可参考修改。

  • 写回答

13条回答 默认 最新

  • 玥轩_521 2023-10-07 22:34
    关注
    获得0.45元问题酬金

    援引通义千问:在进行深度学习模型训练时,如果使用的是pytorch1.4版本之前的反向传播过程,可能会导致报错。解决这个问题的方法是配置低版本的虚拟环境,但这会比较繁琐。目前你已经尝试过的方法包括将dis_loss.backward()调到gen_loss.backward()前面,以及添加retain_graph=True操作,但这些方法都无法解决问题。
    建议你可以参考pytorch开发人员给出的方案,修改代码以解决这个问题。如果你不熟悉神经网络,可能需要找一个熟悉神经网络的人帮助你。你也可以在网上查找关于如何修改pytorch代码的相关教程和资料,以帮助你解决问题。

    评论

报告相同问题?

问题事件

  • 系统已结题 10月15日
  • 修改了问题 10月7日
  • 赞助了问题酬金15元 10月7日
  • 创建了问题 10月7日

悬赏问题

  • ¥15 关于#c语言#的问题,请各位专家解答!
  • ¥15 这个如何解决详细步骤
  • ¥15 在微信h5支付申请中,别人给钱就能用我的软件,这个的所属行业是啥?
  • ¥30 靶向捕获探针设计软件包
  • ¥15 别人给钱就能用我的软件,这个的经营场景是啥?
  • ¥15 react-diff-viewer组件,如何解决数据量过大卡顿问题
  • ¥20 遥感植被物候指数空间分布图制作
  • ¥15 安装了xlrd库但是import不了…
  • ¥20 Github上传代码没有contribution和activity记录
  • ¥20 SNETCracker