I'm writing a conditional GAN (cGAN), but no matter how long I tune the learning rates, the generator and discriminator losses stay unbalanced. How should I set the learning rates?
My code is as follows:
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
import matplotlib.pyplot as plt
import torchaudio.transforms as T

# generator, discriminator, emotion_classification, gradient_penalty,
# data_loader, device, num_epochs, global_min and global_max are defined elsewhere.

criterion = nn.MSELoss()  # LSGAN-style least-squares loss
G_optimizer = Adam(generator.parameters(), lr=5e-4, betas=(0.5, 0.999))
D_optimizer = Adam(discriminator.parameters(), lr=1e-5, betas=(0.5, 0.999))
G_scheduler = lr_scheduler.CosineAnnealingLR(G_optimizer, T_max=num_epochs)
D_scheduler = lr_scheduler.CosineAnnealingLR(D_optimizer, T_max=num_epochs)
G_train_losses = []
D_train_losses = []
lambda_gp = 10   # gradient-penalty weight
num_G = 1        # generator updates per discriminator update

# Training loop
for epoch in range(num_epochs):
    G_losses = 0
    D_losses = 0
    for i, (images, audios) in enumerate(data_loader):
        images = images.to(device)
        audios = audios.to(device)
        # Predict emotion labels for the paired images (no gradients needed)
        with torch.no_grad():
            image_emotion = emotion_classification(images)
            image_emotion = torch.argmax(image_emotion, dim=1)
        # ---- Train the discriminator ----
        D_optimizer.zero_grad()
        # Real samples (audio spectrograms) with their image-derived emotion labels
        real_validity = discriminator(audios, image_emotion)
        real_loss = criterion(real_validity, torch.ones_like(real_validity).to(device))
        # Fake samples with random emotion labels; detach so the D step
        # does not backpropagate into the generator
        z = torch.randn(audios.shape[0], 100).to(device)
        fake_labels = torch.randint(0, 7, (audios.shape[0],)).to(device)
        fake_images = generator(z, fake_labels).detach()
        fake_validity = discriminator(fake_images, fake_labels)
        fake_loss = criterion(fake_validity, torch.zeros_like(fake_validity).to(device))
        # Gradient penalty on interpolations between real and fake samples
        gp = gradient_penalty(discriminator, audios, fake_images, image_emotion)
        # Total discriminator loss
        D_loss = real_loss + fake_loss + lambda_gp * gp
        D_losses += D_loss.item()
        D_loss.backward()
        D_optimizer.step()
        # ---- Train the generator ----
        for j in range(num_G):
            generator.train()
            G_optimizer.zero_grad()
            # Generate fake samples
            z = torch.randn(audios.shape[0], 100).to(device)
            fake_labels = torch.randint(0, 7, (audios.shape[0],)).to(device)
            fake_images = generator(z, fake_labels)
            # Discriminator prediction on the fakes; the generator tries to make them look real
            validity = discriminator(fake_images, fake_labels)
            G_loss = criterion(validity, torch.ones_like(validity).to(device))
            G_losses += G_loss.item()
            G_loss.backward()
            G_optimizer.step()
        # Print a progress bar for this epoch
        total_batches = len(data_loader)
        progress = (i + 1) / total_batches * 100
        bar_length = 30
        filled_length = int(bar_length * (i + 1) // total_batches)
        bar = '=' * filled_length + '-' * (bar_length - filled_length)
        print(f'\rEpoch {epoch+1}/{num_epochs} Training: {progress:3.0f}%|{bar}| {i+1}/{len(data_loader)}', end=' ', flush=True)
    G_train_losses.append(G_losses / (num_G * len(data_loader)))
    D_train_losses.append(D_losses / len(data_loader))
    print(f"[epoch={epoch + 1:3d}] generator loss: {G_train_losses[epoch]:.4f} discriminator loss: {D_train_losses[epoch]:.4f}")
    # Sample one spectrogram per emotion class at the end of each epoch
    generator.eval()
    with torch.no_grad():
        z = torch.randn(7, 100).to(device)
        labels = torch.tensor([0, 1, 2, 3, 4, 5, 6]).to(device)
        sample_audio = generator(z, labels).squeeze().cpu()
    # Plot the generated audio spectrograms
    plt.figure(figsize=(20, 16))
    for i in range(7):
        audio = sample_audio[i]
        # Undo the [-1, 1] normalization and the log1p compression
        audio = (audio + 1) / 2 * (global_max - global_min) + global_min
        audio = audio.expm1()
        spec_db = T.AmplitudeToDB()(audio)
        spec_db = spec_db.squeeze(0)
        plt.subplot(1, 7, i + 1)
        plt.imshow(spec_db)
    plt.tight_layout()
    plt.show()
    G_scheduler.step()
    D_scheduler.step()
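
One thing I'm wondering about: right now my discriminator's learning rate (1e-5) is 50x smaller than the generator's (5e-4). I've read that the two time-scale update rule (TTUR) usually does the opposite and gives the discriminator an equal or larger learning rate. Would something roughly like the sketch below be a more sensible starting point? The 1e-4 / 4e-4 values are just a commonly cited example, not numbers I've validated on my data.

# Sketch of a TTUR-style setup (example values): the discriminator gets the
# larger learning rate, and both keep the same cosine schedule as before.
G_optimizer = Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
D_optimizer = Adam(discriminator.parameters(), lr=4e-4, betas=(0.5, 0.999))
G_scheduler = lr_scheduler.CosineAnnealingLR(G_optimizer, T_max=num_epochs)
D_scheduler = lr_scheduler.CosineAnnealingLR(D_optimizer, T_max=num_epochs)

Or is the imbalance more likely coming from something other than the learning rates themselves?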