尼古拉斯富贵 2018-10-18 13:11 采纳率: 0%
浏览 2599

如何可视化tensorflow版的faster rcnn的训练过程?

小白一枚,从github下载的faster rcnn tf版运行在win10系统上,代码里没有输出训练日志的语句,也没有生成log文件,
想问一下怎么添加语句才能通过tensorboard显示训练过程?下面是train代码:

import time
import tensorflow as tf
import numpy as np

from tensorflow.python import pywrap_tensorflow

import lib.config.config as cfg
from lib.datasets import roidb as rdl_roidb
from lib.datasets.factory import get_imdb
from lib.datasets.imdb import imdb as imdb2
from lib.layer_utils.roi_data_layer import RoIDataLayer
from lib.nets.vgg16 import vgg16
from lib.utils.timer import Timer

# Prefer the C-accelerated cPickle where it exists (Python 2); otherwise fall
# back to the standard pickle module.
try:
    import cPickle as pickle
except ImportError:
    import pickle
import os

def get_training_roidb(imdb):
    """Return a roidb (Region of Interest database) for use in training.

    The image database is first augmented with horizontally-flipped examples
    and then handed to ``rdl_roidb.prepare_roidb`` before its ``roidb`` is
    returned.

    Args:
        imdb: image database object exposing ``append_flipped_images()`` and
            a ``roidb`` attribute.

    Returns:
        ``imdb.roidb`` after flipping and preparation.
    """
    # Flipping is unconditional (the original guarded it with ``if True:``).
    print('Appending horizontally-flipped training examples...')
    imdb.append_flipped_images()
    print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb

def combined_roidb(imdb_names):
    """Combine the training roidbs of one or more datasets.

    Args:
        imdb_names: dataset name, or several names joined with ``'+'``.

    Returns:
        A tuple ``(imdb, roidb)``. ``roidb`` is the roidb of the first dataset
        extended in place with the roidbs of the remaining ones. For a single
        name ``imdb`` is ``get_imdb(imdb_names)``; for several names it is an
        ``imdb2`` built from the combined name string and the classes of the
        second dataset.
    """

    def get_roidb(imdb_name):
        """Load one dataset, select the "gt" proposal method, build its roidb."""
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        imdb.set_proposal_method("gt")
        print('Set proposal method: {:s}'.format("gt"))
        return get_training_roidb(imdb)

    names = imdb_names.split('+')
    roidbs = [get_roidb(s) for s in names]
    roidb = roidbs[0]
    if len(roidbs) > 1:
        for r in roidbs[1:]:
            roidb.extend(r)
        tmp = get_imdb(names[1])
        imdb = imdb2(imdb_names, tmp.classes)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb

class Train:
    """Training driver for a VGG16 Faster R-CNN on ``voc_2007_trainval``.

    Construction builds the network object, the combined roidb and the data
    layer; :meth:`train` then builds the optimiser, loads the pretrained
    weights and runs the optimisation loop, writing periodic snapshots to
    ``self.output_dir``.
    """

    def __init__(self):
        """Create the network, load the training roidb and the data layer.

        Raises:
            NotImplementedError: if ``cfg.FLAGS.net`` is not ``'vgg16'``.
        """
        # Create network
        if cfg.FLAGS.net == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")

        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

    def train(self):
        """Run training from the pretrained weights up to ``cfg.FLAGS.max_iters``.

        Losses are printed every ``cfg.FLAGS.display`` iterations and a
        snapshot is written every ``cfg.FLAGS.snapshot_iterations`` iterations.
        The TensorFlow session is closed when training finishes or fails.
        """
        # Create session
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)
        try:
            lr, train_op = self._build_train_op(sess)
            self._load_pretrained_weights(sess)
            sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
            self._run_training_loop(sess, lr, train_op)
        finally:
            # Release the session's resources even if training raised.
            sess.close()

    def _build_train_op(self, sess):
        """Build the network graph, the momentum optimiser and ``self.saver``.

        Returns:
            A tuple ``(lr, train_op)``: the non-trainable learning-rate
            variable and the op that applies one gradient update.
        """
        with sess.graph.as_default():
            tf.set_random_seed(cfg.FLAGS.rng_seed)
            layers = self.net.create_architecture(
                sess, "TRAIN", self.imdb.num_classes, tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            optimizer = tf.train.MomentumOptimizer(lr, cfg.FLAGS.momentum)

            gvs = optimizer.compute_gradients(loss)

            # Double the gradient of the bias if set
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        if '/biases:' in var.name:
                            grad = tf.multiply(grad, 2.)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # NOTE: no tf.summary.FileWriter is created here, so this script
            # produces no TensorBoard event files.
        return lr, train_op

    def _load_pretrained_weights(self, sess):
        """Initialise all variables, then restore ``cfg.FLAGS.pretrained_model``."""
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(cfg.FLAGS.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, cfg.FLAGS.pretrained_model)
        print('Loaded.')
        # Per the original author's note: the variables are "fixed" so that
        # the RGB weights are changed to BGR; for VGG16 this also changes the
        # convolutional weights fc6 and fc7 to fully connected weights.
        self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
        print('Fixed.')

    def _run_training_loop(self, sess, lr, train_op):
        """Run the optimisation steps with periodic printing and snapshots."""
        timer = Timer()
        # Training always starts fresh, i.e. zero iterations completed so far.
        iteration = 1
        while iteration < cfg.FLAGS.max_iters + 1:
            # Learning rate: decay once, right after ``step_size`` iterations.
            if iteration == cfg.FLAGS.step_size + 1:
                sess.run(tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iteration % (cfg.FLAGS.display) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' % \
                      (iteration, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls,
                       rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iteration % cfg.FLAGS.snapshot_iterations == 0:
                self.snapshot(sess, iteration)

            # Increment last so the number reported and stored in a snapshot
            # equals the number of completed training steps.
            iteration += 1

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map stored in a checkpoint file.

        Args:
            file_name: path of the checkpoint to inspect.

        Returns:
            The mapping produced by the checkpoint reader's
            ``get_variable_to_shape_map()``.

        Raises:
            Exception: whatever the checkpoint reader raised; the message is
                printed (with a hint for SNAPPY-compressed files) and the
                exception is re-raised instead of silently returning ``None``.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            raise

    def snapshot(self, sess, iter):
        """Save the model and the data-pipeline state for iteration ``iter``.

        Args:
            sess: session whose variables are saved through ``self.saver``.
            iter: iteration number used in the file names and stored in the
                meta file.

        Returns:
            A tuple ``(filename, nfilename)``: the ``.ckpt`` path passed to the
            saver and the path of the ``.pkl`` meta file.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info; readers must load the four objects in this order.
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

# Script entry point: build the trainer and run training only when this file
# is executed directly, not when it is imported.
if __name__ == '__main__':
    train = Train()
    train.train()

  • 写回答

1条回答 默认 最新

  • threenewbee 2018-10-18 13:43
    关注
    评论

报告相同问题?

悬赏问题

  • ¥20 怎么在stm32门禁成品上增加记录功能
  • ¥15 Source insight编写代码后使用CCS5.2版本import之后,代码跳到注释行里面
  • ¥50 NT4.0系统 STOP:0X0000007B
  • ¥15 想问一下stata17中这段代码哪里有问题呀
  • ¥15 flink cdc无法实时同步mysql数据
  • ¥100 有人会搭建GPT-J-6B框架吗?有偿
  • ¥15 求差集那个函数有问题,有无佬可以解决
  • ¥15 【提问】基于Invest的水源涵养
  • ¥20 微信网友居然可以通过vx号找到我绑的手机号
  • ¥15 解riccati方程组