我复现SRCNN的时候,不知道为什么model跑个200个epochs就没办法优化了,导致效果像抛硬币,有的图片提升了一丢丢,有的完全是变得更差了.
我用的训练图像是原作者提供的。
import numpy as np
from matplotlib import pyplot as plt
import sys
import keras
import cv2
import numpy
from keras.models import Sequential
from keras.layers import Conv2D
from keras.optimizers import Adam
from skimage.measure import compare_ssim as ssim
import cv2
import math
import random
from keras.callbacks import ModelCheckpoint
import os
from tensorflow.keras.callbacks import ModelCheckpoint
def psnr(target, ref):
target_data = target.astype(float)
ref_data = ref.astype(float)
diff = ref_data - target_data
diff = diff.flatten('C')
rmse = math.sqrt(np.mean(diff ** 2.))
return 20 * math.log10(255. / rmse)
def mse(target, ref):
err = np.sum((target.astype('float') - ref.astype('float')) ** 2)
err /= float(target.shape[0] * target.shape[1])
return err
def compare_images(target, ref):
scores = []
scores.append(psnr(target, ref))
scores.append(mse(target, ref))
scores.append(ssim(target, ref, multichannel =True))
return scores
def modcrop(img, scale):
tmpsz = img.shape
sz = tmpsz[0:2]
# np.mod 是sz%scale
sz = sz - np.mod(sz, scale)
img = img[0:sz[0], 0:sz[1]]
return img
def shave(image, border):
"把周围去掉"
img = image[border: -border, border: -border]
return img
path = './Train'
deg=[]
ref=[]
# deg=np.array()
# ref = np.array()
count=0
for file in os.listdir('./Train'):
if file != ".DS_Store":
ref_e = cv2.imread(path+'/'+file)
ref_e = cv2.cvtColor(ref_e, cv2.COLOR_BGR2YCrCb)
ref_e=ref_e[:,:,1]
ref_e=modcrop(ref_e,3)
# size = ref_e.shape[0], ref_e.shape[1]
# print(size)
h = ref_e.shape[0]
w = ref_e.shape[1]
new_height = h // 2
new_width = w // 2
deg_e=cv2.resize(cv2.resize(ref_e,(new_width,new_height)),(w,h))
temp1=np.zeros((32,32,1))
temp2=np.zeros((20,20,1))
for x in range(0,ref_e.shape[0]-33,14):
for y in range(0,ref_e.shape[1]-33,14):
temp1[:,:,0] = deg_e[x:x + 33 - 1, y:y + 33 - 1].astype(float) / 255;
temp2[:,:,0] = ref_e[x + 6 : x + 6 + 21 - 1, y + 6 : y + 6 + 21 - 1].astype(float) / 255;
deg.append(temp1)
ref.append(temp2)
ref = np.array(ref)
deg = np.array(deg)
def model():
SRCNN = Sequential()
SRCNN.add(Conv2D(filters=64, kernel_size = (9, 9), activation='relu', padding='valid', use_bias=True, input_shape=(32, 32, 1)))
SRCNN.add(Conv2D(filters=32, kernel_size = (3, 3), activation='relu', padding='same', use_bias=True))
SRCNN.add(Conv2D(filters=1, kernel_size = (5, 5), padding='valid', use_bias=True))
SRCNN.compile(optimizer='adam', loss='mean_squared_error', metrics=['mean_squared_error'])
return SRCNN
checkpoint = ModelCheckpoint("SRCNN_check_1.h5", monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=False, mode='min')
callbacks_list = [checkpoint]
history = srcnn.fit(deg,ref, validation_split=0.33, epochs=200, batch_size=128, verbose=1, callbacks=callbacks_list)
def predict(image_path):
# load the degraded and reference images
path, file = os.path.split(image_path)
degraded = cv2.imread(image_path)
ref = cv2.imread('Train/{}'.format(file))
# preprocess the image with modcrop
ref = modcrop(ref, 3)
degraded = modcrop(degraded, 3)
temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)
# create image slice and normalize
Y = numpy.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255
srcnn.load_weights('SRCNN_check_1.h5')
# perform super-resolution with srcnn
pre = srcnn.predict(Y, batch_size=1)
# post-process output
pre *= 255
pre[pre[:] > 255] = 255
pre[pre[:] < 0] = 0
pre = pre.astype(np.uint8)
# copy Y channel back to image and convert to BGR
temp = shave(temp, 6)
temp[:, :, 0] = pre[0, :, :, 0]
output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)
# remove border from reference and degraged image
ref = shave(ref.astype(np.uint8), 6)
degraded = shave(degraded.astype(np.uint8), 6)
# image quality calculations
scores = []
scores.append(compare_images(degraded, ref))
scores.append(compare_images(output, ref))
# return images and scores
return ref, degraded, output, scores
ref, degraded, output, scores= predict('train_lr/butterfly_GT.bmp')
print(degraded.shape)
print(output.shape)
print('Degraded Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[0][0], scores[0][1], scores[0][2]))
print('Reconstructed Image: \nPSNR: {}\nMSE: {}\nSSIM: {}\n'.format(scores[1][0], scores[1][1], scores[1][2]))
fig, axs = plt.subplots(1, 3, figsize=(20, 8))
axs[0].imshow(cv2.cvtColor(degraded, cv2.COLOR_BGR2RGB))
axs[0].set_title('Degraded')
axs[1].imshow(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))
axs[1].set_title('SRCNN')
axs[2].imshow(cv2.cvtColor(ref, cv2.COLOR_BGR2RGB))
axs[2].set_title('orginal')
# plt.savefig("test2.jpg")
# remove the x and y ticks
for ax in axs:
ax.set_xticks([])
ax.set_yticks([])