qq_34326616 2025-04-20 00:08

What is a reasonable learning rate for a ConditionGAN?

I'm writing a ConditionGAN, but after tuning the learning rates for a long time, the generator and discriminator losses are still unbalanced. How should I set the learning rates?
My code is as follows:

# generator, discriminator, emotion_classification, gradient_penalty,
# data_loader, device, num_epochs, global_min and global_max are defined elsewhere
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchaudio.transforms as T
from torch.optim import Adam, lr_scheduler

criterion = nn.MSELoss()
G_optimizer = Adam(generator.parameters(), lr=5e-4, betas=(0.5, 0.999))
D_optimizer = Adam(discriminator.parameters(), lr=1e-5, betas=(0.5, 0.999))
G_scheduler = lr_scheduler.CosineAnnealingLR(G_optimizer, T_max=num_epochs)
D_scheduler = lr_scheduler.CosineAnnealingLR(D_optimizer, T_max=num_epochs)


G_train_losses = []
D_train_losses = []
lambda_gp = 10
num_G = 1

# Training loop
for epoch in range(num_epochs):
    G_losses = 0
    D_losses = 0

    for i, (images, audios) in enumerate(data_loader):

        images = images.to(device)
        audios = audios.to(device)
        # Generate emotion labels for the real samples
        with torch.no_grad():
            image_emotion = emotion_classification(images)
            image_emotion = torch.argmax(image_emotion, dim=1)

        # Discriminator training
        D_optimizer.zero_grad()

        # Real samples
        real_validity = discriminator(audios, image_emotion)
        real_loss = criterion(real_validity, torch.ones_like(real_validity).to(device))

        # Fake samples
        z = torch.randn(audios.shape[0], 100).to(device)
        fake_labels = torch.randint(0, 7, (audios.shape[0],)).to(device)
        fake_images = generator(z, fake_labels)
        fake_validity = discriminator(fake_images, fake_labels)
        fake_loss = criterion(fake_validity, torch.zeros_like(fake_validity).to(device))

        # Gradient penalty
        gp = gradient_penalty(discriminator, audios, fake_images, image_emotion)

        # Total discriminator loss
        D_loss = real_loss + fake_loss + lambda_gp * gp
        D_losses += D_loss.item()
        D_loss.backward()
        D_optimizer.step()

        for j in range(num_G):
            # Generator training
            generator.train()
            G_optimizer.zero_grad()

            # Generate fake images
            z = torch.randn(audios.shape[0], 100).to(device)
            fake_labels = torch.randint(0, 7, (audios.shape[0],)).to(device)
            fake_images = generator(z, fake_labels)

            # Discriminator's verdict on the fakes
            validity = discriminator(fake_images, fake_labels)
            G_loss = criterion(validity, torch.ones_like(validity).to(device))
            G_losses += G_loss.item()
            G_loss.backward()
            G_optimizer.step()


        # Print progress bar
        total_batches = len(data_loader)
        progress = (i + 1) / total_batches * 100
        bar_length = 30
        filled_length = int(bar_length * (i + 1) // total_batches)
        bar = '=' * filled_length + '-' * (bar_length - filled_length)
        print(f'\rEpoch {epoch+1}/{num_epochs}       Training: {progress:3.0f}%|{bar}| {i+1}/{len(data_loader)}', end='  ', flush=True)


    G_train_losses.append(G_losses / (num_G*len(data_loader)))
    D_train_losses.append(D_losses / len(data_loader))
    print(f"[epoch={epoch + 1:3d}]  generator loss: {G_train_losses[epoch]:.4f}  discriminator loss: {D_train_losses[epoch]:.4f}")

    generator.eval()
    z = torch.randn(7, 100).to(device)
    labels = torch.tensor([0, 1, 2, 3, 4, 5, 6]).to(device)
    with torch.no_grad():
        sample_audio = generator(z, labels).squeeze().cpu()

    # Plot the generated audio spectrograms
    plt.figure(figsize=(20, 16))
    for i in range(7):
        audio = sample_audio[i]
        audio = (audio + 1) / 2 * (global_max - global_min) + global_min
        audio = audio.expm1()
        spec_db = T.AmplitudeToDB()(audio)
        spec_db = spec_db.squeeze(0)
        plt.subplot(1, 7, i+1)
        plt.imshow(spec_db)
    plt.tight_layout()
    plt.show()

    G_scheduler.step()
    D_scheduler.step()
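
gradient_penalty is not shown; a standard WGAN-GP interpolation penalty that matches this call signature would look roughly like this (the actual helper may differ):

def gradient_penalty(discriminator, real_samples, fake_samples, labels):
    batch_size = real_samples.size(0)
    # One random interpolation coefficient per sample, broadcast over all dims
    alpha = torch.rand(batch_size, *([1] * (real_samples.dim() - 1)),
                       device=real_samples.device)
    interpolates = (alpha * real_samples
                    + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = discriminator(interpolates, labels)
    # Gradient of the discriminator output w.r.t. the interpolated samples
    gradients = torch.autograd.grad(
        outputs=d_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(batch_size, -1)
    # Penalize deviation of the per-sample gradient norm from 1
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()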


4 answers

  • 阿里嘎多学长 2025-04-20 00:08

    Setting the learning rate for a ConditionGAN

    A ConditionGAN is a conditional generative adversarial network (GAN), and the learning-rate setting matters a great deal for convergence and loss balance. In general, a learning rate that is too high can make the losses unstable, while one that is too low can make convergence very slow.
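
    One widely used heuristic for exactly this balance problem is the two time-scale update rule (TTUR): instead of tuning a single shared value, give the two networks different learning rates, usually with the discriminator getting the larger one. Note that your snippet does the opposite, training D fifty times slower than G (1e-5 vs 5e-4), which by itself can produce the imbalance you describe. A minimal sketch, using the common 1e-4 / 4e-4 starting pair (a frequent default, not a prescription):

    from torch.optim import Adam

    # TTUR: D typically gets the larger learning rate
    G_optimizer = Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    D_optimizer = Adam(discriminator.parameters(), lr=4e-4, betas=(0.5, 0.999))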

    Beyond that, here are a few approaches you can try for tuning the learning rates:

    1. Use a Cyclical Learning Rate (CLR): CLR sweeps the learning rate up and down between two bounds during training, so you are not committed to a single value that may turn out too high or too low.
    import torch.optim as optim

    G_optimizer = optim.Adam(G.parameters(), lr=1e-4)
    D_optimizer = optim.Adam(D.parameters(), lr=1e-4)

    # CyclicLR sweeps the lr between base_lr and max_lr; cycle_momentum must
    # be False for Adam, and the scheduler is stepped once per batch
    G_scheduler = optim.lr_scheduler.CyclicLR(
        G_optimizer, base_lr=1e-5, max_lr=1e-3,
        step_size_up=2000, cycle_momentum=False)
    D_scheduler = optim.lr_scheduler.CyclicLR(
        D_optimizer, base_lr=1e-5, max_lr=1e-3,
        step_size_up=2000, cycle_momentum=False)

    for epoch in range(num_epochs):
        for i, batch in enumerate(train_loader):
            # ... discriminator and generator updates ...
            G_scheduler.step()
            D_scheduler.step()
    
    2. Use a learning-rate scheduler: a scheduler changes the learning rate over the course of training, either on a fixed timetable (e.g. StepLR) or in response to a monitored metric (see the ReduceLROnPlateau sketch after this example).
    import torch.optim as optim

    criterion = nn.MSELoss()
    G_optimizer = optim.Adam(G.parameters(), lr=0.001)
    D_optimizer = optim.Adam(D.parameters(), lr=0.001)

    # One scheduler per optimizer, each stepped once per epoch
    G_scheduler = optim.lr_scheduler.StepLR(G_optimizer, step_size=5, gamma=0.5)
    D_scheduler = optim.lr_scheduler.StepLR(D_optimizer, step_size=5, gamma=0.5)

    for epoch in range(num_epochs):
        for i, batch in enumerate(train_loader):
            # ...
            G_optimizer.zero_grad()
            D_optimizer.zero_grad()
            # ...
            G_optimizer.step()
            D_optimizer.step()
        G_scheduler.step()
        D_scheduler.step()
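
    If you want the schedule to react to the loss itself, ReduceLROnPlateau lowers the learning rate once a monitored metric stops improving. A sketch, assuming you track the average per-epoch losses in epoch_G_loss / epoch_D_loss (GAN losses are noisy, so use a generous patience):

    G_plateau = optim.lr_scheduler.ReduceLROnPlateau(
        G_optimizer, mode='min', factor=0.5, patience=3)
    D_plateau = optim.lr_scheduler.ReduceLROnPlateau(
        D_optimizer, mode='min', factor=0.5, patience=3)

    for epoch in range(num_epochs):
        # ... train for one epoch, computing epoch_G_loss / epoch_D_loss ...
        G_plateau.step(epoch_G_loss)  # stepped with the metric, once per epoch
        D_plateau.step(epoch_D_loss)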
    
    3. Use grid search: grid search simply tries a set of candidate learning rates and keeps the best-performing one (a two-dimensional variant over G/D pairs is sketched after this example).
    import torch.optim as optim

    criterion = nn.MSELoss()

    lr_values = [0.001, 0.005, 0.01]
    for lr in lr_values:
        # Fresh optimizers for each trial (ideally re-initialize the models
        # too, so the runs are comparable)
        G_optimizer = optim.Adam(G.parameters(), lr=lr)
        D_optimizer = optim.Adam(D.parameters(), lr=lr)
        for epoch in range(num_epochs):
            for i, batch in enumerate(train_loader):
                # ...
                G_optimizer.zero_grad()
                D_optimizer.zero_grad()
                # ...
                G_optimizer.step()
                D_optimizer.step()
        # Evaluate the loss and sample quality for this lr
        # ...
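
    Since the actual problem here is the balance between G and D, it is often more informative to search over pairs of learning rates rather than a single shared value. A sketch; build_generator and build_discriminator are hypothetical helpers that re-initialize the models for each trial:

    import itertools

    g_lrs = [1e-4, 2e-4, 5e-4]
    d_lrs = [1e-5, 1e-4, 4e-4]
    for g_lr, d_lr in itertools.product(g_lrs, d_lrs):
        G = build_generator()        # fresh weights per trial (hypothetical helper)
        D = build_discriminator()
        G_optimizer = optim.Adam(G.parameters(), lr=g_lr, betas=(0.5, 0.999))
        D_optimizer = optim.Adam(D.parameters(), lr=d_lr, betas=(0.5, 0.999))
        # ... train for a few epochs and record how the G/D losses evolve ...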

    These approaches can help you find a suitable learning rate, but note that the right setting depends on the specific model and data, so some experimentation is always needed.
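
    Whichever approach you use, log the current learning rates next to the losses so you can see what the schedulers are actually doing; reading them back from the optimizers works with any scheduler:

    g_lr = G_optimizer.param_groups[0]['lr']
    d_lr = D_optimizer.param_groups[0]['lr']
    print(f"epoch {epoch + 1}: G lr={g_lr:.2e}, D lr={d_lr:.2e}")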

