# coding: utf-8
from future import print_function
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import utils
import os
slim = tf.contrib.slim
def gram(layer):
    """Return the normalised Gram matrix of every item in a 4-D feature batch.

    The input is flattened to ``(dim0, -1, dim3)`` and multiplied by its own
    transpose, giving one ``(dim3, dim3)`` matrix per item along axis 0.  The
    result is divided by ``dim1 * dim2 * dim3`` of the input.

    Args:
        layer: rank-4 tensor (axes 0..3 are indexed below); presumably
            ``(batch, height, width, channels)`` -- confirm with the caller.

    Returns:
        Tensor of shape ``(dim0, dim3, dim3)``.
    """
    dims = tf.shape(layer)
    batch = dims[0]
    channels = dims[3]
    # Normaliser: number of elements in a single item of the batch.
    normaliser = tf.to_float(dims[1] * dims[2] * channels)
    # Collapse the two middle axes so each item is a (positions, channels) matrix.
    flat = tf.reshape(layer, tf.stack([batch, -1, channels]))
    return tf.matmul(flat, flat, transpose_a=True) / normaliser
def get_style_features(FLAGS):
    """Compute the Gram-matrix style targets for ``FLAGS.style_image``.

    The style image is decoded, preprocessed and pushed through the loss
    network named by ``FLAGS.loss_model``; the Gram matrix of every endpoint
    listed in ``FLAGS.style_layers`` is then evaluated in a fresh graph and
    session.  As a side effect the preprocessed-then-unprocessed style image
    is written to ``generated/target_style_<FLAGS.naming>.jpg`` (the
    ``generated`` directory is created when missing).

    Per the original author's note, the preprocessing step resizes the
    shorter side to ``FLAGS.image_size`` and applies a central crop; that
    logic lives in ``preprocessing_factory`` and is not visible here.

    Args:
        FLAGS: configuration object.  This function reads ``loss_model``,
            ``image_size``, ``style_image``, ``style_layers`` and ``naming``,
            and passes the whole object to ``utils._get_init_fn``.

    Returns:
        The result of ``sess.run`` on the list of per-layer Gram matrices, in
        the order of ``FLAGS.style_layers``, each with its batch axis removed.
    """
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        # Load and decode the style image; the decoder is picked from the
        # file extension.
        # NOTE(review): no `channels` argument is passed, so the decoded
        # channel count follows the file -- confirm the preprocessing function
        # copes with grayscale / RGBA inputs.
        size = FLAGS.image_size
        img_bytes = tf.read_file(FLAGS.style_image)
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes)
        else:
            image = tf.image.decode_jpeg(img_bytes)

        # Add the batch dimension expected by the network.
        images = tf.expand_dims(image_preprocessing_fn(image, size, size), 0)

        _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        # One Gram matrix per requested style layer, batch axis squeezed out.
        features = [
            tf.squeeze(gram(endpoints_dict[layer]), [0])
            for layer in FLAGS.style_layers
        ]

        with tf.Session() as sess:
            # Restore variables for the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Make sure the 'generated' directory exists.
            if not os.path.exists('generated'):
                os.makedirs('generated')

            # Evaluate the JPEG bytes BEFORE opening the output file so that a
            # failure in the graph does not leave a truncated, empty image.
            save_file = 'generated/target_style_' + FLAGS.naming + '.jpg'
            target_image = image_unprocessing_fn(images[0, :])
            encoded = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
            jpeg_bytes = sess.run(encoded)
            with open(save_file, 'wb') as f:
                f.write(jpeg_bytes)
            tf.logging.info('Target style pattern is saved to: %s.', save_file)

            # Evaluate and return the style targets used by the style loss.
            return sess.run(features)
def style_loss(endpoints_dict, style_features_t, style_layers):
    """Sum the per-layer style losses of the generated images.

    For each ``(target Gram, layer name)`` pair the endpoint tensor is split
    in two along axis 0 and only the first half is used as the generated
    images.  The layer loss is ``2 * l2_loss(gram(generated) - target)``
    divided by the number of elements in the generated half.

    Args:
        endpoints_dict: mapping from layer name to endpoint tensor.
        style_features_t: target Gram matrices, paired positionally with
            ``style_layers``.
        style_layers: names of the layers to evaluate.

    Returns:
        ``(total, per_layer)`` where ``per_layer`` maps each layer name to its
        individual loss and ``total`` is their sum (``0`` when no layers).
    """
    total = 0
    per_layer = {}
    for target_gram, layer_name in zip(style_features_t, style_layers):
        # First half of the batch is the generated images; the rest is unused here.
        generated, _ = tf.split(endpoints_dict[layer_name], 2, 0)
        element_count = tf.size(generated)
        gram_diff = gram(generated) - target_gram
        current = tf.nn.l2_loss(gram_diff) * 2 / tf.to_float(element_count)
        per_layer[layer_name] = current
        total += current
    return total, per_layer
def content_loss(endpoints_dict, content_layers):
    """Sum the per-layer content losses between generated and content images.

    Each endpoint tensor is split in two along axis 0: the first half is
    treated as the generated images, the second as the content images.  The
    layer loss is ``2 * l2_loss(generated - content)`` divided by the number
    of elements in the generated half.

    Args:
        endpoints_dict: mapping from layer name to endpoint tensor.
        content_layers: names of the layers to evaluate.

    Returns:
        The accumulated loss (``0`` when ``content_layers`` is empty).
    """
    total = 0
    for layer_name in content_layers:
        generated, originals = tf.split(endpoints_dict[layer_name], 2, 0)
        element_count = tf.to_float(tf.size(generated))
        squared_error = tf.nn.l2_loss(generated - originals)
        # Same normalisation as in the paper (original author's remark).
        total += squared_error * 2 / element_count
    return total
def total_variation_loss(layer):
    """Return a smoothness penalty built from neighbouring-entry differences.

    Differences are taken between adjacent entries along axis 1 and along
    axis 2 of a rank-4 tensor.  Each set of differences contributes
    ``l2_loss(diff) / size(diff)`` and the two contributions are added.

    Args:
        layer: rank-4 tensor; presumably ``(batch, height, width, channels)``
            -- confirm with the caller.

    Returns:
        Scalar tensor holding the combined penalty.
    """
    dims = tf.shape(layer)
    rows = dims[1]
    cols = dims[2]
    to_the_end = [-1, -1, -1, -1]

    # Axis 1: everything but the last row minus everything but the first row.
    without_last_row = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, rows - 1, -1, -1]))
    without_first_row = tf.slice(layer, [0, 1, 0, 0], to_the_end)
    row_diff = without_last_row - without_first_row

    # Axis 2: everything but the last column minus everything but the first.
    without_last_col = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, -1, cols - 1, -1]))
    without_first_col = tf.slice(layer, [0, 0, 1, 0], to_the_end)
    col_diff = without_last_col - without_first_col

    col_term = tf.nn.l2_loss(col_diff) / tf.to_float(tf.size(col_diff))
    row_term = tf.nn.l2_loss(row_diff) / tf.to_float(tf.size(row_diff))
    return col_term + row_term
# ---- train.py ----
from future import print_function
from future import division
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import reader
import model
import time
import losses
import utils
import os
import argparse
slim = tf.contrib.slim
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings.  When ``None`` (the default)
            ``argparse`` falls back to ``sys.argv[1:]``, which is the original
            behaviour; passing a list allows programmatic use.

    Returns:
        An ``argparse.Namespace`` whose ``conf`` attribute is the path to the
        configuration file (default ``'conf/mosaic.yml'``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--conf', default='conf/mosaic.yml',
                        help='the path to the conf file')
    return parser.parse_args(argv)
def main(FLAGS):
    """Build the style-transfer training graph and run the training loop.

    Steps: compute the style targets, build the generator and the loss
    network, assemble the content / style / total-variation losses, register
    TensorBoard summaries, then optimise every trainable variable whose name
    does not start with ``FLAGS.loss_model``.  Checkpoints and summaries are
    written under ``os.path.join(FLAGS.model_path, FLAGS.naming)``.

    Args:
        FLAGS: configuration object.  This function reads ``model_path``,
            ``naming``, ``loss_model``, ``batch_size``, ``image_size``,
            ``epoch``, ``content_layers``, ``style_layers``,
            ``content_weight``, ``style_weight`` and ``tv_weight``, and passes
            the whole object to ``losses.get_style_features`` and
            ``utils._get_init_fn``.
    """
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    # Raw string on purpose: in an ordinary literal '\7' is an octal escape
    # (BEL) and '\t' is a TAB, which silently corrupts this Windows path.
    train_data_dir = r'F:\Anaconda3\7\train2014/'
    checkpoint_prefix = os.path.join(training_path, 'fast-style-model.ckpt')

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # ---- Build network ----
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            train_data_dir, image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            # Run the generator output through the same preprocessing as the inputs.
            processed_generated = tf.stack([
                image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ])
            # Generated images first, originals second: the loss functions
            # split the endpoints in two along axis 0 in this order.
            _, endpoints_dict = network_fn(
                tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            # ---- Build losses ----
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

            loss = (FLAGS.style_weight * style_loss
                    + FLAGS.content_weight * content_loss
                    + FLAGS.tv_weight * tv_loss)

            # ---- Summaries for TensorBoard ----
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image)
                for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # ---- Prepare to train ----
            global_step = tf.Variable(0, name="global_step", trainable=False)

            # Only optimise variables that do not belong to the loss network.
            variable_to_train = [
                variable for variable in tf.trainable_variables()
                if not variable.name.startswith(FLAGS.loss_model)
            ]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            # The saver likewise skips the loss-network variables.
            variables_to_restore = [
                v for v in tf.global_variables()
                if not v.name.startswith(FLAGS.loss_model)
            ]
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Resume the training model if a checkpoint already exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # ---- Start training ----
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    # Logging
                    if step % 10 == 0:
                        tf.logging.info('step: %d, total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    # Summary
                    # NOTE(review): this is a separate sess.run; if
                    # `processed_images` is queue-fed it will consume an extra
                    # batch -- confirm this is intended.
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    # Checkpoint
                    if step % 1000 == 0:
                        saver.save(sess, checkpoint_prefix, global_step=step)
            except tf.errors.OutOfRangeError:
                # The input pipeline signalled the end of the last epoch.
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                # Release the event file handle even when training aborts.
                writer.close()
            coord.join(threads)
# Script entry point: run training only when executed directly, not on import.
# (`__name__` is the module attribute; the bare `name` raised NameError.)
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()
    FLAGS = utils.read_conf_file(args.conf)
    main(FLAGS)
# (Note from the original post: the error is shown in the attached screenshot.)