我使用GITHUB开源项目KAIR源码框架训练EDSR超分辨率模型,模型能够跑的通但是为什么最后超分辨率输出的图像会是这个样子
这两张图是两次不同迭代代数的效果
并且输出的训练日志中PSNR的数值波动的十分诡异,下面这些是我截图的训练日志
明显看到PSNR从10dB掉到1dB,又再次回到10dB
使用的是自制的数据集,低分辨率图像尺寸是6464 高分辨率是128128,模型是在GoogleColab环境下训练的
这是EDSR模型的源码
from models import common
import torch.nn as nn
url = {
'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
}
def make_model(args, parent=False):
return EDSR(args)
class EDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(EDSR, self).__init__()
n_resblocks = args['n_resblocks']
n_feats = args['n_feats']
kernel_size = 3
scale = args['scale']
act = nn.ReLU(True)
url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
if url_name in url:
self.url = url[url_name]
else:
self.url = None
self.sub_mean = common.MeanShift(args['rgb_range'])
self.add_mean = common.MeanShift(args['rgb_range'], sign=1)
# define head module
m_head = [conv(args['n_colors'], n_feats, kernel_size)]
# define body module
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args['res_scale']
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args['n_colors'], kernel_size)
]
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
这其中的arg是来自于json文件中的配置参数,我是按照KAIR中的框架结构自己添加的模型,也就是说本身KAIR是没有EDSR这个模型的