这是 CIFAR-10 的训练代码。我希望每次开始训练时 global_step 都从 0 开始，而不是从检查点文件中读取之前训练留下的步数。我对 MonitoredTrainingSession 不太懂，有没有大佬能指点一下？谢谢！
def train(max_step, n, resume=False):
    """Train CIFAR-10 until the global step reaches ``max_step``.

    ``tf.train.MonitoredTrainingSession`` restores every variable -- including
    ``global_step`` -- from the latest checkpoint whenever ``checkpoint_dir``
    is given. That is why the step count used to continue from the previous
    run. With ``resume=False`` (the default) no checkpoint directory is handed
    to the session, so all variables are freshly initialized and
    ``global_step`` starts at 0. Checkpoints and summaries are still written
    to ``FLAGS.train_dir`` by explicitly added saver hooks.

    Args:
      max_step: global step at which training stops
        (``tf.train.StopAtStepHook(last_step=max_step)``).
      n: forwarded to ``cifar10.distorted_inputs`` and printed as the
        ``whichdata`` field of every log line.
      resume: if True, keep the previous behaviour and continue training from
        the latest checkpoint found in ``FLAGS.train_dir``.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Force the input pipeline onto CPU:0 so its ops do not sometimes end
        # up on the GPU and slow training down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs(n)

        # Build the inference graph, the loss, and the per-step training op.
        # global_step is passed to cifar10.train, presumably so the op can
        # advance it -- confirm in cifar10.py.
        logits = cifar10.inference(images)
        loss = cifar10.loss(logits, labels)
        train_op = cifar10.train(loss, global_step)

        # Text log appended to across runs; make sure its directory exists.
        tf.gfile.MakeDirs(FLAGS.train_dir)
        log_path = os.path.join(FLAGS.train_dir, "train_output.txt")

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and throughput every FLAGS.log_frequency steps.

            The step shown is this hook's own counter, which starts at 0 on
            every call to train(), independent of global_step.
            """

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                current_time = time.time()
                duration = current_time - self._start_time
                self._start_time = current_time

                loss_value = run_values.results
                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                sec_per_batch = float(duration / FLAGS.log_frequency)

                format_str = ('whichdata:%d %s: step %d, loss = %.2f '
                              '(%.1f examples/sec; %.3f sec/batch)')
                # Format once so the console and the file get the same line.
                line = format_str % (n, datetime.now(), self._step, loss_value,
                                     examples_per_sec, sec_per_batch)
                print(line)
                # Open/close per write so no file handle is leaked.
                with open(log_path, "a") as log_file:
                    log_file.write(line + '\n')

        hooks = [tf.train.StopAtStepHook(last_step=max_step),
                 tf.train.NanTensorHook(loss),
                 _LoggerHook()]
        # Shared scaffold so the explicit saver hooks below use the same
        # saver / summary op that the session finalizes.
        scaffold = tf.train.Scaffold()

        if resume:
            # Old behaviour: restore (and periodically save) via checkpoint_dir.
            checkpoint_dir = FLAGS.train_dir
        else:
            # No checkpoint_dir -> nothing is restored, global_step starts at 0.
            # Without it the session also stops saving, so add the saver hooks
            # explicitly (intervals mirror MonitoredTrainingSession's TF 1.x
            # defaults: checkpoint every 600 s, summaries every 100 steps).
            checkpoint_dir = None
            hooks.extend([
                tf.train.CheckpointSaverHook(
                    checkpoint_dir=FLAGS.train_dir,
                    save_secs=600,
                    scaffold=scaffold),
                tf.train.SummarySaverHook(
                    save_steps=100,
                    output_dir=FLAGS.train_dir,
                    scaffold=scaffold),
            ])

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                scaffold=scaffold,
                hooks=hooks,
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)