tensorflow.GraphDef was modified concurrently during serialization


Create a saver object that will save all the variables:

        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1.0)
        if FLAGS.pre_trained_checkpoint:
            train_utils.restore_fn(FLAGS)

        start_epoch = 0
        # Get the number of training/validation steps per epoch
        tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size)
        if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
            tr_batches += 1
        val_batches = int(MODELNET_VALIDATE_DATA_SIZE / FLAGS.batch_size)
        if MODELNET_VALIDATE_DATA_SIZE % FLAGS.batch_size > 0:
            val_batches += 1

        # The filenames argument to the TFRecordDataset initializer can either be a string,
        # a list of strings, or a tf.Tensor of strings.
        training_filenames = os.path.join(FLAGS.dataset_dir, 'train.record')
        validate_filenames = os.path.join(FLAGS.dataset_dir, 'validate.record')
        ##################
        # Training loop.
        ##################
        for training_epoch in range(start_epoch, FLAGS.how_many_training_epochs):
            print("-------------------------------------")
            print(" Epoch {} ".format(training_epoch))
            print("-------------------------------------")

            sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
            for step in range(tr_batches):
                # Pull the image batch we'll use for training.
                train_batch_xs, train_batch_ys = sess.run(next_batch)

                handle = sess.partial_run_setup([d_scores, final_desc, learning_rate, summary_op,
                                                 accuracy, total_loss, grad_summ_op, train_op],
                                                [X, final_X, ground_truth,
                                                 grouping_scheme, grouping_weight, is_training,
                                                 is_training2, dropout_keep_prob])

                scores, final = sess.partial_run(handle,
                                                 [d_scores, final_desc],
                                                 feed_dict={
                                                    X: train_batch_xs,
                                                    is_training: True}
                                                 )
                schemes = gvcnn.grouping_scheme(scores, NUM_GROUP, FLAGS.num_views)
                weights = gvcnn.grouping_weight(scores, schemes)

                # Run the graph with this batch of training data.
                lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                    sess.partial_run(handle,
                                     [learning_rate, summary_op, accuracy, total_loss, grad_summ_op, train_op],
                                     feed_dict={
                                         final_X: final,
                                         ground_truth: train_batch_ys,
                                         grouping_scheme: schemes,
                                         grouping_weight: weights,
                                         is_training2: True,
                                         dropout_keep_prob: 0.8}
                                     )

                train_writer.add_summary(train_summary, training_epoch)
                train_writer.add_summary(grad_vals, training_epoch)
                tf.logging.info('Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f' %
                                (training_epoch, step, lr, train_accuracy * 100, train_loss))

            # Save a model checkpoint at the end of every epoch.
            checkpoint_path = os.path.join(FLAGS.train_logdir, FLAGS.ckpt_name_to_save)
            tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_epoch)
            saver.save(sess, checkpoint_path, global_step=training_epoch)
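As an aside, the step-count arithmetic near the top of the loop setup (integer divide, then add one when there is a remainder) is just ceiling division. A minimal self-contained sketch, with illustrative sizes:

```python
import math

def steps_per_epoch(data_size, batch_size):
    # Same result as the int-divide-then-increment pattern above:
    # one extra step whenever the last batch is partial.
    return math.ceil(data_size / batch_size)

print(steps_per_epoch(100, 32))  # 3 full batches + 1 partial = 4
print(steps_per_epoch(96, 32))   # exact multiple -> 3
```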

2 Answers

Use a CheckpointSaverHook to save the model instead of writing your own saving code. After asynchronous optimization, session.run calls may execute concurrently rather than in the serial order of the Python code, so in general you should not chain serial session.run calls inside one loop. If you need to run multiple ops, pack them into a dict or tuple and pass them to a single session.run.

Jay_Zhou_XMU
Jay_Zhou_XMU: That may be similar, but his method doesn't solve my problem. Any advice would be appreciated!
Replied 5 months ago