武倔 2017-11-28 09:57

(TensorFlow) How do I reset global_step to 0 after restoring from a checkpoint?

This is the CIFAR-10 training code. I want global_step to start from 0 every time training begins, rather than continuing from the step count that gets read back from the previous run's checkpoint. I don't really understand MonitoredTrainingSession, so could someone show me how? Thanks!

def train(max_step, n):
  """Train CIFAR-10 for a number of steps."""
  """创建图"""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
        images, labels = cifar10.distorted_inputs(n)
    # distorted_inputs() reads the CIFAR-10 binaries (under
    # /tmp/cifar10_data/cifar-10-batches-bin) and returns batches of
    # FLAGS.batch_size = 128 images with their labels.
    # Build a Graph that computes the logits predictions from the
    # inference model.
    # Per-step log lines are appended to train_output.txt in the training dir.
    file_output_path = os.path.join(FLAGS.train_dir, "train_output.txt")

    """初始化logits(预测值)以及训练"""
    logits = cifar10.inference(images)

    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('whichdata:%d %s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (n, datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))
          """
          /////////////////
          """
          file_output=open(file_output_path,"a")
          file_output.write(format_str % (n, datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch)+'\n')
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=max_step),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

1 answer

  • Enzo洛 2018-01-09 01:20

    Pass only the variables you want to read back into tf.train.Saver([variables]) via its var_list argument, and run the initializer op for the remaining variables (global_step included), so they start from their initial values.
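
    In other words: build a Saver over every variable except global_step, run
    the global initializer first (which leaves global_step at 0), and only then
    restore the model weights. Below is a minimal sketch of that idea using a
    plain tf.train.Saver and tf.Session (TF 1.x APIs; the model-building calls
    are elided, and FLAGS.train_dir is taken from the question's code):

    import tensorflow as tf

    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()
        # ... build images/labels, logits, loss and train_op as in the question ...

        # Restore everything EXCEPT global_step, so it keeps its initial value 0.
        vars_to_restore = [v for v in tf.global_variables()
                           if v is not global_step]
        saver = tf.train.Saver(var_list=vars_to_restore)

        with tf.Session() as sess:
            # Initialize all variables first; this sets global_step to 0 ...
            sess.run(tf.global_variables_initializer())
            # ... then overwrite only the model weights from the checkpoint.
            ckpt = tf.train.latest_checkpoint(FLAGS.train_dir)
            if ckpt is not None:
                saver.restore(sess, ckpt)
            # The training loop then runs sess.run(train_op) from step 0.

    Note that MonitoredTrainingSession restores the full checkpoint
    (global_step included) whenever checkpoint_dir points at an existing
    training directory, which is exactly the behaviour the question describes.
    With the manual Saver above you either drive the loop with a plain Session
    as sketched, or let MonitoredTrainingSession write to a fresh
    checkpoint_dir so there is nothing for it to restore.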

