SyusukeVae 2020-03-08 17:10 采纳率: 0%
浏览 569
已结题

训练风格迁移模型时遇到一些无法解决的错误

# coding: utf-8

from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import utils
import os

slim = tf.contrib.slim

def gram(layer):
    """Return the normalized Gram matrix of every item in a feature batch.

    ``layer`` is indexed as a 4-D tensor. Axis 0 is kept as the batch axis,
    axes 1 and 2 are flattened together, and axis 3 is treated as the
    filter axis (presumably an NHWC feature map -- confirm with the caller).

    The result has one ``num_filters x num_filters`` matrix per batch item,
    divided by the product of the sizes of axes 1, 2 and 3.
    """
    dims = tf.shape(layer)
    batch, dim_a, dim_b, channels = dims[0], dims[1], dims[2], dims[3]

    # Collapse the two middle axes so each batch item becomes a
    # (positions, channels) matrix.
    flattened = tf.reshape(layer, tf.stack([batch, -1, channels]))

    # transpose_a=True computes F^T . F per batch item.
    inner_products = tf.matmul(flattened, flattened, transpose_a=True)
    return inner_products / tf.to_float(dim_a * dim_b * channels)

def get_style_features(FLAGS):
    """
    Compute the target style features (Gram matrices) of the style image.

    The function builds a throw-away graph, runs the style image through the
    loss network named by ``FLAGS.loss_model`` and evaluates the Gram matrix
    of every endpoint listed in ``FLAGS.style_layers``.

    For the "style_image", the preprocessing step (as described by the
    original author; the preprocessing function itself lives in
    ``preprocessing_factory``) is:
        1. Resize the shorter side to FLAGS.image_size
        2. Apply central crop

    Side effect: the preprocessed style image is written to
    ``generated/target_style_<FLAGS.naming>.jpg``.

    Args:
        FLAGS: configuration object; this function reads ``loss_model``,
            ``image_size``, ``style_image``, ``style_layers`` and ``naming``
            (``utils._get_init_fn`` may read more).

    Returns:
        A list with one evaluated Gram matrix per entry of
        ``FLAGS.style_layers``, in the same order.
    """
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        # Get the style image data
        size = FLAGS.image_size
        img_bytes = tf.read_file(FLAGS.style_image)
        # Force three channels: without `channels`, a PNG with an alpha
        # channel decodes to 4 channels and a grayscale file to 1, which an
        # RGB loss network cannot consume.
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes, channels=3)
        else:
            image = tf.image.decode_jpeg(img_bytes, channels=3)

        # Add the batch dimension
        images = tf.expand_dims(image_preprocessing_fn(image, size, size), 0)

        _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        features = []
        for layer in FLAGS.style_layers:
            feature = endpoints_dict[layer]
            feature = tf.squeeze(gram(feature), [0])  # remove the batch dimension
            features.append(feature)

        with tf.Session() as sess:
            # Restore variables for loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Make sure the 'generated' directory exists.
            if not os.path.exists('generated'):
                os.makedirs('generated')
            # Indicate cropped style image path
            save_file = 'generated/target_style_' + FLAGS.naming + '.jpg'

            # Encode first, open the file second: if evaluation fails we do
            # not leave an empty image file behind.
            target_image = image_unprocessing_fn(images[0, :])
            value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
            encoded = sess.run(value)

            # Write preprocessed style image to indicated path
            with open(save_file, 'wb') as f:
                f.write(encoded)
            tf.logging.info('Target style pattern is saved to: %s.' % save_file)

            # Return the features of the layers used for measuring style loss.
            return sess.run(features)

def style_loss(endpoints_dict, style_features_t, style_layers):
    """Sum the per-layer style losses between generated images and the style target.

    Each endpoint in ``endpoints_dict`` is split in two along axis 0; only the
    first half (the generated images) is used here. Its Gram matrix is
    compared with the matching precomputed entry of ``style_features_t``.

    Args:
        endpoints_dict: mapping from layer name to activation tensor.
        style_features_t: target Gram matrices, paired positionally with
            ``style_layers``.
        style_layers: names of the layers that contribute to the style loss.

    Returns:
        ``(total, per_layer)`` where ``per_layer`` maps each layer name to
        its individual loss term and ``total`` is their sum (``0`` when no
        layers are given).
    """
    total = 0
    per_layer = {}
    for target_gram, layer_name in zip(style_features_t, style_layers):
        # First half of the batch = generated images; second half is unused here.
        generated, _ = tf.split(endpoints_dict[layer_name], 2, 0)
        element_count = tf.size(generated)
        gram_difference = gram(generated) - target_gram
        # l2_loss is sum(x ** 2) / 2, hence the factor 2 before normalizing.
        layer_term = tf.nn.l2_loss(gram_difference) * 2 / tf.to_float(element_count)
        per_layer[layer_name] = layer_term
        total += layer_term
    return total, per_layer

def content_loss(endpoints_dict, content_layers):
    """Sum the content losses over ``content_layers``.

    Every endpoint is split in two along axis 0: the first half holds the
    activations of the generated images, the second half those of the
    content images. The loss per layer is the squared difference of the two
    halves divided by the number of elements in one half.

    Returns ``0`` when ``content_layers`` is empty.
    """
    accumulated = 0
    for layer_name in content_layers:
        generated, originals = tf.split(endpoints_dict[layer_name], 2, 0)
        element_count = tf.size(generated)
        squared_error = tf.nn.l2_loss(generated - originals) * 2
        # Same normalization as in the paper.
        accumulated += squared_error / tf.to_float(element_count)
    return accumulated

def total_variation_loss(layer):
    """Return a total-variation style smoothness penalty for a 4-D tensor.

    Differences between neighbouring entries are taken along axis 1 and
    along axis 2; each set of differences contributes its ``l2_loss``
    divided by its own element count.
    """
    dims = tf.shape(layer)
    size_axis1 = dims[1]
    size_axis2 = dims[2]

    origin = [0, 0, 0, 0]
    to_the_end = [-1, -1, -1, -1]

    # Neighbour differences along axis 1: entries [0, n-1) minus entries [1, n).
    diff_axis1 = (tf.slice(layer, origin, tf.stack([-1, size_axis1 - 1, -1, -1]))
                  - tf.slice(layer, [0, 1, 0, 0], to_the_end))
    # Neighbour differences along axis 2, built the same way.
    diff_axis2 = (tf.slice(layer, origin, tf.stack([-1, -1, size_axis2 - 1, -1]))
                  - tf.slice(layer, [0, 0, 1, 0], to_the_end))

    term_axis2 = tf.nn.l2_loss(diff_axis2) / tf.to_float(tf.size(diff_axis2))
    term_axis1 = tf.nn.l2_loss(diff_axis1) / tf.to_float(tf.size(diff_axis1))
    return term_axis2 + term_axis1

train.py

from __future__ import print_function
from __future__ import division

import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import reader
import model
import time
import losses
import utils
import os
import argparse

slim = tf.contrib.slim

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            previous behavior of this function.

    Returns:
        An ``argparse.Namespace`` whose ``conf`` attribute holds the path to
        the configuration file (default ``'conf/mosaic.yml'``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--conf', default='conf/mosaic.yml', help='the path to the conf file')
    return parser.parse_args(argv)

def main(FLAGS):
    """Build the style-transfer training graph and run the training loop.

    Steps, in order:
        1. Compute the target style features with ``losses.get_style_features``.
        2. Build the generator (``model.net``) and feed both the generated and
           the original images through the loss network.
        3. Combine content, style and total-variation losses using the
           weights found in ``FLAGS``.
        4. Train with Adam, logging every 10 steps, writing summaries every
           25 steps and saving a checkpoint every 1000 steps, until the
           input queue raises ``OutOfRangeError``.

    Args:
        FLAGS: configuration object. This function reads ``model_path``,
            ``naming``, ``loss_model``, ``batch_size``, ``image_size``,
            ``epoch``, ``content_layers``, ``style_layers``,
            ``style_weight``, ``content_weight`` and ``tv_weight``.
    """
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    # Must be a raw string: in a normal literal "\7" is an octal escape (the
    # BEL character) and "\t" is a TAB, so the Windows path would silently
    # point somewhere that does not exist.
    dataset_dir = r'F:\Anaconda3\7\train2014/'

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # ---- Build Network ----
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            dataset_dir, image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                                   ]
            processed_generated = tf.stack(processed_generated)
            # Generated images first, originals second: the loss functions
            # split every endpoint in two along axis 0 in this order.
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of loss network
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            # ---- Build Losses ----
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # ---- Add Summary (for visualization in tensorboard) ----
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # ---- Prepare to Train ----
            global_step = tf.Variable(0, name="global_step", trainable=False)

            # Only the generator is trained; variables whose name starts with
            # the loss-model name are left out of the optimizer.
            variable_to_train = []
            for variable in tf.trainable_variables():
                if not variable.name.startswith(FLAGS.loss_model):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

            # Likewise, checkpoints only contain the non-loss-model variables.
            variables_to_restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # ---- Start Training ----
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    # logging
                    if step % 10 == 0:
                        tf.logging.info('step: %d,  total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    # summary
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    # checkpoint
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                # The input queue is exhausted: the requested number of
                # epochs has been consumed.
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)

# Script entry point: only runs when the file is executed directly.
# The guard must test the module attribute `__name__`; a bare `name` is an
# undefined identifier and raises NameError on start-up.
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()
    FLAGS = utils.read_conf_file(args.conf)
    main(FLAGS)

`

图片说明

错误情况如图

  • 写回答

1条回答 默认 最新

  • dabocaiqq 2020-03-08 21:44
    关注
    评论

报告相同问题?

悬赏问题

  • ¥15 教务系统账号被盗号如何追溯设备
  • ¥20 delta降尺度方法,未来数据怎么降尺度
  • ¥15 c# 使用NPOI快速将datatable数据导入excel中指定sheet,要求快速高效
  • ¥15 在不同版本的系统上,TCP传输速度不一致
  • ¥15 高德地图点聚合中Marker的位置无法实时更新
  • ¥15 DIFY API Endpoint 问题。
  • ¥20 sub地址DHCP问题
  • ¥15 delta降尺度计算的一些细节,有偿
  • ¥15 Arduino红外遥控代码有问题
  • ¥15 数值计算离散正交多项式