# Create a saver object which will save all the variables.
# keep_checkpoint_every_n_hours=1.0 additionally preserves one checkpoint per
# hour of training, on top of the Saver's rolling window of recent checkpoints.
saver = tf.train.Saver(keep_checkpoint_every_n_hours=1.0)
# Optionally warm-start from a pre-trained checkpoint.
if FLAGS.pre_trained_checkpoint:
    # NOTE(review): restore_fn receives no session and its return value is
    # discarded. If it returns an init/restore callable rather than restoring
    # in place, no weights are actually loaded; confirm against train_utils.
    train_utils.restore_fn(FLAGS)
# Epoch index the training loop starts counting from.
start_epoch = 0

# Number of training/validation steps per epoch. Integer ceiling division so
# that a final partial batch still gets its own step; this avoids the float
# round-trip of int(a / b) followed by a modulo fix-up.
batch_size = FLAGS.batch_size
tr_batches = (MODELNET_TRAIN_DATA_SIZE + batch_size - 1) // batch_size
val_batches = (MODELNET_VALIDATE_DATA_SIZE + batch_size - 1) // batch_size

# Paths to the train/validate TFRecord files. training_filenames is fed as the
# `filenames` value when the dataset iterator is initialized. The filenames
# argument to the TFRecordDataset initializer can be a string, a list of
# strings, or a tf.Tensor of strings.
training_filenames = os.path.join(FLAGS.dataset_dir, 'train.record')
validate_filenames = os.path.join(FLAGS.dataset_dir, 'validate.record')
##################
# Training loop.
##################
# Checkpoint prefix is loop-invariant; Saver appends "-<epoch>" via global_step.
checkpoint_path = os.path.join(FLAGS.train_logdir, FLAGS.ckpt_name_to_save)
for training_epoch in range(start_epoch, FLAGS.how_many_training_epochs):
    print("-------------------------------------")
    print(" Epoch {} ".format(training_epoch))
    print("-------------------------------------")
    # Re-initialize the dataset iterator over the training records each epoch.
    sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
    for step in range(tr_batches):
        # Monotonic x-axis value for TensorBoard. Using the epoch index here
        # would write every step of an epoch at the same position.
        summary_step = training_epoch * tr_batches + step

        # Pull the image batch we'll use for training.
        train_batch_xs, train_batch_ys = sess.run(next_batch)

        # The graph is run in two partial_run phases because the grouping
        # scheme/weights are computed in Python from phase-1 outputs and then
        # fed back in. partial_run_setup must declare every fetch and feed that
        # either phase will use.
        handle = sess.partial_run_setup(
            [d_scores, final_desc, learning_rate, summary_op,
             accuracy, total_loss, grad_summ_op, train_op],
            [X, final_X, ground_truth,
             grouping_scheme, grouping_weight, is_training,
             is_training2, dropout_keep_prob])

        # Phase 1: feed the raw batch and fetch d_scores / final_desc
        # (presumably per-view scores and descriptors; confirm in the model).
        scores, final = sess.partial_run(
            handle,
            [d_scores, final_desc],
            feed_dict={X: train_batch_xs, is_training: True})

        # Computed outside the graph from the phase-1 scores.
        schemes = gvcnn.grouping_scheme(scores, NUM_GROUP, FLAGS.num_views)
        weights = gvcnn.grouping_weight(scores, schemes)

        # Phase 2: feed the phase-1 descriptors, labels and grouping, and run
        # the train op along with metrics and summaries.
        (lr, train_summary, train_accuracy, train_loss,
         grad_vals, _) = sess.partial_run(
            handle,
            [learning_rate, summary_op, accuracy, total_loss,
             grad_summ_op, train_op],
            feed_dict={
                final_X: final,
                ground_truth: train_batch_ys,
                grouping_scheme: schemes,
                grouping_weight: weights,
                is_training2: True,
                dropout_keep_prob: 0.8,
            })

        train_writer.add_summary(train_summary, summary_step)
        train_writer.add_summary(grad_vals, summary_step)
        tf.logging.info('Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f',
                        training_epoch, step, lr, train_accuracy * 100, train_loss)

    # Save a model checkpoint at the end of every epoch.
    tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_epoch)
    saver.save(sess, checkpoint_path, global_step=training_epoch)