CaptainSxy 2018-09-28 02:22 采纳率: 0%
浏览 2241

tensorflow cifar10教程如何实现断点续训?我希望每次能从上次的结果继续训练

cifar10_train.py

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', 'D:/tmp/cifar10_trainn',
"""Directory where to write event logs """
"""and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 100000,
"""Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
"""Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
"""How often to log results to the console.""")

def train():
"""Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()

# Get images and labels for CIFAR-10.
# Force input pipeline to CPU:0 to avoid operations sometimes ending up on
# GPU and resulting in a slow down.
with tf.device('/cpu:0'):
  images, labels = cifar10.distorted_inputs()

# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images)

# Calculate loss.
loss = cifar10.loss(logits, labels)

# Build a Graph that trains the model with one batch of examples and
# updates the model parameters.
train_op = cifar10.train(loss, global_step)

class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()

  def before_run(self, run_context):
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time

      loss_value = run_values.results
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print (format_str % (datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

saver = tf.train.Saver()
with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.train_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)

def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
train()

if name == '__main__':
tf.app.run()

  • 写回答

1条回答 默认 最新

  • CSDN-Ada助手 CSDN-AI 官方账号 2022-10-25 19:24
    关注
    不知道你这个问题是否已经解决, 如果还没有解决的话:

    如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
    评论

报告相同问题?

悬赏问题

  • ¥15 有赏,i卡绘世画不出
  • ¥15 如何用stata画出文献中常见的安慰剂检验图
  • ¥15 c语言链表结构体数据插入
  • ¥40 使用MATLAB解答线性代数问题
  • ¥15 COCOS的问题COCOS的问题
  • ¥15 FPGA-SRIO初始化失败
  • ¥15 MapReduce实现倒排索引失败
  • ¥15 ZABBIX6.0L连接数据库报错,如何解决?(操作系统-centos)
  • ¥15 找一位技术过硬的游戏pj程序员
  • ¥15 matlab生成电测深三层曲线模型代码