小瘪️ 2022-04-12 11:38 采纳率: 40%
浏览 31

有关深度学习去雨的问题

问题遇到的现象和发生背景

各位大佬好!我是刚接触深度学习的小白。我想复现一篇用于单幅图像去噪的非局部增强编解码网络(Non-locally Enhanced Encoder-Decoder Network for SingleImage De-raining)的论文。该论文没有给出train.py文件,我尝试自己编写遇到了一些问题。

问题相关代码

在我训练模型时,生成的图片非常的暗,并且效果也不是很好。我也不知道是什么原因。
其中左侧是带有噪声的图片,中间是我训练出的图片。右侧是原图。

img

img

img

以下是train.py


# ======================= Pytorch Lib =============================
import torch.nn as nn
import torch
from torchvision import transforms
# ======================= My Lib ===================================
from lib.NLEDN import NLEDN
from lib.data_loader_train import DataSet_Train
from lib.utils import calc_psnr, calc_ssim
# ======================= Config File ===============================
import config as cfg
# ======================= Origin Lib ================================
import os
import time
import torch.optim as optim

# ======================= Config ===================================
print('-' * 40)
print('cuda number:', cfg.CUDA_NUMBER, '\n')
print('train dir:', cfg.train_dir)

# ======================= DataSet ===================================
dataset_train = DataSet_Train(cfg)
train_batches = dataset_train.train_loader.__len__()
train_samples = dataset_train.train_dataset.__len__()

print('Train: %d batches, %d samples' % (train_batches, train_samples))
print('-' * 40 + '\n')

# ==================== Network ======================
net = NLEDN()



# ================== Network to GPU =========================
if torch.cuda.is_available():
    net.cuda(cfg.CUDA_NUMBER)

total_pnsr = 0
total_ssim = 0


# opt = torch.optim.Adam(net.parameters(), lr=0.0005)
optimizerG = optim.Adam(net.parameters(), lr = 0.0005, betas = (0.9, 0.999), weight_decay=0.0001)
loss_func = nn.L1Loss()



epoch_index = 0
start_time = time.time()
for epoch_index in range(10):
    
    total_train_step = 0

    # torch.set_grad_enabled(True)
    net.train()
    for batch_index, (img_batch, label_batch, name_list) in enumerate(dataset_train.train_loader):
        print('[%d/%d]' % (batch_index, train_batches), name_list[0])
        if torch.cuda.is_available():
            img_batch = img_batch.cuda(cfg.CUDA_NUMBER)
            label_batch = label_batch.cuda(cfg.CUDA_NUMBER)

        label_res_batch = img_batch - label_batch

        # ------------------------ Res Predict ------------------------
        prediction_res_batch = net(img_batch)
        

        loss = loss_func(prediction_res_batch, label_batch)

        prediction_batch = img_batch - prediction_res_batch
        prediction_batch = torch.clamp(prediction_batch, 0, 1)

        net.zero_grad()
        loss.backward()
        optimizerG.step()


        # ------------------------ Save Image And Calc Metric------------------------
        prediction_PIL = transforms.ToPILImage()(prediction_batch[0].cpu().data)
        label_batch_PIL = transforms.ToPILImage()(label_batch[0].cpu().data)

        pnsr = calc_psnr(prediction_PIL, label_batch_PIL)
        ssim = calc_ssim(prediction_PIL, label_batch_PIL)
        total_pnsr += pnsr
        total_ssim += ssim

        img = torch.cat([img_batch, prediction_batch, label_batch], dim=3)
        img = transforms.ToPILImage()(img[0].cpu().data)

        img.save(os.path.join(cfg.train_compare_results_dir, name_list[0]), format='png')
        prediction_PIL.save(os.path.join(cfg.train_results_dir, name_list[0]), format='png')

        total_train_step = total_train_step + 1
        # if total_train_step % 100 == 0:
        end_time = time.time()
        print(end_time - start_time)
        print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))

    torch.save(net, "net_{}.pth".format(epoch_index))
    # torch.save(net.state_dic(), "/weights_myself/net_{}.pth".format(epoch_index))
    
    
mean_pnsr = total_pnsr / test_batches
mean_ssim = total_ssim / test_batches
print('PNSR:%.4f SSIM:%.4f' % (mean_pnsr, mean_ssim))



这是论文的源代码:https://github.com/AlexHex7/NLEDN

  • 写回答

1条回答 默认 最新

  • 不会长胖的斜杠 后端领域新星创作者 2022-04-12 11:43
    关注

    大概率是模型问题,训练得到的模型参数还有问题

    评论

报告相同问题?

问题事件

  • 创建了问题 4月12日

悬赏问题

  • ¥15 C语言使用vscode编码错误
  • ¥15 用KSV5转成本时,如何不生成那笔中间凭证
  • ¥20 ensp怎么配置让PC1和PC2通讯上
  • ¥50 有没有适合匹配类似图中的运动规律的图像处理算法
  • ¥15 dnat基础问题,本机发出,别人返回的包,不能命中
  • ¥15 请各位帮我看看是哪里出了问题
  • ¥15 vs2019的js智能提示
  • ¥15 关于#开发语言#的问题:FDTD建模问题图中代码没有报错,但是模型却变透明了
  • ¥15 uniapp的h5项目写一个抽奖动画
  • ¥15 hadoop中启动hive报错如下怎么解决