加油啊利酱! 2021-08-23 23:33

SRCNN: training results are very poor, how should I fix it?

When reproducing SRCNN, I don't know why the model stops improving after about 200 epochs. The results are like a coin flip: some images improve a tiny bit, while others actually get worse.
The training images I used are the ones provided by the original author.

import os
import math
import numpy as np
import cv2
from matplotlib import pyplot as plt
from keras.models import Sequential
from keras.layers import Conv2D
# use either keras or tensorflow.keras consistently; mixing the two breaks callbacks
from keras.callbacks import ModelCheckpoint
# on scikit-image >= 0.18 this moved to: from skimage.metrics import structural_similarity as ssim
from skimage.measure import compare_ssim as ssim


def psnr(target, ref):
    # peak signal-to-noise ratio in dB, assuming 8-bit (0-255) images
    target_data = target.astype(float)
    ref_data = ref.astype(float)
    diff = ref_data - target_data
    diff = diff.flatten('C')
    rmse = math.sqrt(np.mean(diff ** 2.))

    return 20 * math.log10(255. / rmse)


def mse(target, ref):
    err = np.sum((target.astype('float') - ref.astype('float')) ** 2)
    err /= float(target.shape[0] * target.shape[1])
    return err


def compare_images(target, ref):
    scores = []
    scores.append(psnr(target, ref))
    scores.append(mse(target, ref))
    scores.append(ssim(target, ref, multichannel =True))
    
    return scores

def modcrop(img, scale):
    tmpsz = img.shape
    sz = tmpsz[0:2]

    # np.mod(sz, scale) is sz % scale: trim so both dimensions are divisible by scale
    sz = sz - np.mod(sz, scale)

    img = img[0:sz[0], 0:sz[1]]
    return img


def shave(image, border):
    "Remove a border of the given width from all four sides."
    img = image[border: -border, border: -border]
    return img


path = './Train'
deg = []
ref = []
for file in os.listdir(path):
    if file != ".DS_Store":
        ref_e = cv2.imread(path + '/' + file)
        ref_e = cv2.cvtColor(ref_e, cv2.COLOR_BGR2YCrCb)

        # channel 0 of OpenCV's YCrCb is the luminance (Y) channel;
        # SRCNN is trained on Y only, and predict() below also uses channel 0
        ref_e = ref_e[:, :, 0]
        # note: modcrop uses scale 3 here while the degradation below downsamples by 2;
        # the two factors should be chosen consistently
        ref_e = modcrop(ref_e, 3)

        h = ref_e.shape[0]
        w = ref_e.shape[1]

        # build the degraded input: downscale then upscale back with bicubic
        # interpolation, as in the original SRCNN setup
        new_height = h // 2
        new_width = w // 2
        deg_e = cv2.resize(cv2.resize(ref_e, (new_width, new_height), interpolation=cv2.INTER_CUBIC),
                           (w, h), interpolation=cv2.INTER_CUBIC)

        # extract overlapping 32x32 input patches and the matching 20x20 label
        # patches (the 'valid' convolutions shave 6 pixels off each side)
        for x in range(0, ref_e.shape[0] - 33, 14):
            for y in range(0, ref_e.shape[1] - 33, 14):
                sub_deg = deg_e[x:x + 32, y:y + 32].astype(float) / 255
                sub_ref = ref_e[x + 6:x + 26, y + 6:y + 26].astype(float) / 255
                # append a fresh array every iteration; reusing one buffer and
                # appending it repeatedly makes every stored patch identical
                deg.append(sub_deg[:, :, np.newaxis])
                ref.append(sub_ref[:, :, np.newaxis])
                
ref = np.array(ref)
deg = np.array(deg)
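
# Quick sanity checks on the patch arrays just built (a minimal sketch):
# confirm they contain distinct, correctly scaled patches before training.
print(deg.shape, ref.shape)          # expect (N, 32, 32, 1) and (N, 20, 20, 1)
print(deg.min(), deg.max())          # values should stay inside [0, 1]
print(np.allclose(deg[0], deg[-1]))  # should print False: patches must not all be identical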


def model():
    # SRCNN: feature extraction (9x9), non-linear mapping (3x3), reconstruction (5x5);
    # the two 'valid' convolutions turn a 32x32 input patch into a 20x20 output patch
    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=64, kernel_size=(9, 9), activation='relu', padding='valid', use_bias=True, input_shape=(32, 32, 1)))
    SRCNN.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', padding='same', use_bias=True))
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5), padding='valid', use_bias=True))
    SRCNN.compile(optimizer='adam', loss='mean_squared_error', metrics=['mean_squared_error'])

    return SRCNN
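
# model() fixes the input shape to 32x32, which is fine for patch training, but
# predict() below runs the network on whole images. If Keras rejects the
# full-size input, one workaround (a sketch; inference_model is a name
# introduced here for illustration) is to rebuild the same layers with an
# unspecified spatial size and load the trained weights into that copy.
def inference_model():
    net = Sequential()
    net.add(Conv2D(filters=64, kernel_size=(9, 9), activation='relu', padding='valid', use_bias=True, input_shape=(None, None, 1)))
    net.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', padding='same', use_bias=True))
    net.add(Conv2D(filters=1, kernel_size=(5, 5), padding='valid', use_bias=True))
    net.compile(optimizer='adam', loss='mean_squared_error')
    return net
# usage at inference time: srcnn_full = inference_model(); srcnn_full.load_weights('SRCNN_check_1.h5')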


srcnn = model()
checkpoint = ModelCheckpoint("SRCNN_check_1.h5", monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=False, mode='min')
callbacks_list = [checkpoint]
history = srcnn.fit(deg, ref, validation_split=0.33, epochs=200, batch_size=128, verbose=1, callbacks=callbacks_list)
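
# Since the complaint is that training stalls around 200 epochs, plotting the
# loss curves from the history object returned by fit() above (a minimal
# sketch) shows whether the validation loss has really plateaued or whether it
# is still noisy or diverging from the training loss.
plt.figure()
plt.plot(history.history['loss'], label='train MSE')
plt.plot(history.history['val_loss'], label='val MSE')
plt.yscale('log')  # MSE differences become small late in training; log scale is easier to read
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()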


def predict(image_path):
    
    # load the degraded and reference images
    path, file = os.path.split(image_path)
    degraded = cv2.imread(image_path)
    ref = cv2.imread('Train/{}'.format(file))
    
    # preprocess the image with modcrop
    ref = modcrop(ref, 3)
    degraded = modcrop(degraded, 3)
    
    temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)
    
    # create image slice and normalize  
    Y = np.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
    Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255
    
    # load the best checkpoint saved during training; if the fixed 32x32 input
    # shape of model() rejects a full-size image here, load these weights into
    # a copy of the network built with input_shape=(None, None, 1) instead
    srcnn.load_weights('SRCNN_check_1.h5')
    # perform super-resolution with srcnn
    pre = srcnn.predict(Y, batch_size=1)
    
    # post-process output: rescale to [0, 255], clamp, and convert back to uint8
    pre = np.clip(pre * 255, 0, 255).astype(np.uint8)
    
    # copy Y channel back to image and convert to BGR
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, :, :, 0]
    output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)
    
    # remove border from reference and degraded image
    ref = shave(ref.astype(np.uint8), 6)
    degraded = shave(degraded.astype(np.uint8), 6)
    
    # image quality calculations
    scores = []
    scores.append(compare_images(degraded, ref))
    scores.append(compare_images(output, ref))
    
    # return images and scores
    return ref, degraded, output, scores


ref, degraded, output, scores= predict('train_lr/butterfly_GT.bmp')
print(degraded.shape)
print(output.shape)
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[0][0], scores[0][1], scores[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[1][0], scores[1][1], scores[1][2]))

fig, axs = plt.subplots(1, 3, figsize=(20, 8))
axs[0].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
axs[0].set_title('Degraded')
axs[1].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
axs[1].set_title('SRCNN')
axs[2].imshow(cv2.cvtColor(ref, cv2.COLOR_BGR2RGB))
axs[2].set_title('Original')
# plt.savefig("test2.jpg")
# remove the x and y ticks
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])

1 answer

  • 爱晚乏客游 2021-08-24 01:22

    If there is nothing wrong with the original author's code and nothing wrong with your settings, then this kind of problem usually comes down to the data: is the dataset large enough, and is its distribution reasonable?
    Also check what quality the original author actually achieved. Sometimes the network itself is the bottleneck, and in that case you have to adjust the parameters yourself for your specific problem, which is the hardest situation.
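
    To make the "is the data reasonable" check concrete, you can compare one degraded input patch with its label patch before training. The snippet below is a minimal sketch against the deg and ref arrays built in the question; the "Degraded Image" PSNR that predict() already prints is the bicubic baseline the reconstruction has to beat.

    # visually compare one blurred input patch with its sharp label patch
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(deg[0, :, :, 0], cmap='gray')
    axs[0].set_title('input (bicubic-degraded)')
    axs[1].imshow(ref[0, :, :, 0], cmap='gray')
    axs[1].set_title('label (original)')
    plt.show()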
